Skip to content

Commit e3e1873

Browse files
committed
Refactored bulk inference execution.
1 parent fd5d656 commit e3e1873

20 files changed

+627
-702
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.elasticsearch.action.ActionListener;
11+
import org.elasticsearch.compute.operator.FailureCollector;
12+
import org.elasticsearch.core.CheckedConsumer;
13+
import org.elasticsearch.index.seqno.LocalCheckpointTracker;
14+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
15+
16+
import java.util.Iterator;
17+
import java.util.Map;
18+
import java.util.concurrent.ConcurrentHashMap;
19+
import java.util.concurrent.atomic.AtomicBoolean;
20+
21+
import static org.elasticsearch.index.seqno.SequenceNumbers.NO_OPS_PERFORMED;
22+
23+
public class BulkInferenceOperation {
24+
private final Iterator<InferenceAction.Request> requests;
25+
private final CheckedConsumer<InferenceAction.Response, ?> responseConsumer;
26+
private final LocalCheckpointTracker checkpoint = new LocalCheckpointTracker(NO_OPS_PERFORMED, NO_OPS_PERFORMED);
27+
private final FailureCollector failureCollector = new FailureCollector();
28+
private final Map<Long, InferenceAction.Response> bufferedResponses = new ConcurrentHashMap<>();
29+
private final AtomicBoolean responseSent = new AtomicBoolean(false);
30+
31+
public BulkInferenceOperation(
32+
Iterator<InferenceAction.Request> requests,
33+
CheckedConsumer<InferenceAction.Response, ?> responseConsumer
34+
) {
35+
this.requests = requests;
36+
this.responseConsumer = responseConsumer;
37+
}
38+
39+
public void execute(InferenceExecutionContext ctx, ActionListener<Void> completionListener) {
40+
int threadCount = 0;
41+
while (threadCount++ < ctx.maxConcurrentRequests()) {
42+
ctx.executorService().submit(() -> {
43+
while (true) {
44+
BulkInferenceRequestItem bulkItemRequest = nextRequest();
45+
if (bulkItemRequest == null) {
46+
break;
47+
}
48+
execute(bulkItemRequest, ctx, () -> onResponseProcessed(completionListener));
49+
}
50+
});
51+
}
52+
}
53+
54+
private void execute(BulkInferenceRequestItem bulkItemRequest, InferenceExecutionContext ctx, Runnable onResponseProcessed) {
55+
final ActionListener<InferenceAction.Response> responseListener = ActionListener.wrap(
56+
inferenceResponse -> onInferenceResponse(bulkItemRequest.seqNo(), inferenceResponse),
57+
exception -> onInferenceException(bulkItemRequest.seqNo(), exception)
58+
);
59+
60+
ctx.inferenceRunner().doInference(bulkItemRequest.request(), ActionListener.runAfter(responseListener, onResponseProcessed));
61+
}
62+
63+
private boolean isCompleted() {
64+
return requests.hasNext() == false && checkpoint.getMaxSeqNo() == checkpoint.getPersistedCheckpoint();
65+
}
66+
67+
private void onResponseProcessed(ActionListener<Void> completionListener) {
68+
if (isCompleted() && responseSent.compareAndSet(false, true)) {
69+
if (failureCollector.hasFailure()) {
70+
completionListener.onFailure(failureCollector.getFailure());
71+
return;
72+
}
73+
completionListener.onResponse(null);
74+
}
75+
}
76+
77+
private BulkInferenceRequestItem nextRequest() {
78+
synchronized (checkpoint) {
79+
if (requests.hasNext()) {
80+
return new BulkInferenceRequestItem(checkpoint.generateSeqNo(), requests.next());
81+
}
82+
83+
return null;
84+
}
85+
}
86+
87+
private void onInferenceResponse(long seqNo, InferenceAction.Response response) {
88+
if (failureCollector.hasFailure() == false) {
89+
bufferedResponses.put(seqNo, response);
90+
}
91+
checkpoint.markSeqNoAsProcessed(seqNo);
92+
93+
synchronized (checkpoint) {
94+
long persistSeqNo = checkpoint.getPersistedCheckpoint();
95+
while (persistSeqNo < checkpoint.getProcessedCheckpoint()) {
96+
persistSeqNo++;
97+
if (failureCollector.hasFailure() == false) {
98+
try {
99+
responseConsumer.accept(bufferedResponses.remove(persistSeqNo));
100+
} catch (Exception e) {
101+
failureCollector.unwrapAndCollect(e);
102+
}
103+
}
104+
checkpoint.markSeqNoAsPersisted(persistSeqNo);
105+
}
106+
}
107+
}
108+
109+
public void onInferenceException(long seqNo, Exception e) {
110+
failureCollector.unwrapAndCollect(e);
111+
checkpoint.markSeqNoAsProcessed(seqNo);
112+
113+
synchronized (checkpoint) {
114+
long persistSeqNo = checkpoint.getPersistedCheckpoint();
115+
while (persistSeqNo < checkpoint.getProcessedCheckpoint()) {
116+
persistSeqNo++;
117+
bufferedResponses.remove(persistSeqNo);
118+
checkpoint.markSeqNoAsPersisted(persistSeqNo);
119+
}
120+
}
121+
}
122+
123+
private record BulkInferenceRequestItem(long seqNo, InferenceAction.Request request) {}
124+
}
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.apache.lucene.util.BytesRefBuilder;
12+
import org.elasticsearch.compute.data.Block;
13+
import org.elasticsearch.compute.data.BlockFactory;
14+
import org.elasticsearch.compute.data.BytesRefBlock;
15+
import org.elasticsearch.compute.data.Page;
16+
import org.elasticsearch.compute.operator.DriverContext;
17+
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
18+
import org.elasticsearch.compute.operator.Operator;
19+
import org.elasticsearch.core.Releasables;
20+
import org.elasticsearch.inference.TaskType;
21+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
22+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
23+
24+
import java.util.List;
25+
import java.util.NoSuchElementException;
26+
27+
public class CompletionOperator extends InferenceOperator<ChatCompletionResults> {
28+
29+
public record Factory(InferenceRunner inferenceRunner, String inferenceId, ExpressionEvaluator.Factory promptEvaluatorFactory)
30+
implements
31+
OperatorFactory {
32+
@Override
33+
public String describe() {
34+
return "Completion[inference_id=[" + inferenceId + "]]";
35+
}
36+
37+
@Override
38+
public Operator get(DriverContext driverContext) {
39+
return new CompletionOperator(driverContext, inferenceRunner, inferenceId, promptEvaluatorFactory.get(driverContext));
40+
}
41+
}
42+
43+
private final ExpressionEvaluator promptEvaluator;
44+
private final BlockFactory blockFactory;
45+
46+
public CompletionOperator(
47+
DriverContext driverContext,
48+
InferenceRunner inferenceRunner,
49+
String inferenceId,
50+
ExpressionEvaluator promptEvaluator
51+
) {
52+
super(driverContext, inferenceRunner, inferenceId);
53+
this.promptEvaluator = promptEvaluator;
54+
this.blockFactory = driverContext.blockFactory();
55+
}
56+
57+
@Override
58+
protected void doClose() {
59+
Releasables.closeExpectNoException(promptEvaluator);
60+
}
61+
62+
@Override
63+
public String toString() {
64+
return "CompletionOperator[inference_id=[" + inferenceId() + "]]";
65+
}
66+
67+
@Override
68+
protected RequestIterator requests(Page inputPage) {
69+
return new InferenceOperator.RequestIterator() {
70+
private final BytesRefBlock promptBlock = (BytesRefBlock) promptEvaluator.eval(inputPage);
71+
private BytesRef readBuffer = new BytesRef();
72+
private int currentPos = 0;
73+
74+
@Override
75+
public boolean hasNext() {
76+
return currentPos < promptBlock.getPositionCount();
77+
}
78+
79+
@Override
80+
public InferenceAction.Request next() {
81+
if (hasNext() == false) {
82+
throw new NoSuchElementException();
83+
}
84+
int pos = currentPos++;
85+
86+
if (promptBlock.isNull(pos)) {
87+
return null;
88+
}
89+
90+
StringBuilder promptBuilder = new StringBuilder();
91+
for (int valueIndex = 0; valueIndex < promptBlock.getValueCount(pos); valueIndex++) {
92+
readBuffer = promptBlock.getBytesRef(promptBlock.getFirstValueIndex(pos) + valueIndex, readBuffer);
93+
promptBuilder.append(readBuffer.utf8ToString()).append("\n");
94+
}
95+
96+
return inferenceRequest(promptBuilder.toString());
97+
}
98+
99+
@Override
100+
public void close() {
101+
promptBlock.allowPassingToDifferentDriver();
102+
Releasables.closeExpectNoException(promptBlock);
103+
}
104+
};
105+
}
106+
107+
@Override
108+
protected OutputBuilder<ChatCompletionResults> outputBuilder(Page inputPage) {
109+
return new InferenceOperator.OutputBuilder<>() {
110+
private final BytesRefBlock.Builder outputBlockBuilder = blockFactory.newBytesRefBlockBuilder(inputPage.getPositionCount());
111+
private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
112+
113+
@Override
114+
public void close() {
115+
Releasables.closeExpectNoException(outputBlockBuilder);
116+
}
117+
118+
@Override
119+
public void onInferenceResults(ChatCompletionResults completionResults) {
120+
if (completionResults == null || completionResults.getResults().isEmpty()) {
121+
outputBlockBuilder.appendNull();
122+
} else {
123+
outputBlockBuilder.beginPositionEntry();
124+
for (ChatCompletionResults.Result rankedDocsResult : completionResults.getResults()) {
125+
bytesRefBuilder.copyChars(rankedDocsResult.content());
126+
outputBlockBuilder.appendBytesRef(bytesRefBuilder.get());
127+
bytesRefBuilder.clear();
128+
}
129+
outputBlockBuilder.endPositionEntry();
130+
}
131+
}
132+
133+
@Override
134+
protected Class<ChatCompletionResults> inferenceResultsClass() {
135+
return ChatCompletionResults.class;
136+
}
137+
138+
@Override
139+
public Page buildOutput() {
140+
Block outputBlock = outputBlockBuilder.build();
141+
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
142+
return inputPage.appendBlock(outputBlock);
143+
}
144+
};
145+
}
146+
147+
private InferenceAction.Request inferenceRequest(String prompt) {
148+
return InferenceAction.Request.builder(inferenceId(), TaskType.COMPLETION).setInput(List.of(prompt)).build();
149+
}
150+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.inference;
9+
10+
import java.util.concurrent.ExecutorService;
11+
12+
public class InferenceExecutionContext {
13+
private static final int DEFAULT_MAX_CONCURRENT_REQUESTS = 10;
14+
private final InferenceRunner inferenceRunner;
15+
private final ExecutorService executorService;
16+
private final int maxConcurrentRequests;
17+
18+
private InferenceExecutionContext(InferenceRunner inferenceRunner, ExecutorService executorService, int maxConcurrentRequests) {
19+
this.inferenceRunner = inferenceRunner;
20+
this.executorService = executorService;
21+
this.maxConcurrentRequests = maxConcurrentRequests;
22+
}
23+
24+
public InferenceRunner inferenceRunner() {
25+
return inferenceRunner;
26+
}
27+
28+
public ExecutorService executorService() {
29+
return executorService;
30+
}
31+
32+
public int maxConcurrentRequests() {
33+
return maxConcurrentRequests;
34+
}
35+
36+
public static class Builder {
37+
private final InferenceRunner inferenceRunner;
38+
private final ExecutorService executorService;
39+
private int maxConcurrentRequests = DEFAULT_MAX_CONCURRENT_REQUESTS;
40+
41+
Builder(InferenceRunner inferenceRunner, ExecutorService executorService) {
42+
this.inferenceRunner = inferenceRunner;
43+
this.executorService = executorService;
44+
}
45+
46+
public InferenceExecutionContext build() {
47+
return new InferenceExecutionContext(inferenceRunner, executorService, maxConcurrentRequests);
48+
}
49+
50+
public Builder setMaxConcurrentRequests(int maxConcurrentRequests) {
51+
this.maxConcurrentRequests = maxConcurrentRequests;
52+
return this;
53+
}
54+
}
55+
}

0 commit comments

Comments
 (0)