@@ -276,7 +276,8 @@ auto train(raft::resources const& res,
276
276
const size_t dataset_size = dataset.extent (0 );
277
277
quantizer.threshold = raft::make_device_vector<T, int64_t >(res, dataset_dim);
278
278
279
- std::vector<T> host_threshold_vec (dataset_dim);
279
+ using compute_t = std::conditional_t <std::is_same_v<half, T>, float , T>;
280
+ std::vector<compute_t > host_threshold_vec (dataset_dim);
280
281
auto threshold_ptr = host_threshold_vec.data ();
281
282
282
283
if (params.threshold == cuvs::preprocessing::quantize::binary::bit_threshold::mean) {
@@ -286,7 +287,7 @@ auto train(raft::resources const& res,
286
287
#pragma omp parallel for reduction(+ : threshold_ptr[ : dataset_dim])
287
288
for (size_t i = 0 ; i < dataset_size; i++) {
288
289
for (uint32_t j = 0 ; j < dataset_dim; j++) {
289
- threshold_ptr[j] += static_cast <T >(dataset.data_handle ()[i * dataset_dim + j]);
290
+ threshold_ptr[j] += static_cast <compute_t >(dataset.data_handle ()[i * dataset_dim + j]);
290
291
}
291
292
}
292
293
for (uint32_t j = 0 ; j < dataset_dim; j++) {
@@ -314,13 +315,13 @@ auto train(raft::resources const& res,
314
315
const auto stride = stride_prime_list[prime_i];
315
316
316
317
// Transposed
317
- auto sampled_dataset = raft::make_host_matrix<T , int64_t >(dataset_dim, num_samples);
318
+ auto sampled_dataset = raft::make_host_matrix<compute_t , int64_t >(dataset_dim, num_samples);
318
319
#pragma omp parallel for
319
320
for (size_t out_i = 0 ; out_i < num_samples; out_i++) {
320
321
const auto in_i = (out_i * stride) % dataset_size;
321
322
for (uint32_t j = 0 ; j < dataset_dim; j++) {
322
323
sampled_dataset.data_handle ()[j * num_samples + out_i] =
323
- static_cast <T >(dataset.data_handle ()[in_i * dataset_dim + j]);
324
+ static_cast <compute_t >(dataset.data_handle ()[in_i * dataset_dim + j]);
324
325
}
325
326
}
326
327
@@ -332,15 +333,15 @@ auto train(raft::resources const& res,
332
333
}
333
334
}
334
335
335
- if constexpr (std::is_same_v<T, T >) {
336
+ if constexpr (std::is_same_v<T, compute_t >) {
336
337
raft::copy (quantizer.threshold .data_handle (),
337
338
host_threshold_vec.data (),
338
339
dataset_dim,
339
340
raft::resource::get_cuda_stream (res));
340
341
} else {
341
- auto mr = raft::resource::get_workspace_resource (res);
342
- auto casted_vec =
343
- raft::make_device_mdarray<T, int64_t >( res, mr, raft::make_extents<int64_t >(dataset_dim));
342
+ auto mr = raft::resource::get_workspace_resource (res);
343
+ auto casted_vec = raft::make_device_mdarray< compute_t , int64_t >(
344
+ res, mr, raft::make_extents<int64_t >(dataset_dim));
344
345
raft::copy (casted_vec.data_handle (),
345
346
host_threshold_vec.data (),
346
347
dataset_dim,
@@ -422,10 +423,10 @@ void transform(raft::resources const& res,
422
423
T* threshold_ptr = nullptr ;
423
424
424
425
if (quantizer.threshold .size () != 0 ) {
425
- threshold_vec = raft::make_host_vector<T , int64_t >(dataset_dim);
426
+ threshold_vec = raft::make_host_vector<compute_t , int64_t >(dataset_dim);
426
427
threshold_ptr = threshold_vec.data_handle ();
427
428
428
- if constexpr (std::is_same_v<T , T>) {
429
+ if constexpr (std::is_same_v<compute_t , T>) {
429
430
raft::copy (threshold_ptr,
430
431
quantizer.threshold .data_handle (),
431
432
dataset_dim,
@@ -436,7 +437,7 @@ void transform(raft::resources const& res,
436
437
res, mr, raft::make_extents<int64_t >(dataset_dim));
437
438
raft::linalg::map (res,
438
439
casted_vec.view (),
439
- raft::cast_op<T >{},
440
+ raft::cast_op<compute_t >{},
440
441
raft::make_const_mdspan (quantizer.threshold .view ()));
441
442
raft::copy (
442
443
threshold_ptr, casted_vec.data_handle (), dataset_dim, raft::resource::get_cuda_stream (res));
@@ -453,7 +454,7 @@ void transform(raft::resources const& res,
453
454
if (threshold_ptr == nullptr ) {
454
455
if (is_positive (dataset (i, in_j))) { pack |= (1u << pack_j); }
455
456
} else {
456
- if (is_positive (static_cast <T >(dataset (i, in_j)) - threshold_ptr[in_j])) {
457
+ if (is_positive (static_cast <compute_t >(dataset (i, in_j)) - threshold_ptr[in_j])) {
457
458
pack |= (1u << pack_j);
458
459
}
459
460
}
0 commit comments