Skip to content

Commit 7347b9d

Browse files
authored
Handle const and reference scan accumulator types (#736)
We recently updated the default accumulator types for rocprim::inclusive_scan, deterministic_inclusive_scan, exclusive_scan, and deterministic_exclusive_scan. The update set the default type to the return type of the scan operator. The scan operator can return a type that is const or a reference. This is problematic because internally, the scan algorithm need to create instances of the accumulator type and periodically update their value. It also needs to create pointers of the accumulator type. This change fixes the problem by stripping any const or reference from the accumulator type before it is passed to the internal scan implementation function.
1 parent 1b52696 commit 7347b9d

File tree

1 file changed

+54
-38
lines changed

1 file changed

+54
-38
lines changed

rocprim/include/rocprim/device/device_scan.hpp

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -505,23 +505,27 @@ inline hipError_t inclusive_scan(void* temporary_storage,
505505
const hipStream_t stream = 0,
506506
bool debug_synchronous = false)
507507
{
508+
// AccType may be const or a reference. Get the non-const, non-reference type.
509+
// This is necessary because we may need to assign to instances of this type or create pointers to it.
510+
using safe_acc_type = typename std::remove_const<typename std::remove_reference<AccType>::type>::type;
511+
508512
// input_type() is a dummy initial value (not used)
509513
return detail::scan_impl<detail::lookback_scan_determinism::default_determinism,
510514
false,
511515
Config,
512516
InputIterator,
513517
OutputIterator,
514-
AccType,
518+
safe_acc_type,
515519
BinaryFunction,
516-
AccType>(temporary_storage,
517-
storage_size,
518-
input,
519-
output,
520-
AccType{},
521-
size,
522-
scan_op,
523-
stream,
524-
debug_synchronous);
520+
safe_acc_type>(temporary_storage,
521+
storage_size,
522+
input,
523+
output,
524+
safe_acc_type{},
525+
size,
526+
scan_op,
527+
stream,
528+
debug_synchronous);
525529
}
526530

527531
/// \brief Bitwise-reproducible parallel inclusive scan primitive for device level.
@@ -546,22 +550,26 @@ inline hipError_t deterministic_inclusive_scan(void* temporary_stora
546550
const hipStream_t stream = 0,
547551
bool debug_synchronous = false)
548552
{
553+
// AccType may be const or a reference. Get the non-const, non-reference type.
554+
// This is necessary because we may need to assign to instances of this type or create pointers to it.
555+
using safe_acc_type = typename std::remove_const<typename std::remove_reference<AccType>::type>::type;
556+
549557
return detail::scan_impl<detail::lookback_scan_determinism::deterministic,
550558
false,
551559
Config,
552560
InputIterator,
553561
OutputIterator,
554-
AccType,
562+
safe_acc_type,
555563
BinaryFunction,
556-
AccType>(temporary_storage,
557-
storage_size,
558-
input,
559-
output,
560-
AccType{},
561-
size,
562-
scan_op,
563-
stream,
564-
debug_synchronous);
564+
safe_acc_type>(temporary_storage,
565+
storage_size,
566+
input,
567+
output,
568+
safe_acc_type{},
569+
size,
570+
scan_op,
571+
stream,
572+
debug_synchronous);
565573
}
566574

567575
/// \brief Parallel exclusive scan primitive for device level.
@@ -672,22 +680,26 @@ inline hipError_t exclusive_scan(void* temporary_storage,
672680
const hipStream_t stream = 0,
673681
bool debug_synchronous = false)
674682
{
683+
// AccType may be const or a reference. Get the non-const, non-reference type.
684+
// This is necessary because we may need to assign to instances of this type or create pointers to it.
685+
using safe_acc_type = typename std::remove_const<typename std::remove_reference<AccType>::type>::type;
686+
675687
return detail::scan_impl<detail::lookback_scan_determinism::default_determinism,
676688
true,
677689
Config,
678690
InputIterator,
679691
OutputIterator,
680692
InitValueType,
681693
BinaryFunction,
682-
AccType>(temporary_storage,
683-
storage_size,
684-
input,
685-
output,
686-
initial_value,
687-
size,
688-
scan_op,
689-
stream,
690-
debug_synchronous);
694+
safe_acc_type>(temporary_storage,
695+
storage_size,
696+
input,
697+
output,
698+
initial_value,
699+
size,
700+
scan_op,
701+
stream,
702+
debug_synchronous);
691703
}
692704

693705
/// \brief Bitwise-reproducible parallel exclusive scan primitive for device level.
@@ -714,22 +726,26 @@ inline hipError_t deterministic_exclusive_scan(void* temporary_sto
714726
const hipStream_t stream = 0,
715727
bool debug_synchronous = false)
716728
{
729+
// AccType may be const or a reference. Get the non-const, non-reference type.
730+
// This is necessary because we may need to assign to instances of this type or create pointers to it.
731+
using safe_acc_type = typename std::remove_const<typename std::remove_reference<AccType>::type>::type;
732+
717733
return detail::scan_impl<detail::lookback_scan_determinism::deterministic,
718734
true,
719735
Config,
720736
InputIterator,
721737
OutputIterator,
722738
InitValueType,
723739
BinaryFunction,
724-
AccType>(temporary_storage,
725-
storage_size,
726-
input,
727-
output,
728-
initial_value,
729-
size,
730-
scan_op,
731-
stream,
732-
debug_synchronous);
740+
safe_acc_type>(temporary_storage,
741+
storage_size,
742+
input,
743+
output,
744+
initial_value,
745+
size,
746+
scan_op,
747+
stream,
748+
debug_synchronous);
733749
}
734750

735751
/// @}

0 commit comments

Comments
 (0)