Skip to content

Commit 4449fde

Browse files
Address Review Comments
Signed-off-by: Naveen Tatikonda <[email protected]>
1 parent 83e1754 commit 4449fde

File tree

5 files changed

+121
-133
lines changed

5 files changed

+121
-133
lines changed

jni/include/faiss_index_service.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class IndexService {
8787

8888
protected:
8989
virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors);
90+
virtual jlong initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors);
9091

9192
std::unique_ptr<FaissMethods> faissMethods;
9293
}; // class IndexService
@@ -157,10 +158,11 @@ class BinaryIndexService final : public IndexService {
157158
* @param templateIndexJ template index
158159
* @return memory address of the native index object
159160
*/
160-
virtual jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ);
161+
jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ) final;
161162

162163
protected:
163164
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
165+
jlong initAndAllocateIndex(std::unique_ptr<faiss::IndexBinary> &index, size_t threadCount, size_t dim, size_t numVectors);
164166
}; // class BinaryIndexService
165167

166168
/**
@@ -229,10 +231,11 @@ class ByteIndexService final : public IndexService {
229231
* @param templateIndexJ template index
230232
* @return memory address of the native index object
231233
*/
232-
virtual jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ);
234+
jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ) final;
233235

234236
protected:
235237
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
238+
jlong initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors) final;
236239
}; // class ByteIndexService
237240

238241
}

jni/src/faiss_index_service.cpp

+58-87
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,25 @@ void IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVector
6868
}
6969
}
7070

71+
jlong IndexService::initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors) {
72+
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
73+
if (threadCount != 0) {
74+
omp_set_num_threads(threadCount);
75+
}
76+
77+
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
78+
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
79+
idMap->own_fields = true;
80+
81+
// TODO: allocIndex for IVF
82+
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
83+
84+
//Release ownership so that the underlying index that was created is not deleted here. The index is needed later
85+
//in insert and write operations
86+
index.release();
87+
return reinterpret_cast<jlong>(idMap.release());
88+
}
89+
7190
jlong IndexService::initIndex(
7291
knn_jni::JNIUtilInterface * jniUtil,
7392
JNIEnv * env,
@@ -81,11 +100,6 @@ jlong IndexService::initIndex(
81100
// Create index using Faiss factory method
82101
std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));
83102

84-
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
85-
if (threadCount != 0) {
86-
omp_set_num_threads(threadCount);
87-
}
88-
89103
// Add extra parameters that can't be configured with the index factory
90104
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, index.get());
91105

@@ -94,16 +108,7 @@ jlong IndexService::initIndex(
94108
throw std::runtime_error("Index is not trained");
95109
}
96110

97-
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
98-
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
99-
idMap->own_fields = true;
100-
101-
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
102-
103-
//Release ownership so that the underlying index that was created is not deleted here. The index is needed later
104-
//in insert and write operations
105-
index.release();
106-
return reinterpret_cast<jlong>(idMap.release());
111+
return initAndAllocateIndex(index, threadCount, dim, numVectors);
107112
}
108113

109114
void IndexService::insertToIndex(
@@ -178,16 +183,31 @@ jlong IndexService::initIndexFromTemplate(
178183
std::unique_ptr<faiss::Index> index;
179184
index.reset(faiss::read_index(&vectorIoReader, 0));
180185

186+
return initAndAllocateIndex(index, threadCount, dim, numVectors);
187+
}
188+
189+
BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> _faissMethods)
190+
: IndexService(std::move(_faissMethods)) {
191+
}
192+
193+
void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
194+
if (auto * indexBinaryHNSW = dynamic_cast<faiss::IndexBinaryHNSW *>(index)) {
195+
auto * indexBinaryFlat = dynamic_cast<faiss::IndexBinaryFlat *>(indexBinaryHNSW->storage);
196+
indexBinaryFlat->xb.reserve(dim * numVectors / 8);
197+
}
198+
}
199+
200+
jlong BinaryIndexService::initAndAllocateIndex(std::unique_ptr<faiss::IndexBinary> &index, size_t threadCount, size_t dim, size_t numVectors) {
181201
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
182202
if (threadCount != 0) {
183203
omp_set_num_threads(threadCount);
184204
}
185205

186-
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
206+
std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(index.get()));
187207
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
188208
idMap->own_fields = true;
189209

190-
// TODO: allocIndex
210+
// TODO: allocIndex for IVF
191211
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
192212

193213
//Release ownership so that the underlying index that was created is not deleted here. The index is needed later
@@ -196,17 +216,6 @@ jlong IndexService::initIndexFromTemplate(
196216
return reinterpret_cast<jlong>(idMap.release());
197217
}
198218

