Skip to content

Commit 1b52696

Browse files
authored
Change device scan default accumulator type (#724)
* Change device scan default accumulator type The default accumulator types are changing for ROCm 7.0 as described below. Note that this is a breaking change. * rocprim::inclusive_scan * previous default: class AccType = typename std::iterator_traits<InputIterator>::value_type> * new default: class AccType = rocprim::invoke_result_binary_op_t<typename std::iterator_traits<InputIterator>::value_type, BinaryFunction>` * `rocprim::deterministic_inclusive_scan * previous default: class AccType = typename std::iterator_traits<InputIterator>::value_type>` * new default: class AccType = rocprim::invoke_result_binary_op_t<typename std::iterator_traits<InputIterator>::value_type, BinaryFunction>` * `rocprim::exclusive_scan * previous default: class AccType = detail::input_type_t<InitValueType>> * new default: class AccType = rocprim::invoke_result_binary_op_t<rocprim::detail::input_type_t<InitValueType>, BinaryFunction> * `rocprim::deterministic_exclusive_scan * previous default: class AccType = detail::input_type_t<InitValueType>> * new default: class AccType = rocprim::invoke_result_binary_op_t<rocprim::detail::input_type_t<InitValueType>, BinaryFunction> * Grammatical corrections for 7.0 changelog items This fix corrects some grammatical issues in the changelog. * Update device_scan tests to use new defaults Update the inclusive and exclusive host scan routines so they use the new accumulator type defaults.
1 parent 97ddb51 commit 1b52696

File tree

3 files changed

+32
-10
lines changed

3 files changed

+32
-10
lines changed

CHANGELOG.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,28 @@
22

33
Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/).
44

5+
## rocPRIM 3.6.0 for ROCm 7.0
6+
7+
### Changed
8+
9+
* The default scan accumulator types for device-level scan algorithms have changed. This is a breaking change.
10+
The previous default accumulator types could lead to situations in which unexpected overflow occured, such as
11+
when the input or inital type was smaller than the output type.
12+
13+
This is a complete list of affected functions and how their default accumulator types are changing:
14+
* `rocprim::inclusive_scan`
15+
* past default: `class AccType = typename std::iterator_traits<InputIterator>::value_type>`
16+
* new default: `class AccType = rocprim::invoke_result_binary_op_t<typename std::iterator_traits<InputIterator>::value_type, BinaryFunction>`
17+
* `rocprim::deterministic_inclusive_scan`
18+
* past default: `class AccType = typename std::iterator_traits<InputIterator>::value_type>`
19+
* new default: `class AccType = rocprim::invoke_result_binary_op_t<typename std::iterator_traits<InputIterator>::value_type, BinaryFunction>`
20+
* `rocprim::exclusive_scan`
21+
* past default: `class AccType = detail::input_type_t<InitValueType>>`
22+
* new default: `class AccType = rocprim::invoke_result_binary_op_t<rocprim::detail::input_type_t<InitValueType>, BinaryFunction>`
23+
* `rocprim::deterministic_exclusive_scan`
24+
* past default: `class AccType = detail::input_type_t<InitValueType>>`
25+
* new default: `class AccType = rocprim::invoke_result_binary_op_t<rocprim::detail::input_type_t<InitValueType>, BinaryFunction>`
26+
527
## rocPRIM 3.5.0 for ROCm 6.5.0
628

729
### Removed

