Skip to content

Commit 4292332

Browse files
committed
Fixing binary again
1 parent 39dd4b7 commit 4292332

File tree

1 file changed

+13 additions, -12 deletions (lines changed)

1 file changed

+13 additions, -12 deletions (lines changed)

cpp/src/preprocessing/quantize/detail/binary.cuh

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -276,7 +276,8 @@ auto train(raft::resources const& res,
276276
const size_t dataset_size = dataset.extent(0);
277277
quantizer.threshold = raft::make_device_vector<T, int64_t>(res, dataset_dim);
278278

279-
std::vector<T> host_threshold_vec(dataset_dim);
279+
using compute_t = std::conditional_t<std::is_same_v<half, T>, float, T>;
280+
std::vector<compute_t> host_threshold_vec(dataset_dim);
280281
auto threshold_ptr = host_threshold_vec.data();
281282

282283
if (params.threshold == cuvs::preprocessing::quantize::binary::bit_threshold::mean) {
@@ -286,7 +287,7 @@ auto train(raft::resources const& res,
286287
#pragma omp parallel for reduction(+ : threshold_ptr[ : dataset_dim])
287288
for (size_t i = 0; i < dataset_size; i++) {
288289
for (uint32_t j = 0; j < dataset_dim; j++) {
289-
threshold_ptr[j] += static_cast<T>(dataset.data_handle()[i * dataset_dim + j]);
290+
threshold_ptr[j] += static_cast<compute_t>(dataset.data_handle()[i * dataset_dim + j]);
290291
}
291292
}
292293
for (uint32_t j = 0; j < dataset_dim; j++) {
@@ -314,13 +315,13 @@ auto train(raft::resources const& res,
314315
const auto stride = stride_prime_list[prime_i];
315316

316317
// Transposed
317-
auto sampled_dataset = raft::make_host_matrix<T, int64_t>(dataset_dim, num_samples);
318+
auto sampled_dataset = raft::make_host_matrix<compute_t, int64_t>(dataset_dim, num_samples);
318319
#pragma omp parallel for
319320
for (size_t out_i = 0; out_i < num_samples; out_i++) {
320321
const auto in_i = (out_i * stride) % dataset_size;
321322
for (uint32_t j = 0; j < dataset_dim; j++) {
322323
sampled_dataset.data_handle()[j * num_samples + out_i] =
323-
static_cast<T>(dataset.data_handle()[in_i * dataset_dim + j]);
324+
static_cast<compute_t>(dataset.data_handle()[in_i * dataset_dim + j]);
324325
}
325326
}
326327

@@ -332,15 +333,15 @@ auto train(raft::resources const& res,
332333
}
333334
}
334335

335-
if constexpr (std::is_same_v<T, T>) {
336+
if constexpr (std::is_same_v<T, compute_t>) {
336337
raft::copy(quantizer.threshold.data_handle(),
337338
host_threshold_vec.data(),
338339
dataset_dim,
339340
raft::resource::get_cuda_stream(res));
340341
} else {
341-
auto mr = raft::resource::get_workspace_resource(res);
342-
auto casted_vec =
343-
raft::make_device_mdarray<T, int64_t>(res, mr, raft::make_extents<int64_t>(dataset_dim));
342+
auto mr = raft::resource::get_workspace_resource(res);
343+
auto casted_vec = raft::make_device_mdarray<compute_t, int64_t>(
344+
res, mr, raft::make_extents<int64_t>(dataset_dim));
344345
raft::copy(casted_vec.data_handle(),
345346
host_threshold_vec.data(),
346347
dataset_dim,
@@ -422,10 +423,10 @@ void transform(raft::resources const& res,
422423
T* threshold_ptr = nullptr;
423424

424425
if (quantizer.threshold.size() != 0) {
425-
threshold_vec = raft::make_host_vector<T, int64_t>(dataset_dim);
426+
threshold_vec = raft::make_host_vector<compute_t, int64_t>(dataset_dim);
426427
threshold_ptr = threshold_vec.data_handle();
427428

428-
if constexpr (std::is_same_v<T, T>) {
429+
if constexpr (std::is_same_v<compute_t, T>) {
429430
raft::copy(threshold_ptr,
430431
quantizer.threshold.data_handle(),
431432
dataset_dim,
@@ -436,7 +437,7 @@ void transform(raft::resources const& res,
436437
res, mr, raft::make_extents<int64_t>(dataset_dim));
437438
raft::linalg::map(res,
438439
casted_vec.view(),
439-
raft::cast_op<T>{},
440+
raft::cast_op<compute_t>{},
440441
raft::make_const_mdspan(quantizer.threshold.view()));
441442
raft::copy(
442443
threshold_ptr, casted_vec.data_handle(), dataset_dim, raft::resource::get_cuda_stream(res));
@@ -453,7 +454,7 @@ void transform(raft::resources const& res,
453454
if (threshold_ptr == nullptr) {
454455
if (is_positive(dataset(i, in_j))) { pack |= (1u << pack_j); }
455456
} else {
456-
if (is_positive(static_cast<T>(dataset(i, in_j)) - threshold_ptr[in_j])) {
457+
if (is_positive(static_cast<compute_t>(dataset(i, in_j)) - threshold_ptr[in_j])) {
457458
pack |= (1u << pack_j);
458459
}
459460
}

0 commit comments

Comments
 (0)