199-
BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> _faissMethods)
200-
: IndexService(std::move(_faissMethods)) {
201-
}
202-
203-
void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
204-
if (auto * indexBinaryHNSW = dynamic_cast<faiss::IndexBinaryHNSW *>(index)) {
205-
auto * indexBinaryFlat = dynamic_cast<faiss::IndexBinaryFlat *>(indexBinaryHNSW->storage);
206-
indexBinaryFlat->xb.reserve(dim * numVectors / 8);
207-
}
208-
}
209-
210219
jlong BinaryIndexService::initIndex(
211220
knn_jni::JNIUtilInterface * jniUtil,
212221
JNIEnv * env,
@@ -219,10 +228,6 @@ jlong BinaryIndexService::initIndex(
219228
) {
220229
// Create index using Faiss factory method
221230
std::unique_ptr<faiss::IndexBinary> index(faissMethods->indexBinaryFactory(dim, indexDescription.c_str()));
222-
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
223-
if (threadCount != 0) {
224-
omp_set_num_threads(threadCount);
225-
}
226231

227232
// Add extra parameters that can't be configured with the index factory
228233
SetExtraParameters<faiss::IndexBinary, faiss::IndexBinaryIVF, faiss::IndexBinaryHNSW>(jniUtil, env, parameters, index.get());
@@ -232,16 +237,7 @@ jlong BinaryIndexService::initIndex(
232237
throw std::runtime_error("Index is not trained");
233238
}
234239

235-
std::unique_ptr<faiss::IndexBinaryIDMap> idMap(faissMethods->indexBinaryIdMap(index.get()));
236-
//Makes sure the index is deleted when the destructor is called
237-
idMap->own_fields = true;
238-
239-
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
240-
241-
//Release ownership so that the underlying index that was created is not deleted here. The index is needed later
242-
//in insert and write operations
243-
index.release();
244-
return reinterpret_cast<jlong>(idMap.release());
240+
return initAndAllocateIndex(index, threadCount, dim, numVectors);
245241
}
246242

247243
void BinaryIndexService::insertToIndex(
@@ -319,16 +315,32 @@ jlong BinaryIndexService::initIndexFromTemplate(
319315
std::unique_ptr<faiss::IndexBinary> index;
320316
index.reset(faiss::read_index_binary(&vectorIoReader, 0));
321317

318+
return initAndAllocateIndex(index, threadCount, dim, numVectors);
319+
}
320+
321+
ByteIndexService::ByteIndexService(std::unique_ptr<FaissMethods> _faissMethods)
322+
: IndexService(std::move(_faissMethods)) {
323+
}
324+
325+
void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
326+
if (auto * indexHNSWSQ = dynamic_cast<faiss::IndexHNSWSQ *>(index)) {
327+
if(auto * indexScalarQuantizer = dynamic_cast<faiss::IndexScalarQuantizer *>(indexHNSWSQ->storage)) {
328+
indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors);
329+
}
330+
}
331+
}
332+
333+
jlong ByteIndexService::initAndAllocateIndex(std::unique_ptr<faiss::Index> &index, size_t threadCount, size_t dim, size_t numVectors) {
322334
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
323335
if (threadCount != 0) {
324336
omp_set_num_threads(threadCount);
325337
}
326338

327-
std::unique_ptr<faiss::IndexBinaryIDMap> idMap (faissMethods->indexBinaryIdMap(index.get()));
339+
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
328340
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
329341
idMap->own_fields = true;
330342

331-
// TODO: allocIndex
343+
// TODO: allocIndex for IVF
332344
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
333345

334346
//Release ownership so that the underlying index that was created is not deleted here. The index is needed later
@@ -337,18 +349,6 @@ jlong BinaryIndexService::initIndexFromTemplate(
337349
return reinterpret_cast<jlong>(idMap.release());
338350
}
339351

340-
ByteIndexService::ByteIndexService(std::unique_ptr<FaissMethods> _faissMethods)
341-
: IndexService(std::move(_faissMethods)) {
342-
}
343-
344-
void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) {
345-
if (auto * indexHNSWSQ = dynamic_cast<faiss::IndexHNSWSQ *>(index)) {
346-
if(auto * indexScalarQuantizer = dynamic_cast<faiss::IndexScalarQuantizer *>(indexHNSWSQ->storage)) {
347-
indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors);
348-
}
349-
}
350-
}
351-
352352
jlong ByteIndexService::initIndex(
353353
knn_jni::JNIUtilInterface * jniUtil,
354354
JNIEnv * env,
@@ -362,11 +362,6 @@ jlong ByteIndexService::initIndex(
362362
// Create index using Faiss factory method
363363
std::unique_ptr<faiss::Index> index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric));
364364