rocprim/include/rocprim/device/device_scan.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,8 @@ inline auto scan_impl(void* temporary_storage,
403403
/// requirements of a C++ OutputIterator concept. It can be a simple pointer type.
404404
/// \tparam BinaryFunction type of binary function used for scan. Default type
405405
/// is \p rocprim::plus<T>, where \p T is a \p value_type of \p InputIterator.
406-
/// \tparam AccType accumulator type used to propagate the scanned values. Default type
407-
/// is value type of the input iterator.
406+
/// \tparam AccType accumulator type used to propagate the scanned values. The default is the type that
407+
/// is returned by a function of type BinaryFunction when it's is passed an InputIterator value.
408408
///
409409
/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When
410410
/// a null pointer is passed, the required allocation size (in bytes) is written to
@@ -495,7 +495,7 @@ template<class Config = default_config,
495495
class OutputIterator,
496496
class BinaryFunction
497497
= ::rocprim::plus<typename std::iterator_traits<InputIterator>::value_type>,
498-
class AccType = typename std::iterator_traits<InputIterator>::value_type>
498+
class AccType = rocprim::invoke_result_binary_op_t<typename std::iterator_traits<InputIterator>::value_type, BinaryFunction>>
499499
inline hipError_t inclusive_scan(void* temporary_storage,
500500
size_t& storage_size,
501501
InputIterator input,
@@ -536,7 +536,7 @@ template<class Config = default_config,
536536
class OutputIterator,
537537
class BinaryFunction
538538
= ::rocprim::plus<typename std::iterator_traits<InputIterator>::value_type>,
539-
class AccType = typename std::iterator_traits<InputIterator>::value_type>
539+
class AccType = rocprim::invoke_result_binary_op_t<typename std::iterator_traits<InputIterator>::value_type, BinaryFunction>>
540540
inline hipError_t deterministic_inclusive_scan(void* temporary_storage,
541541
size_t& storage_size,
542542
InputIterator input,
@@ -590,8 +590,8 @@ inline hipError_t deterministic_inclusive_scan(void* temporary_stora
590590
/// \tparam InitValueType type of the initial value.
591591
/// \tparam BinaryFunction type of binary function used for scan. Default type
592592
/// is \p rocprim::plus<T>, where \p T is a \p value_type of \p InputIterator.
593-
/// \tparam AccType accumulator type used to propagate the scanned values. Default type
594-
/// is 'InitValueType', unless it's 'rocprim::future_value'. Then it will be the wrapped input type.
593+
/// \tparam AccType accumulator type used to propagate the scanned values. The default is the type that
594+
/// is returned by a function of type BinaryFunction when it's is passed a value of type InitValueType.
595595
///
596596
/// \param [in] temporary_storage pointer to a device-accessible temporary storage. When
597597
/// a null pointer is passed, the required allocation size (in bytes) is written to
@@ -661,7 +661,7 @@ template<class Config = default_config,
661661
class InitValueType,
662662
class BinaryFunction
663663
= ::rocprim::plus<typename std::iterator_traits<InputIterator>::value_type>,
664-
class AccType = detail::input_type_t<InitValueType>>
664+
class AccType = rocprim::invoke_result_binary_op_t<rocprim::detail::input_type_t<InitValueType>, BinaryFunction>>
665665
inline hipError_t exclusive_scan(void* temporary_storage,
666666
size_t& storage_size,
667667
InputIterator input,
@@ -703,7 +703,7 @@ template<class Config = default_config,
703703
class InitValueType,
704704
class BinaryFunction
705705
= ::rocprim::plus<typename std::iterator_traits<InputIterator>::value_type>,
706-
class AccType = detail::input_type_t<InitValueType>>
706+
class AccType = rocprim::invoke_result_binary_op_t<rocprim::detail::input_type_t<InitValueType>, BinaryFunction>>
707707
inline hipError_t deterministic_exclusive_scan(void* temporary_storage,
708708
size_t& storage_size,
709709
InputIterator input,

test/rocprim/test_utils.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ template<class InputIt, class OutputIt, class BinaryOperation>
273273
OutputIt host_inclusive_scan(InputIt first, InputIt last,
274274
OutputIt d_first, BinaryOperation op)
275275
{
276-
using acc_type = typename std::iterator_traits<InputIt>::value_type;
276+
using acc_type = rocprim::invoke_result_binary_op_t<typename std::iterator_traits<InputIt>::value_type, BinaryOperation>;
277277
return host_inclusive_scan_impl(first, last, d_first, op, acc_type{});
278278
}
279279

@@ -313,7 +313,7 @@ OutputIt host_exclusive_scan(InputIt first, InputIt last,
313313
T initial_value, OutputIt d_first,
314314
BinaryOperation op)
315315
{
316-
using acc_type = typename std::iterator_traits<InputIt>::value_type;
316+
using acc_type = rocprim::invoke_result_binary_op_t<rocprim::detail::input_type_t<T>, BinaryOperation>;
317317
return host_exclusive_scan_impl(first, last, initial_value, d_first, op, acc_type{});
318318
}
319319

0 commit comments

Comments
 (0)