@@ -139,6 +139,10 @@ butil::Status VectorIndexUtils::CalcDistanceByFaiss(
139
139
return CalcCosineDistanceByFaiss (op_left_vectors, op_right_vectors, is_return_normlize, distances,
140
140
result_op_left_vectors, result_op_right_vectors);
141
141
}
142
+ case pb::common::METRIC_TYPE_HAMMING: {
143
+ return CalcHammingDistanceByFaiss (op_left_vectors, op_right_vectors, is_return_normlize, distances,
144
+ result_op_left_vectors, result_op_right_vectors);
145
+ }
142
146
case pb::common::METRIC_TYPE_NONE:
143
147
case pb::common::MetricType_INT_MIN_SENTINEL_DO_NOT_USE_:
144
148
case pb::common::MetricType_INT_MAX_SENTINEL_DO_NOT_USE_: {
@@ -213,6 +217,17 @@ butil::Status VectorIndexUtils::CalcCosineDistanceByFaiss(
213
217
result_op_right_vectors, DoCalcCosineDistanceByFaiss);
214
218
}
215
219
220
+ butil::Status VectorIndexUtils::CalcHammingDistanceByFaiss (
221
+ const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
222
+ const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
223
+ std::vector<std::vector<float >>& distances, // NOLINT
224
+ std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors, // NOLINT
225
+ std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors) // NOLINT
226
+ { // NOLINT
227
+ return CalcDistanceCore (op_left_vectors, op_right_vectors, is_return_normlize, distances, result_op_left_vectors,
228
+ result_op_right_vectors, DoCalcHammingDistanceByFaiss);
229
+ }
230
+
216
231
butil::Status VectorIndexUtils::CalcL2DistanceByHnswlib (
217
232
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
218
233
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
@@ -307,6 +322,33 @@ butil::Status VectorIndexUtils::DoCalcCosineDistanceByFaiss(
307
322
return butil::Status ();
308
323
}
309
324
325
+ butil::Status VectorIndexUtils::DoCalcHammingDistanceByFaiss (
326
+ const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
327
+ bool is_return_normlize,
328
+ float & distance, // NOLINT
329
+ dingodb::pb::common::Vector& result_op_left_vectors, // NOLINT
330
+ dingodb::pb::common::Vector& result_op_right_vectors) // NOLINT
331
+ { // NOLINT
332
+ faiss::VectorDistance<faiss::MetricType::METRIC_HAMMING> vector_distance;
333
+ vector_distance.d = op_left_vectors.binary_values ().size ();
334
+
335
+ std::vector<uint8_t > left_vectors = std::vector<uint8_t >(op_left_vectors.binary_values ().size ());
336
+ for (int j = 0 ; j < op_left_vectors.binary_values ().size (); j++) {
337
+ left_vectors[j] = static_cast <uint8_t >(op_left_vectors.binary_values ()[j][0 ]);
338
+ }
339
+ std::vector<uint8_t > right_vectors = std::vector<uint8_t >(op_right_vectors.binary_values ().size ());
340
+ for (int j = 0 ; j < op_right_vectors.binary_values ().size (); j++) {
341
+ right_vectors[j] = static_cast <uint8_t >(op_right_vectors.binary_values ()[j][0 ]);
342
+ }
343
+
344
+ distance = vector_distance (left_vectors.data (), right_vectors.data ());
345
+
346
+ ResultOpBinaryVectorAssignmentWrapper (op_left_vectors, op_right_vectors, is_return_normlize, result_op_left_vectors,
347
+ result_op_right_vectors);
348
+
349
+ return butil::Status ();
350
+ }
351
+
310
352
butil::Status VectorIndexUtils::DoCalcL2DistanceByHnswlib (
311
353
const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
312
354
bool is_return_normlize,
@@ -386,6 +428,13 @@ void VectorIndexUtils::ResultOpVectorAssignment(dingodb::pb::common::Vector& res
386
428
result_op_vectors.set_value_type (::dingodb::pb::common::ValueType::FLOAT);
387
429
}
388
430
431
+ void VectorIndexUtils::ResultOpBinaryVectorAssignment (dingodb::pb::common::Vector& result_op_vectors,
432
+ const ::dingodb::pb::common::Vector& op_vectors) {
433
+ result_op_vectors = op_vectors;
434
+ result_op_vectors.set_dimension (result_op_vectors.binary_values ().size () * CHAR_BIT);
435
+ result_op_vectors.set_value_type (::dingodb::pb::common::ValueType::UINT8);
436
+ }
437
+
389
438
void VectorIndexUtils::ResultOpVectorAssignmentWrapper (const ::dingodb::pb::common::Vector& op_left_vectors,
390
439
const ::dingodb::pb::common::Vector& op_right_vectors,
391
440
bool is_return_normlize,
@@ -403,6 +452,23 @@ void VectorIndexUtils::ResultOpVectorAssignmentWrapper(const ::dingodb::pb::comm
403
452
}
404
453
}
405
454
455
+ void VectorIndexUtils::ResultOpBinaryVectorAssignmentWrapper (
456
+ const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
457
+ bool is_return_normlize,
458
+ dingodb::pb::common::Vector& result_op_left_vectors, // NOLINT
459
+ dingodb::pb::common::Vector& result_op_right_vectors) // NOLINT
460
+ { // NOLINT
461
+ if (is_return_normlize) {
462
+ if (result_op_left_vectors.binary_values ().empty ()) {
463
+ ResultOpBinaryVectorAssignment (result_op_left_vectors, op_left_vectors);
464
+ }
465
+
466
+ if (result_op_right_vectors.binary_values ().empty ()) {
467
+ ResultOpBinaryVectorAssignment (result_op_right_vectors, op_right_vectors);
468
+ }
469
+ }
470
+ }
471
+
406
472
void VectorIndexUtils::NormalizeVectorForFaiss (float * x, int32_t d) {
407
473
static const float kFloatAccuracy = 0.00001 ;
408
474
@@ -446,6 +512,10 @@ butil::Status VectorIndexUtils::CheckVectorDimension(const std::vector<pb::commo
446
512
DINGO_LOG (ERROR) << s;
447
513
return butil::Status (pb::error::Errno::EVECTOR_INVALID, s);
448
514
}
515
+ if (vector_with_id.vector ().dimension () != dimension) {
516
+ std::string s = fmt::format (" vector dimension not match, {} {}" , vector_with_id.vector ().dimension (), dimension);
517
+ return butil::Status (pb::error::Errno::EVECTOR_INVALID, s);
518
+ }
449
519
}
450
520
451
521
return butil::Status::OK ();
@@ -486,9 +556,9 @@ template <typename T>
486
556
std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue (const std::vector<pb::common::VectorWithId>& vector_with_ids,
487
557
faiss::idx_t dimension, bool normalize) {
488
558
std::unique_ptr<T[]> vectors = nullptr ;
489
- if (std::is_same<T, float >::value) {
559
+ if constexpr (std::is_same<T, float >::value) {
490
560
vectors = std::make_unique<T[]>(vector_with_ids.size () * dimension);
491
- } else if (std::is_same<T, uint8_t >::value) {
561
+ } else if constexpr (std::is_same<T, uint8_t >::value) {
492
562
vectors = std::make_unique<T[]>(vector_with_ids.size () * dimension / CHAR_BIT);
493
563
} else {
494
564
std::string s = fmt::format (" invalid value typename type" );
@@ -497,8 +567,8 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
497
567
}
498
568
499
569
for (size_t i = 0 ; i < vector_with_ids.size (); ++i) {
500
- if (vector_with_ids[i]. vector (). value_type () == pb::common::ValueType::FLOAT ) {
501
- if (!std::is_same<T, float >::value ) {
570
+ if constexpr (std::is_same<T, float >::value ) {
571
+ if (vector_with_ids[i]. vector (). value_type () != pb::common::ValueType::FLOAT ) {
502
572
std::string s = fmt::format (" template not match vectors value_type : {}" ,
503
573
pb::common::ValueType_Name (vector_with_ids[i].vector ().value_type ()));
504
574
DINGO_LOG (ERROR) << s;
@@ -509,15 +579,17 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
509
579
if (normalize) {
510
580
VectorIndexUtils::NormalizeVectorForFaiss (reinterpret_cast <float *>(vectors.get ()) + i * dimension, dimension);
511
581
}
512
- } else if (vector_with_ids[i]. vector (). value_type () == pb::common::ValueType::UINT8 ) {
513
- if (!std::is_same<T, uint8_t >::value ) {
582
+ } else if constexpr (std::is_same<T, uint8_t >::value ) {
583
+ if (vector_with_ids[i]. vector (). value_type () != pb::common::ValueType::UINT8 ) {
514
584
std::string s = fmt::format (" template not match vectors value_type : {}" ,
515
585
pb::common::ValueType_Name (vector_with_ids[i].vector ().value_type ()));
516
586
DINGO_LOG (ERROR) << s;
517
587
return nullptr ;
518
588
}
519
589
const auto & vector_value = vector_with_ids[i].vector ().binary_values ();
520
- memcpy (vectors.get () + i * dimension / CHAR_BIT, vector_value.data (), dimension / CHAR_BIT);
590
+ for (int j = 0 ; j < vector_value.size (); j++) {
591
+ vectors.get ()[i * dimension / CHAR_BIT + j] = static_cast <uint8_t >(vector_value[j][0 ]);
592
+ }
521
593
} else {
522
594
std::string s =
523
595
fmt::format (" invalid value type : {}" , pb::common::ValueType_Name (vector_with_ids[i].vector ().value_type ()));
@@ -855,8 +927,9 @@ butil::Status VectorIndexUtils::ValidateVectorIndexParameter(
855
927
!(ivf_flat_parameter.metric_type () == pb::common::METRIC_TYPE_INNER_PRODUCT) &&
856
928
!(ivf_flat_parameter.metric_type () == pb::common::METRIC_TYPE_L2)) {
857
929
DINGO_LOG (ERROR) << " ivf_flat_parameter.metric_type is illegal " << ivf_flat_parameter.metric_type ();
858
- return butil::Status (pb::error::Errno::EILLEGAL_PARAMTETERS,
859
- " ivf_flat_parameter.metric_type is illegal " + std::to_string (ivf_flat_parameter.metric_type ()));
930
+ return butil::Status (
931
+ pb::error::Errno::EILLEGAL_PARAMTETERS,
932
+ " ivf_flat_parameter.metric_type is illegal " + std::to_string (ivf_flat_parameter.metric_type ()));
860
933
}
861
934
862
935
// check ivf_flat_parameter.ncentroids
0 commit comments