Skip to content

Commit 8c5f7fd

Browse files
committed
[feat][store]Add calculate hamming distance
1 parent 463fcee commit 8c5f7fd

File tree

5 files changed

+219
-21
lines changed

5 files changed

+219
-21
lines changed

src/engine/storage.cc

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "engine/storage.h"
1616

17+
#include <climits>
1718
#include <cstdint>
1819
#include <limits>
1920
#include <string>
@@ -839,23 +840,50 @@ butil::Status Storage::VectorCalcDistance(const ::dingodb::pb::index::VectorCalc
839840
}
840841

841842
int64_t dimension = 0;
843+
auto value_type = op_left_vectors[0].value_type();
842844

843-
auto lambda_op_vector_check_function = [&dimension](const auto& op_vector, const std::string& name) {
845+
auto lambda_op_vector_check_function = [&dimension,&value_type](const auto& op_vector, const std::string& name) {
844846
if (!op_vector.empty()) {
845847
size_t i = 0;
846848
for (const auto& vector : op_vector) {
847-
int64_t current_dimension = static_cast<int64_t>(vector.float_values().size());
848-
if (0 == dimension) {
849-
dimension = current_dimension;
849+
if(vector.value_type() != value_type) {
850+
std::string s = fmt::format("{} index : {} value_type : {} unequal value_type : {}", name, i,
851+
::dingodb::pb::common::ValueType_Name(value_type),
852+
::dingodb::pb::common::ValueType_Name(vector.value_type()));
853+
LOG(ERROR) << s;
854+
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
850855
}
851-
852-
if (dimension != current_dimension) {
853-
std::string s = fmt::format("{} index : {} dimension : {} unequal current_dimension : {}", name, i,
854-
dimension, current_dimension);
856+
if (vector.value_type() == ::dingodb::pb::common::ValueType::FLOAT) {
857+
int64_t current_dimension = static_cast<int64_t>(vector.float_values().size());
858+
if (0 == dimension) {
859+
dimension = current_dimension;
860+
}
861+
862+
if (dimension != current_dimension) {
863+
std::string s = fmt::format("{} float index : {} dimension : {} unequal current_dimension : {}", name, i,
864+
dimension, current_dimension);
865+
LOG(ERROR) << s;
866+
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
867+
}
868+
i++;
869+
} else if (vector.value_type() == ::dingodb::pb::common::ValueType::UINT8) {
870+
int64_t current_dimension = static_cast<int64_t>(vector.binary_values().size());
871+
if (0 == dimension) {
872+
dimension = current_dimension;
873+
}
874+
875+
if (dimension != current_dimension) {
876+
std::string s = fmt::format("{} binary index : {} dimension : {} unequal current_dimension : {}", name, i,
877+
dimension * CHAR_BIT, current_dimension * CHAR_BIT);
878+
LOG(ERROR) << s;
879+
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
880+
}
881+
i++;
882+
} else {
883+
std::string s = fmt::format("{} index : {} value_type : VALUE_TYPE_NONE", name, i);
855884
LOG(ERROR) << s;
856885
return butil::Status(pb::error::EILLEGAL_PARAMTETERS, s);
857886
}
858-
i++;
859887
}
860888
}
861889

src/vector/vector_index_utils.cc

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ butil::Status VectorIndexUtils::CalcDistanceByFaiss(
139139
return CalcCosineDistanceByFaiss(op_left_vectors, op_right_vectors, is_return_normlize, distances,
140140
result_op_left_vectors, result_op_right_vectors);
141141
}
142+
case pb::common::METRIC_TYPE_HAMMING: {
143+
return CalcHammingDistanceByFaiss(op_left_vectors, op_right_vectors, is_return_normlize, distances,
144+
result_op_left_vectors, result_op_right_vectors);
145+
}
142146
case pb::common::METRIC_TYPE_NONE:
143147
case pb::common::MetricType_INT_MIN_SENTINEL_DO_NOT_USE_:
144148
case pb::common::MetricType_INT_MAX_SENTINEL_DO_NOT_USE_: {
@@ -213,6 +217,17 @@ butil::Status VectorIndexUtils::CalcCosineDistanceByFaiss(
213217
result_op_right_vectors, DoCalcCosineDistanceByFaiss);
214218
}
215219

220+
butil::Status VectorIndexUtils::CalcHammingDistanceByFaiss(
221+
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
222+
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
223+
std::vector<std::vector<float>>& distances, // NOLINT
224+
std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors, // NOLINT
225+
std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors) // NOLINT
226+
{ // NOLINT
227+
return CalcDistanceCore(op_left_vectors, op_right_vectors, is_return_normlize, distances, result_op_left_vectors,
228+
result_op_right_vectors, DoCalcHammingDistanceByFaiss);
229+
}
230+
216231
butil::Status VectorIndexUtils::CalcL2DistanceByHnswlib(
217232
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
218233
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors, bool is_return_normlize,
@@ -307,6 +322,33 @@ butil::Status VectorIndexUtils::DoCalcCosineDistanceByFaiss(
307322
return butil::Status();
308323
}
309324

325+
butil::Status VectorIndexUtils::DoCalcHammingDistanceByFaiss(
326+
const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
327+
bool is_return_normlize,
328+
float& distance, // NOLINT
329+
dingodb::pb::common::Vector& result_op_left_vectors, // NOLINT
330+
dingodb::pb::common::Vector& result_op_right_vectors) // NOLINT
331+
{ // NOLINT
332+
faiss::VectorDistance<faiss::MetricType::METRIC_HAMMING> vector_distance;
333+
vector_distance.d = op_left_vectors.binary_values().size();
334+
335+
std::vector<uint8_t> left_vectors = std::vector<uint8_t>(op_left_vectors.binary_values().size());
336+
for (int j = 0; j < op_left_vectors.binary_values().size(); j++) {
337+
left_vectors[j] = static_cast<uint8_t>(op_left_vectors.binary_values()[j][0]);
338+
}
339+
std::vector<uint8_t> right_vectors = std::vector<uint8_t>(op_right_vectors.binary_values().size());
340+
for (int j = 0; j < op_right_vectors.binary_values().size(); j++) {
341+
right_vectors[j] = static_cast<uint8_t>(op_right_vectors.binary_values()[j][0]);
342+
}
343+
344+
distance = vector_distance(left_vectors.data(), right_vectors.data());
345+
346+
ResultOpBinaryVectorAssignmentWrapper(op_left_vectors, op_right_vectors, is_return_normlize, result_op_left_vectors,
347+
result_op_right_vectors);
348+
349+
return butil::Status();
350+
}
351+
310352
butil::Status VectorIndexUtils::DoCalcL2DistanceByHnswlib(
311353
const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
312354
bool is_return_normlize,
@@ -386,6 +428,13 @@ void VectorIndexUtils::ResultOpVectorAssignment(dingodb::pb::common::Vector& res
386428
result_op_vectors.set_value_type(::dingodb::pb::common::ValueType::FLOAT);
387429
}
388430

431+
void VectorIndexUtils::ResultOpBinaryVectorAssignment(dingodb::pb::common::Vector& result_op_vectors,
432+
const ::dingodb::pb::common::Vector& op_vectors) {
433+
result_op_vectors = op_vectors;
434+
result_op_vectors.set_dimension(result_op_vectors.binary_values().size() * CHAR_BIT);
435+
result_op_vectors.set_value_type(::dingodb::pb::common::ValueType::UINT8);
436+
}
437+
389438
void VectorIndexUtils::ResultOpVectorAssignmentWrapper(const ::dingodb::pb::common::Vector& op_left_vectors,
390439
const ::dingodb::pb::common::Vector& op_right_vectors,
391440
bool is_return_normlize,
@@ -403,6 +452,23 @@ void VectorIndexUtils::ResultOpVectorAssignmentWrapper(const ::dingodb::pb::comm
403452
}
404453
}
405454

455+
void VectorIndexUtils::ResultOpBinaryVectorAssignmentWrapper(
456+
const ::dingodb::pb::common::Vector& op_left_vectors, const ::dingodb::pb::common::Vector& op_right_vectors,
457+
bool is_return_normlize,
458+
dingodb::pb::common::Vector& result_op_left_vectors, // NOLINT
459+
dingodb::pb::common::Vector& result_op_right_vectors) // NOLINT
460+
{ // NOLINT
461+
if (is_return_normlize) {
462+
if (result_op_left_vectors.binary_values().empty()) {
463+
ResultOpBinaryVectorAssignment(result_op_left_vectors, op_left_vectors);
464+
}
465+
466+
if (result_op_right_vectors.binary_values().empty()) {
467+
ResultOpBinaryVectorAssignment(result_op_right_vectors, op_right_vectors);
468+
}
469+
}
470+
}
471+
406472
void VectorIndexUtils::NormalizeVectorForFaiss(float* x, int32_t d) {
407473
static const float kFloatAccuracy = 0.00001;
408474

@@ -446,6 +512,10 @@ butil::Status VectorIndexUtils::CheckVectorDimension(const std::vector<pb::commo
446512
DINGO_LOG(ERROR) << s;
447513
return butil::Status(pb::error::Errno::EVECTOR_INVALID, s);
448514
}
515+
if (vector_with_id.vector().dimension() != dimension) {
516+
std::string s = fmt::format("vector dimension not match, {} {}", vector_with_id.vector().dimension(), dimension);
517+
return butil::Status(pb::error::Errno::EVECTOR_INVALID, s);
518+
}
449519
}
450520

451521
return butil::Status::OK();
@@ -486,9 +556,9 @@ template <typename T>
486556
std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::common::VectorWithId>& vector_with_ids,
487557
faiss::idx_t dimension, bool normalize) {
488558
std::unique_ptr<T[]> vectors = nullptr;
489-
if (std::is_same<T, float>::value) {
559+
if constexpr (std::is_same<T, float>::value) {
490560
vectors = std::make_unique<T[]>(vector_with_ids.size() * dimension);
491-
} else if (std::is_same<T, uint8_t>::value) {
561+
} else if constexpr (std::is_same<T, uint8_t>::value) {
492562
vectors = std::make_unique<T[]>(vector_with_ids.size() * dimension / CHAR_BIT);
493563
} else {
494564
std::string s = fmt::format("invalid value typename type");
@@ -497,8 +567,8 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
497567
}
498568

499569
for (size_t i = 0; i < vector_with_ids.size(); ++i) {
500-
if (vector_with_ids[i].vector().value_type() == pb::common::ValueType::FLOAT) {
501-
if (!std::is_same<T, float>::value) {
570+
if constexpr (std::is_same<T, float>::value) {
571+
if (vector_with_ids[i].vector().value_type() != pb::common::ValueType::FLOAT) {
502572
std::string s = fmt::format("template not match vectors value_type : {}",
503573
pb::common::ValueType_Name(vector_with_ids[i].vector().value_type()));
504574
DINGO_LOG(ERROR) << s;
@@ -509,15 +579,17 @@ std::unique_ptr<T[]> VectorIndexUtils::ExtractVectorValue(const std::vector<pb::
509579
if (normalize) {
510580
VectorIndexUtils::NormalizeVectorForFaiss(reinterpret_cast<float*>(vectors.get()) + i * dimension, dimension);
511581
}
512-
} else if (vector_with_ids[i].vector().value_type() == pb::common::ValueType::UINT8) {
513-
if (!std::is_same<T, uint8_t>::value) {
582+
} else if constexpr (std::is_same<T, uint8_t>::value) {
583+
if (vector_with_ids[i].vector().value_type() != pb::common::ValueType::UINT8) {
514584
std::string s = fmt::format("template not match vectors value_type : {}",
515585
pb::common::ValueType_Name(vector_with_ids[i].vector().value_type()));
516586
DINGO_LOG(ERROR) << s;
517587
return nullptr;
518588
}
519589
const auto& vector_value = vector_with_ids[i].vector().binary_values();
520-
memcpy(vectors.get() + i * dimension / CHAR_BIT, vector_value.data(), dimension / CHAR_BIT);
590+
for (int j = 0; j < vector_value.size(); j++) {
591+
vectors.get()[i * dimension / CHAR_BIT + j] = static_cast<uint8_t>(vector_value[j][0]);
592+
}
521593
} else {
522594
std::string s =
523595
fmt::format("invalid value type : {}", pb::common::ValueType_Name(vector_with_ids[i].vector().value_type()));
@@ -855,8 +927,9 @@ butil::Status VectorIndexUtils::ValidateVectorIndexParameter(
855927
!(ivf_flat_parameter.metric_type() == pb::common::METRIC_TYPE_INNER_PRODUCT) &&
856928
!(ivf_flat_parameter.metric_type() == pb::common::METRIC_TYPE_L2)) {
857929
DINGO_LOG(ERROR) << "ivf_flat_parameter.metric_type is illegal " << ivf_flat_parameter.metric_type();
858-
return butil::Status(pb::error::Errno::EILLEGAL_PARAMTETERS,
859-
"ivf_flat_parameter.metric_type is illegal " + std::to_string(ivf_flat_parameter.metric_type()));
930+
return butil::Status(
931+
pb::error::Errno::EILLEGAL_PARAMTETERS,
932+
"ivf_flat_parameter.metric_type is illegal " + std::to_string(ivf_flat_parameter.metric_type()));
860933
}
861934

862935
// check ivf_flat_parameter.ncentroids

src/vector/vector_index_utils.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ class VectorIndexUtils {
9292
std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors,
9393
std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors);
9494

95+
static butil::Status CalcHammingDistanceByFaiss(
96+
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
97+
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors,
98+
bool is_return_normlize, std::vector<std::vector<float>>& distances,
99+
std::vector<::dingodb::pb::common::Vector>& result_op_left_vectors,
100+
std::vector<::dingodb::pb::common::Vector>& result_op_right_vectors);
101+
95102
static butil::Status CalcL2DistanceByHnswlib(
96103
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_left_vectors,
97104
const google::protobuf::RepeatedPtrField<::dingodb::pb::common::Vector>& op_right_vectors,
@@ -132,6 +139,12 @@ class VectorIndexUtils {
132139
dingodb::pb::common::Vector& result_op_left_vectors,
133140
dingodb::pb::common::Vector& result_op_right_vectors);
134141

142+
static butil::Status DoCalcHammingDistanceByFaiss(const ::dingodb::pb::common::Vector& op_left_vectors,
143+
const ::dingodb::pb::common::Vector& op_right_vectors,
144+
bool is_return_normlize, float& distance,
145+
dingodb::pb::common::Vector& result_op_left_vectors,
146+
dingodb::pb::common::Vector& result_op_right_vectors);
147+
135148
static butil::Status DoCalcL2DistanceByHnswlib(const ::dingodb::pb::common::Vector& op_left_vectors,
136149
const ::dingodb::pb::common::Vector& op_right_vectors,
137150
bool is_return_normlize, float& distance,
@@ -152,18 +165,26 @@ class VectorIndexUtils {
152165

153166
static void ResultOpVectorAssignment(dingodb::pb::common::Vector& result_op_vectors,
154167
const ::dingodb::pb::common::Vector& op_vectors);
168+
static void ResultOpBinaryVectorAssignment(dingodb::pb::common::Vector& result_op_vectors,
169+
const ::dingodb::pb::common::Vector& op_vectors);
155170

156171
static void ResultOpVectorAssignmentWrapper(const ::dingodb::pb::common::Vector& op_left_vectors,
157172
const ::dingodb::pb::common::Vector& op_right_vectors,
158173
bool is_return_normlize,
159174
dingodb::pb::common::Vector& result_op_left_vectors,
160175
dingodb::pb::common::Vector& result_op_right_vectors);
161176

177+
static void ResultOpBinaryVectorAssignmentWrapper(const ::dingodb::pb::common::Vector& op_left_vectors,
178+
const ::dingodb::pb::common::Vector& op_right_vectors,
179+
bool is_return_normlize,
180+
dingodb::pb::common::Vector& result_op_left_vectors,
181+
dingodb::pb::common::Vector& result_op_right_vectors);
182+
162183
static void NormalizeVectorForFaiss(float* x, int32_t d);
163184
static void NormalizeVectorForHnsw(const float* data, uint32_t dimension, float* norm_array);
164185

165-
static butil::Status CheckVectorDimension(
166-
const std::vector<pb::common::VectorWithId>& vector_with_ids, int dimension);
186+
static butil::Status CheckVectorDimension(const std::vector<pb::common::VectorWithId>& vector_with_ids,
187+
int dimension);
167188

168189
static std::unique_ptr<faiss::idx_t[]> CastVectorId(const std::vector<int64_t>& delete_ids);
169190

test/unit_test/vector/test_vector_index_utils.cc

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
#include <gtest/gtest.h>
1616

1717
#include <array>
18+
#include <climits>
1819
#include <cstddef>
1920
#include <cstdint>
2021
#include <cstdio>
2122
#include <cstdlib>
2223
#include <iomanip>
2324
#include <iostream>
2425
#include <random>
26+
#include <string>
2527
#include <vector>
2628

2729
#include "butil/status.h"
@@ -4222,6 +4224,80 @@ TEST_F(VectorIndexUtilsTest, DoCalcCosineDistanceByFaiss) {
42224224
}
42234225
}
42244226

4227+
TEST_F(VectorIndexUtilsTest, DoCalcHammingDistanceByFaiss) {
4228+
// ok
4229+
{
4230+
constexpr uint32_t kDimension = 16;
4231+
std::array<uint8_t, kDimension / CHAR_BIT> data_left{};
4232+
4233+
std::mt19937 rng;
4234+
std::uniform_real_distribution<> distrib(0, 255);
4235+
for (auto& elem : data_left) {
4236+
elem = distrib(rng);
4237+
}
4238+
4239+
LOG(INFO) << "left_data : \t";
4240+
for (const auto elem : data_left) {
4241+
LOG(INFO) << std::setw(3) << static_cast<int32_t>(elem) << " ";
4242+
}
4243+
4244+
std::array<uint8_t, kDimension / CHAR_BIT> data_right{};
4245+
for (auto& elem : data_right) {
4246+
elem = distrib(rng);
4247+
}
4248+
4249+
LOG(INFO) << "right_data : \t";
4250+
for (const auto elem : data_right) {
4251+
LOG(INFO) << std::setw(3) << static_cast<int32_t>(elem) << " ";
4252+
}
4253+
4254+
::dingodb::pb::common::Vector result_op_left_vectors;
4255+
::dingodb::pb::common::Vector result_op_right_vectors;
4256+
::dingodb::pb::common::Vector op_left_vectors;
4257+
::dingodb::pb::common::Vector op_right_vectors;
4258+
bool is_return_normlize = true;
4259+
float distance = 0.0f;
4260+
4261+
op_left_vectors.set_value_type(::dingodb::pb::common::ValueType::UINT8);
4262+
op_right_vectors.set_value_type(::dingodb::pb::common::ValueType::UINT8);
4263+
4264+
for (const auto elem : data_left) {
4265+
std::string str = std::string(1, static_cast<char>(elem));
4266+
op_left_vectors.add_binary_values(str);
4267+
}
4268+
4269+
for (const auto elem : data_right) {
4270+
std::string str = std::string(1, static_cast<char>(elem));
4271+
op_right_vectors.add_binary_values(str);
4272+
}
4273+
4274+
butil::Status ok =
4275+
VectorIndexUtils::DoCalcHammingDistanceByFaiss(op_left_vectors, op_right_vectors, is_return_normlize, distance,
4276+
result_op_left_vectors, result_op_right_vectors);
4277+
4278+
EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);
4279+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:distance:" << distance;
4280+
4281+
EXPECT_EQ(result_op_left_vectors.value_type(), ::dingodb::pb::common::ValueType::UINT8);
4282+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:left";
4283+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:value_type : " << result_op_left_vectors.value_type();
4284+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:dimension : " << result_op_left_vectors.dimension();
4285+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:data : \t\t";
4286+
for (const auto& elem : result_op_left_vectors.binary_values()) {
4287+
LOG(INFO) << static_cast<int32_t>(elem[0]) << " ";
4288+
}
4289+
4290+
EXPECT_EQ(result_op_right_vectors.value_type(), ::dingodb::pb::common::ValueType::UINT8);
4291+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:right";
4292+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:value_type : " << result_op_right_vectors.value_type();
4293+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:dimension : " << result_op_right_vectors.dimension();
4294+
LOG(INFO) << "DoCalcHammingDistanceByFaiss:data : \t\t";
4295+
for (const auto& elem : result_op_right_vectors.binary_values()) {
4296+
LOG(INFO) << static_cast<int32_t>(elem[0]) << " ";
4297+
}
4298+
}
4299+
}
4300+
42254301
TEST_F(VectorIndexUtilsTest, DoCalcL2DistanceByHnswlib) {
42264302
// ok
42274303
{

0 commit comments

Comments
 (0)