Skip to content

Commit 2847695

Browse files
authored
Unify precomputation of aggregations behind a common API (#16733)
* Unify precomputation of aggregations behind a common API We've had a series of aggregation speedups that use the same strategy: instead of iterating through documents that match the query one-by-one, we can look at a Lucene segment and compute the aggregation directly (if some particular conditions are met). In every case, we've hooked that into custom logic hijacks the getLeafCollector method and throws CollectionTerminatedException. This creates the illusion that we're implementing a custom LeafCollector, when really we're not collecting at all (which is the whole point). With this refactoring, the mechanism (hijacking getLeafCollector) is moved into AggregatorBase. Aggregators that have a strategy to precompute their answer can override tryPrecomputeAggregationForLeaf, which is expected to return true if they managed to precompute. This should also make it easier to keep track of which aggregations have precomputation approaches (since they override this method). Signed-off-by: Michael Froh <[email protected]> * Remove subaggregator check from CompositeAggregator Not sure why I added this, when the existing implementation didn't have it. That said, we *should* call finishLeaf() before precomputing the current leaf. Signed-off-by: Michael Froh <[email protected]> * Resolve conflicts with star-tree changes Signed-off-by: Michael Froh <[email protected]> * Skip precomputation when valuesSource is null Signed-off-by: Michael Froh <[email protected]> * Add comment as suggested by @bowenlan-amzn Signed-off-by: Michael Froh <[email protected]> --------- Signed-off-by: Michael Froh <[email protected]>
1 parent 679a08f commit 2847695

File tree

11 files changed

+168
-132
lines changed

11 files changed

+168
-132
lines changed

server/src/main/java/org/opensearch/search/aggregations/AggregatorBase.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
package org.opensearch.search.aggregations;
3333

3434
import org.apache.lucene.index.LeafReaderContext;
35+
import org.apache.lucene.search.CollectionTerminatedException;
3536
import org.apache.lucene.search.MatchAllDocsQuery;
3637
import org.apache.lucene.search.ScoreMode;
3738
import org.opensearch.core.common.breaker.CircuitBreaker;
@@ -200,6 +201,9 @@ public Map<String, Object> metadata() {
200201

201202
@Override
202203
public final LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException {
204+
if (tryPrecomputeAggregationForLeaf(ctx)) {
205+
throw new CollectionTerminatedException();
206+
}
203207
preGetSubLeafCollectors(ctx);
204208
final LeafBucketCollector sub = collectableSubAggregators.getLeafCollector(ctx);
205209
return getLeafCollector(ctx, sub);
@@ -216,6 +220,21 @@ protected void preGetSubLeafCollectors(LeafReaderContext ctx) throws IOException
216220
*/
217221
protected void doPreCollection() throws IOException {}
218222

223+
/**
224+
* Subclasses may override this method if they have an efficient way of computing their aggregation for the given
225+
* segment (versus collecting matching documents). If this method returns true, collection for the given segment
226+
* will be terminated, rather than executing normally.
227+
* <p>
228+
* If this method returns true, the aggregator's state should be identical to what it would be if matching
229+
* documents from the segment were fully collected. If this method returns false, the aggregator's state should
230+
* be unchanged from before this method is called.
231+
* @param ctx the context for the given segment
232+
* @return true if and only if results for this segment have been precomputed
233+
*/
234+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
235+
return false;
236+
}
237+
219238
@Override
220239
public final void preCollection() throws IOException {
221240
List<BucketCollector> collectors = Arrays.asList(subAggregators);
@@ -251,8 +270,8 @@ public Aggregator[] subAggregators() {
251270
public Aggregator subAggregator(String aggName) {
252271
if (subAggregatorbyName == null) {
253272
subAggregatorbyName = new HashMap<>(subAggregators.length);
254-
for (int i = 0; i < subAggregators.length; i++) {
255-
subAggregatorbyName.put(subAggregators[i].name(), subAggregators[i]);
273+
for (Aggregator subAggregator : subAggregators) {
274+
subAggregatorbyName.put(subAggregator.name(), subAggregator);
256275
}
257276
}
258277
return subAggregatorbyName.get(aggName);

server/src/main/java/org/opensearch/search/aggregations/bucket/composite/CompositeAggregator.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,10 +564,13 @@ private void processLeafFromQuery(LeafReaderContext ctx, Sort indexSortPrefix) t
564564
}
565565

566566
@Override
567-
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
568-
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
569-
if (optimized) throw new CollectionTerminatedException();
567+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
568+
finishLeaf(); // May need to wrap up previous leaf if it could not be precomputed
569+
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
570+
}
570571

572+
@Override
573+
protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
571574
finishLeaf();
572575

573576
boolean fillDocIdSet = deferredCollectors != NO_OP_COLLECTOR;

server/src/main/java/org/opensearch/search/aggregations/bucket/histogram/DateHistogramAggregator.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
import org.apache.lucene.index.LeafReaderContext;
3535
import org.apache.lucene.index.SortedNumericDocValues;
36-
import org.apache.lucene.search.CollectionTerminatedException;
3736
import org.apache.lucene.search.DocIdSetIterator;
3837
import org.apache.lucene.search.ScoreMode;
3938
import org.apache.lucene.util.CollectionUtil;
@@ -187,22 +186,23 @@ public ScoreMode scoreMode() {
187186
}
188187

189188
@Override
190-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
191-
if (valuesSource == null) {
192-
return LeafBucketCollector.NO_OP_COLLECTOR;
193-
}
194-
195-
boolean optimized = filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
196-
if (optimized) throw new CollectionTerminatedException();
197-
198-
SortedNumericDocValues values = valuesSource.longValues(ctx);
189+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
199190
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
200191
if (supportedStarTree != null) {
201192
if (preComputeWithStarTree(ctx, supportedStarTree) == true) {
202-
throw new CollectionTerminatedException();
193+
return true;
203194
}
204195
}
196+
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, segmentMatchAll(context, ctx));
197+
}
205198

199+
@Override
200+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
201+
if (valuesSource == null) {
202+
return LeafBucketCollector.NO_OP_COLLECTOR;
203+
}
204+
205+
SortedNumericDocValues values = valuesSource.longValues(ctx);
206206
return new LeafBucketCollectorBase(sub, values) {
207207
@Override
208208
public void collect(int doc, long owningBucketOrd) throws IOException {

server/src/main/java/org/opensearch/search/aggregations/bucket/range/RangeAggregator.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
package org.opensearch.search.aggregations.bucket.range;
3333

3434
import org.apache.lucene.index.LeafReaderContext;
35-
import org.apache.lucene.search.CollectionTerminatedException;
3635
import org.apache.lucene.search.ScoreMode;
3736
import org.opensearch.core.ParseField;
3837
import org.opensearch.core.common.io.stream.StreamInput;
@@ -310,10 +309,15 @@ public ScoreMode scoreMode() {
310309
}
311310

312311
@Override
313-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
314-
if (segmentMatchAll(context, ctx) && filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false)) {
315-
throw new CollectionTerminatedException();
312+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
313+
if (segmentMatchAll(context, ctx)) {
314+
return filterRewriteOptimizationContext.tryOptimize(ctx, this::incrementBucketDocCount, false);
316315
}
316+
return false;
317+
}
318+
319+
@Override
320+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
317321

318322
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
319323
return new LeafBucketCollectorBase(sub, values) {

server/src/main/java/org/opensearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import org.apache.lucene.index.SortedSetDocValues;
4141
import org.apache.lucene.index.Terms;
4242
import org.apache.lucene.index.TermsEnum;
43-
import org.apache.lucene.search.CollectionTerminatedException;
4443
import org.apache.lucene.search.Weight;
4544
import org.apache.lucene.util.ArrayUtil;
4645
import org.apache.lucene.util.BytesRef;
@@ -165,35 +164,32 @@ public void setWeight(Weight weight) {
165164
@return A LeafBucketCollector implementation with collection termination, since collection is complete
166165
@throws IOException If an I/O error occurs during reading
167166
*/
168-
LeafBucketCollector termDocFreqCollector(
169-
LeafReaderContext ctx,
170-
SortedSetDocValues globalOrds,
171-
BiConsumer<Long, Integer> ordCountConsumer
172-
) throws IOException {
167+
boolean tryCollectFromTermFrequencies(LeafReaderContext ctx, SortedSetDocValues globalOrds, BiConsumer<Long, Integer> ordCountConsumer)
168+
throws IOException {
173169
if (weight == null) {
174170
// Weight not assigned - cannot use this optimization
175-
return null;
171+
return false;
176172
} else {
177173
if (weight.count(ctx) == 0) {
178174
// No documents matches top level query on this segment, we can skip the segment entirely
179-
return LeafBucketCollector.NO_OP_COLLECTOR;
175+
return true;
180176
} else if (weight.count(ctx) != ctx.reader().maxDoc()) {
181177
// weight.count(ctx) == ctx.reader().maxDoc() implies there are no deleted documents and
182178
// top-level query matches all docs in the segment
183-
return null;
179+
return false;
184180
}
185181
}
186182

187183
Terms segmentTerms = ctx.reader().terms(this.fieldName);
188184
if (segmentTerms == null) {
189185
// Field is not indexed.
190-
return null;
186+
return false;
191187
}
192188

193189
NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), DocCountFieldMapper.NAME);
194190
if (docCountValues.nextDoc() != NO_MORE_DOCS) {
195191
// This segment has at least one document with the _doc_count field.
196-
return null;
192+
return false;
197193
}
198194

199195
TermsEnum indexTermsEnum = segmentTerms.iterator();
@@ -217,31 +213,28 @@ LeafBucketCollector termDocFreqCollector(
217213
ordinalTerm = globalOrdinalTermsEnum.next();
218214
}
219215
}
220-
return new LeafBucketCollector() {
221-
@Override
222-
public void collect(int doc, long owningBucketOrd) throws IOException {
223-
throw new CollectionTerminatedException();
224-
}
225-
};
216+
return true;
226217
}
227218

228219
@Override
229-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
220+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
230221
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
231-
collectionStrategy.globalOrdsReady(globalOrds);
232-
233222
if (collectionStrategy instanceof DenseGlobalOrds
234223
&& this.resultStrategy instanceof StandardTermsResults
235-
&& sub == LeafBucketCollector.NO_OP_COLLECTOR) {
236-
LeafBucketCollector termDocFreqCollector = termDocFreqCollector(
224+
&& subAggregators.length == 0) {
225+
return tryCollectFromTermFrequencies(
237226
ctx,
238227
globalOrds,
239228
(ord, docCount) -> incrementBucketDocCount(collectionStrategy.globalOrdToBucketOrd(0, ord), docCount)
240229
);
241-
if (termDocFreqCollector != null) {
242-
return termDocFreqCollector;
243-
}
244230
}
231+
return false;
232+
}
233+
234+
@Override
235+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
236+
SortedSetDocValues globalOrds = valuesSource.globalOrdinalsValues(ctx);
237+
collectionStrategy.globalOrdsReady(globalOrds);
245238

246239
SortedDocValues singleValues = DocValues.unwrapSingleton(globalOrds);
247240
if (singleValues != null) {
@@ -436,6 +429,24 @@ static class LowCardinality extends GlobalOrdinalsStringTermsAggregator {
436429
this.segmentDocCounts = context.bigArrays().newLongArray(1, true);
437430
}
438431

432+
@Override
433+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
434+
if (subAggregators.length == 0) {
435+
if (mapping != null) {
436+
mapSegmentCountsToGlobalCounts(mapping);
437+
}
438+
final SortedSetDocValues segmentOrds = valuesSource.ordinalsValues(ctx);
439+
segmentDocCounts = context.bigArrays().grow(segmentDocCounts, 1 + segmentOrds.getValueCount());
440+
mapping = valuesSource.globalOrdinalsMapping(ctx);
441+
return tryCollectFromTermFrequencies(
442+
ctx,
443+
segmentOrds,
444+
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
445+
);
446+
}
447+
return false;
448+
}
449+
439450
@Override
440451
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
441452
if (mapping != null) {
@@ -446,17 +457,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCol
446457
assert sub == LeafBucketCollector.NO_OP_COLLECTOR;
447458
mapping = valuesSource.globalOrdinalsMapping(ctx);
448459

449-
if (this.resultStrategy instanceof StandardTermsResults) {
450-
LeafBucketCollector termDocFreqCollector = this.termDocFreqCollector(
451-
ctx,
452-
segmentOrds,
453-
(ord, docCount) -> incrementBucketDocCount(mapping.applyAsLong(ord), docCount)
454-
);
455-
if (termDocFreqCollector != null) {
456-
return termDocFreqCollector;
457-
}
458-
}
459-
460460
final SortedDocValues singleValues = DocValues.unwrapSingleton(segmentOrds);
461461
if (singleValues != null) {
462462
segmentsWithSingleValuedOrds++;

server/src/main/java/org/opensearch/search/aggregations/metrics/AvgAggregator.java

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
package org.opensearch.search.aggregations.metrics;
3333

3434
import org.apache.lucene.index.LeafReaderContext;
35-
import org.apache.lucene.search.CollectionTerminatedException;
3635
import org.apache.lucene.search.DocIdSetIterator;
3736
import org.apache.lucene.search.ScoreMode;
3837
import org.apache.lucene.util.FixedBitSet;
@@ -104,23 +103,29 @@ public ScoreMode scoreMode() {
104103
}
105104

106105
@Override
107-
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
106+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
108107
if (valuesSource == null) {
109-
return LeafBucketCollector.NO_OP_COLLECTOR;
108+
return false;
110109
}
111110
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
112111
if (supportedStarTree != null) {
113112
if (parent != null && subAggregators.length == 0) {
114113
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
115114
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
116-
return LeafBucketCollector.NO_OP_COLLECTOR;
115+
return true;
117116
}
118-
return getStarTreeLeafCollector(ctx, sub, supportedStarTree);
117+
precomputeLeafUsingStarTree(ctx, supportedStarTree);
118+
return true;
119119
}
120-
return getDefaultLeafCollector(ctx, sub);
120+
return false;
121121
}
122122

123-
private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
123+
@Override
124+
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
125+
if (valuesSource == null) {
126+
return LeafBucketCollector.NO_OP_COLLECTOR;
127+
}
128+
124129
final BigArrays bigArrays = context.bigArrays();
125130
final SortedNumericDoubleValues values = valuesSource.doubleValues(ctx);
126131
final CompensatedSum kahanSummation = new CompensatedSum(0, 0);
@@ -154,8 +159,7 @@ public void collect(int doc, long bucket) throws IOException {
154159
};
155160
}
156161

157-
public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree)
158-
throws IOException {
162+
private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
159163
StarTreeValues starTreeValues = StarTreeQueryHelper.getStarTreeValues(ctx, starTree);
160164
assert starTreeValues != null;
161165

@@ -200,12 +204,6 @@ public LeafBucketCollector getStarTreeLeafCollector(LeafReaderContext ctx, LeafB
200204

201205
sums.set(0, kahanSummation.value());
202206
compensations.set(0, kahanSummation.delta());
203-
return new LeafBucketCollectorBase(sub, valuesSource.doubleValues(ctx)) {
204-
@Override
205-
public void collect(int doc, long bucket) {
206-
throw new CollectionTerminatedException();
207-
}
208-
};
209207
}
210208

211209
@Override

server/src/main/java/org/opensearch/search/aggregations/metrics/MaxAggregator.java

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,24 @@ public ScoreMode scoreMode() {
104104
return valuesSource != null && valuesSource.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
105105
}
106106

107+
@Override
108+
protected boolean tryPrecomputeAggregationForLeaf(LeafReaderContext ctx) throws IOException {
109+
if (valuesSource == null) {
110+
return false;
111+
}
112+
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
113+
if (supportedStarTree != null) {
114+
if (parent != null && subAggregators.length == 0) {
115+
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
116+
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
117+
return true;
118+
}
119+
precomputeLeafUsingStarTree(ctx, supportedStarTree);
120+
return true;
121+
}
122+
return false;
123+
}
124+
107125
@Override
108126
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException {
109127
if (valuesSource == null) {
@@ -130,20 +148,6 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc
130148
}
131149
}
132150

133-
CompositeIndexFieldInfo supportedStarTree = getSupportedStarTree(this.context.getQueryShardContext());
134-
if (supportedStarTree != null) {
135-
if (parent != null && subAggregators.length == 0) {
136-
// If this a child aggregator, then the parent will trigger star-tree pre-computation.
137-
// Returning NO_OP_COLLECTOR explicitly because the getLeafCollector() are invoked starting from innermost aggregators
138-
return LeafBucketCollector.NO_OP_COLLECTOR;
139-
}
140-
getStarTreeCollector(ctx, sub, supportedStarTree);
141-
}
142-
return getDefaultLeafCollector(ctx, sub);
143-
}
144-
145-
private LeafBucketCollector getDefaultLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException {
146-
147151
final BigArrays bigArrays = context.bigArrays();
148152
final SortedNumericDoubleValues allValues = valuesSource.doubleValues(ctx);
149153
final NumericDoubleValues values = MultiValueMode.MAX.select(allValues);
@@ -167,9 +171,9 @@ public void collect(int doc, long bucket) throws IOException {
167171
};
168172
}
169173

170-
public void getStarTreeCollector(LeafReaderContext ctx, LeafBucketCollector sub, CompositeIndexFieldInfo starTree) throws IOException {
174+
private void precomputeLeafUsingStarTree(LeafReaderContext ctx, CompositeIndexFieldInfo starTree) throws IOException {
171175
AtomicReference<Double> max = new AtomicReference<>(maxes.get(0));
172-
StarTreeQueryHelper.getStarTreeLeafCollector(context, valuesSource, ctx, sub, starTree, MetricStat.MAX.getTypeName(), value -> {
176+
StarTreeQueryHelper.precomputeLeafUsingStarTree(context, valuesSource, ctx, starTree, MetricStat.MAX.getTypeName(), value -> {
173177
max.set(Math.max(max.get(), (NumericUtils.sortableLongToDouble(value))));
174178
}, () -> maxes.set(0, max.get()));
175179
}

0 commit comments

Comments
 (0)