Skip to content

Commit c0f5e27

Browse files
committed
Fix circuit breaker errors.
1 parent 7e0cdee commit c0f5e27

File tree

11 files changed

+151
-151
lines changed

11 files changed

+151
-151
lines changed

muted-tests.yml

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -381,9 +381,6 @@ tests:
381381
- class: org.elasticsearch.packaging.test.DockerTests
382382
method: test024InstallPluginFromArchiveUsingConfigFile
383383
issue: https://github.com/elastic/elasticsearch/issues/126936
384-
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
385-
method: test {rerank.Reranker before a limit ASYNC}
386-
issue: https://github.com/elastic/elasticsearch/issues/127051
387384
- class: org.elasticsearch.packaging.test.DockerTests
388385
method: test026InstallBundledRepositoryPlugins
389386
issue: https://github.com/elastic/elasticsearch/issues/127081
@@ -402,9 +399,6 @@ tests:
402399
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
403400
method: test {p0=ml/data_frame_analytics_cat_apis/Test cat data frame analytics all jobs with header}
404401
issue: https://github.com/elastic/elasticsearch/issues/127625
405-
- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT
406-
method: test {rerank.Reranker using another sort order ASYNC}
407-
issue: https://github.com/elastic/elasticsearch/issues/127638
408402
- class: org.elasticsearch.xpack.search.CrossClusterAsyncSearchIT
409403
method: testCancellationViaTimeoutWithAllowPartialResultsSetToFalse
410404
issue: https://github.com/elastic/elasticsearch/issues/127096

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 54 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -12,27 +12,30 @@
1212
import org.elasticsearch.compute.data.Page;
1313
import org.elasticsearch.compute.operator.AsyncOperator;
1414
import org.elasticsearch.compute.operator.DriverContext;
15+
import org.elasticsearch.core.Releasable;
1516
import org.elasticsearch.inference.InferenceServiceResults;
1617
import org.elasticsearch.threadpool.ThreadPool;
18+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1719
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutionConfig;
1820
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceExecutor;
19-
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
2021
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
2122

