diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index e58df600b694..74048906a71b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -67,6 +67,8 @@ Optimizations --------------------- * GITHUB#14418: Quick exit on filter query matching no docs when rewriting knn query. (Pan Guixin) +* GITHUB#14268: PointInSetQuery early exit on non-matching segments. (hanbj) + Bug Fixes --------------------- (No changes) diff --git a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java index d94fec11187c..e269f29616d0 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PointInSetQuery.java @@ -62,6 +62,8 @@ public abstract class PointInSetQuery extends Query implements Accountable { final int numDims; final int bytesPerDim; final long ramBytesUsed; // cache + byte[] lowerPoint = null; + byte[] upperPoint = null; /** Iterator of encoded point values. */ // TODO: if we want to stream, maybe we should use jdk stream class? @@ -108,6 +110,9 @@ protected PointInSetQuery(String field, int numDims, int bytesPerDim, Stream pac } if (previous == null) { previous = new BytesRefBuilder(); + lowerPoint = new byte[bytesPerDim * numDims]; + assert lowerPoint.length == current.length; + System.arraycopy(current.bytes, current.offset, lowerPoint, 0, current.length); } else { int cmp = previous.get().compareTo(current); if (cmp == 0) { @@ -122,6 +127,12 @@ protected PointInSetQuery(String field, int numDims, int bytesPerDim, Stream pac } sortedPackedPoints = builder.finish(); sortedPackedPointsHashCode = sortedPackedPoints.hashCode(); + if (previous != null) { + BytesRef max = previous.get(); + upperPoint = new byte[bytesPerDim * numDims]; + assert upperPoint.length == max.length; + System.arraycopy(max.bytes, max.offset, upperPoint, 0, max.length); + } ramBytesUsed = BASE_RAM_BYTES + RamUsageEstimator.sizeOfObject(field) @@ -172,6 +183,22 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti + bytesPerDim); } + if (values.getDocCount() == 0) { + return null; + } else if (lowerPoint != null) { + assert upperPoint != null; + ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(bytesPerDim); + final byte[] fieldPackedLower = values.getMinPackedValue(); + final byte[] fieldPackedUpper = values.getMaxPackedValue(); + for (int i = 0; i < numDims; ++i) { + int offset = i * bytesPerDim; + if (comparator.compare(lowerPoint, offset, fieldPackedUpper, offset) > 0 + || comparator.compare(upperPoint, offset, fieldPackedLower, offset) < 0) { + return null; + } + } + } + if (numDims == 1) { // We optimize this common case, effectively doing a merge sort of the indexed values vs // the queried set: diff --git a/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java b/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java index 0ca2d177b09f..caedc1aa2dc5 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestPointQueries.java @@ -2497,4 +2497,45 @@ public void testRangeQuerySkipsNonMatchingSegments() throws IOException { w.close(); dir.close(); } + + public void testPointInSetQuerySkipsNonMatchingSegments() throws IOException { + Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()); + Document doc = new Document(); + doc.add(new IntPoint("field", 10)); + doc.add(new IntPoint("field2d", 10, 10)); + w.addDocument(doc); + + DirectoryReader reader = DirectoryReader.open(w); + IndexSearcher searcher = newSearcher(reader); + + Query query = IntPoint.newSetQuery("field", 1, 3, 5); + Weight weight = + searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + assertNull(weight.scorerSupplier(reader.leaves().get(0))); + + query = IntPoint.newSetQuery("field", 11, 13, 15); + weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + assertNull(weight.scorerSupplier(reader.leaves().get(0))); + + query = IntPoint.newSetQuery("field", 5, 10, 15); + weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + assertNotNull(weight.scorerSupplier(reader.leaves().get(0))); + + query = newMultiDimIntSetQuery("field2d", 2, 5, 5); + weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + assertNull(weight.scorerSupplier(reader.leaves().get(0))); + + query = newMultiDimIntSetQuery("field2d", 2, 15, 15); + weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + assertNull(weight.scorerSupplier(reader.leaves().get(0))); + + query = newMultiDimIntSetQuery("field2d", 2, 10, 10); + weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f); + assertNotNull(weight.scorerSupplier(reader.leaves().get(0))); + + reader.close(); + w.close(); + dir.close(); + } }