Skip to content

Commit 6df6b06

Browse files
committed
Refactoring of resolution logic
PR changes a lot of the resolution logic and does some renaming. Signed-off-by: John Mazanec <[email protected]>
1 parent 28e25b5 commit 6df6b06

File tree

71 files changed

+1545
-1922
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+1545
-1922
lines changed

src/main/java/org/opensearch/knn/common/FieldInfoExtractor.java

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.opensearch.knn.index.SpaceType;
1212
import org.opensearch.knn.index.VectorDataType;
1313
import org.opensearch.knn.index.engine.KNNEngine;
14+
import org.opensearch.knn.index.quantizationService.QuantizationService;
1415
import org.opensearch.knn.indices.ModelMetadata;
1516
import org.opensearch.knn.indices.ModelUtil;
1617

@@ -21,6 +22,7 @@
2122

2223
import static org.opensearch.knn.common.KNNConstants.QFRAMEWORK_CONFIG;
2324
import org.opensearch.knn.indices.ModelDao;
25+
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
2426

2527
import java.util.Locale;
2628

@@ -47,20 +49,43 @@ public static KNNEngine extractKNNEngine(final FieldInfo field) {
4749
}
4850

4951
/**
50-
* Extracts VectorDataType from FieldInfo
52+
* Extracts VectorDataType from FieldInfo. This VectorDataType represents what vectors will be input to the
53+
* library layer. For the data type that is transfered to the native layer, see extractVectorDataTypeForTransfer (better comment)
54+
*
5155
* @param fieldInfo {@link FieldInfo}
5256
* @return {@link VectorDataType}
5357
*/
5458
public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
5559
String vectorDataTypeString = fieldInfo.getAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD);
56-
if (StringUtils.isEmpty(vectorDataTypeString)) {
57-
final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
58-
if (modelMetadata != null) {
59-
VectorDataType vectorDataType = modelMetadata.getVectorDataType();
60-
vectorDataTypeString = vectorDataType == null ? null : vectorDataType.getValue();
61-
}
60+
if (StringUtils.isNotEmpty(vectorDataTypeString)) {
61+
return VectorDataType.get(vectorDataTypeString);
62+
}
63+
64+
final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
65+
if (modelMetadata == null) {
66+
return VectorDataType.DEFAULT;
67+
}
68+
return modelMetadata.getVectorDataType();
69+
}
70+
71+
/**
72+
* Extracts VectorDataType for transfer from FieldInfo. This VectorDataType represents what vectors will be transfered
73+
* to the native layer. For the data type that is input to the library layer, see extractVectorDataType (better comment)
74+
*
75+
* @param fieldInfo {@link FieldInfo}
76+
* @param quantizationParams {@link QuantizationParams}
77+
* @return {@link VectorDataType}
78+
*/
79+
public static VectorDataType extractVectorDataTypeForTransfer(final FieldInfo fieldInfo, QuantizationParams quantizationParams) {
80+
if (quantizationParams != null) {
81+
return QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo);
6282
}
63-
return StringUtils.isNotEmpty(vectorDataTypeString) ? VectorDataType.get(vectorDataTypeString) : VectorDataType.DEFAULT;
83+
QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo);
84+
if (quantizationConfig != null && quantizationConfig != QuantizationConfig.EMPTY) {
85+
return VectorDataType.BINARY;
86+
}
87+
88+
return extractVectorDataType(fieldInfo);
6489
}
6590

