Skip to content

Add support for two-phase iterators to DenseConjunctionBulkScorer. #14359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -332,11 +332,9 @@ BulkScorer filteredOptionalBulkScorer() throws IOException {
assert scoreMode.needsScores() == false;
filters.add(new DisjunctionSumScorer(optionalScorers, scoreMode, cost));

if (filters.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
if (maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
&& cost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
return new DenseConjunctionBulkScorer(
filters.stream().map(Scorer::iterator).toList(), maxDoc, 0f);
return DenseConjunctionBulkScorer.of(filters, maxDoc, 0f);
}

return new DefaultBulkScorer(new ConjunctionScorer(filters, Collections.emptyList()));
Expand Down Expand Up @@ -392,14 +390,14 @@ private BulkScorer requiredBulkScorer() throws IOException {
}
if (scoreMode != ScoreMode.TOP_SCORES
&& requiredScoring.size() + requiredNoScoring.size() >= 2
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
&& requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
if (requiredScoring.isEmpty()
&& maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE
&& leadCost >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
return new DenseConjunctionBulkScorer(
requiredNoScoring.stream().map(Scorer::iterator).toList(), maxDoc, 0f);
} else {
return DenseConjunctionBulkScorer.of(requiredNoScoring, maxDoc, 0f);
} else if (requiredNoScoring.stream()
.map(Scorer::twoPhaseIterator)
.allMatch(Objects::isNull)) {
return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import org.apache.lucene.search.Weight.DefaultBulkScorer;

/**
Expand Down Expand Up @@ -78,9 +79,18 @@ public final Scorer get(long leadCost) throws IOException {
public final BulkScorer bulkScorer() throws IOException {
DocIdSetIterator iterator = iterator(Long.MAX_VALUE);
if (maxDoc >= DenseConjunctionBulkScorer.WINDOW_SIZE / 2
&& iterator.cost() >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE
&& TwoPhaseIterator.unwrap(iterator) == null) {
return new DenseConjunctionBulkScorer(Collections.singletonList(iterator), maxDoc, score);
&& iterator.cost() >= maxDoc / DenseConjunctionBulkScorer.DENSITY_THRESHOLD_INVERSE) {
TwoPhaseIterator twoPhase = TwoPhaseIterator.unwrap(iterator);
List<DocIdSetIterator> iterators;
List<TwoPhaseIterator> twoPhases;
if (twoPhase == null) {
iterators = Collections.singletonList(iterator);
twoPhases = Collections.emptyList();
} else {
iterators = Collections.emptyList();
twoPhases = Collections.singletonList(twoPhase);
}
return new DenseConjunctionBulkScorer(iterators, twoPhases, maxDoc, score);
} else {
return new DefaultBulkScorer(new ConstantScoreScorer(score, scoreMode, iterator));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,28 @@
*/
final class DenseConjunctionBulkScorer extends BulkScorer {

private record DisiWrapper(DocIdSetIterator approximation, TwoPhaseIterator twoPhase) {
DisiWrapper(DocIdSetIterator iterator) {
this(iterator, null);
}

DisiWrapper(TwoPhaseIterator twoPhase) {
this(twoPhase.approximation(), twoPhase);
}

int docID() {
return approximation().docID();
}

int docIDRunEnd() throws IOException {
if (twoPhase() == null) {
return approximation().docIDRunEnd();
} else {
return twoPhase().docIDRunEnd();
}
}
}

// Use a small-ish window size to make sure that we can take advantage of gaps in the postings of
// clauses that are not leading iteration.
static final int WINDOW_SIZE = 4096;
Expand All @@ -39,25 +61,49 @@ final class DenseConjunctionBulkScorer extends BulkScorer {
static final int DENSITY_THRESHOLD_INVERSE = Long.SIZE / 2;

private final int maxDoc;
private final List<DocIdSetIterator> iterators;
private final List<DisiWrapper> iterators;
private final SimpleScorable scorable;

private final FixedBitSet windowMatches = new FixedBitSet(WINDOW_SIZE);
private final FixedBitSet clauseWindowMatches = new FixedBitSet(WINDOW_SIZE);
private final List<DocIdSetIterator> windowIterators = new ArrayList<>();
private final List<DocIdSetIterator> windowApproximations = new ArrayList<>();
private final List<TwoPhaseIterator> windowTwoPhases = new ArrayList<>();
private final DocIdStreamView docIdStreamView = new DocIdStreamView();
private final RangeDocIdStream rangeDocIdStream = new RangeDocIdStream();
private final SingleIteratorDocIdStream singleIteratorDocIdStream =
new SingleIteratorDocIdStream();

DenseConjunctionBulkScorer(List<DocIdSetIterator> iterators, int maxDoc, float constantScore) {
if (iterators.isEmpty()) {
static DenseConjunctionBulkScorer of(List<Scorer> filters, int maxDoc, float constantScore) {
List<DocIdSetIterator> iterators = new ArrayList<>();
List<TwoPhaseIterator> twoPhases = new ArrayList<>();
for (Scorer filter : filters) {
TwoPhaseIterator twoPhase = filter.twoPhaseIterator();
if (twoPhase != null) {
twoPhases.add(twoPhase);
} else {
iterators.add(filter.iterator());
}
}
return new DenseConjunctionBulkScorer(iterators, twoPhases, maxDoc, constantScore);
}

DenseConjunctionBulkScorer(
List<DocIdSetIterator> iterators,
List<TwoPhaseIterator> twoPhases,
int maxDoc,
float constantScore) {
if (iterators.isEmpty() && twoPhases.isEmpty()) {
throw new IllegalArgumentException("Expected one or more iterators, got 0");
}
this.maxDoc = maxDoc;
iterators = new ArrayList<>(iterators);
iterators.sort(Comparator.comparingLong(DocIdSetIterator::cost));
this.iterators = iterators;
this.iterators = new ArrayList<>();
for (DocIdSetIterator iterator : iterators) {
this.iterators.add(new DisiWrapper(iterator));
}
for (TwoPhaseIterator twoPhase : twoPhases) {
this.iterators.add(new DisiWrapper(twoPhase));
}
this.iterators.sort(Comparator.comparing(w -> w.approximation().cost()));
this.scorable = new SimpleScorable();
scorable.score = constantScore;
}
Expand All @@ -66,21 +112,21 @@ final class DenseConjunctionBulkScorer extends BulkScorer {
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
collector.setScorer(scorable);

List<DocIdSetIterator> iterators = this.iterators;
List<DisiWrapper> iterators = this.iterators;
if (collector.competitiveIterator() != null) {
iterators = new ArrayList<>(iterators);
iterators.add(collector.competitiveIterator());
iterators.add(new DisiWrapper(collector.competitiveIterator()));
}

for (DocIdSetIterator it : iterators) {
min = Math.max(min, it.docID());
for (DisiWrapper w : iterators) {
min = Math.max(min, w.approximation().docID());
}

max = Math.min(max, maxDoc);

DocIdSetIterator lead = iterators.get(0);
DisiWrapper lead = iterators.get(0);
if (lead.docID() < min) {
min = lead.advance(min);
min = lead.approximation.advance(min);
}

while (min < max) {
Expand Down Expand Up @@ -108,17 +154,17 @@ private static int advance(FixedBitSet set, int i) {
}

private int scoreWindow(
LeafCollector collector, Bits acceptDocs, List<DocIdSetIterator> iterators, int min, int max)
LeafCollector collector, Bits acceptDocs, List<DisiWrapper> iterators, int min, int max)
throws IOException {

// Advance all iterators to the first doc that is greater than or equal to min. This is
// important as this is the only place where we can take advantage of a large gap between
// consecutive matches in any clause.
for (DocIdSetIterator iterator : iterators) {
if (iterator.docID() >= min) {
min = iterator.docID();
for (DisiWrapper w : iterators) {
if (w.docID() >= min) {
min = w.docID();
} else {
min = iterator.advance(min);
min = w.approximation().advance(min);
}
if (min >= max) {
return min;
Expand All @@ -127,12 +173,12 @@ private int scoreWindow(

if (acceptDocs == null) {
int minDocIDRunEnd = max;
for (DocIdSetIterator iterator : iterators) {
if (iterator.docID() > min) {
for (DisiWrapper w : iterators) {
if (w.docID() > min) {
minDocIDRunEnd = min;
break;
} else {
minDocIDRunEnd = Math.min(minDocIDRunEnd, iterator.docIDRunEnd());
minDocIDRunEnd = Math.min(minDocIDRunEnd, w.docIDRunEnd());
}
}

Expand All @@ -147,22 +193,34 @@ private int scoreWindow(

int bitsetWindowMax = (int) Math.min(max, (long) min + WINDOW_SIZE);

for (DocIdSetIterator it : iterators) {
if (it.docID() > min || it.docIDRunEnd() < bitsetWindowMax) {
windowIterators.add(it);
for (DisiWrapper w : iterators) {
if (w.docID() > min || w.docIDRunEnd() < bitsetWindowMax) {
windowApproximations.add(w.approximation());
if (w.twoPhase() != null) {
windowTwoPhases.add(w.twoPhase());
}
}
}

if (acceptDocs == null && windowIterators.size() == 1) {
// We have a range of doc IDs where all matches of an iterator are matches of the conjunction.
singleIteratorDocIdStream.iterator = windowIterators.get(0);
singleIteratorDocIdStream.from = min;
singleIteratorDocIdStream.to = bitsetWindowMax;
collector.collect(singleIteratorDocIdStream);
if (windowTwoPhases.isEmpty()) {
if (acceptDocs == null && windowApproximations.size() == 1) {
// We have a range of doc IDs where all matches of an iterator are matches of the
// conjunction.
singleIteratorDocIdStream.iterator = windowApproximations.get(0);
singleIteratorDocIdStream.from = min;
singleIteratorDocIdStream.to = bitsetWindowMax;
collector.collect(singleIteratorDocIdStream);
} else {
scoreWindowUsingBitSet(collector, acceptDocs, windowApproximations, min, bitsetWindowMax);
}
} else {
scoreWindowUsingBitSet(collector, acceptDocs, windowIterators, min, bitsetWindowMax);
windowTwoPhases.sort(Comparator.comparingDouble(TwoPhaseIterator::matchCost));
scoreWindowUsingLeapFrog(
collector, acceptDocs, windowApproximations, windowTwoPhases, min, bitsetWindowMax);
windowTwoPhases.clear();
}
windowIterators.clear();
windowApproximations.clear();

return bitsetWindowMax;
}

Expand Down Expand Up @@ -238,9 +296,79 @@ private void scoreWindowUsingBitSet(
windowMatches.clear();
}

private static void scoreWindowUsingLeapFrog(
LeafCollector collector,
Bits acceptDocs,
List<DocIdSetIterator> approximations,
List<TwoPhaseIterator> twoPhases,
int min,
int max)
throws IOException {
assert twoPhases.size() > 0;
assert approximations.size() >= twoPhases.size();

if (approximations.size() == 1) {
// scoreWindowUsingLeapFrog is only used if there is at least one two-phase iterator, so our
// single clause is a two-phase iterator
assert twoPhases.size() == 1;
DocIdSetIterator approximation = approximations.get(0);
TwoPhaseIterator twoPhase = twoPhases.get(0);
if (approximation.docID() < min) {
approximation.advance(min);
}
for (int doc = approximation.docID(); doc < max; doc = approximation.nextDoc()) {
if ((acceptDocs == null || acceptDocs.get(doc)) && twoPhase.matches()) {
collector.collect(doc);
}
}
} else {
DocIdSetIterator lead1 = approximations.get(0);
DocIdSetIterator lead2 = approximations.get(1);

if (lead1.docID() < min) {
lead1.advance(min);
}

advanceHead:
for (int doc = lead1.docID(); doc < max; ) {
if (acceptDocs != null && acceptDocs.get(doc) == false) {
doc = lead1.nextDoc();
continue;
}
int doc2 = lead2.docID();
if (doc2 < doc) {
doc2 = lead2.advance(doc);
}
if (doc != doc2) {
doc = lead1.advance(Math.min(doc2, max));
continue;
}
for (int i = 2; i < approximations.size(); ++i) {
DocIdSetIterator other = approximations.get(i);
int docN = other.docID();
if (docN < doc) {
docN = other.advance(doc);
}
if (doc != docN) {
doc = lead1.advance(Math.min(docN, max));
continue advanceHead;
}
}
for (TwoPhaseIterator twoPhase : twoPhases) {
if (twoPhase.matches() == false) {
doc = lead1.nextDoc();
continue advanceHead;
}
}
collector.collect(doc);
doc = lead1.nextDoc();
}
}
}

@Override
public long cost() {
return iterators.get(0).cost();
return iterators.get(0).approximation().cost();
}

final class DocIdStreamView extends DocIdStream {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ public final boolean matches() throws IOException {
};
}

@Override
public int docIDRunEnd() throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have/need any tests against this implementation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have tests for the upTo variable, which is almost the same thing. I'll improve the test to cover #docIdRunEnd() at the same time.

if (approximation.match == Match.YES) {
return approximation.upTo + 1;
}
return super.docIDRunEnd();
}

@Override
public float matchCost() {
return innerTwoPhase.matchCost();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,18 @@ public DocIdSetIterator approximation() {
* indexing an array. The returned value must be positive.
*/
public abstract float matchCost();

/**
* Returns the end of the run of consecutive doc IDs that match this {@link TwoPhaseIterator} and
* that contains the current doc ID of the approximation, that is: one plus the last doc ID of the
* run.
*
* <p><b>Note</b>: It is illegal to call this method when the approximation is exhausted or not
* positioned.
*
* <p>The default implementation returns the current doc ID of the approximation.
*/
public int docIDRunEnd() throws IOException {
return approximation().docID();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public Scorer get(long leadCost) throws IOException {
public BulkScorer bulkScorer() throws IOException {
List<DocIdSetIterator> clauses =
Collections.singletonList(DocIdSetIterator.all(context.reader().maxDoc()));
return new DenseConjunctionBulkScorer(clauses, context.reader().maxDoc(), score());
return new DenseConjunctionBulkScorer(
clauses, Collections.emptyList(), context.reader().maxDoc(), score());
}

@Override
Expand Down
Loading