Skip to content

Commit 71726cd

Browse files
authored
add shuffle backpressure (#461)
1 parent 744f4f1 commit 71726cd

File tree

33 files changed

+719
-124
lines changed

33 files changed

+719
-124
lines changed

geaflow/geaflow-common/src/main/java/com/antgroup/geaflow/common/config/keys/ExecutionConfigKeys.java

+15
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,11 @@ public class ExecutionConfigKeys implements Serializable {
400400
.defaultValue("snappy")
401401
.description("codec of shuffle compression");
402402

403+
public static final ConfigKey SHUFFLE_BACKPRESSURE_ENABLE = ConfigKeys
404+
.key("geaflow.shuffle.backpressure.enable")
405+
.defaultValue(false)
406+
.description("whether to enable shuffle backpressure");
407+
403408
/** Shuffle network config. */
404409

405410
public static final ConfigKey NETTY_SERVER_HOST = ConfigKeys
@@ -484,6 +489,11 @@ public class ExecutionConfigKeys implements Serializable {
484489
.defaultValue(1)
485490
.description("size of shuffle fetch queue");
486491

492+
public static final ConfigKey SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE = ConfigKeys
493+
.key("geaflow.shuffle.fetch.channel.queue.size")
494+
.defaultValue(64)
495+
.description("buffer number per channel");
496+
487497
/** Shuffle write config. */
488498

489499
public static final ConfigKey SHUFFLE_SPILL_RECORDS = ConfigKeys
@@ -501,6 +511,11 @@ public class ExecutionConfigKeys implements Serializable {
501511
.defaultValue(128 * 1024)
502512
.description("size of shuffle write buffer");
503513

514+
public static final ConfigKey SHUFFLE_WRITER_BUFFER_SIZE = ConfigKeys
515+
.key("geaflow.shuffle.writer.buffer.size")
516+
.defaultValue(64 * 1024 * 1024)
517+
.description("max buffer size for the shuffle writer in bytes");
518+
504519
public static final ConfigKey SHUFFLE_EMIT_BUFFER_SIZE = ConfigKeys
505520
.key("geaflow.shuffle.emit.buffer.size")
506521
.defaultValue(1024)

geaflow/geaflow-core/geaflow-engine/geaflow-cluster/src/main/java/com/antgroup/geaflow/cluster/fetcher/PrefetchMessageBuffer.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@ public void onMessage(PipelineMessage<T> message) {
6262
AbstractMessageIterator<T> iterator = (AbstractMessageIterator<T>) message.getMessageIterator();
6363
OutBuffer outBuffer = iterator.getOutBuffer();
6464
long windowId = message.getRecordArgs().getWindowId();
65-
this.slice.add(new PipeBuffer(outBuffer, windowId, true));
65+
this.slice.add(new PipeBuffer(outBuffer, windowId));
6666
}
6767

6868
@Override
6969
public void onBarrier(PipelineBarrier barrier) {
7070
if (barrier.getEdgeId() != this.edgeId) {
7171
return;
7272
}
73-
this.slice.add(new PipeBuffer(barrier.getWindowId(), (int) barrier.getCount(), false, true));
73+
this.slice.add(new PipeBuffer(barrier.getWindowId(), (int) barrier.getCount(), true));
7474
this.slice.flush();
7575
this.latch.countDown();
7676
}

geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/com/antgroup/geaflow/shuffle/api/writer/PipelineShardWriter.java

+42-2
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,14 @@
1717
import com.antgroup.geaflow.common.exception.GeaflowRuntimeException;
1818
import com.antgroup.geaflow.shuffle.message.Shard;
1919
import com.antgroup.geaflow.shuffle.message.SliceId;
20+
import com.antgroup.geaflow.shuffle.pipeline.buffer.OutBuffer.BufferBuilder;
21+
import com.antgroup.geaflow.shuffle.pipeline.slice.BlockingSlice;
22+
import com.antgroup.geaflow.shuffle.pipeline.slice.IPipelineSlice;
2023
import com.antgroup.geaflow.shuffle.pipeline.slice.PipelineSlice;
2124
import java.io.IOException;
2225
import java.util.List;
2326
import java.util.Optional;
27+
import java.util.concurrent.atomic.AtomicInteger;
2428
import java.util.concurrent.atomic.AtomicReference;
2529
import org.slf4j.Logger;
2630
import org.slf4j.LoggerFactory;
@@ -31,23 +35,31 @@ public class PipelineShardWriter<T> extends ShardWriter<T, Shard> {
3135

3236
private OutputFlusher outputFlusher;
3337
private final AtomicReference<Throwable> throwable;
38+
private final AtomicInteger curBufferBytes;
39+
private int maxWriteBufferSize;
3440

3541
/**
 * Creates a writer with an empty error holder and a buffered-bytes counter starting at zero.
 */
public PipelineShardWriter() {
    // AtomicInteger's no-arg constructor starts at 0.
    this.curBufferBytes = new AtomicInteger();
    this.throwable = new AtomicReference<>();
}
3845

3946
@Override
4047
public void init(IWriterContext writerContext) {
4148
super.init(writerContext);
4249
String threadName = String.format("flusher-%s", writerContext.getTaskName());
4350
int flushTimeout = this.shuffleConfig.getFlushBufferTimeoutMs();
51+
this.maxWriteBufferSize = shuffleConfig.getMaxWriteBufferSize();
4452
this.outputFlusher = new OutputFlusher(threadName, flushTimeout);
4553
this.outputFlusher.start();
4654
}
4755

4856
@Override
49-
protected PipelineSlice newSlice(String taskLogTag, SliceId sliceId) {
50-
return new PipelineSlice(taskLogTag, sliceId);
57+
protected IPipelineSlice newSlice(String taskLogTag, SliceId sliceId) {
58+
if (enableBackPressure) {
59+
return new BlockingSlice(taskLogTag, sliceId, this);
60+
} else {
61+
return new PipelineSlice(taskLogTag, sliceId);
62+
}
5163
}
5264

5365
@Override
@@ -68,6 +80,34 @@ public Optional<Shard> doFinish(long windowId) throws IOException {
6880
return Optional.empty();
6981
}
7082

83+
@Override
84+
protected void sendBuffer(int sliceIndex, BufferBuilder builder, long windowId) {
85+
if (enableBackPressure) {
86+
if (curBufferBytes.get() >= maxWriteBufferSize) {
87+
synchronized (this) {
88+
while (curBufferBytes.get() >= maxWriteBufferSize) {
89+
try {
90+
this.wait();
91+
} catch (InterruptedException e) {
92+
throw new GeaflowRuntimeException(e);
93+
}
94+
}
95+
}
96+
}
97+
curBufferBytes.addAndGet(builder.getBufferSize());
98+
}
99+
super.sendBuffer(sliceIndex, builder, windowId);
100+
}
101+
102+
public void notifyBufferConsumed(int bufferBytes) {
103+
int preBytes = curBufferBytes.getAndAdd(-bufferBytes);
104+
if (preBytes >= maxWriteBufferSize && curBufferBytes.get() < maxWriteBufferSize) {
105+
synchronized (this) {
106+
this.notifyAll();
107+
}
108+
}
109+
}
110+
71111
private void flushAll() {
72112
boolean flushed = this.flushSlices();
73113
if (!flushed) {

geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/com/antgroup/geaflow/shuffle/api/writer/ShardWriter.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ public abstract class ShardWriter<T, R> {
4545
protected int edgeId;
4646
protected int taskIndex;
4747
protected int targetChannels;
48+
protected boolean enableBackPressure;
4849

4950
protected String taskLogTag;
5051
protected long[] recordCounter;
@@ -72,7 +73,8 @@ public void init(IWriterContext writerContext) {
7273
this.taskLogTag = writerContext.getTaskName();
7374
this.recordCounter = new long[this.targetChannels];
7475
this.bytesCounter = new long[this.targetChannels];
75-
this.maxBufferSize = this.shuffleConfig.getFlushBufferSizeBytes();
76+
this.maxBufferSize = this.shuffleConfig.getMaxBufferSizeBytes();
77+
this.enableBackPressure = this.shuffleConfig.isBackpressureEnabled();
7678

7779
this.buffers = this.buildBufferBuilder(this.targetChannels);
7880
this.resultSlices = this.buildResultSlices(this.targetChannels);
@@ -146,16 +148,16 @@ public Optional<R> finish(long windowId) throws IOException {
146148

147149
protected abstract Optional<R> doFinish(long windowId) throws IOException;
148150

149-
private void sendBuffer(int sliceIndex, BufferBuilder builder, long windowId) {
151+
protected void sendBuffer(int sliceIndex, BufferBuilder builder, long windowId) {
150152
this.recordCounter[sliceIndex] += builder.getRecordCount();
151153
this.bytesCounter[sliceIndex] += builder.getBufferSize();
152154
IPipelineSlice resultSlice = this.resultSlices[sliceIndex];
153-
resultSlice.add(new PipeBuffer(builder.build(), windowId, true));
155+
resultSlice.add(new PipeBuffer(builder.build(), windowId));
154156
}
155157

156158
private void sendBarrier(int sliceIndex, long windowId, int count, boolean isFinish) {
157159
IPipelineSlice resultSlice = this.resultSlices[sliceIndex];
158-
resultSlice.add(new PipeBuffer(windowId, count, false, isFinish));
160+
resultSlice.add(new PipeBuffer(windowId, count, isFinish));
159161
}
160162

161163
private void flushFloatingBuffers(long windowId) {

geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/com/antgroup/geaflow/shuffle/config/ShuffleConfig.java

+26-4
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,18 @@
2828
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.NETTY_SERVER_PORT;
2929
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.NETTY_SERVER_THREADS_NUM;
3030
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.NETTY_THREAD_CACHE_ENABLE;
31+
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_BACKPRESSURE_ENABLE;
3132
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_COMPRESSION_ENABLE;
3233
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_EMIT_BUFFER_SIZE;
3334
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_EMIT_QUEUE_SIZE;
35+
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE;
3436
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_FETCH_QUEUE_SIZE;
3537
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_FETCH_TIMEOUT_MS;
3638
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_FLUSH_BUFFER_SIZE_BYTES;
3739
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_FLUSH_BUFFER_TIMEOUT_MS;
3840
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_MEMORY_POOL_ENABLE;
3941
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_STORAGE_TYPE;
42+
import static com.antgroup.geaflow.common.config.keys.ExecutionConfigKeys.SHUFFLE_WRITER_BUFFER_SIZE;
4043

4144
import com.antgroup.geaflow.common.config.Configuration;
4245
import com.antgroup.geaflow.common.shuffle.StorageLevel;
@@ -69,6 +72,7 @@ public class ShuffleConfig {
6972
private final boolean threadCacheEnabled;
7073
private final boolean preferDirectBuffer;
7174
private final boolean customFrameDecoderEnable;
75+
private final boolean enableBackpressure;
7276

7377
//////////////////////////////
7478
// Read & Write
@@ -83,14 +87,17 @@ public class ShuffleConfig {
8387

8488
private final int fetchTimeoutMs;
8589
private final int fetchQueueSize;
90+
private final int channelQueueSize;
91+
8692

8793
//////////////////////////////
8894
// Write
8995
//////////////////////////////
9096

9197
private final int emitQueueSize;
9298
private final int emitBufferSize;
93-
private final int flushBufferSizeBytes;
99+
private final int maxBufferSizeBytes;
100+
private final int maxWriteBufferSize;
94101
private final int flushBufferTimeoutMs;
95102
private final StorageLevel storageLevel;
96103

@@ -116,15 +123,18 @@ public ShuffleConfig(Configuration config) {
116123
// read & write
117124
this.memoryPoolEnable = config.getBoolean(SHUFFLE_MEMORY_POOL_ENABLE);
118125
this.compressionEnabled = config.getBoolean(SHUFFLE_COMPRESSION_ENABLE);
126+
this.enableBackpressure = config.getBoolean(SHUFFLE_BACKPRESSURE_ENABLE);
119127

120128
// read
121129
this.fetchTimeoutMs = config.getInteger(SHUFFLE_FETCH_TIMEOUT_MS);
122130
this.fetchQueueSize = config.getInteger(SHUFFLE_FETCH_QUEUE_SIZE);
131+
this.channelQueueSize = config.getInteger(SHUFFLE_FETCH_CHANNEL_QUEUE_SIZE);
123132

124133
// write
125134
this.emitQueueSize = config.getInteger(SHUFFLE_EMIT_QUEUE_SIZE);
126135
this.emitBufferSize = config.getInteger(SHUFFLE_EMIT_BUFFER_SIZE);
127-
this.flushBufferSizeBytes = config.getInteger(SHUFFLE_FLUSH_BUFFER_SIZE_BYTES);
136+
this.maxBufferSizeBytes = config.getInteger(SHUFFLE_FLUSH_BUFFER_SIZE_BYTES);
137+
this.maxWriteBufferSize = config.getInteger(SHUFFLE_WRITER_BUFFER_SIZE);
128138
this.flushBufferTimeoutMs = config.getInteger(SHUFFLE_FLUSH_BUFFER_TIMEOUT_MS);
129139
this.storageLevel = StorageLevel.valueOf(config.getString(SHUFFLE_STORAGE_TYPE));
130140

@@ -221,6 +231,10 @@ public int getFetchQueueSize() {
221231
return this.fetchQueueSize;
222232
}
223233

234+
public int getChannelQueueSize() {
235+
return channelQueueSize;
236+
}
237+
224238
public int getEmitQueueSize() {
225239
return this.emitQueueSize;
226240
}
@@ -229,8 +243,16 @@ public int getEmitBufferSize() {
229243
return this.emitBufferSize;
230244
}
231245

232-
public int getFlushBufferSizeBytes() {
233-
return this.flushBufferSizeBytes;
246+
public int getMaxBufferSizeBytes() {
247+
return this.maxBufferSizeBytes;
248+
}
249+
250+
public boolean isBackpressureEnabled() {
251+
return enableBackpressure;
252+
}
253+
254+
public int getMaxWriteBufferSize() {
255+
return maxWriteBufferSize;
234256
}
235257

236258
public int getFlushBufferTimeoutMs() {

geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/com/antgroup/geaflow/shuffle/message/SliceId.java

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import java.util.Objects;
2020

2121
public class SliceId implements Serializable {
22+
private static final long serialVersionUID = 1L;
23+
public static final int SLICE_ID_BYTES = 20;
2224

2325
private final WriterId writerId;
2426
private final int sliceIndex;

geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/com/antgroup/geaflow/shuffle/network/netty/SliceOutputChannelHandler.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,10 @@ public void notifyNonEmpty(final SequenceSliceReader reader) {
6969
* availability, so there is no race condition here.
7070
*/
7171
private void enqueueReader(final SequenceSliceReader reader) throws Exception {
72-
if (reader.isRegistered() || !reader.hasNext()) {
72+
if (reader.isRegistered() || !reader.isAvailable()) {
7373
return;
7474
}
75+
7576
// Queue an available reader for consumption. If the queue is empty,
7677
// we try trigger the actual write. Otherwise, this will be handled by
7778
// the writeAndFlushNextMessageIfPossible calls.
@@ -102,7 +103,7 @@ public void close() throws IOException {
102103
allReaders.clear();
103104
}
104105

105-
void updateRequestedBatchId(ChannelId receiverId, Consumer<SequenceSliceReader> operation)
106+
public void applyReaderOperation(ChannelId receiverId, Consumer<SequenceSliceReader> operation)
106107
throws Exception {
107108
if (fatalError) {
108109
return;

geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/com/antgroup/geaflow/shuffle/network/netty/SliceRequestClient.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
import com.antgroup.geaflow.shuffle.message.SliceId;
1818
import com.antgroup.geaflow.shuffle.network.ConnectionId;
19+
import com.antgroup.geaflow.shuffle.network.protocol.AddCreditRequest;
1920
import com.antgroup.geaflow.shuffle.network.protocol.BatchRequest;
2021
import com.antgroup.geaflow.shuffle.network.protocol.CloseRequest;
22+
import com.antgroup.geaflow.shuffle.network.protocol.NettyMessage;
2123
import com.antgroup.geaflow.shuffle.network.protocol.SliceRequest;
2224
import com.antgroup.geaflow.shuffle.pipeline.channel.RemoteInputChannel;
2325
import com.antgroup.geaflow.shuffle.util.AtomicReferenceCounter;
@@ -79,7 +81,7 @@ public void requestSlice(SliceId sliceId, final RemoteInputChannel inputChannel,
7981
clientHandler.addInputChannel(inputChannel);
8082

8183
final SliceRequest request = new SliceRequest(sliceId, startBatchId,
82-
inputChannel.getInputChannelId());
84+
inputChannel.getInputChannelId(), inputChannel.getInitialCredit());
8385

8486
final ChannelFutureListener listener = new ChannelFutureListener() {
8587
@Override
@@ -111,10 +113,23 @@ public void run() {
111113

112114
public void requestNextBatch(long batchId, final RemoteInputChannel inputChannel)
113115
throws IOException {
116+
checkNotClosed();
117+
final BatchRequest request = new BatchRequest(batchId, inputChannel.getInputChannelId());
118+
sendRequest(inputChannel, request);
119+
}
114120

121+
public void notifyCreditAvailable(RemoteInputChannel inputChannel) throws IOException {
115122
checkNotClosed();
116123

117-
final BatchRequest request = new BatchRequest(batchId, inputChannel.getInputChannelId());
124+
int credit = inputChannel.getAndResetAvailableCredit();
125+
Preconditions.checkArgument(credit > 0, "Credit must be greater than zero.");
126+
final AddCreditRequest request = new AddCreditRequest(credit,
127+
inputChannel.getInputChannelId());
128+
sendRequest(inputChannel, request);
129+
}
130+
131+
private void sendRequest(RemoteInputChannel inputChannel, NettyMessage request) throws IOException {
132+
checkNotClosed();
118133

119134
final ChannelFutureListener listener = new ChannelFutureListener() {
120135
@Override

geaflow/geaflow-core/geaflow-engine/geaflow-shuffle/src/main/java/com/antgroup/geaflow/shuffle/network/netty/SliceRequestServerHandler.java

+9-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
package com.antgroup.geaflow.shuffle.network.netty;
1616

17+
import com.antgroup.geaflow.shuffle.network.protocol.AddCreditRequest;
1718
import com.antgroup.geaflow.shuffle.network.protocol.BatchRequest;
1819
import com.antgroup.geaflow.shuffle.network.protocol.CancelRequest;
1920
import com.antgroup.geaflow.shuffle.network.protocol.CloseRequest;
@@ -56,7 +57,8 @@ protected void channelRead0(ChannelHandlerContext ctx, NettyMessage msg) throws
5657
try {
5758
SequenceSliceReader reader = new SequenceSliceReader(
5859
request.getReceiverId(), outboundQueue);
59-
reader.createSliceReader(request.getSliceId(), request.getStartBatchId());
60+
reader.createSliceReader(request.getSliceId(), request.getStartBatchId(),
61+
request.getInitialCredit());
6062

6163
outboundQueue.notifyReaderCreated(reader);
6264
} catch (Throwable notFound) {
@@ -72,8 +74,13 @@ protected void channelRead0(ChannelHandlerContext ctx, NettyMessage msg) throws
7274
} else if (msgClazz == BatchRequest.class) {
7375
BatchRequest request = (BatchRequest) msg;
7476

75-
outboundQueue.updateRequestedBatchId(request.receiverId(),
77+
outboundQueue.applyReaderOperation(request.receiverId(),
7678
reader -> reader.requestBatch(request.getNextBatchId()));
79+
} else if (msgClazz == AddCreditRequest.class) {
80+
AddCreditRequest request = (AddCreditRequest) msg;
81+
82+
outboundQueue.applyReaderOperation(request.receiverId(),
83+
reader -> reader.addCredit(request.getCredit()));
7784
} else {
7885
LOGGER.warn("Received unexpected client request: {}", msg);
7986
respondWithError(ctx, new IllegalArgumentException("unknown request:" + msg));

0 commit comments

Comments
 (0)