365-
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
366-
if (threadCount != 0) {
367-
omp_set_num_threads(threadCount);
368-
}
369-
370365
// Add extra parameters that can't be configured with the index factory
371366
SetExtraParameters<faiss::Index, faiss::IndexIVF, faiss::IndexHNSW>(jniUtil, env, parameters, index.get());
372367

@@ -375,16 +370,7 @@ jlong ByteIndexService::initIndex(
375370
throw std::runtime_error("Index is not trained");
376371
}
377372

378-
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
379-
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
380-
idMap->own_fields = true;
381-
382-
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
383-
384-
//Release ownership so that the underlying index that was created is not deleted here. The index is needed later
385-
//in insert and write operations
386-
index.release();
387-
return reinterpret_cast<jlong>(idMap.release());
373+
return initAndAllocateIndex(index, threadCount, dim, numVectors);
388374
}
389375

390376
void ByteIndexService::insertToIndex(
@@ -477,22 +463,7 @@ jlong ByteIndexService::initIndexFromTemplate(
477463
std::unique_ptr<faiss::Index> index;
478464
index.reset(faiss::read_index(&vectorIoReader, 0));
479465

480-
// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
481-
if (threadCount != 0) {
482-
omp_set_num_threads(threadCount);
483-
}
484-
485-
std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
486-
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
487-
idMap->own_fields = true;
488-
489-
// TODO: allocIndex
490-
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);
491-
492-
//Release ownership so that the underlying index that was created is not deleted here. The index is needed later
493-
//in insert and write operations
494-
index.release();
495-
return reinterpret_cast<jlong>(idMap.release());
466+
return initAndAllocateIndex(index, threadCount, dim, numVectors);
496467
}
497468
} // namespace faiss_wrapper
498469
} // namespace knn_jni

src/main/java/org/opensearch/knn/index/KNNIndexShard.java

+11-6
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,7 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
193193
fileExtension,
194194
spaceType,
195195
modelId,
196-
quantizationConfig == QuantizationConfig.EMPTY
197-
? VectorDataType.get(
198-
fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())
199-
)
200-
: quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT ? VectorDataType.BYTE
201-
: VectorDataType.BINARY
196+
getVectorDataType(quantizationConfig, fieldInfo)
202197
)
203198
);
204199
}
@@ -228,6 +223,16 @@ List<EngineFileContext> getEngineFileContexts(
228223
.collect(Collectors.toList());
229224
}
230225

226+
private VectorDataType getVectorDataType(QuantizationConfig quantizationConfig, FieldInfo fieldInfo) {
227+
if (quantizationConfig == QuantizationConfig.EMPTY) {
228+
return VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()));
229+
}
230+
if (quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT) {
231+
return VectorDataType.BYTE;
232+
}
233+
return VectorDataType.BINARY;
234+
}
235+
231236
@AllArgsConstructor
232237
@Getter
233238
@VisibleForTesting

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java

+16-13
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,8 @@ private QuantizationState train(
257257
if (quantizationParams != null && totalLiveDocs > 0) {
258258
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
259259

260-
// We will not be writing quantization state for 8 bits into a segment file because we are not quantizing the query vectors and
261-
// we are storing the template index after training in the quantization state to use it later in the index build strategy for
262-
// ingesting data
263-
if ((quantizationParams.getTypeIdentifier()).equals(
264-
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT)
265-
)) {
266-
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo);
267-
} else {
268-
initQuantizationStateWriterIfNecessary();
269-
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo);
270-
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
271-
}
260+
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo);
261+
writeQuantizationState(quantizationParams, quantizationState, fieldInfo.getFieldNumber());
272262
}
273263

274264
return quantizationState;
@@ -289,11 +279,24 @@ private int getLiveDocs(KNNVectorValues<?> vectorValues) throws IOException {
289279
return liveDocs;
290280
}
291281

292-
private void initQuantizationStateWriterIfNecessary() throws IOException {
282+
private void writeQuantizationState(QuantizationParams quantizationParams, QuantizationState quantizationState, int fieldNumber)
283+
throws IOException {
284+
285+
// We will not be writing quantization state for 8 bits into a segment file because we are not quantizing the query vectors and
286+
// we are storing the template index after training in the quantization state to use it later in the index build strategy for
287+
// ingesting data
288+
if ((quantizationParams.getTypeIdentifier()).equals(
289+
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT)
290+
)) {
291+
return;
292+
}
293+
294+
// Initialize quantization state writer if required
293295
if (quantizationStateWriter == null) {
294296
quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState);
295297
quantizationStateWriter.writeHeader(segmentWriteState);
296298
}
299+
quantizationStateWriter.writeState(fieldNumber, quantizationState);
297300
}
298301

299302
private boolean shouldSkipBuildingVectorDataStructure(final long docCount) {

0 commit comments

Comments
 (0)