22-
public abstract class InferenceOperator<IR extends InferenceServiceResults> extends AsyncOperator<Page> {
23+
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
24+
25+
public abstract class InferenceOperator<IR extends InferenceServiceResults> extends AsyncOperator<InferenceOperator.OngoingInference> {
2326

2427
// Move to a setting.
2528
private static final int MAX_INFERENCE_WORKER = 10;
2629
private final String inferenceId;
2730
private final BlockFactory blockFactory;
2831

29-
private final BulkInferenceExecutor<IR, Page> bulkInferenceExecutor;
32+
private final BulkInferenceExecutor bulkInferenceExecutor;
3033

3134
@SuppressWarnings("this-escape")
3235
public InferenceOperator(DriverContext driverContext, InferenceRunner inferenceRunner, ThreadPool threadPool, String inferenceId) {
3336
super(driverContext, threadPool.getThreadContext(), MAX_INFERENCE_WORKER);
3437
this.blockFactory = driverContext.blockFactory();
35-
this.bulkInferenceExecutor = new BulkInferenceExecutor<IR, Page>(inferenceRunner, threadPool, bulkExecutionConfig());
38+
this.bulkInferenceExecutor = new BulkInferenceExecutor(inferenceRunner, threadPool, bulkExecutionConfig());
3639
this.inferenceId = inferenceId;
3740
}
3841

@@ -45,37 +48,71 @@ protected String inferenceId() {
4548
}
4649

4750
@Override
48-
protected void releaseFetchedOnAnyThread(Page fetched) {
49-
releasePageOnAnyThread(fetched);
51+
protected void releaseFetchedOnAnyThread(OngoingInference result) {
52+
releasePageOnAnyThread(result.inputPage);
5053
}
5154

5255
@Override
5356
public Page getOutput() {
54-
return fetchFromBuffer();
57+
OngoingInference ongoingInference = fetchFromBuffer();
58+
if (ongoingInference == null) {
59+
return null;
60+
}
61+
62+
try (OutputBuilder<IR> outputBuilder = outputBuilder(ongoingInference.inputPage)) {
63+
for (int i = 0; i < ongoingInference.responses.length; i++) {
64+
outputBuilder.addInferenceResults(inferenceResults(ongoingInference.responses[i]));
65+
}
66+
return outputBuilder.buildOutput();
67+
} finally {
68+
releaseFetchedOnAnyThread(ongoingInference);
69+
}
5570
}
5671

5772
@Override
58-
protected void performAsync(Page input, ActionListener<Page> listener) {
59-
ActionListener<Page> completionListener = ActionListener.runAfter(listener, () -> releasePageOnAnyThread(input));
60-
73+
protected void performAsync(Page input, ActionListener<OngoingInference> listener) {
6174
try {
6275
BulkInferenceRequestIterator requests = requests(input);
63-
completionListener = ActionListener.releaseBefore(requests, completionListener);
76+
listener = ActionListener.releaseBefore(requests, listener);
6477

65-
BulkInferenceOutputBuilder<IR, Page> outputBuilder = outputBuilder(input);
66-
completionListener = ActionListener.releaseBefore(outputBuilder, completionListener);
67-
68-
bulkInferenceExecutor.execute(requests, outputBuilder, completionListener);
78+
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInference(input, responses)));
6979
} catch (Exception e) {
70-
completionListener.onFailure(e);
80+
listener.onFailure(e);
7181
}
7282
}
7383

7484
protected BulkInferenceExecutionConfig bulkExecutionConfig() {
7585
return BulkInferenceExecutionConfig.DEFAULT;
7686
}
7787

88+
private IR inferenceResults(InferenceAction.Response inferenceResponse) {
89+
InferenceServiceResults results = inferenceResponse.getResults();
90+
if (inferenceResultsClass().isInstance(results)) {
91+
return inferenceResultsClass().cast(results);
92+
}
93+
94+
throw new IllegalStateException(
95+
format(
96+
"Inference result has wrong type. Got [{}] while expecting [{}]",
97+
results.getClass().getName(),
98+
inferenceResultsClass().getName()
99+
)
100+
);
101+
}
102+
78103
protected abstract BulkInferenceRequestIterator requests(Page input);
79104

80-
protected abstract BulkInferenceOutputBuilder<IR, Page> outputBuilder(Page input);
105+
protected abstract Class<IR> inferenceResultsClass();
106+
107+
protected abstract OutputBuilder<IR> outputBuilder(Page input);
108+
109+
public record OngoingInference(Page inputPage, InferenceAction.Response[] responses) {
110+
111+
}
112+
113+
public interface OutputBuilder<IR extends InferenceServiceResults> extends Releasable {
114+
void addInferenceResults(IR inferenceResults);
115+
116+
Page buildOutput();
117+
}
81118
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceExecutor.java

Lines changed: 10 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -9,19 +9,18 @@
99

1010
import org.elasticsearch.action.ActionListener;
1111
import org.elasticsearch.common.util.concurrent.ThrottledTaskRunner;
12-
import org.elasticsearch.inference.InferenceServiceResults;
1312
import org.elasticsearch.threadpool.ThreadPool;
1413
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1514
import org.elasticsearch.xpack.esql.inference.InferenceRunner;
1615
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
1716

17+
import java.util.ArrayList;
18+
import java.util.List;
1819
import java.util.concurrent.ExecutorService;
1920
import java.util.concurrent.RejectedExecutionException;
2021
import java.util.concurrent.TimeoutException;
2122

22-
import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
23-
24-
public class BulkInferenceExecutor<IR extends InferenceServiceResults, OutputType> {
23+
public class BulkInferenceExecutor {
2524
private static final String TASK_RUNNER_NAME = "bulk_inference_operation";
2625
private static final int INFERENCE_RESPONSE_TIMEOUT = 30; // TODO: should be in the config.
2726
private final ThrottledInferenceRunner throttledInferenceRunner;
@@ -32,12 +31,8 @@ public BulkInferenceExecutor(InferenceRunner inferenceRunner, ThreadPool threadP
3231
throttledInferenceRunner = ThrottledInferenceRunner.create(inferenceRunner, executorService, bulkExecutionConfig);
3332
}
3433

35-
public void execute(
36-
BulkInferenceRequestIterator requests,
37-
BulkInferenceOutputBuilder<IR, OutputType> outputBuilder,
38-
ActionListener<OutputType> listener
39-
) throws Exception {
40-
final ResponseHandler<IR, OutputType> responseHandler = new ResponseHandler<>(outputBuilder);
34+
public void execute(BulkInferenceRequestIterator requests, ActionListener<InferenceAction.Response[]> listener) throws Exception {
35+
final ResponseHandler responseHandler = new ResponseHandler();
4136
runInferenceRequests(requests, listener.delegateFailureAndWrap(responseHandler::handleResponses));
4237
}
4338

@@ -66,14 +61,10 @@ private void runInferenceRequests(BulkInferenceRequestIterator requests, ActionL
6661
}
6762
}
6863

69-
private static class ResponseHandler<IR extends InferenceServiceResults, OutputType> {
70-
private final BulkInferenceOutputBuilder<IR, OutputType> outputBuilder;
71-
72-
private ResponseHandler(BulkInferenceOutputBuilder<IR, OutputType> outputBuilder) {
73-
this.outputBuilder = outputBuilder;
74-
}
64+
private static class ResponseHandler {
65+
private final List<InferenceAction.Response> responses = new ArrayList<>();
7566

76-
private void handleResponses(ActionListener<OutputType> listener, BulkInferenceExecutionState bulkExecutionState) {
67+
private void handleResponses(ActionListener<InferenceAction.Response[]> listener, BulkInferenceExecutionState bulkExecutionState) {
7768

7869
try {
7970
persistsInferenceResponses(bulkExecutionState);
@@ -84,7 +75,7 @@ private void handleResponses(ActionListener<OutputType> listener, BulkInferenceE
8475

8576
if (bulkExecutionState.hasFailure() == false) {
8677
try {
87-
listener.onResponse(outputBuilder.buildOutput());
78+
listener.onResponse(responses.toArray(InferenceAction.Response[]::new));
8879
return;
8980
} catch (Exception e) {
9081
bulkExecutionState.addFailure(e);
@@ -105,7 +96,7 @@ private void persistsInferenceResponses(BulkInferenceExecutionState bulkExecutio
10596
assert response != null || bulkExecutionState.hasFailure();
10697
if (bulkExecutionState.hasFailure() == false) {
10798
try {
108-
persistsInferenceResponse(response);
99+
responses.add(response);
109100
} catch (Exception e) {
110101
bulkExecutionState.addFailure(e);
111102
}
@@ -114,22 +105,6 @@ private void persistsInferenceResponses(BulkInferenceExecutionState bulkExecutio
114105
}
115106
}
116107
}
117-
118-
private void persistsInferenceResponse(InferenceAction.Response inferenceResponse) {
119-
InferenceServiceResults results = inferenceResponse.getResults();
120-
if (outputBuilder.inferenceResultsClass().isInstance(results)) {
121-
outputBuilder.addInferenceResults(outputBuilder.inferenceResultsClass().cast(results));
122-
return;
123-
}
124-
125-
throw new IllegalStateException(
126-
format(
127-
"Inference result has wrong type. Got [{}] while expecting [{}]",
128-
results.getClass().getName(),
129-
outputBuilder.inferenceResultsClass().getName()
130-
)
131-
);
132-
}
133108
}
134109

135110
private static class ThrottledInferenceRunner extends ThrottledTaskRunner {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceOutputBuilder.java

Lines changed: 0 additions & 19 deletions
This file was deleted.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,11 @@ public CompletionOperator(
5454
this.promptEvaluator = promptEvaluator;
5555
}
5656

57+
@Override
58+
public Class<ChatCompletionResults> inferenceResultsClass() {
59+
return ChatCompletionResults.class;
60+
}
61+
5762
@Override
5863
protected void doClose() {
5964
Releasables.close(promptEvaluator);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,11 @@
1313
import org.elasticsearch.compute.data.Page;
1414
import org.elasticsearch.core.Releasables;
1515
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
16-
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
16+
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
1717

1818
import java.util.concurrent.atomic.AtomicBoolean;
1919

20-
public class CompletionOperatorOutputBuilder implements BulkInferenceOutputBuilder<ChatCompletionResults, Page> {
20+
public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder<ChatCompletionResults> {
2121
private final Page inputPage;
2222
private final BytesRefBlock.Builder outputBlockBuilder;
2323
private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();
@@ -28,11 +28,6 @@ public CompletionOperatorOutputBuilder(BytesRefBlock.Builder outputBlockBuilder,
2828
this.outputBlockBuilder = outputBlockBuilder;
2929
}
3030

31-
@Override
32-
public Class<ChatCompletionResults> inferenceResultsClass() {
33-
return ChatCompletionResults.class;
34-
}
35-
3631
@Override
3732
public void close() {
3833
Releasables.close(outputBlockBuilder);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ public RerankOperator(
6969
this.scoreChannel = scoreChannel;
7070
}
7171

72+
@Override
73+
public Class<RankedDocsResults> inferenceResultsClass() {
74+
return RankedDocsResults.class;
75+
}
76+
7277
@Override
7378
protected void doClose() {
7479
Releasables.close(rowEncoder);

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import org.elasticsearch.compute.data.Page;
1313
import org.elasticsearch.core.Releasables;
1414
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
15-
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceOutputBuilder;
15+
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
1616

1717
import java.util.Comparator;
1818

19-
public class RerankOperatorOutputBuilder implements BulkInferenceOutputBuilder<RankedDocsResults, Page> {
19+
public class RerankOperatorOutputBuilder implements InferenceOperator.OutputBuilder<RankedDocsResults> {
2020

2121
private final Page inputPage;
2222
private final DoubleBlock.Builder scoreBlockBuilder;
@@ -28,11 +28,6 @@ public RerankOperatorOutputBuilder(DoubleBlock.Builder scoreBlockBuilder, Page i
2828
this.scoreChannel = scoreChannel;
2929
}
3030

31-
@Override
32-
public Class<RankedDocsResults> inferenceResultsClass() {
33-
return RankedDocsResults.class;
34-
}
35-
3631
@Override
3732
public void close() {
3833
Releasables.close(scoreBlockBuilder);

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.compute.operator.SourceOperator;
2929
import org.elasticsearch.compute.test.AbstractBlockSourceOperator;
3030
import org.elasticsearch.compute.test.OperatorTestCase;
31+
import org.elasticsearch.core.TimeValue;
3132
import org.elasticsearch.inference.InferenceServiceResults;
3233
import org.elasticsearch.threadpool.FixedExecutorBuilder;
3334
import org.elasticsearch.threadpool.TestThreadPool;
@@ -114,12 +115,25 @@ protected InferenceRunner mockedSimpleInferenceRunner() {
114115
InferenceRunner inferenceRunner = mock(InferenceRunner.class);
115116
when(inferenceRunner.threadPool()).thenReturn(threadPool());
116117
doAnswer(i -> {
117-
@SuppressWarnings("unchecked")
118-
ActionListener<InferenceAction.Response> listener = i.getArgument(1, ActionListener.class);
119-
InferenceAction.Request request = i.getArgument(0, InferenceAction.Request.class);
120-
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
121-
when(inferenceResponse.getResults()).thenReturn(mockInferenceResult(request));
122-
listener.onResponse(inferenceResponse);
118+
Runnable sendResponse = () -> {
119+
@SuppressWarnings("unchecked")
120+
ActionListener<InferenceAction.Response> listener = i.getArgument(1, ActionListener.class);
121+
InferenceAction.Request request = i.getArgument(0, InferenceAction.Request.class);
122+
InferenceAction.Response inferenceResponse = mock(InferenceAction.Response.class);
123+
when(inferenceResponse.getResults()).thenReturn(mockInferenceResult(request));
124+
listener.onResponse(inferenceResponse);
125+
};
126+
127+
if (randomBoolean()) {
128+
sendResponse.run();
129+
} else {
130+
threadPool.schedule(
131+
sendResponse,
132+
TimeValue.timeValueNanos(between(1, 1_000)),
133+
threadPool.executor(EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME)
134+
);
135+
}
136+
123137
return null;
124138
}).when(inferenceRunner).doInference(any(InferenceAction.Request.class), any());
125139
return inferenceRunner;

0 commit comments

Comments
 (0)