Skip to content

Commit b504110

Browse files
authored
refactor: replace old warp size api (#705)
* feat: add new api to query wavefront size * refactor: deprecate old warp size api * chore: update copyright dates * style: update formatting * test(test_warp_sort): fix kernels not compiling with c++14 * style: fix styling in tests * docs(arch.hpp): fix typo * fix: add previously relaxed static assertions back as run-time assertions in ctor * fix: fix too strict constraints on static asserts when defining block/warp algo types on host * fix(block_histogram_sort.hpp): force loop that couldn't be unrolled (due to hotfix) to be marked as no unroll to prevent compiler warnings
1 parent 99642d5 commit b504110

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+524
-364
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projec
3232

3333
### Upcoming changes
3434

35+
* The next major release may change the template parameters of warp and block algorithms.
36+
3537
* The default scan accumulator types for device-level scan algorithms will be changed. This is a breaking change.
3638

3739
Previously, the default accumulator type was set to the input type for inclusive scans and to the initial value type for exclusive scans. These default types could cause unexpected overflow in situations where the input or initial type is smaller than the output type when the user doesn't explicitly set an accumulator type using the `AccType` template parameter.
@@ -56,6 +58,13 @@ The following is the complete list of affected functions and how their default a
5658

5759
* `rocprim::load_cs` and `rocprim::store_cs` are deprecated. Use `rocprim::load_nontemporal` and `rocprim::store_nontemporal` now.
5860

61+
* Due to an upcoming compiler change, the following warp size-related symbols will be removed in the next major release and are thus marked as deprecated:
62+
* `rocprim::device_warp_size()`
63+
* For compile-time constants, this is replaced with `rocprim::arch::wavefront::min_size()` and `rocprim::arch::wavefront::max_size()`. Use this when allocating global or shared memory.
64+
* For run-time constants, this is replaced with `rocprim::arch::wavefront::size()`.
65+
* `rocprim::warp_size()`
66+
* `ROCPRIM_WAVEFRONT_SIZE`
67+
5968
## rocPRIM 3.4.0 for ROCm 6.4.0
6069

6170
### Added

benchmark/benchmark_device_memory.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ struct operation<atomics_inter_warp_collision, T, ItemsPerThread, BlockSize>
179179
(void)shared_storage;
180180
(void)shared_storage_size;
181181
(void)input;
182-
unsigned int index = (threadIdx.x % rocprim::device_warp_size()) * ItemsPerThread
182+
unsigned int index = (threadIdx.x % rocprim::arch::wavefront::min_size()) * ItemsPerThread
183183
+ blockIdx.x * blockDim.x * ItemsPerThread;
184184
ROCPRIM_UNROLL
185185
for(unsigned int i = 0; i < ItemsPerThread; ++i)

benchmark/benchmark_device_radix_sort_onesweep.parallel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ struct device_radix_sort_onesweep_benchmark_generator
398398
template<unsigned int ItemsPerThread, rocprim::block_radix_rank_algorithm RadixRankAlgorithm>
399399
static constexpr bool is_buildable()
400400
{
401-
// Calculation uses `rocprim::device_warp_size()`, which is 64 on host side unless overridden.
401+
// Calculation uses `rocprim::arch::wavefront::min_size()`, which is 64 on host side unless overridden.
402402
// However, this does not affect the total size of shared memory for the current configuration space.
403403
// Were the implementation to change, causing retuning, this needs to be re-evaluated and possibly taken into account.
404404
using sharedmem_storage = typename rocprim::detail::onesweep_iteration_helper<

benchmark/benchmark_utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include <rocprim/config.hpp>
3535
#include <rocprim/device/config_types.hpp>
3636
#include <rocprim/device/detail/device_config_helper.hpp> // partition_config_params
37+
#include <rocprim/intrinsics/arch.hpp>
3738
#include <rocprim/intrinsics/thread.hpp>
3839
#include <rocprim/type_traits.hpp>
3940
#include <rocprim/type_traits_interface.hpp>

common/utils.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,10 @@
5454

5555
namespace common
5656
{
57-
5857
template<unsigned int LogicalWarpSize>
59-
__device__ constexpr bool device_test_enabled_for_warp_size_v
60-
= ::rocprim::device_warp_size() >= LogicalWarpSize;
58+
__device__
59+
constexpr bool device_test_enabled_for_warp_size_v
60+
= ::rocprim::arch::wavefront::min_size() >= LogicalWarpSize;
6161

6262
inline char* __get_env(const char* name)
6363
{

rocprim/include/rocprim/block/block_exchange.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
#include "../functional.hpp"
2828
#include "../intrinsics.hpp"
29+
#include "../intrinsics/arch.hpp"
2930
#include "../types.hpp"
3031

3132
#include "config.hpp"
@@ -88,8 +89,8 @@ class block_exchange
8889
{
8990
static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
9091
// Select warp size
91-
static constexpr unsigned int warp_size =
92-
detail::get_min_warp_size(BlockSize, ::rocprim::device_warp_size());
92+
static constexpr unsigned int warp_size
93+
= detail::get_min_warp_size(BlockSize, ::rocprim::arch::wavefront::min_size());
9394
// Number of warps in block
9495
static constexpr unsigned int warps_no = ::rocprim::detail::ceiling_div(BlockSize, warp_size);
9596
static constexpr unsigned int banks_no = ::rocprim::detail::get_lds_banks_no();
@@ -656,16 +657,18 @@ class block_exchange
656657
/// ...
657658
/// }
658659
/// \endcode
659-
template<unsigned int WarpSize = device_warp_size(), class U, class Offset>
660+
template<unsigned int WarpSize = arch::wavefront::min_size(), class U, class Offset>
660661
ROCPRIM_DEVICE ROCPRIM_INLINE
661662
void scatter_to_warp_striped(const T (&input)[ItemsPerThread],
662663
U (&output)[ItemsPerThread],
663664
const Offset (&ranks)[ItemsPerThread],
664665
storage_type& storage)
665666
{
666-
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(),
667+
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(),
667668
"WarpSize must be a power of two and equal or less"
668669
"than the size of hardware warp.");
670+
assert(WarpSize <= arch::wavefront::size());
671+
669672
const unsigned int flat_id
670673
= ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
671674
const unsigned int thread_id = detail::logical_lane_id<WarpSize>();

rocprim/include/rocprim/block/block_load.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,9 +770,15 @@ class block_load<T, BlockSizeX, ItemsPerThread, block_load_method::block_load_wa
770770
using block_exchange_type = block_exchange<T, BlockSizeX, ItemsPerThread, BlockSizeY, BlockSizeZ>;
771771

772772
public:
773-
ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT(BlockSize % ::rocprim::device_warp_size() == 0,
773+
ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT(BlockSize % ::rocprim::arch::wavefront::min_size() == 0,
774774
"BlockSize must be a multiple of hardware warpsize");
775775

776+
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
777+
block_load()
778+
{
779+
assert(BlockSize % ::rocprim::arch::wavefront::size() == 0);
780+
}
781+
776782
using storage_type = typename block_exchange_type::storage_type;
777783

778784
template<class InputIterator>

rocprim/include/rocprim/block/block_load_func.hpp

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
#include "../config.hpp"
2525
#include "../detail/various.hpp"
2626

27-
#include "../intrinsics.hpp"
2827
#include "../functional.hpp"
28+
#include "../intrinsics.hpp"
2929
#include "../types.hpp"
30+
#include "rocprim/intrinsics/arch.hpp"
3031

3132
/// \addtogroup blockmodule
3233
/// @{
@@ -367,20 +368,20 @@ void block_load_direct_striped(unsigned int flat_id,
367368
/// \param flat_id a local flat 1D thread id in a block (tile) for the calling thread
368369
/// \param block_input the input iterator from the thread block to load from
369370
/// \param items array that data is loaded to
370-
template<
371-
unsigned int WarpSize = device_warp_size(),
372-
class InputIterator,
373-
class T,
374-
unsigned int ItemsPerThread
375-
>
371+
template<unsigned int WarpSize = arch::wavefront::min_size(),
372+
class InputIterator,
373+
class T,
374+
unsigned int ItemsPerThread>
376375
ROCPRIM_DEVICE ROCPRIM_INLINE
377-
void block_load_direct_warp_striped(unsigned int flat_id,
376+
void block_load_direct_warp_striped(unsigned int flat_id,
378377
InputIterator block_input,
379378
T (&items)[ItemsPerThread])
380379
{
381-
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(),
382-
"WarpSize must be a power of two and equal or less"
383-
"than the size of hardware warp.");
380+
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(),
381+
"WarpSize must be a power of two and equal or less"
382+
"than the size of hardware warp.");
383+
assert(WarpSize <= arch::wavefront::size());
384+
384385
unsigned int thread_id = detail::logical_lane_id<WarpSize>();
385386
unsigned int warp_id = flat_id / WarpSize;
386387
unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread;
@@ -419,21 +420,21 @@ void block_load_direct_warp_striped(unsigned int flat_id,
419420
/// \param block_input the input iterator from the thread block to load from
420421
/// \param items array that data is loaded to
421422
/// \param valid maximum range of valid numbers to load
422-
template<
423-
unsigned int WarpSize = device_warp_size(),
424-
class InputIterator,
425-
class T,
426-
unsigned int ItemsPerThread
427-
>
423+
template<unsigned int WarpSize = arch::wavefront::min_size(),
424+
class InputIterator,
425+
class T,
426+
unsigned int ItemsPerThread>
428427
ROCPRIM_DEVICE ROCPRIM_INLINE
429-
void block_load_direct_warp_striped(unsigned int flat_id,
428+
void block_load_direct_warp_striped(unsigned int flat_id,
430429
InputIterator block_input,
431430
T (&items)[ItemsPerThread],
432431
unsigned int valid)
433432
{
434-
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(),
435-
"WarpSize must be a power of two and equal or less"
436-
"than the size of hardware warp.");
433+
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(),
434+
"WarpSize must be a power of two and equal or less"
435+
"than the size of hardware warp.");
436+
assert(WarpSize <= arch::wavefront::size());
437+
437438
unsigned int thread_id = detail::logical_lane_id<WarpSize>();
438439
unsigned int warp_id = flat_id / WarpSize;
439440
unsigned int warp_offset = warp_id * WarpSize * ItemsPerThread;
@@ -479,23 +480,23 @@ void block_load_direct_warp_striped(unsigned int flat_id,
479480
/// \param items array that data is loaded to
480481
/// \param valid maximum range of valid numbers to load
481482
/// \param out_of_bounds default value assigned to out-of-bound items
482-
template<
483-
unsigned int WarpSize = device_warp_size(),
484-
class InputIterator,
485-
class T,
486-
unsigned int ItemsPerThread,
487-
class Default
488-
>
483+
template<unsigned int WarpSize = arch::wavefront::min_size(),
484+
class InputIterator,
485+
class T,
486+
unsigned int ItemsPerThread,
487+
class Default>
489488
ROCPRIM_DEVICE ROCPRIM_INLINE
490-
void block_load_direct_warp_striped(unsigned int flat_id,
489+
void block_load_direct_warp_striped(unsigned int flat_id,
491490
InputIterator block_input,
492491
T (&items)[ItemsPerThread],
493492
unsigned int valid,
494-
Default out_of_bounds)
493+
Default out_of_bounds)
495494
{
496-
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= device_warp_size(),
497-
"WarpSize must be a power of two and equal or less"
498-
"than the size of hardware warp.");
495+
static_assert(detail::is_power_of_two(WarpSize) && WarpSize <= arch::wavefront::max_size(),
496+
"WarpSize must be a power of two and equal or less"
497+
"than the size of hardware warp.");
498+
assert(WarpSize <= arch::wavefront::size());
499+
499500
ROCPRIM_UNROLL
500501
for (unsigned int item = 0; item < ItemsPerThread; item++)
501502
{

rocprim/include/rocprim/block/block_radix_sort.hpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "block_exchange.hpp"
3535
#include "block_radix_rank.hpp"
3636
#include "rocprim/block/config.hpp"
37+
#include "rocprim/intrinsics/arch.hpp"
3738

3839
/// \addtogroup blockmodule
3940
/// @{
@@ -102,10 +103,11 @@ template<class Key,
102103
unsigned int BlockSizeY = 1,
103104
unsigned int BlockSizeZ = 1,
104105
unsigned int RadixBitsPerPass
105-
= (BlockSizeX * BlockSizeY * BlockSizeZ) % device_warp_size() == 0 ? 8 /* match */
106-
: 4 /* basic_memoize */,
106+
= (BlockSizeX * BlockSizeY * BlockSizeZ) % arch::wavefront::min_size() == 0
107+
? 8 /* match */
108+
: 4 /* basic_memoize */,
107109
block_radix_rank_algorithm RadixRankAlgorithm
108-
= (BlockSizeX * BlockSizeY * BlockSizeZ) % device_warp_size() == 0
110+
= (BlockSizeX * BlockSizeY * BlockSizeZ) % arch::wavefront::min_size() == 0
109111
? block_radix_rank_algorithm::match
110112
: block_radix_rank_algorithm::basic_memoize,
111113
block_padding_hint PaddingHint = block_padding_hint::lds_occupancy_bound>
@@ -119,11 +121,10 @@ class block_radix_sort
119121
static constexpr bool with_values = !std::is_same<Value, empty_type>::value;
120122
static constexpr bool warp_striped = RadixRankAlgorithm == block_radix_rank_algorithm::match;
121123

122-
#if __HIP_DEVICE_COMPILE__
123-
static_assert(!warp_striped || (BlockSize % device_warp_size()) == 0,
124-
"When using 'block_radix_rank_algorithm::match', the block size should be a "
125-
"multiple of the warp size");
126-
#endif
124+
ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT(
125+
!warp_striped || (BlockSize % ::rocprim::arch::wavefront::min_size()) == 0,
126+
"When using 'block_radix_rank_algorithm::match', the block size should be a "
127+
"multiple of the warp size");
127128

128129
static constexpr bool is_key_and_value_aligned
129130
= alignof(Key) == alignof(Value) && sizeof(Key) == sizeof(Value);
@@ -160,6 +161,12 @@ class block_radix_sort
160161
using storage_type = storage_type_; // only for Doxygen
161162
#endif
162163

164+
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
165+
block_radix_sort()
166+
{
167+
assert(BlockSize % ::rocprim::arch::wavefront::size() == 0);
168+
}
169+
163170
/// \brief Performs ascending radix sort over keys partitioned across threads in a block.
164171
///
165172
/// \tparam Decomposer The type of the decomposer argument. Defaults to the identity decomposer.
@@ -1060,7 +1067,7 @@ class block_radix_sort
10601067

10611068
private:
10621069
static constexpr bool use_warp_exchange
1063-
= device_warp_size() % ItemsPerThread == 0 && ItemsPerThread <= 4;
1070+
= ::rocprim::arch::wavefront::min_size() % ItemsPerThread == 0 && ItemsPerThread <= 4;
10641071

10651072
template<class SortedValue>
10661073
ROCPRIM_DEVICE ROCPRIM_INLINE
@@ -1126,7 +1133,8 @@ class block_radix_sort
11261133
{
11271134
// This appears to be slower with high large items per thread.
11281135
constexpr bool use_warp_exchange
1129-
= device_warp_size() % ItemsPerThread == 0 && ItemsPerThread <= 4;
1136+
= ::rocprim::arch::wavefront::min_size() % ItemsPerThread == 0
1137+
&& ItemsPerThread <= 4;
11301138
blocked_to_warp_striped(keys,
11311139
values,
11321140
storage,

rocprim/include/rocprim/block/block_reduce.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ struct select_block_reduce_impl<block_reduce_algorithm::raking_reduce_commutativ
9898
/// * \p ItemsPerThread is greater than one,
9999
/// * \p T is an arithmetic type,
100100
/// * reduce operation is simple addition operator, and
101-
/// * the number of threads in the block is a multiple of the hardware warp size (see rocprim::device_warp_size()).
101+
/// * the number of threads in the block is a multiple of the hardware warp size (see \p rocprim::arch::wavefront::min_size() ).
102102
/// * block_reduce has three alternative implementations: \p block_reduce_algorithm::using_warp_reduce,
103103
/// \p block_reduce_algorithm::raking_reduce and \p block_reduce_algorithm::raking_reduce_commutative_only.
104104
/// * If the block sizes less than 64 only one warp reduction is used. The block reduction algorithm

rocprim/include/rocprim/block/block_scan.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
#include "../intrinsics.hpp"
3030
#include "../functional.hpp"
3131

32-
#include "detail/block_scan_warp_scan.hpp"
3332
#include "detail/block_scan_reduce_then_scan.hpp"
33+
#include "detail/block_scan_warp_scan.hpp"
34+
#include "rocprim/intrinsics/arch.hpp"
3435

3536
/// \addtogroup blockmodule
3637
/// @{
@@ -70,10 +71,9 @@ struct select_block_scan_impl<block_scan_algorithm::reduce_then_scan>
7071
// When BlockSize is less than hardware warp size block_scan_warp_scan performs better than
7172
// block_scan_reduce_then_scan by specializing for warps
7273
using type = typename std::conditional<
73-
(BlockSizeX * BlockSizeY * BlockSizeZ <= ::rocprim::device_warp_size()),
74-
block_scan_warp_scan<T, BlockSizeX, BlockSizeY, BlockSizeZ>,
75-
block_scan_reduce_then_scan<T, BlockSizeX, BlockSizeY, BlockSizeZ>
76-
>::type;
74+
(BlockSizeX * BlockSizeY * BlockSizeZ <= ::rocprim::arch::wavefront::min_size()),
75+
block_scan_warp_scan<T, BlockSizeX, BlockSizeY, BlockSizeZ>,
76+
block_scan_reduce_then_scan<T, BlockSizeX, BlockSizeY, BlockSizeZ>>::type;
7777
};
7878

7979
} // end namespace detail
@@ -96,7 +96,7 @@ struct select_block_scan_impl<block_scan_algorithm::reduce_then_scan>
9696
/// * \p ItemsPerThread is greater than one,
9797
/// * \p T is an arithmetic type,
9898
/// * scan operation is simple addition operator, and
99-
/// * the number of threads in the block is a multiple of the hardware warp size (see rocprim::device_warp_size()).
99+
/// * the number of threads in the block is a multiple of the hardware warp size (see \p rocprim::arch::wavefront::min_size() ).
100100
/// * block_scan has two alternative implementations: \p block_scan_algorithm::using_warp_scan
101101
/// and block_scan_algorithm::reduce_then_scan.
102102
///

rocprim/include/rocprim/block/block_store.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,11 +498,17 @@ class block_store<T, BlockSizeX, ItemsPerThread, block_store_method::block_store
498498
using block_exchange_type = block_exchange<T, BlockSize, ItemsPerThread>;
499499

500500
public:
501-
ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT(BlockSize % ::rocprim::device_warp_size() == 0,
501+
ROCPRIM_DETAIL_DEVICE_STATIC_ASSERT(BlockSize % ::rocprim::arch::wavefront::min_size() == 0,
502502
"BlockSize must be a multiple of hardware warpsize");
503503

504504
using storage_type = typename block_exchange_type::storage_type;
505505

506+
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
507+
block_store()
508+
{
509+
assert(BlockSize % ::rocprim::arch::wavefront::size() == 0);
510+
}
511+
506512
template<class OutputIterator>
507513
ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
508514
void store(OutputIterator block_output,

0 commit comments

Comments
 (0)