From 51ec0e9b1c5ef9061d162d94ea57d989f5ed1730 Mon Sep 17 00:00:00 2001 From: hanbj Date: Fri, 21 Feb 2025 11:25:41 +0800 Subject: [PATCH 1/4] Reduce the number of comparisons when lowerPoint is equal to upperPoint --- .../apache/lucene/search/PointRangeQuery.java | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java index fddafdcaef78..eeac84a94431 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java @@ -54,6 +54,7 @@ public abstract class PointRangeQuery extends Query { final int bytesPerDim; final byte[] lowerPoint; final byte[] upperPoint; + final boolean equalValues; /** * Expert: create a multidimensional range query for point values. @@ -89,6 +90,17 @@ protected PointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, in this.lowerPoint = lowerPoint; this.upperPoint = upperPoint; + + ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + boolean equalValues = true; + int offset = 0; + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + if (comparator.compare(lowerPoint, offset, upperPoint, offset) != 0) { + equalValues = false; + break; + } + } + this.equalValues = equalValues; } /** @@ -129,6 +141,16 @@ public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, fl private boolean matches(byte[] packedValue) { int offset = 0; + + if (equalValues) { + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + if (comparator.compare(packedValue, offset, lowerPoint, offset) != 0) { + return false; + } + } + return true; + } + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { if (comparator.compare(packedValue, offset, lowerPoint, offset) < 0) { // Doc's value is too low, in this dimension @@ -147,6 +169,31 @@ private 
Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { boolean crosses = false; int offset = 0; + if (equalValues) { + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + + int cmpMin = comparator.compare(minPackedValue, offset, lowerPoint, offset); + if (cmpMin > 0) { + return Relation.CELL_OUTSIDE_QUERY; + } + + int cmpMax = comparator.compare(maxPackedValue, offset, lowerPoint, offset); + if (cmpMax < 0) { + return Relation.CELL_OUTSIDE_QUERY; + } + + if (cmpMin != 0 || cmpMax != 0) { + crosses = true; + } + } + + if (crosses) { + return Relation.CELL_CROSSES_QUERY; + } else { + return Relation.CELL_INSIDE_QUERY; + } + } + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { if (comparator.compare(minPackedValue, offset, upperPoint, offset) > 0 From ee97cd335a8705ee562a8f58ef241c967866b7f8 Mon Sep 17 00:00:00 2001 From: hanbj Date: Tue, 1 Apr 2025 18:13:11 +0800 Subject: [PATCH 2/4] code format and add test --- .../apache/lucene/search/PointRangeQuery.java | 702 +++++++++--------- .../lucene/search/TestPointQueries.java | 43 ++ 2 files changed, 412 insertions(+), 333 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java index eeac84a94431..ceba2c71ad78 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java @@ -132,416 +132,447 @@ public void visit(QueryVisitor visitor) { public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + if (this.equalValues) { // lowerPoint==upperPoint + return new SinglePointConstantScoreWeight(this, scoreMode, boost); + } // We don't use RandomAccessWeight here: it's no good to approximate with "match all docs". 
// This is an inverted structure and should be used in the first pass: + return new MultiPointsConstantScoreWeight(this, scoreMode, boost); + } + + /** + * Essentially, it is to reduce the number of comparisons. This is an optimization, used for the + * case of lowerPoint==upperPoint. + */ + protected class SinglePointConstantScoreWeight extends MultiPointsConstantScoreWeight { + + public SinglePointConstantScoreWeight(Query query, ScoreMode scoreMode, float boost) { + super(query, scoreMode, boost); + } + + @Override + public boolean matches(byte[] packedValue) { + int offset = 0; + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + if (comparator.compare(packedValue, offset, lowerPoint, offset) != 0) { + return false; + } + } + return true; + } - return new ConstantScoreWeight(this, boost) { + @Override + public Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { + boolean crosses = false; + int offset = 0; - private final ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { - private boolean matches(byte[] packedValue) { - int offset = 0; + int cmpMin = comparator.compare(minPackedValue, offset, lowerPoint, offset); + if (cmpMin > 0) { + return Relation.CELL_OUTSIDE_QUERY; + } - if (equalValues) { - for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { - if (comparator.compare(packedValue, offset, lowerPoint, offset) != 0) { - return false; - } - } - return true; + int cmpMax = comparator.compare(maxPackedValue, offset, lowerPoint, offset); + if (cmpMax < 0) { + return Relation.CELL_OUTSIDE_QUERY; } - for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { - if (comparator.compare(packedValue, offset, lowerPoint, offset) < 0) { - // Doc's value is too low, in this dimension - return false; - } - if (comparator.compare(packedValue, offset, upperPoint, offset) > 0) { - // Doc's value is too high, in this dimension - return 
false; - } + if (cmpMin != 0 || cmpMax != 0) { + crosses = true; } - return true; } - private Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { + if (crosses) { + return Relation.CELL_CROSSES_QUERY; + } else { + return Relation.CELL_INSIDE_QUERY; + } + } + } - boolean crosses = false; - int offset = 0; + /** + * A weight that used for lowerPoint != upperPoint case, the query range may include multiple + * points. + */ + protected class MultiPointsConstantScoreWeight extends ConstantScoreWeight { - if (equalValues) { - for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + protected ScoreMode scoreMode; + protected ByteArrayComparator comparator; - int cmpMin = comparator.compare(minPackedValue, offset, lowerPoint, offset); - if (cmpMin > 0) { - return Relation.CELL_OUTSIDE_QUERY; - } + public MultiPointsConstantScoreWeight(Query query, ScoreMode scoreMode, float boost) { + super(query, boost); + this.scoreMode = scoreMode; + this.comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + } - int cmpMax = comparator.compare(maxPackedValue, offset, lowerPoint, offset); - if (cmpMax < 0) { - return Relation.CELL_OUTSIDE_QUERY; - } + public boolean matches(byte[] packedValue) { + int offset = 0; + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + if (comparator.compare(packedValue, offset, lowerPoint, offset) < 0) { + // Doc's value is too low, in this dimension + return false; + } + if (comparator.compare(packedValue, offset, upperPoint, offset) > 0) { + // Doc's value is too high, in this dimension + return false; + } + } + return true; + } - if (cmpMin != 0 || cmpMax != 0) { - crosses = true; - } - } + public Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { - if (crosses) { - return Relation.CELL_CROSSES_QUERY; - } else { - return Relation.CELL_INSIDE_QUERY; - } + boolean crosses = false; + int offset = 0; + + for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + + if 
(comparator.compare(minPackedValue, offset, upperPoint, offset) > 0 + || comparator.compare(maxPackedValue, offset, lowerPoint, offset) < 0) { + return Relation.CELL_OUTSIDE_QUERY; } - for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { + crosses |= + comparator.compare(minPackedValue, offset, lowerPoint, offset) < 0 + || comparator.compare(maxPackedValue, offset, upperPoint, offset) > 0; + } - if (comparator.compare(minPackedValue, offset, upperPoint, offset) > 0 - || comparator.compare(maxPackedValue, offset, lowerPoint, offset) < 0) { - return Relation.CELL_OUTSIDE_QUERY; - } + if (crosses) { + return Relation.CELL_CROSSES_QUERY; + } else { + return Relation.CELL_INSIDE_QUERY; + } + } - crosses |= - comparator.compare(minPackedValue, offset, lowerPoint, offset) < 0 - || comparator.compare(maxPackedValue, offset, upperPoint, offset) > 0; - } + private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) { + return new IntersectVisitor() { - if (crosses) { - return Relation.CELL_CROSSES_QUERY; - } else { - return Relation.CELL_INSIDE_QUERY; + DocIdSetBuilder.BulkAdder adder; + + @Override + public void grow(int count) { + adder = result.grow(count); } - } - private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) { - return new IntersectVisitor() { + @Override + public void visit(int docID) { + adder.add(docID); + } - DocIdSetBuilder.BulkAdder adder; + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + adder.add(iterator); + } - @Override - public void grow(int count) { - adder = result.grow(count); - } + @Override + public void visit(IntsRef ref) { + adder.add(ref); + } - @Override - public void visit(int docID) { - adder.add(docID); + @Override + public void visit(int docID, byte[] packedValue) { + if (matches(packedValue)) { + visit(docID); } + } - @Override - public void visit(DocIdSetIterator iterator) throws IOException { + @Override + public void visit(DocIdSetIterator iterator, byte[] 
packedValue) throws IOException { + if (matches(packedValue)) { adder.add(iterator); } + } - @Override - public void visit(IntsRef ref) { - adder.add(ref); + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return relate(minPackedValue, maxPackedValue); + } + }; + } + + /** Create a visitor that sets documents that do NOT match the range. */ + private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) { + return new IntersectVisitor() { + + @Override + public void visit(int docID) { + result.set(docID); + cost[0]++; + } + + @Override + public void visit(DocIdSetIterator iterator) throws IOException { + result.or(iterator); + cost[0] += iterator.cost(); + } + + @Override + public void visit(IntsRef ref) { + for (int i = ref.offset; i < ref.offset + ref.length; i++) { + result.set(ref.ints[i]); } + cost[0] += ref.length; + } - @Override - public void visit(int docID, byte[] packedValue) { - if (matches(packedValue)) { - visit(docID); - } + @Override + public void visit(int docID, byte[] packedValue) { + if (matches(packedValue) == false) { + visit(docID); } + } - @Override - public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { - if (matches(packedValue)) { - adder.add(iterator); - } + @Override + public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { + if (matches(packedValue) == false) { + visit(iterator); } + } - @Override - public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - return relate(minPackedValue, maxPackedValue); + @Override + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + Relation relation = relate(minPackedValue, maxPackedValue); + switch (relation) { + case CELL_INSIDE_QUERY: + // all points match, skip this subtree + return Relation.CELL_OUTSIDE_QUERY; + case CELL_OUTSIDE_QUERY: + // none of the points match, clear all documents + return Relation.CELL_INSIDE_QUERY; + case 
CELL_CROSSES_QUERY: + default: + return relation; } - }; + } + }; + } + + private boolean checkValidPointValues(PointValues values) throws IOException { + if (values == null) { + // No docs in this segment/field indexed any points + return false; } - /** Create a visitor that sets documents that do NOT match the range. */ - private IntersectVisitor getInverseIntersectVisitor(FixedBitSet result, long[] cost) { - return new IntersectVisitor() { + if (values.getNumIndexDimensions() != numDims) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" was indexed with numIndexDimensions=" + + values.getNumIndexDimensions() + + " but this query has numDims=" + + numDims); + } + if (bytesPerDim != values.getBytesPerDimension()) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" was indexed with bytesPerDim=" + + values.getBytesPerDimension() + + " but this query has bytesPerDim=" + + bytesPerDim); + } + return true; + } - @Override - public void visit(int docID) { - result.set(docID); - cost[0]++; - } + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + LeafReader reader = context.reader(); - @Override - public void visit(DocIdSetIterator iterator) throws IOException { - result.or(iterator); - cost[0] += iterator.cost(); - } + PointValues values = reader.getPointValues(field); + if (checkValidPointValues(values) == false) { + return null; + } - @Override - public void visit(IntsRef ref) { - for (int i = ref.offset; i < ref.offset + ref.length; i++) { - result.set(ref.ints[i]); - } - cost[0] += ref.length; + if (values.getDocCount() == 0) { + return null; + } else { + final byte[] fieldPackedLower = values.getMinPackedValue(); + final byte[] fieldPackedUpper = values.getMaxPackedValue(); + for (int i = 0; i < numDims; ++i) { + int offset = i * bytesPerDim; + if (comparator.compare(lowerPoint, offset, fieldPackedUpper, offset) > 0 + || comparator.compare(upperPoint, offset, fieldPackedLower, 
offset) < 0) { + // If this query is a required clause of a boolean query, then returning null here + // will help make sure that we don't call ScorerSupplier#get on other required clauses + // of the same boolean query, which is an expensive operation for some queries (e.g. + // multi-term queries). + return null; } + } + } - @Override - public void visit(int docID, byte[] packedValue) { - if (matches(packedValue) == false) { - visit(docID); - } + boolean allDocsMatch; + if (values.getDocCount() == reader.maxDoc()) { + final byte[] fieldPackedLower = values.getMinPackedValue(); + final byte[] fieldPackedUpper = values.getMaxPackedValue(); + allDocsMatch = true; + for (int i = 0; i < numDims; ++i) { + int offset = i * bytesPerDim; + if (comparator.compare(lowerPoint, offset, fieldPackedLower, offset) > 0 + || comparator.compare(upperPoint, offset, fieldPackedUpper, offset) < 0) { + allDocsMatch = false; + break; } + } + } else { + allDocsMatch = false; + } + + if (allDocsMatch) { + // all docs have a value and all points are within bounds, so everything matches + return ConstantScoreScorerSupplier.matchAll(score(), scoreMode, reader.maxDoc()); + } else { + return new ConstantScoreScorerSupplier(score(), scoreMode, reader.maxDoc()) { + + final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); + final IntersectVisitor visitor = getIntersectVisitor(result); + long cost = -1; @Override - public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException { - if (matches(packedValue) == false) { - visit(iterator); + public DocIdSetIterator iterator(long leadCost) throws IOException { + if (values.getDocCount() == reader.maxDoc() + && values.getDocCount() == values.size() + && cost() > reader.maxDoc() / 2) { + // If all docs have exactly one value and the cost is greater + // than half the leaf size then maybe we can make things faster + // by computing the set of documents that do NOT match the range + final FixedBitSet result = new 
FixedBitSet(reader.maxDoc()); + long[] cost = new long[1]; + values.intersect(getInverseIntersectVisitor(result, cost)); + // Flip the bit set and cost + result.flip(0, reader.maxDoc()); + cost[0] = Math.max(0, reader.maxDoc() - cost[0]); + return new BitSetIterator(result, cost[0]); } + + values.intersect(visitor); + return result.build().iterator(); } @Override - public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - Relation relation = relate(minPackedValue, maxPackedValue); - switch (relation) { - case CELL_INSIDE_QUERY: - // all points match, skip this subtree - return Relation.CELL_OUTSIDE_QUERY; - case CELL_OUTSIDE_QUERY: - // none of the points match, clear all documents - return Relation.CELL_INSIDE_QUERY; - case CELL_CROSSES_QUERY: - default: - return relation; + public long cost() { + if (cost == -1) { + // Computing the cost may be expensive, so only do it if necessary + cost = values.estimateDocCount(visitor); + assert cost >= 0; } + return cost; } }; } + } - private boolean checkValidPointValues(PointValues values) throws IOException { - if (values == null) { - // No docs in this segment/field indexed any points - return false; - } + @Override + public int count(LeafReaderContext context) throws IOException { + LeafReader reader = context.reader(); - if (values.getNumIndexDimensions() != numDims) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" was indexed with numIndexDimensions=" - + values.getNumIndexDimensions() - + " but this query has numDims=" - + numDims); - } - if (bytesPerDim != values.getBytesPerDimension()) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" was indexed with bytesPerDim=" - + values.getBytesPerDimension() - + " but this query has bytesPerDim=" - + bytesPerDim); - } - return true; + PointValues values = reader.getPointValues(field); + if (checkValidPointValues(values) == false) { + return 0; } - @Override - public ScorerSupplier scorerSupplier(LeafReaderContext 
context) throws IOException { - LeafReader reader = context.reader(); - - PointValues values = reader.getPointValues(field); - if (checkValidPointValues(values) == false) { - return null; + if (reader.hasDeletions() == false) { + if (relate(values.getMinPackedValue(), values.getMaxPackedValue()) + == Relation.CELL_INSIDE_QUERY) { + return values.getDocCount(); } - - if (values.getDocCount() == 0) { - return null; - } else { - final byte[] fieldPackedLower = values.getMinPackedValue(); - final byte[] fieldPackedUpper = values.getMaxPackedValue(); - for (int i = 0; i < numDims; ++i) { - int offset = i * bytesPerDim; - if (comparator.compare(lowerPoint, offset, fieldPackedUpper, offset) > 0 - || comparator.compare(upperPoint, offset, fieldPackedLower, offset) < 0) { - // If this query is a required clause of a boolean query, then returning null here - // will help make sure that we don't call ScorerSupplier#get on other required clauses - // of the same boolean query, which is an expensive operation for some queries (e.g. - // multi-term queries). - return null; - } - } + // only 1D: we have the guarantee that it will actually run fast since there are at most 2 + // crossing leaves. + // docCount == size : counting according number of points in leaf node, so must be + // single-valued. 
+ if (numDims == 1 && values.getDocCount() == values.size()) { + return (int) pointCount(values.getPointTree(), this::relate, this::matches); } + } + return super.count(context); + } - boolean allDocsMatch; - if (values.getDocCount() == reader.maxDoc()) { - final byte[] fieldPackedLower = values.getMinPackedValue(); - final byte[] fieldPackedUpper = values.getMaxPackedValue(); - allDocsMatch = true; - for (int i = 0; i < numDims; ++i) { - int offset = i * bytesPerDim; - if (comparator.compare(lowerPoint, offset, fieldPackedLower, offset) > 0 - || comparator.compare(upperPoint, offset, fieldPackedUpper, offset) < 0) { - allDocsMatch = false; - break; + /** + * Finds the number of points matching the provided range conditions. Using this method is + * faster than calling {@link PointValues#intersect(IntersectVisitor)} to get the count of + * intersecting points. This method does not enforce live documents, therefore it should only be + * used when there are no deleted documents. + * + * @param pointTree start node of the count operation + * @param nodeComparator comparator to be used for checking whether the internal node is inside + * the range + * @param leafComparator comparator to be used for checking whether the leaf node is inside the + * range + * @return count of points that match the range + */ + private long pointCount( + PointValues.PointTree pointTree, + BiFunction nodeComparator, + Predicate leafComparator) + throws IOException { + final long[] matchingNodeCount = {0}; + // create a custom IntersectVisitor that records the number of leafNodes that matched + final IntersectVisitor visitor = + new IntersectVisitor() { + @Override + public void visit(int docID) { + // this branch should be unreachable + throw new UnsupportedOperationException( + "This IntersectVisitor does not perform any actions on a " + + "docID=" + + docID + + " node being visited"); } - } - } else { - allDocsMatch = false; - } - - if (allDocsMatch) { - // all docs have a value and all 
points are within bounds, so everything matches - return ConstantScoreScorerSupplier.matchAll(score(), scoreMode, reader.maxDoc()); - } else { - return new ConstantScoreScorerSupplier(score(), scoreMode, reader.maxDoc()) { - - final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values); - final IntersectVisitor visitor = getIntersectVisitor(result); - long cost = -1; @Override - public DocIdSetIterator iterator(long leadCost) throws IOException { - if (values.getDocCount() == reader.maxDoc() - && values.getDocCount() == values.size() - && cost() > reader.maxDoc() / 2) { - // If all docs have exactly one value and the cost is greater - // than half the leaf size then maybe we can make things faster - // by computing the set of documents that do NOT match the range - final FixedBitSet result = new FixedBitSet(reader.maxDoc()); - long[] cost = new long[1]; - values.intersect(getInverseIntersectVisitor(result, cost)); - // Flip the bit set and cost - result.flip(0, reader.maxDoc()); - cost[0] = Math.max(0, reader.maxDoc() - cost[0]); - return new BitSetIterator(result, cost[0]); + public void visit(int docID, byte[] packedValue) { + if (leafComparator.test(packedValue)) { + matchingNodeCount[0]++; } - - values.intersect(visitor); - return result.build().iterator(); } @Override - public long cost() { - if (cost == -1) { - // Computing the cost may be expensive, so only do it if necessary - cost = values.estimateDocCount(visitor); - assert cost >= 0; - } - return cost; + public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { + return nodeComparator.apply(minPackedValue, maxPackedValue); } }; - } - } - - @Override - public int count(LeafReaderContext context) throws IOException { - LeafReader reader = context.reader(); - - PointValues values = reader.getPointValues(field); - if (checkValidPointValues(values) == false) { - return 0; - } + pointCount(visitor, pointTree, matchingNodeCount); + return matchingNodeCount[0]; + } - if 
(reader.hasDeletions() == false) { - if (relate(values.getMinPackedValue(), values.getMaxPackedValue()) - == Relation.CELL_INSIDE_QUERY) { - return values.getDocCount(); - } - // only 1D: we have the guarantee that it will actually run fast since there are at most 2 - // crossing leaves. - // docCount == size : counting according number of points in leaf node, so must be - // single-valued. - if (numDims == 1 && values.getDocCount() == values.size()) { - return (int) pointCount(values.getPointTree(), this::relate, this::matches); + private void pointCount( + IntersectVisitor visitor, PointValues.PointTree pointTree, long[] matchingNodeCount) + throws IOException { + Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); + switch (r) { + case CELL_OUTSIDE_QUERY: + // This cell is fully outside the query shape: return 0 as the count of its nodes + return; + case CELL_INSIDE_QUERY: + // This cell is fully inside the query shape: return the size of the entire node as the + // count + matchingNodeCount[0] += pointTree.size(); + return; + case CELL_CROSSES_QUERY: + /* + The cell crosses the shape boundary, or the cell fully contains the query, so we fall + through and do full counting. + */ + if (pointTree.moveToChild()) { + do { + pointCount(visitor, pointTree, matchingNodeCount); + } while (pointTree.moveToSibling()); + pointTree.moveToParent(); + } else { + // we have reached a leaf node here. + pointTree.visitDocValues(visitor); + // leaf node count is saved in the matchingNodeCount array by the visitor } - } - return super.count(context); - } - - /** - * Finds the number of points matching the provided range conditions. Using this method is - * faster than calling {@link PointValues#intersect(IntersectVisitor)} to get the count of - * intersecting points. This method does not enforce live documents, therefore it should only - * be used when there are no deleted documents. 
- * - * @param pointTree start node of the count operation - * @param nodeComparator comparator to be used for checking whether the internal node is - * inside the range - * @param leafComparator comparator to be used for checking whether the leaf node is inside - * the range - * @return count of points that match the range - */ - private long pointCount( - PointValues.PointTree pointTree, - BiFunction nodeComparator, - Predicate leafComparator) - throws IOException { - final long[] matchingNodeCount = {0}; - // create a custom IntersectVisitor that records the number of leafNodes that matched - final IntersectVisitor visitor = - new IntersectVisitor() { - @Override - public void visit(int docID) { - // this branch should be unreachable - throw new UnsupportedOperationException( - "This IntersectVisitor does not perform any actions on a " - + "docID=" - + docID - + " node being visited"); - } - - @Override - public void visit(int docID, byte[] packedValue) { - if (leafComparator.test(packedValue)) { - matchingNodeCount[0]++; - } - } - - @Override - public Relation compare(byte[] minPackedValue, byte[] maxPackedValue) { - return nodeComparator.apply(minPackedValue, maxPackedValue); - } - }; - pointCount(visitor, pointTree, matchingNodeCount); - return matchingNodeCount[0]; - } - - private void pointCount( - IntersectVisitor visitor, PointValues.PointTree pointTree, long[] matchingNodeCount) - throws IOException { - Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue()); - switch (r) { - case CELL_OUTSIDE_QUERY: - // This cell is fully outside the query shape: return 0 as the count of its nodes - return; - case CELL_INSIDE_QUERY: - // This cell is fully inside the query shape: return the size of the entire node as the - // count - matchingNodeCount[0] += pointTree.size(); - return; - case CELL_CROSSES_QUERY: - /* - The cell crosses the shape boundary, or the cell fully contains the query, so we fall - through and do full counting. 
- */ - if (pointTree.moveToChild()) { - do { - pointCount(visitor, pointTree, matchingNodeCount); - } while (pointTree.moveToSibling()); - pointTree.moveToParent(); - } else { - // we have reached a leaf node here. - pointTree.visitDocValues(visitor); - // leaf node count is saved in the matchingNodeCount array by the visitor - } - return; - default: - throw new IllegalArgumentException("Unreachable code"); - } + return; + default: + throw new IllegalArgumentException("Unreachable code"); } + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } - }; + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } public String getField() { @@ -564,6 +595,11 @@ public byte[] getUpperPoint() { return upperPoint.clone(); } + // for test + public boolean isEqualValues() { + return equalValues; + } + @Override public final int hashCode() { int hash = classHash(); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java b/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java index caedc1aa2dc5..68fe70ba2fa2 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java @@ -2538,4 +2538,47 @@ public void testPointInSetQuerySkipsNonMatchingSegments() throws IOException { w.close(); dir.close(); } + + public void testPointRangeQueryWithEqualValues() throws Exception { + Directory dir = newDirectory(); + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setCodec(getCodec()); + IndexWriter w = new IndexWriter(dir, iwc); + + int cardinality = TestUtil.nextInt(random(), 2, 20); + + int zeroCount = 0; + int oneCount = 0; + for (int i = 0; i < 10000; i++) { + int x = random().nextInt(cardinality); + if (x == 0) { + zeroCount++; + } else if (x == 1) { + oneCount++; + } + Document doc = new Document(); + doc.add(new IntPoint("int", x)); + w.addDocument(doc); + } + + IndexReader r = 
DirectoryReader.open(w); + IndexSearcher s = newSearcher(r, false); + + PointRangeQuery query = (PointRangeQuery) IntPoint.newRangeQuery("int", 0, 1); + assertFalse(query.isEqualValues()); + Weight weight = query.createWeight(s, ScoreMode.COMPLETE_NO_SCORES, 1f); + assertTrue(weight instanceof PointRangeQuery.MultiPointsConstantScoreWeight); + + query = (PointRangeQuery) IntPoint.newRangeQuery("int", 0, 0); + assertTrue(query.isEqualValues()); + weight = query.createWeight(s, ScoreMode.COMPLETE_NO_SCORES, 1f); + assertTrue(weight instanceof PointRangeQuery.SinglePointConstantScoreWeight); + + assertEquals(zeroCount, s.count(IntPoint.newRangeQuery("int", 0, 0))); + assertEquals(oneCount, s.count(IntPoint.newRangeQuery("int", 1, 1))); + + w.close(); + r.close(); + dir.close(); + } } From ae3bfb6cffecf45a0f07ecbdd7eba2bc500ec66b Mon Sep 17 00:00:00 2001 From: hanbj Date: Tue, 1 Apr 2025 18:13:28 +0800 Subject: [PATCH 3/4] add change --- lucene/CHANGES.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 74048906a71b..1311be7f60a1 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -69,6 +69,8 @@ Optimizations * GITHUB#14268: PointInSetQuery early exit on non-matching segments. (hanbj) +* GITHUB#14267: Reduce the number of comparisons when lowerPoint is equal to upperPoint. 
(hanbj) + Bug Fixes --------------------- (No changes) From 471179e351da992355574e2478b26d6c25e43eb1 Mon Sep 17 00:00:00 2001 From: hanbj Date: Wed, 2 Apr 2025 14:39:47 +0800 Subject: [PATCH 4/4] code refactoring --- .../apache/lucene/search/PointRangeQuery.java | 64 +++++++++++++------ .../lucene/search/TestPointQueries.java | 4 +- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java index ceba2c71ad78..4552c8e78842 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointRangeQuery.java @@ -133,20 +133,23 @@ public final Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, fl throws IOException { if (this.equalValues) { // lowerPoint==upperPoint - return new SinglePointConstantScoreWeight(this, scoreMode, boost); + return new SinglePointRangeQueryWeight(this, scoreMode, boost); } // We don't use RandomAccessWeight here: it's no good to approximate with "match all docs". // This is an inverted structure and should be used in the first pass: - return new MultiPointsConstantScoreWeight(this, scoreMode, boost); + return new MultiPointRangeQueryWeight(this, scoreMode, boost); } /** - * Essentially, it is to reduce the number of comparisons. This is an optimization, used for the - * case of lowerPoint==upperPoint. + * Single-point range query weight implementation class, used to handle the special case where the + * lower and upper bounds are equal (i.e. single-point query). + * + *

<p>Optimize query performance by reducing the number of comparisons between dimensions. This + * implementation is used when the upper and lower bounds of all dimensions are exactly the same. */ - protected class SinglePointConstantScoreWeight extends MultiPointsConstantScoreWeight { + protected class SinglePointRangeQueryWeight extends PointRangeQueryWeight { - public SinglePointConstantScoreWeight(Query query, ScoreMode scoreMode, float boost) { + protected SinglePointRangeQueryWeight(Query query, ScoreMode scoreMode, float boost) { super(query, scoreMode, boost); } @@ -192,21 +195,20 @@ public Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { } /** - * A weight that used for lowerPoint != upperPoint case, the query range may include multiple - * points. + * Multiple-point range query weight implementation class, used to handle the situation where the + * query range contains multiple points. + * + *

<p>When the lower bound (lowerPoint) of the query is not equal to the upper bound (upperPoint), + * this implementation is used to check whether each dimension is within the query range. */ - protected class MultiPointsConstantScoreWeight extends ConstantScoreWeight { + protected class MultiPointRangeQueryWeight extends PointRangeQueryWeight { - protected ScoreMode scoreMode; - protected ByteArrayComparator comparator; - - public MultiPointsConstantScoreWeight(Query query, ScoreMode scoreMode, float boost) { - super(query, boost); - this.scoreMode = scoreMode; - this.comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + protected MultiPointRangeQueryWeight(Query query, ScoreMode scoreMode, float boost) { + super(query, scoreMode, boost); } - public boolean matches(byte[] packedValue) { + @Override + protected boolean matches(byte[] packedValue) { int offset = 0; for (int dim = 0; dim < numDims; dim++, offset += bytesPerDim) { if (comparator.compare(packedValue, offset, lowerPoint, offset) < 0) { @@ -221,7 +223,8 @@ public boolean matches(byte[] packedValue) { return true; } - public Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { + @Override + protected Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { boolean crosses = false; int offset = 0; @@ -244,6 +247,31 @@ public Relation relate(byte[] minPackedValue, byte[] maxPackedValue) { return Relation.CELL_INSIDE_QUERY; } } + } + + /** + * Basic weight class, inherited from {@link ConstantScoreWeight}, subclasses need to implement + * specific point value matching logic and range relationship judgment. + * + * @see SinglePointRangeQueryWeight for the specific implementation of single-point range query. + * @see MultiPointRangeQueryWeight for the specific implementation of multi-point range query.
+ */ + protected abstract class PointRangeQueryWeight extends ConstantScoreWeight { + + protected ScoreMode scoreMode; + protected ByteArrayComparator comparator; + + protected PointRangeQueryWeight(Query query, ScoreMode scoreMode, float boost) { + super(query, boost); + this.scoreMode = scoreMode; + this.comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + } + + /** whether the point value matches the query range. */ + protected abstract boolean matches(byte[] packedValue); + + /** relation between the point value range and the query range. */ + protected abstract Relation relate(byte[] minPackedValue, byte[] maxPackedValue); private IntersectVisitor getIntersectVisitor(DocIdSetBuilder result) { return new IntersectVisitor() { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java b/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java index 68fe70ba2fa2..ab7ba6f3e46c 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java @@ -2567,12 +2567,12 @@ public void testPointRangeQueryWithEqualValues() throws Exception { PointRangeQuery query = (PointRangeQuery) IntPoint.newRangeQuery("int", 0, 1); assertFalse(query.isEqualValues()); Weight weight = query.createWeight(s, ScoreMode.COMPLETE_NO_SCORES, 1f); - assertTrue(weight instanceof PointRangeQuery.MultiPointsConstantScoreWeight); + assertTrue(weight instanceof PointRangeQuery.MultiPointRangeQueryWeight); query = (PointRangeQuery) IntPoint.newRangeQuery("int", 0, 0); assertTrue(query.isEqualValues()); weight = query.createWeight(s, ScoreMode.COMPLETE_NO_SCORES, 1f); - assertTrue(weight instanceof PointRangeQuery.SinglePointConstantScoreWeight); + assertTrue(weight instanceof PointRangeQuery.SinglePointRangeQueryWeight); assertEquals(zeroCount, s.count(IntPoint.newRangeQuery("int", 0, 0))); assertEquals(oneCount, s.count(IntPoint.newRangeQuery("int", 1, 1)));