Skip to content

Commit f16e2f3

Browse files
authored
Fix patience knn queries to work with seeded knn queries (#14688)
1 parent a5ab7e2 commit f16e2f3

File tree

4 files changed

+82
-13
lines changed

4 files changed

+82
-13
lines changed

lucene/core/src/java/org/apache/lucene/search/HnswQueueSaturationCollector.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,13 @@ public void nextCandidate() {
9494
@Override
9595
public KnnSearchStrategy getSearchStrategy() {
9696
KnnSearchStrategy delegateStrategy = delegate.getSearchStrategy();
97-
assert delegateStrategy instanceof KnnSearchStrategy.Hnsw;
98-
return new KnnSearchStrategy.Patience(
99-
this, ((KnnSearchStrategy.Hnsw) delegateStrategy).filteredSearchThreshold());
97+
if (delegateStrategy instanceof KnnSearchStrategy.Hnsw hnswStrategy) {
98+
return new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold());
99+
} else if (delegateStrategy instanceof KnnSearchStrategy.Seeded seededStrategy) {
100+
if (seededStrategy.originalStrategy() instanceof KnnSearchStrategy.Hnsw hnswStrategy) {
101+
return new KnnSearchStrategy.Patience(this, hnswStrategy.filteredSearchThreshold());
102+
}
103+
}
104+
return delegateStrategy;
100105
}
101106
}

lucene/core/src/java/org/apache/lucene/search/PatienceKnnVectorQuery.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ public class PatienceKnnVectorQuery extends AbstractKnnVectorQuery {
4343

4444
private final int patience;
4545
private final double saturationThreshold;
46-
47-
final AbstractKnnVectorQuery delegate;
46+
private AbstractKnnVectorQuery delegate;
4847

4948
/**
5049
* Construct a new PatienceKnnVectorQuery instance for a float vector field
@@ -234,4 +233,18 @@ public KnnCollector newCollector(
234233
patience);
235234
}
236235
}
236+
237+
@Override
238+
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
239+
if (delegate instanceof SeededKnnVectorQuery seededKnnVectorQuery) {
240+
// this is required because SeededKnnVectorQuery now requires its own rewriting logic (to
241+
// create the seed Weight)
242+
delegate =
243+
new SeededKnnVectorQuery(
244+
seededKnnVectorQuery.delegate,
245+
seededKnnVectorQuery.seed,
246+
seededKnnVectorQuery.createSeedWeight(indexSearcher));
247+
}
248+
return super.rewrite(indexSearcher);
249+
}
237250
}

lucene/core/src/test/org/apache/lucene/search/TestPatienceByteVectorQuery.java

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,27 @@
3030
import org.apache.lucene.search.knn.KnnSearchStrategy;
3131
import org.apache.lucene.store.Directory;
3232
import org.apache.lucene.util.TestVectorUtil;
33+
import org.junit.Before;
3334

3435
public class TestPatienceByteVectorQuery extends BaseKnnVectorQueryTestCase {
3536

37+
private boolean wrapSeeded;
38+
39+
@Before
40+
@Override
41+
public void setUp() throws Exception {
42+
super.setUp();
43+
wrapSeeded = random().nextBoolean();
44+
}
45+
3646
@Override
3747
PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
38-
return PatienceKnnVectorQuery.fromByteQuery(
39-
new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter));
48+
KnnByteVectorQuery knnQuery =
49+
new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
50+
return wrapSeeded
51+
? PatienceKnnVectorQuery.fromSeededQuery(
52+
SeededKnnVectorQuery.fromByteQuery(knnQuery, new MatchNoDocsQuery()))
53+
: PatienceKnnVectorQuery.fromByteQuery(knnQuery);
4054
}
4155

4256
@Override
@@ -80,7 +94,13 @@ public void testToString() throws IOException {
8094
IndexReader reader = DirectoryReader.open(indexStore)) {
8195
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10);
8296
assertEquals(
83-
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10]}",
97+
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate="
98+
+ (wrapSeeded
99+
? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate="
100+
: "")
101+
+ "KnnByteVectorQuery:field[0,...][10]"
102+
+ (wrapSeeded ? "}" : "")
103+
+ "}",
84104
query.toString("ignored"));
85105

86106
assertDocScoreQueryToString(query.rewrite(newSearcher(reader)));
@@ -89,7 +109,13 @@ public void testToString() throws IOException {
89109
Query filter = new TermQuery(new Term("id", "text"));
90110
query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter);
91111
assertEquals(
92-
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnByteVectorQuery:field[0,...][10][id:text]}",
112+
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate="
113+
+ (wrapSeeded
114+
? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate="
115+
: "")
116+
+ "KnnByteVectorQuery:field[0,...][10][id:text]"
117+
+ (wrapSeeded ? "}" : "")
118+
+ "}",
93119
query.toString("ignored"));
94120
}
95121
}

lucene/core/src/test/org/apache/lucene/search/TestPatienceFloatVectorQuery.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,26 @@
2828
import org.apache.lucene.search.knn.KnnSearchStrategy;
2929
import org.apache.lucene.store.Directory;
3030
import org.apache.lucene.util.TestVectorUtil;
31+
import org.junit.Before;
3132

3233
public class TestPatienceFloatVectorQuery extends BaseKnnVectorQueryTestCase {
3334

35+
private boolean wrapSeeded;
36+
37+
@Before
38+
@Override
39+
public void setUp() throws Exception {
40+
super.setUp();
41+
wrapSeeded = random().nextBoolean();
42+
}
43+
3444
@Override
3545
PatienceKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
36-
return PatienceKnnVectorQuery.fromFloatQuery(
37-
new KnnFloatVectorQuery(field, query, k, queryFilter));
46+
KnnFloatVectorQuery knnQuery = new KnnFloatVectorQuery(field, query, k, queryFilter);
47+
return wrapSeeded
48+
? PatienceKnnVectorQuery.fromSeededQuery(
49+
SeededKnnVectorQuery.fromFloatQuery(knnQuery, new MatchNoDocsQuery()))
50+
: PatienceKnnVectorQuery.fromFloatQuery(knnQuery);
3851
}
3952

4053
@Override
@@ -71,7 +84,13 @@ public void testToString() throws IOException {
7184
IndexReader reader = DirectoryReader.open(indexStore)) {
7285
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10);
7386
assertEquals(
74-
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10]}",
87+
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate="
88+
+ (wrapSeeded
89+
? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate="
90+
: "")
91+
+ "KnnFloatVectorQuery:field[0.0,...][10]"
92+
+ (wrapSeeded ? "}" : "")
93+
+ "}",
7594
query.toString("ignored"));
7695

7796
assertDocScoreQueryToString(query.rewrite(newSearcher(reader)));
@@ -80,7 +99,13 @@ public void testToString() throws IOException {
8099
Query filter = new TermQuery(new Term("id", "text"));
81100
query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10, filter);
82101
assertEquals(
83-
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate=KnnFloatVectorQuery:field[0.0,...][10][id:text]}",
102+
"PatienceKnnVectorQuery{saturationThreshold=0.995, patience=7, delegate="
103+
+ (wrapSeeded
104+
? "SeededKnnVectorQuery{seed=MatchNoDocsQuery(\"\"), seedWeight=null, delegate="
105+
: "")
106+
+ "KnnFloatVectorQuery:field[0.0,...][10][id:text]"
107+
+ (wrapSeeded ? "}" : "")
108+
+ "}",
84109
query.toString("ignored"));
85110
}
86111
}

0 commit comments

Comments
 (0)