Skip to content

Commit dd38b7f

Browse files
committed
Move input block building outside the async portion.
1 parent: c0f5e27 — commit: dd38b7f

File tree

10 files changed: 38 additions (+38) and 39 deletions (-39).

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/InferenceOperator.java

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -72,10 +72,7 @@ public Page getOutput() {
7272
@Override
7373
protected void performAsync(Page input, ActionListener<OngoingInference> listener) {
7474
try {
75-
BulkInferenceRequestIterator requests = requests(input);
76-
listener = ActionListener.releaseBefore(requests, listener);
77-
78-
bulkInferenceExecutor.execute(requests, listener.map(responses -> new OngoingInference(input, responses)));
75+
bulkInferenceExecutor.execute(requests(input), listener.map(responses -> new OngoingInference(input, responses)));
7976
} catch (Exception e) {
8077
listener.onFailure(e);
8178
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/XContentRowEncoder.java

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828
import java.util.Arrays;
2929
import java.util.List;
3030
import java.util.Map;
31-
import java.util.Objects;
3231
import java.util.stream.Collectors;
3332

3433
/**
@@ -109,7 +108,6 @@ public BytesRefBlock eval(Page page) {
109108
} catch (IOException e) {
110109
throw new UncheckedIOException(e);
111110
} finally {
112-
Arrays.stream(fieldValueBlocks).filter(Objects::nonNull).forEach(Block::allowPassingToDifferentDriver);
113111
Releasables.closeExpectNoException(fieldValueBlocks);
114112
}
115113
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/bulk/BulkInferenceRequestIterator.java

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,10 @@
77

88
package org.elasticsearch.xpack.esql.inference.bulk;
99

10-
import org.elasticsearch.core.Releasable;
1110
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1211

1312
import java.util.Iterator;
1413

15-
public interface BulkInferenceRequestIterator extends Iterator<InferenceAction.Request>, Releasable {
14+
public interface BulkInferenceRequestIterator extends Iterator<InferenceAction.Request> {
1615

1716
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperator.java

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.esql.inference.completion;
99

10-
import org.elasticsearch.compute.data.BytesRefBlock;
1110
import org.elasticsearch.compute.data.Page;
1211
import org.elasticsearch.compute.operator.DriverContext;
1312
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
@@ -54,6 +53,16 @@ public CompletionOperator(
5453
this.promptEvaluator = promptEvaluator;
5554
}
5655

56+
@Override
57+
public void addInput(Page input) {
58+
try {
59+
super.addInput(input.appendBlock(promptEvaluator.eval(input)));
60+
} catch (Exception e) {
61+
releasePageOnAnyThread(input);
62+
throw e;
63+
}
64+
}
65+
5766
@Override
5867
public Class<ChatCompletionResults> inferenceResultsClass() {
5968
return ChatCompletionResults.class;
@@ -71,7 +80,7 @@ public String toString() {
7180

7281
@Override
7382
protected BulkInferenceRequestIterator requests(Page inputPage) {
74-
return new CompletionOperatorRequestIterator((BytesRefBlock) promptEvaluator.eval(inputPage), inferenceId());
83+
return new CompletionOperatorRequestIterator(inputPage, inferenceId());
7584
}
7685

7786
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorOutputBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.xpack.esql.inference.InferenceOperator;
1717

1818
import java.util.concurrent.atomic.AtomicBoolean;
19+
import java.util.stream.IntStream;
1920

2021
public class CompletionOperatorOutputBuilder implements InferenceOperator.OutputBuilder<ChatCompletionResults> {
2122
private final Page inputPage;
@@ -52,6 +53,6 @@ public void addInferenceResults(ChatCompletionResults completionResults) {
5253
public Page buildOutput() {
5354
Block outputBlock = outputBlockBuilder.build();
5455
assert outputBlock.getPositionCount() == inputPage.getPositionCount();
55-
return inputPage.shallowCopy().appendBlock(outputBlock);
56+
return inputPage.projectBlocks(IntStream.range(0, inputPage.getBlockCount() - 1).toArray()).appendBlock(outputBlock);
5657
}
5758
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/completion/CompletionOperatorRequestIterator.java

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,7 @@
99

1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.compute.data.BytesRefBlock;
12-
import org.elasticsearch.core.Releasable;
13-
import org.elasticsearch.core.Releasables;
12+
import org.elasticsearch.compute.data.Page;
1413
import org.elasticsearch.inference.TaskType;
1514
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1615
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
@@ -25,17 +24,14 @@ public class CompletionOperatorRequestIterator implements BulkInferenceRequestIt
2524
private final int size;
2625
private int currentPos = 0;
2726

28-
public CompletionOperatorRequestIterator(BytesRefBlock promptBlock, String inferenceId) {
27+
public CompletionOperatorRequestIterator(Page inputPage, String inferenceId) {
28+
assert inputPage.getBlockCount() > 0;
29+
BytesRefBlock promptBlock = inputPage.getBlock(inputPage.getBlockCount() - 1);
2930
this.promptReader = new PromptReader(promptBlock);
3031
this.size = promptBlock.getPositionCount();
3132
this.inferenceId = inferenceId;
3233
}
3334

34-
@Override
35-
public void close() {
36-
Releasables.close(promptReader);
37-
}
38-
3935
@Override
4036
public boolean hasNext() {
4137
return currentPos < size;
@@ -53,7 +49,7 @@ private InferenceAction.Request inferenceRequest(String prompt) {
5349
return InferenceAction.Request.builder(inferenceId, TaskType.COMPLETION).setInput(List.of(prompt)).build();
5450
}
5551

56-
private static class PromptReader implements Releasable {
52+
private static class PromptReader {
5753
private final BytesRefBlock promptBlock;
5854
private BytesRef readBuffer = new BytesRef();
5955
private StringBuilder strBuilder = new StringBuilder();
@@ -76,11 +72,5 @@ public String readPrompt(int pos) {
7672

7773
return strBuilder.toString();
7874
}
79-
80-
@Override
81-
public void close() {
82-
promptBlock.allowPassingToDifferentDriver();
83-
Releasables.close(promptBlock);
84-
}
8575
}
8676
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperator.java

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,6 @@
77

88
package org.elasticsearch.xpack.esql.inference.rerank;
99

10-
import org.elasticsearch.compute.data.BytesRefBlock;
1110
import org.elasticsearch.compute.data.Page;
1211
import org.elasticsearch.compute.operator.DriverContext;
1312
import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
@@ -69,6 +68,16 @@ public RerankOperator(
6968
this.scoreChannel = scoreChannel;
7069
}
7170

71+
@Override
72+
public void addInput(Page input) {
73+
try {
74+
super.addInput(input.appendBlock(rowEncoder.eval(input)));
75+
} catch (Exception e) {
76+
releasePageOnAnyThread(input);
77+
throw e;
78+
}
79+
}
80+
7281
@Override
7382
public Class<RankedDocsResults> inferenceResultsClass() {
7483
return RankedDocsResults.class;
@@ -86,7 +95,7 @@ public String toString() {
8695

8796
@Override
8897
protected RerankOperatorRequestIterator requests(Page inputPage) {
89-
return new RerankOperatorRequestIterator((BytesRefBlock) rowEncoder.eval(inputPage), inferenceId(), queryText, batchSize);
98+
return new RerankOperatorRequestIterator(inputPage, inferenceId(), queryText, batchSize);
9099
}
91100

92101
@Override

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorOutputBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public void close() {
3535

3636
@Override
3737
public Page buildOutput() {
38-
int blockCount = Integer.max(inputPage.getBlockCount(), scoreChannel + 1);
38+
int blockCount = Integer.max(inputPage.getBlockCount() - 1, scoreChannel + 1);
3939
Block[] blocks = new Block[blockCount];
4040

4141
try {

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/inference/rerank/RerankOperatorRequestIterator.java

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import org.apache.lucene.util.BytesRef;
1111
import org.elasticsearch.common.lucene.BytesRefs;
1212
import org.elasticsearch.compute.data.BytesRefBlock;
13-
import org.elasticsearch.core.Releasables;
13+
import org.elasticsearch.compute.data.Page;
1414
import org.elasticsearch.inference.TaskType;
1515
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
1616
import org.elasticsearch.xpack.esql.inference.bulk.BulkInferenceRequestIterator;
@@ -26,8 +26,9 @@ public class RerankOperatorRequestIterator implements BulkInferenceRequestIterat
2626
private final int batchSize;
2727
private int remainingPositions;
2828

29-
public RerankOperatorRequestIterator(BytesRefBlock inputBlock, String inferenceId, String queryText, int batchSize) {
30-
this.inputBlock = inputBlock;
29+
public RerankOperatorRequestIterator(Page inputPage, String inferenceId, String queryText, int batchSize) {
30+
assert inputPage.getBlockCount() > 0;
31+
this.inputBlock = inputPage.getBlock(inputPage.getBlockCount() - 1);
3132
this.inferenceId = inferenceId;
3233
this.queryText = queryText;
3334
this.batchSize = batchSize;
@@ -64,12 +65,6 @@ public InferenceAction.Request next() {
6465
return inferenceRequest(inputs);
6566
}
6667

67-
@Override
68-
public void close() {
69-
inputBlock.allowPassingToDifferentDriver();
70-
Releasables.close(inputBlock);
71-
}
72-
7368
private InferenceAction.Request inferenceRequest(List<String> inputs) {
7469
return InferenceAction.Request.builder(inferenceId, TaskType.RERANK).setInput(inputs).setQuery(queryText).build();
7570
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/inference/InferenceOperatorTestCase.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,8 @@ protected ThreadPool threadPool() {
7979

8080
@Override
8181
protected SourceOperator simpleInput(BlockFactory blockFactory, int size) {
82-
return new AbstractBlockSourceOperator(blockFactory, 8 * 1024) {
82+
final int minPageSize = Math.max(1, size / 100);
83+
return new AbstractBlockSourceOperator(blockFactory, between(minPageSize, size)) {
8384
@Override
8485
protected int remaining() {
8586
return size - currentPosition;

0 commit comments

Comments
 (0)