Skip to content

Commit 6e2f1e6

Browse files
committed
Add a DoubleValuesSource for scoring full precision vector similarity (#14708)
1 parent c92e06c commit 6e2f1e6

File tree

5 files changed

+368
-19
lines changed

5 files changed

+368
-19
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ New Features
2323
* GITHUB#14776: Add a Rescorer that uses values from provided DoubleValuesSource to re-score
2424
first pass hits. (Vigya Sharma)
2525

26+
* GITHUB#14708: Add a DoubleValuesSource for full precision vector similarity scores. (Vigya Sharma)
27+
2628
Improvements
2729
---------------------
2830
* GITHUB#14458: Add an IndexDeletion policy that retains the last N commits. (Owais Kazi)

lucene/core/src/java/org/apache/lucene/search/DoubleValuesSource.java

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import org.apache.lucene.index.DocValues;
2525
import org.apache.lucene.index.LeafReaderContext;
2626
import org.apache.lucene.index.NumericDocValues;
27-
import org.apache.lucene.index.VectorEncoding;
2827
import org.apache.lucene.search.comparators.DoubleComparator;
2928
import org.apache.lucene.util.NumericUtils;
3029

@@ -250,14 +249,6 @@ public LongValuesSource rewrite(IndexSearcher searcher) throws IOException {
250249
*/
251250
public static DoubleValues similarityToQueryVector(
252251
LeafReaderContext ctx, byte[] queryVector, String vectorField) throws IOException {
253-
if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding()
254-
!= VectorEncoding.BYTE) {
255-
throw new IllegalArgumentException(
256-
"Field "
257-
+ vectorField
258-
+ " does not have the expected vector encoding: "
259-
+ VectorEncoding.BYTE);
260-
}
261252
return new ByteVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
262253
}
263254

@@ -273,14 +264,6 @@ public static DoubleValues similarityToQueryVector(
273264
*/
274265
public static DoubleValues similarityToQueryVector(
275266
LeafReaderContext ctx, float[] queryVector, String vectorField) throws IOException {
276-
if (ctx.reader().getFieldInfos().fieldInfo(vectorField).getVectorEncoding()
277-
!= VectorEncoding.FLOAT32) {
278-
throw new IllegalArgumentException(
279-
"Field "
280-
+ vectorField
281-
+ " does not have the expected vector encoding: "
282-
+ VectorEncoding.FLOAT32);
283-
}
284267
return new FloatVectorSimilarityValuesSource(queryVector, vectorField).getValues(ctx, null);
285268
}
286269

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.search;
19+
20+
import java.io.IOException;
21+
import java.util.Arrays;
22+
import java.util.Objects;
23+
import org.apache.lucene.index.FieldInfo;
24+
import org.apache.lucene.index.FloatVectorValues;
25+
import org.apache.lucene.index.KnnVectorValues;
26+
import org.apache.lucene.index.LeafReaderContext;
27+
import org.apache.lucene.index.VectorSimilarityFunction;
28+
29+
/**
30+
* A {@link DoubleValuesSource} that computes vector similarity between a query vector and raw full
31+
* precision vectors indexed in provided {@link org.apache.lucene.document.KnnFloatVectorField} in
32+
* documents.
33+
*/
34+
public class FullPrecisionFloatVectorSimilarityValuesSource extends DoubleValuesSource {
35+
36+
private final float[] queryVector;
37+
private final String fieldName;
38+
private VectorSimilarityFunction vectorSimilarityFunction;
39+
40+
/**
41+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
42+
* query vector and field for documents.
43+
*
44+
* @param vector the query vector
45+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
46+
* @param vectorSimilarityFunction the vector similarity function to use
47+
*/
48+
public FullPrecisionFloatVectorSimilarityValuesSource(
49+
float[] vector, String fieldName, VectorSimilarityFunction vectorSimilarityFunction) {
50+
this.queryVector = vector;
51+
this.fieldName = fieldName;
52+
this.vectorSimilarityFunction = vectorSimilarityFunction;
53+
}
54+
55+
/**
56+
* Creates a {@link DoubleValuesSource} that returns vector similarity score between provided
57+
* query vector and field for documents. Uses the configured vector similarity function for the
58+
* field.
59+
*
60+
* @param vector the query vector
61+
* @param fieldName the field name of the {@link org.apache.lucene.document.KnnFloatVectorField}
62+
*/
63+
public FullPrecisionFloatVectorSimilarityValuesSource(float[] vector, String fieldName) {
64+
this(vector, fieldName, null);
65+
}
66+
67+
/** Sugar to fetch full precision similarity score values */
68+
public DoubleValues getSimilarityScores(LeafReaderContext ctx) throws IOException {
69+
return getValues(ctx, null);
70+
}
71+
72+
@Override
73+
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
74+
final FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
75+
if (vectorValues == null) {
76+
FloatVectorValues.checkField(ctx.reader(), fieldName);
77+
return DoubleValues.EMPTY;
78+
}
79+
final FieldInfo fi = ctx.reader().getFieldInfos().fieldInfo(fieldName);
80+
if (fi.getVectorDimension() != queryVector.length) {
81+
throw new IllegalArgumentException(
82+
"Query vector dimension does not match field dimension: "
83+
+ queryVector.length
84+
+ " != "
85+
+ fi.getVectorDimension());
86+
}
87+
88+
if (vectorSimilarityFunction == null) {
89+
this.vectorSimilarityFunction = fi.getVectorSimilarityFunction();
90+
}
91+
final KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
92+
return new DoubleValues() {
93+
@Override
94+
public double doubleValue() throws IOException {
95+
return vectorSimilarityFunction.compare(
96+
queryVector, vectorValues.vectorValue(iterator.index()));
97+
}
98+
99+
@Override
100+
public boolean advanceExact(int doc) throws IOException {
101+
return doc >= iterator.docID() && (iterator.docID() == doc || iterator.advance(doc) == doc);
102+
}
103+
};
104+
}
105+
106+
@Override
107+
public boolean needsScores() {
108+
return false;
109+
}
110+
111+
@Override
112+
public DoubleValuesSource rewrite(IndexSearcher reader) throws IOException {
113+
return this;
114+
}
115+
116+
@Override
117+
public int hashCode() {
118+
return Objects.hash(fieldName, Arrays.hashCode(queryVector), vectorSimilarityFunction);
119+
}
120+
121+
@Override
122+
public boolean equals(Object obj) {
123+
if (this == obj) return true;
124+
if (obj == null || getClass() != obj.getClass()) return false;
125+
FullPrecisionFloatVectorSimilarityValuesSource other =
126+
(FullPrecisionFloatVectorSimilarityValuesSource) obj;
127+
return Objects.equals(fieldName, other.fieldName)
128+
&& Objects.equals(vectorSimilarityFunction, other.vectorSimilarityFunction)
129+
&& Arrays.equals(queryVector, other.queryVector);
130+
}
131+
132+
@Override
133+
public String toString() {
134+
return "FullPrecisionFloatVectorSimilarityValuesSource(fieldName="
135+
+ fieldName
136+
+ " vectorSimilarityFunction="
137+
+ vectorSimilarityFunction.name()
138+
+ " queryVector="
139+
+ Arrays.toString(queryVector)
140+
+ ")";
141+
}
142+
143+
@Override
144+
public boolean isCacheable(LeafReaderContext ctx) {
145+
return true;
146+
}
147+
}

0 commit comments

Comments
 (0)