Skip to content

Commit ae1a2ee

Browse files
committed
Add support for two-phase iterators to DenseConjunctionBulkScorer. (#14359)
While this could help "normal" conjunctions a bit by applying "normal" iterators more efficiently, the main motivation is to efficiently evaluate range queries on fields that have a doc-value index enabled. These range queries produce two-phase iterators that are expected to match large contiguous ranges of doc IDs. Using `DenseConjunctionBulkScorer` lets the conjunction skip these clauses on the doc ID ranges that they fully match.
1 parent ffbbd7c commit ae1a2ee

File tree

9 files changed

+694
-58
lines changed

9 files changed

+694
-58
lines changed

lucene/core/src/java/org/apache/lucene/search/BooleanScorerSupplier.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -332,11 +332,9 @@ BulkScorer filteredOptionalBulkScorer() throws IOException {
332332
assert scoreMode.needsScores() == false;
333333
filters.add(new DisjunctionSumScorer(optionalScorers, scoreMode, cost));
334334

335-
if (filters.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
336-
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
335+
if (maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
337336
&& cost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
338-
return new DenseConjunctionBulkScorer(
339-
filters.stream().map(Scorer::iterator).toList(), maxDoc, 0f);
337+
return DenseConjunctionBulkScorer.of(filters, maxDoc, 0f);
340338
}
341339

342340
return new DefaultBulkScorer(new ConjunctionScorer(filters, Collections.emptyList()));
@@ -392,14 +390,14 @@ private BulkScorer requiredBulkScorer() throws IOException {
392390
}
393391
if (scoreMode != ScoreMode.TOP_SCORES
394392
&& requiredScoring.size() + requiredNoScoring.size() >= 2
395-
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
396-
&& requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
393+
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
397394
if (requiredScoring.isEmpty()
398395
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
399396
&& leadCost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
400-
return new DenseConjunctionBulkScorer(
401-
requiredNoScoring.stream().map(Scorer::iterator).toList(), maxDoc, 0f);
402-
} else {
397+
return DenseConjunctionBulkScorer.of(requiredNoScoring, maxDoc, 0f);
398+
} else if (requiredNoScoring.stream()
399+
.map(Scorer::twoPhaseIterator)
400+
.allMatch(Objects::isNull)) {
403401
return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
404402
}
405403
}

lucene/core/src/java/org/apache/lucene/search/ConstantScoreScorerSupplier.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import java.util.Collections;
21+
import java.util.List;
2122
import org.apache.lucene.search.Weight.DefaultBulkScorer;
2223

2324
/**
@@ -78,9 +79,18 @@ public final Scorer get(long leadCost) throws IOException {
7879
public final BulkScorer bulkScorer() throws IOException {
7980
DocIdSetIterator iterator = iterator(Long.MAX_VALUE);
8081
if (maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE / 2
81-
&& iterator.cost() >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE
82-
&& TwoPhaseIterator.unwrap(iterator) == null) {
83-
return new DenseConjunctionBulkScorer(Collections.singletonList(iterator), maxDoc, score);
82+
&& iterator.cost() >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
83+
TwoPhaseIterator twoPhase = TwoPhaseIterator.unwrap(iterator);
84+
List<DocIdSetIterator> iterators;
85+
List<TwoPhaseIterator> twoPhases;
86+
if (twoPhase == null) {
87+
iterators = Collections.singletonList(iterator);
88+
twoPhases = Collections.emptyList();
89+
} else {
90+
iterators = Collections.emptyList();
91+
twoPhases = Collections.singletonList(twoPhase);
92+
}
93+
return new DenseConjunctionBulkScorer(iterators, twoPhases, maxDoc, score);
8494
} else {
8595
return new DefaultBulkScorer(new ConstantScoreScorer(score, scoreMode, iterator));
8696
}

lucene/core/src/java/org/apache/lucene/search/DenseConjunctionBulkScorer.java

Lines changed: 161 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,28 @@
3030
*/
3131
final class DenseConjunctionBulkScorer extends BulkScorer {
3232

33+
private record DisiWrapper(DocIdSetIterator approximation, TwoPhaseIterator twoPhase) {
34+
DisiWrapper(DocIdSetIterator iterator) {
35+
this(iterator, null);
36+
}
37+
38+
DisiWrapper(TwoPhaseIterator twoPhase) {
39+
this(twoPhase.approximation(), twoPhase);
40+
}
41+
42+
int docID() {
43+
return approximation().docID();
44+
}
45+
46+
int docIDRunEnd() throws IOException {
47+
if (twoPhase() == null) {
48+
return approximation().docIDRunEnd();
49+
} else {
50+
return twoPhase().docIDRunEnd();
51+
}
52+
}
53+
}
54+
3355
// Use a small-ish window size to make sure that we can take advantage of gaps in the postings of
3456
// clauses that are not leading iteration.
3557
static final int WINDOW_SIZE = 4096;
@@ -39,25 +61,49 @@ final class DenseConjunctionBulkScorer extends BulkScorer {
3961
static final int DENSITY_THRESHOLD_INVERSE = Long.SIZE / 2;
4062

4163
private final int maxDoc;
42-
private final List<DocIdSetIterator> iterators;
64+
private final List<DisiWrapper> iterators;
4365
private final SimpleScorable scorable;
4466

4567
private final FixedBitSet windowMatches = new FixedBitSet(WINDOW_SIZE);
4668
private final FixedBitSet clauseWindowMatches = new FixedBitSet(WINDOW_SIZE);
47-
private final List<DocIdSetIterator> windowIterators = new ArrayList<>();
69+
private final List<DocIdSetIterator> windowApproximations = new ArrayList<>();
70+
private final List<TwoPhaseIterator> windowTwoPhases = new ArrayList<>();
4871
private final DocIdStreamView docIdStreamView = new DocIdStreamView();
4972
private final RangeDocIdStream rangeDocIdStream = new RangeDocIdStream();
5073
private final SingleIteratorDocIdStream singleIteratorDocIdStream =
5174
new SingleIteratorDocIdStream();
5275

53-
DenseConjunctionBulkScorer(List<DocIdSetIterator> iterators, int maxDoc, float constantScore) {
54-
if (iterators.isEmpty()) {
76+
static DenseConjunctionBulkScorer of(List<Scorer> filters, int maxDoc, float constantScore) {
77+
List<DocIdSetIterator> iterators = new ArrayList<>();
78+
List<TwoPhaseIterator> twoPhases = new ArrayList<>();
79+
for (Scorer filter : filters) {
80+
TwoPhaseIterator twoPhase = filter.twoPhaseIterator();
81+
if (twoPhase != null) {
82+
twoPhases.add(twoPhase);
83+
} else {
84+
iterators.add(filter.iterator());
85+
}
86+
}
87+
return new DenseConjunctionBulkScorer(iterators, twoPhases, maxDoc, constantScore);
88+
}
89+
90+
DenseConjunctionBulkScorer(
91+
List<DocIdSetIterator> iterators,
92+
List<TwoPhaseIterator> twoPhases,
93+
int maxDoc,
94+
float constantScore) {
95+
if (iterators.isEmpty() && twoPhases.isEmpty()) {
5596
throw new IllegalArgumentException("Expected one or more iterators, got 0");
5697
}
5798
this.maxDoc = maxDoc;
58-
iterators = new ArrayList<>(iterators);
59-
iterators.sort(Comparator.comparingLong(DocIdSetIterator::cost));
60-
this.iterators = iterators;
99+
this.iterators = new ArrayList<>();
100+
for (DocIdSetIterator iterator : iterators) {
101+
this.iterators.add(new DisiWrapper(iterator));
102+
}
103+
for (TwoPhaseIterator twoPhase : twoPhases) {
104+
this.iterators.add(new DisiWrapper(twoPhase));
105+
}
106+
this.iterators.sort(Comparator.comparing(w -> w.approximation().cost()));
61107
this.scorable = new SimpleScorable();
62108
scorable.score = constantScore;
63109
}
@@ -66,21 +112,21 @@ final class DenseConjunctionBulkScorer extends BulkScorer {
66112
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
67113
collector.setScorer(scorable);
68114

69-
List<DocIdSetIterator> iterators = this.iterators;
115+
List<DisiWrapper> iterators = this.iterators;
70116
if (collector.competitiveIterator() != null) {
71117
iterators = new ArrayList<>(iterators);
72-
iterators.add(collector.competitiveIterator());
118+
iterators.add(new DisiWrapper(collector.competitiveIterator()));
73119
}
74120

75-
for (DocIdSetIterator it : iterators) {
76-
min = Math.max(min, it.docID());
121+
for (DisiWrapper w : iterators) {
122+
min = Math.max(min, w.approximation().docID());
77123
}
78124

79125
max = Math.min(max, maxDoc);
80126

81-
DocIdSetIterator lead = iterators.get(0);
127+
DisiWrapper lead = iterators.get(0);
82128
if (lead.docID() < min) {
83-
min = lead.advance(min);
129+
min = lead.approximation.advance(min);
84130
}
85131

86132
while (min < max) {
@@ -108,17 +154,17 @@ private static int advance(FixedBitSet set, int i) {
108154
}
109155

110156
private int scoreWindow(
111-
LeafCollector collector, Bits acceptDocs, List<DocIdSetIterator> iterators, int min, int max)
157+
LeafCollector collector, Bits acceptDocs, List<DisiWrapper> iterators, int min, int max)
112158
throws IOException {
113159

114160
// Advance all iterators to the first doc that is greater than or equal to min. This is
115161
// important as this is the only place where we can take advantage of a large gap between
116162
// consecutive matches in any clause.
117-
for (DocIdSetIterator iterator : iterators) {
118-
if (iterator.docID() >= min) {
119-
min = iterator.docID();
163+
for (DisiWrapper w : iterators) {
164+
if (w.docID() >= min) {
165+
min = w.docID();
120166
} else {
121-
min = iterator.advance(min);
167+
min = w.approximation().advance(min);
122168
}
123169
if (min >= max) {
124170
return min;
@@ -127,12 +173,12 @@ private int scoreWindow(
127173

128174
if (acceptDocs == null) {
129175
int minDocIDRunEnd = max;
130-
for (DocIdSetIterator iterator : iterators) {
131-
if (iterator.docID() > min) {
176+
for (DisiWrapper w : iterators) {
177+
if (w.docID() > min) {
132178
minDocIDRunEnd = min;
133179
break;
134180
} else {
135-
minDocIDRunEnd = Math.min(minDocIDRunEnd, iterator.docIDRunEnd());
181+
minDocIDRunEnd = Math.min(minDocIDRunEnd, w.docIDRunEnd());
136182
}
137183
}
138184

@@ -147,22 +193,34 @@ private int scoreWindow(
147193

148194
int bitsetWindowMax = (int) Math.min(max, (long) min + WINDOW_SIZE);
149195

150-
for (DocIdSetIterator it : iterators) {
151-
if (it.docID() > min || it.docIDRunEnd() < bitsetWindowMax) {
152-
windowIterators.add(it);
196+
for (DisiWrapper w : iterators) {
197+
if (w.docID() > min || w.docIDRunEnd() < bitsetWindowMax) {
198+
windowApproximations.add(w.approximation());
199+
if (w.twoPhase() != null) {
200+
windowTwoPhases.add(w.twoPhase());
201+
}
153202
}
154203
}
155204

156-
if (acceptDocs == null && windowIterators.size() == 1) {
157-
// We have a range of doc IDs where all matches of an iterator are matches of the conjunction.
158-
singleIteratorDocIdStream.iterator = windowIterators.get(0);
159-
singleIteratorDocIdStream.from = min;
160-
singleIteratorDocIdStream.to = bitsetWindowMax;
161-
collector.collect(singleIteratorDocIdStream);
205+
if (windowTwoPhases.isEmpty()) {
206+
if (acceptDocs == null && windowApproximations.size() == 1) {
207+
// We have a range of doc IDs where all matches of an iterator are matches of the
208+
// conjunction.
209+
singleIteratorDocIdStream.iterator = windowApproximations.get(0);
210+
singleIteratorDocIdStream.from = min;
211+
singleIteratorDocIdStream.to = bitsetWindowMax;
212+
collector.collect(singleIteratorDocIdStream);
213+
} else {
214+
scoreWindowUsingBitSet(collector, acceptDocs, windowApproximations, min, bitsetWindowMax);
215+
}
162216
} else {
163-
scoreWindowUsingBitSet(collector, acceptDocs, windowIterators, min, bitsetWindowMax);
217+
windowTwoPhases.sort(Comparator.comparingDouble(TwoPhaseIterator::matchCost));
218+
scoreWindowUsingLeapFrog(
219+
collector, acceptDocs, windowApproximations, windowTwoPhases, min, bitsetWindowMax);
220+
windowTwoPhases.clear();
164221
}
165-
windowIterators.clear();
222+
windowApproximations.clear();
223+
166224
return bitsetWindowMax;
167225
}
168226

@@ -238,9 +296,79 @@ private void scoreWindowUsingBitSet(
238296
windowMatches.clear();
239297
}
240298

299+
private static void scoreWindowUsingLeapFrog(
300+
LeafCollector collector,
301+
Bits acceptDocs,
302+
List<DocIdSetIterator> approximations,
303+
List<TwoPhaseIterator> twoPhases,
304+
int min,
305+
int max)
306+
throws IOException {
307+
assert twoPhases.size() > 0;
308+
assert approximations.size() >= twoPhases.size();
309+
310+
if (approximations.size() == 1) {
311+
// scoreWindowUsingLeapFrog is only used if there is at least one two-phase iterator, so our
312+
// single clause is a two-phase iterator
313+
assert twoPhases.size() == 1;
314+
DocIdSetIterator approximation = approximations.get(0);
315+
TwoPhaseIterator twoPhase = twoPhases.get(0);
316+
if (approximation.docID() < min) {
317+
approximation.advance(min);
318+
}
319+
for (int doc = approximation.docID(); doc < max; doc = approximation.nextDoc()) {
320+
if ((acceptDocs == null || acceptDocs.get(doc)) && twoPhase.matches()) {
321+
collector.collect(doc);
322+
}
323+
}
324+
} else {
325+
DocIdSetIterator lead1 = approximations.get(0);
326+
DocIdSetIterator lead2 = approximations.get(1);
327+
328+
if (lead1.docID() < min) {
329+
lead1.advance(min);
330+
}
331+
332+
advanceHead:
333+
for (int doc = lead1.docID(); doc < max; ) {
334+
if (acceptDocs != null && acceptDocs.get(doc) == false) {
335+
doc = lead1.nextDoc();
336+
continue;
337+
}
338+
int doc2 = lead2.docID();
339+
if (doc2 < doc) {
340+
doc2 = lead2.advance(doc);
341+
}
342+
if (doc != doc2) {
343+
doc = lead1.advance(Math.min(doc2, max));
344+
continue;
345+
}
346+
for (int i = 2; i < approximations.size(); ++i) {
347+
DocIdSetIterator other = approximations.get(i);
348+
int docN = other.docID();
349+
if (docN < doc) {
350+
docN = other.advance(doc);
351+
}
352+
if (doc != docN) {
353+
doc = lead1.advance(Math.min(docN, max));
354+
continue advanceHead;
355+
}
356+
}
357+
for (TwoPhaseIterator twoPhase : twoPhases) {
358+
if (twoPhase.matches() == false) {
359+
doc = lead1.nextDoc();
360+
continue advanceHead;
361+
}
362+
}
363+
collector.collect(doc);
364+
doc = lead1.nextDoc();
365+
}
366+
}
367+
}
368+
241369
@Override
242370
public long cost() {
243-
return iterators.get(0).cost();
371+
return iterators.get(0).approximation().cost();
244372
}
245373

246374
final class DocIdStreamView extends DocIdStream {

lucene/core/src/java/org/apache/lucene/search/DocValuesRangeIterator.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ public final boolean matches() throws IOException {
210210
};
211211
}
212212

213+
@Override
214+
public int docIDRunEnd() throws IOException {
215+
if (approximation.match == Match.YES) {
216+
return approximation.upTo + 1;
217+
}
218+
return super.docIDRunEnd();
219+
}
220+
213221
@Override
214222
public float matchCost() {
215223
return innerTwoPhase.matchCost();

lucene/core/src/java/org/apache/lucene/search/TwoPhaseIterator.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,18 @@ public DocIdSetIterator approximation() {
118118
* indexing an array. The returned value must be positive.
119119
*/
120120
public abstract float matchCost();
121+
122+
/**
123+
* Returns the end of the run of consecutive doc IDs that match this {@link TwoPhaseIterator} and
124+
* that contains the current doc ID of the approximation, that is: one plus the last doc ID of the
125+
* run.
126+
*
127+
* <p><b>Note</b>: It is illegal to call this method when the approximation is exhausted or not
128+
* positioned.
129+
*
130+
* <p>The default implementation returns the current doc ID of the approximation.
131+
*/
132+
public int docIDRunEnd() throws IOException {
133+
return approximation().docID();
134+
}
121135
}

lucene/core/src/test/org/apache/lucene/search/ReadAheadMatchAllDocsQuery.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ public Scorer get(long leadCost) throws IOException {
7777
public BulkScorer bulkScorer() throws IOException {
7878
List<DocIdSetIterator> clauses =
7979
Collections.singletonList(DocIdSetIterator.all(context.reader().maxDoc()));
80-
return new DenseConjunctionBulkScorer(clauses, context.reader().maxDoc(), score());
80+
return new DenseConjunctionBulkScorer(
81+
clauses, Collections.emptyList(), context.reader().maxDoc(), score());
8182
}
8283

8384
@Override

0 commit comments

Comments
 (0)