6691
/**
@@ -71,10 +96,15 @@ public static VectorDataType extractVectorDataType(final FieldInfo fieldInfo) {
7196
*/
7297
public static QuantizationConfig extractQuantizationConfig(final FieldInfo fieldInfo) {
7398
String quantizationConfigString = fieldInfo.getAttribute(QFRAMEWORK_CONFIG);
74-
if (StringUtils.isEmpty(quantizationConfigString)) {
99+
if (StringUtils.isNotEmpty(quantizationConfigString)) {
100+
return QuantizationConfigParser.fromCsv(quantizationConfigString);
101+
}
102+
103+
final ModelMetadata modelMetadata = ModelUtil.getModelMetadata(fieldInfo.getAttribute(KNNConstants.MODEL_ID));
104+
if (modelMetadata == null || modelMetadata.getKNNLibraryIndex().isEmpty()) {
75105
return QuantizationConfig.EMPTY;
76106
}
77-
return QuantizationConfigParser.fromCsv(quantizationConfigString);
107+
return modelMetadata.getKNNLibraryIndex().get().getQuantizationConfig();
78108
}
79109

80110
/**

src/main/java/org/opensearch/knn/index/KNNIndexShard.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
import java.util.concurrent.ExecutionException;
3535
import java.util.stream.Collectors;
3636

37+
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataTypeForTransfer;
3738
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
3839
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
39-
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
4040
import static org.opensearch.knn.index.util.IndexUtil.getParametersAtLoading;
4141
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFilePrefix;
4242
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileSuffix;
@@ -182,7 +182,7 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
182182
shardPath,
183183
spaceType,
184184
modelId,
185-
VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()))
185+
extractVectorDataTypeForTransfer(fieldInfo, null)
186186
)
187187
);
188188
}

src/main/java/org/opensearch/knn/index/codec/BasePerFieldKnnVectorsFormat.java

Lines changed: 39 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -77,56 +77,54 @@ public KnnVectorsFormat getKnnVectorsFormatForField(final String field) {
7777
).fieldType(field);
7878

7979
if (mappedFieldType.getModelId().isPresent()) {
80-
return getFormatForModelBasedIndices();
81-
}
82-
if (mappedFieldType.getKNNEngine() == null) {
83-
throw new IllegalStateException("Method config context cannot be empty");
80+
return getNativeEngines990KnnVectorsFormat();
8481
}
8582
return getFormatForMethodBasedIndices(mappedFieldType.getKNNEngine(), mappedFieldType.getLibraryParameters(), field);
8683
}
8784

88-
private KnnVectorsFormat getFormatForModelBasedIndices() {
89-
return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()));
90-
}
91-
9285
private KnnVectorsFormat getFormatForMethodBasedIndices(KNNEngine knnEngine, Map<String, Object> params, String field) {
93-
if (knnEngine == KNNEngine.LUCENE) {
94-
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
95-
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
96-
params,
97-
defaultMaxConnections,
98-
defaultBeamWidth
99-
);
100-
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
101-
log.debug(
102-
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
103-
field,
104-
MAX_CONNECTIONS,
105-
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
106-
BEAM_WIDTH,
107-
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
108-
LUCENE_SQ_CONFIDENCE_INTERVAL,
109-
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
110-
LUCENE_SQ_BITS,
111-
knnScalarQuantizedVectorsFormatParams.getBits()
112-
);
113-
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
114-
}
115-
}
86+
if (knnEngine != KNNEngine.LUCENE) {
87+
return getNativeEngines990KnnVectorsFormat();
88+
}
11689

117-
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
118-
log.debug(
119-
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
120-
field,
121-
MAX_CONNECTIONS,
122-
knnVectorsFormatParams.getMaxConnections(),
123-
BEAM_WIDTH,
124-
knnVectorsFormatParams.getBeamWidth()
90+
// For Lucene, we need to properly configure the format because format initialization is when parameters are
91+
// set
92+
if (params != null && params.containsKey(METHOD_ENCODER_PARAMETER)) {
93+
KNNScalarQuantizedVectorsFormatParams knnScalarQuantizedVectorsFormatParams = new KNNScalarQuantizedVectorsFormatParams(
94+
params,
95+
defaultMaxConnections,
96+
defaultBeamWidth
12597
);
126-
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
98+
if (knnScalarQuantizedVectorsFormatParams.validate(params)) {
99+
log.debug(
100+
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\", [{}] = \"{}\"",
101+
field,
102+
MAX_CONNECTIONS,
103+
knnScalarQuantizedVectorsFormatParams.getMaxConnections(),
104+
BEAM_WIDTH,
105+
knnScalarQuantizedVectorsFormatParams.getBeamWidth(),
106+
LUCENE_SQ_CONFIDENCE_INTERVAL,
107+
knnScalarQuantizedVectorsFormatParams.getConfidenceInterval(),
108+
LUCENE_SQ_BITS,
109+
knnScalarQuantizedVectorsFormatParams.getBits()
110+
);
111+
return scalarQuantizedVectorsFormatSupplier.apply(knnScalarQuantizedVectorsFormatParams);
112+
}
127113
}
128114

129-
// All native engines to use NativeEngines990KnnVectorsFormat
115+
KNNVectorsFormatParams knnVectorsFormatParams = new KNNVectorsFormatParams(params, defaultMaxConnections, defaultBeamWidth);
116+
log.debug(
117+
"Initialize KNN vector format for field [{}] with params [{}] = \"{}\" and [{}] = \"{}\"",
118+
field,
119+
MAX_CONNECTIONS,
120+
knnVectorsFormatParams.getMaxConnections(),
121+
BEAM_WIDTH,
122+
knnVectorsFormatParams.getBeamWidth()
123+
);
124+
return vectorsFormatSupplier.apply(knnVectorsFormatParams);
125+
}
126+
127+
private NativeEngines990KnnVectorsFormat getNativeEngines990KnnVectorsFormat() {
130128
return new NativeEngines990KnnVectorsFormat(new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()));
131129
}
132130

src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
import java.io.IOException;
2727

28-
import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;
2928
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
29+
import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;
3030

3131
/**
3232
* This class writes the KNN docvalues to the segments

src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java

Lines changed: 40 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,8 @@
2424
import org.opensearch.knn.index.VectorDataType;
2525
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
2626
import org.opensearch.knn.index.engine.KNNEngine;
27-
import org.opensearch.knn.index.engine.KNNIndexContext;
28-
import org.opensearch.knn.index.quantizationService.QuantizationService;
29-
import org.opensearch.knn.index.util.IndexUtil;
3027
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
3128
import org.opensearch.knn.indices.Model;
32-
import org.opensearch.knn.indices.ModelCache;
3329
import org.opensearch.knn.indices.ModelUtil;
3430
import org.opensearch.knn.plugin.stats.KNNGraphValue;
3531
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
@@ -47,7 +43,7 @@
4743
import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC;
4844
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
4945
import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine;
50-
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType;
46+
import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataTypeForTransfer;
5147
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
5248
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
5349
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
@@ -161,17 +157,14 @@ private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues) throws
161157
// TODO: Refactor this so its scalable. Possibly move it out of this class
162158
private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException {
163159
final Map<String, Object> parameters;
164-
VectorDataType vectorDataType;
165-
if (quantizationState != null) {
166-
vectorDataType = QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo);
167-
} else {
168-
vectorDataType = extractVectorDataType(fieldInfo);
169-
}
170-
if (fieldInfo.attributes().containsKey(MODEL_ID)) {
171-
Model model = getModel(fieldInfo);
172-
parameters = getTemplateParameters(fieldInfo, model);
173-
} else {
160+
VectorDataType vectorDataType = extractVectorDataTypeForTransfer(
161+
fieldInfo,
162+
quantizationState == null ? null : quantizationState.getQuantizationParams()
163+
);
164+
if (fieldInfo.attributes().containsKey(MODEL_ID) == false) {
174165
parameters = getParameters(fieldInfo, vectorDataType, knnEngine);
166+
} else {
167+
parameters = getTemplateParameters(fieldInfo, vectorDataType);
175168
}
176169

177170
return BuildIndexParams.builder()
@@ -215,7 +208,6 @@ private Map<String, Object> getParameters(FieldInfo fieldInfo, VectorDataType ve
215208
);
216209
}
217210

218-
parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue());
219211
// In OpenSearch 2.16, we added the prefix for binary indices in the index description in the codec logic.
220212
// After 2.16, we added the binary prefix in the faiss library code. However, to ensure backwards compatibility,
221213
// we need to ensure that if the description does not contain the prefix but the type is binary, we add the
@@ -228,60 +220,20 @@ private Map<String, Object> getParameters(FieldInfo fieldInfo, VectorDataType ve
228220
return parameters;
229221
}
230222

231-
private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map<String, Object> parameters, Map<String, String> fieldAttributes) {
232-
if (KNNEngine.FAISS != knnEngine) {
233-
return;
234-
}
235-
236-
if (!VectorDataType.BINARY.getValue()
237-
.equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) {
238-
return;
239-
}
240-
241-
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) {
242-
return;
243-
}
244-
245-
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) {
246-
return;
223+
private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, VectorDataType vectorDataTypeForTransfer) {
224+
Model model = ModelUtil.getModel(fieldInfo.getAttribute(MODEL_ID));
225+
if (model == null) {
226+
throw new IllegalStateException("Model not found for field " + fieldInfo.name);
247227
}
248228

249-
parameters.put(
250-
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
251-
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString()
252-
);
253-
IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
254-
}
255-
256-
private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException {
257229
Map<String, Object> parameters = new HashMap<>();
258230
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
259-
parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID));
231+
parameters.put(KNNConstants.MODEL_ID, model.getModelID());
260232
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob());
261-
262-
// TODO: Is there any way we could avoid resolving it like this?
263-
KNNIndexContext knnIndexContext = ModelUtil.getKnnMethodContextFromModelMetadata(model.getModelID(), model.getModelMetadata());
264-
if (knnIndexContext != null && knnIndexContext.getLibraryParameters().containsKey(VECTOR_DATA_TYPE_FIELD)) {
265-
IndexUtil.updateVectorDataTypeToParameters(
266-
parameters,
267-
VectorDataType.get((String) knnIndexContext.getLibraryParameters().get(VECTOR_DATA_TYPE_FIELD))
268-
);
269-
} else {
270-
IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
271-
}
272-
233+
parameters.put(VECTOR_DATA_TYPE_FIELD, vectorDataTypeForTransfer.getValue());
273234
return parameters;
274235
}
275236

276-
private Model getModel(FieldInfo fieldInfo) {
277-
String modelId = fieldInfo.attributes().get(MODEL_ID);
278-
Model model = ModelCache.getInstance().get(modelId);
279-
if (model.getModelBlob() == null) {
280-
throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
281-
}
282-
return model;
283-
}
284-
285237
private void startMergeStats(int numDocs, long bytesPerVector) {
286238
KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
287239
KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs);
@@ -358,4 +310,30 @@ private static NativeIndexWriter createWriter(
358310
: DefaultIndexBuildStrategy.getInstance();
359311
return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState);
360312
}
313+
314+
private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map<String, Object> parameters, Map<String, String> fieldAttributes) {
315+
if (KNNEngine.FAISS != knnEngine) {
316+
return;
317+
}
318+
319+
if (!VectorDataType.BINARY.getValue()
320+
.equals(fieldAttributes.getOrDefault(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.DEFAULT.getValue()))) {
321+
return;
322+
}
323+
324+
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER) == null) {
325+
return;
326+
}
327+
328+
if (parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString().startsWith(FAISS_BINARY_INDEX_DESCRIPTION_PREFIX)) {
329+
return;
330+
}
331+
332+
parameters.put(
333+
KNNConstants.INDEX_DESCRIPTION_PARAMETER,
334+
FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + parameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString()
335+
);
336+
337+
parameters.put(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue());
338+
}
361339
}

src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ public class BuildIndexParams {
2222
String fieldName;
2323
KNNEngine knnEngine;
2424
String indexPath;
25+
/**
26+
* Vector data type represents the type used to build the library index. If something like binary quantization is
27+
* done, then this will be different from the vector data type the user provides
28+
*/
2529
VectorDataType vectorDataType;
2630
Map<String, Object> parameters;
2731
/**

0 commit comments

Comments
 (0)