7 files changed (+39 / -11 lines)
File 1 of 7:

@@ -267,6 +267,34 @@ if(USE_CUDA)
   endif()
 endif()
 
+if(USE_ROCM)
+  find_package(HIP)
+  include_directories(${HIP_INCLUDE_DIRS})
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_AMD__")
+  set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} ${OpenMP_CXX_FLAGS} -fPIC -Wall")
+
+  # avoid warning: unused variable 'mask' due to __shfl_down_sync work-around
+  set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-unused-variable")
+  # avoid warning: 'hipHostAlloc' is deprecated: use hipHostMalloc instead
+  set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-deprecated-declarations")
+  # avoid many warnings about missing overrides
+  set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-inconsistent-missing-override")
+  # avoid warning: shift count >= width of type in feature_histogram.hpp
+  set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-shift-count-overflow")
+
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DISABLED_WARNINGS}")
+  set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} ${DISABLED_WARNINGS}")
+
+  if(USE_DEBUG)
+    set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -g -O0")
+  else()
+    set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3")
+  endif()
+  message(STATUS "CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}")
+
+  add_definitions(-DUSE_ROCM)
+endif()
+
 include(CheckCXXSourceCompiles)
 check_cxx_source_compiles("
 #include <xmmintrin.h>
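The add_definitions(-DUSE_ROCM) line above is what the device and host sources key off. As a hypothetical illustration only (the project's real include wiring is not shown in this diff), a translation unit could switch runtimes on that define like so:

// Hypothetical sketch, assuming the USE_ROCM definition added by the CMake block above.
#ifdef USE_ROCM
#include <hip/hip_runtime.h>   // __HIP_PLATFORM_AMD__ arrives via the CMAKE_CXX_FLAGS set above
#else
#include <cuda_runtime.h>
#endif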
File 2 of 7:

@@ -409,7 +409,7 @@ __global__ void GetGradientsKernel_RankXENDCG_SharedMemory(
   const data_size_t block_reduce_size = query_item_count >= 1024 ? 1024 : query_item_count;
   __shared__ double shared_rho[SHARED_MEMORY_SIZE];
   // assert that warpSize == 32
-  __shared__ double shared_buffer[32];
+  __shared__ double shared_buffer[WARPSIZE];
   __shared__ double shared_params[SHARED_MEMORY_SIZE];
   __shared__ score_t shared_lambdas[SHARED_MEMORY_SIZE];
   __shared__ double reduce_result;

@@ -527,7 +527,7 @@ __global__ void GetGradientsKernel_RankXENDCG_GlobalMemory(
   double* cuda_params_buffer_pointer = cuda_params_buffer + item_index_start;
   const data_size_t block_reduce_size = query_item_count > 1024 ? 1024 : query_item_count;
   // assert that warpSize == 32, so we use buffer size 1024 / 32 = 32
-  __shared__ double shared_buffer[32];
+  __shared__ double shared_buffer[WARPSIZE];
   __shared__ double reduce_result;
   if (query_item_count <= 1) {
     for (data_size_t i = 0; i <= query_item_count; ++i) {
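This and all of the following hunks replace hard-coded shared-memory sizes of 32 with WARPSIZE, whose definition is not part of this diff. A plausible definition, stated here only as an assumption, would follow the platform's warp/wavefront width (32 threads on NVIDIA, 64 on most ROCm-supported AMD GPUs):

// Hypothetical definition; the real macro lives in a header not shown in this diff.
#ifdef USE_ROCM
#define WARPSIZE 64  // AMD wavefront width on most ROCm-supported GPUs
#else
#define WARPSIZE 32  // NVIDIA warp width
#endif

Sizing the per-warp reduction buffers with this constant keeps them large enough on both platforms.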
File 3 of 7:

@@ -364,9 +364,9 @@ __device__ void FindBestSplitsDiscretizedForLeafKernelInner(
     }
   }
   __shared__ uint32_t best_thread_index;
-  __shared__ double shared_double_buffer[32];
-  __shared__ bool shared_bool_buffer[32];
-  __shared__ uint32_t shared_int_buffer[64];
+  __shared__ double shared_double_buffer[WARPSIZE];
+  __shared__ bool shared_bool_buffer[WARPSIZE];
+  __shared__ uint32_t shared_int_buffer[2 * WARPSIZE];  // need 2 * WARPSIZE since the actual ACC_HIST_TYPE could be long int
   const unsigned int threadIdx_x = threadIdx.x;
   const bool skip_sum = REVERSE ?
     (task->skip_default_bin && (task->num_bin - 1 - threadIdx_x) == static_cast<int>(task->default_bin)) :
File 4 of 7:

@@ -1080,7 +1080,7 @@ __global__ void RenewDiscretizedTreeLeavesKernel(
     double* leaf_grad_stat_buffer,
     double* leaf_hess_stat_buffer,
     double* leaf_values) {
-  __shared__ double shared_mem_buffer[32];
+  __shared__ double shared_mem_buffer[WARPSIZE];
   const int leaf_index = static_cast<int>(blockIdx.x);
   const data_size_t* data_indices_in_leaf = data_indices + leaf_data_start[leaf_index];
   const data_size_t num_data_in_leaf = leaf_num_data[leaf_index];
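For context on why one slot per warp is enough: these kernels use a two-stage block reduction in which each warp first reduces its values with register shuffles, lane 0 of every warp writes one partial result into shared memory, and the first warp then reduces those partials. Below is a minimal generic sketch of the pattern, not the project's actual reduction helper, and assuming blockDim.x is a multiple of warpSize:

// Sketch of a warp-shuffle block sum; valid result in thread 0 only.
__device__ double BlockReduceSumSketch(double val, double* shared_buffer) {
  const unsigned int lane = threadIdx.x % warpSize;
  const unsigned int warp_id = threadIdx.x / warpSize;
  const unsigned int num_warps = blockDim.x / warpSize;
  // Stage 1: reduce within each warp via shuffles.
  for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffffu, val, offset);
  }
  // One shared-memory slot per warp: at most blockDim.x / warpSize <= WARPSIZE slots
  // for a 1024-thread block, which is why the buffers in these hunks are sized WARPSIZE.
  if (lane == 0) {
    shared_buffer[warp_id] = val;
  }
  __syncthreads();
  // Stage 2: the first warp reduces the per-warp partials.
  val = (warp_id == 0 && lane < num_warps) ? shared_buffer[lane] : 0.0;
  if (warp_id == 0) {
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(0xffffffffu, val, offset);
    }
  }
  return val;
}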
File 5 of 7:

@@ -22,7 +22,7 @@ __global__ void ReduceMinMaxKernel(
     score_t* grad_max_block_buffer,
     score_t* hess_min_block_buffer,
     score_t* hess_max_block_buffer) {
-  __shared__ score_t shared_mem_buffer[32];
+  __shared__ score_t shared_mem_buffer[WARPSIZE];
   const data_size_t index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
   score_t grad_max_val = kMinScore;
   score_t grad_min_val = kMaxScore;

@@ -56,7 +56,7 @@ __global__ void ReduceBlockMinMaxKernel(
     score_t* grad_max_block_buffer,
     score_t* hess_min_block_buffer,
     score_t* hess_max_block_buffer) {
-  __shared__ score_t shared_mem_buffer[32];
+  __shared__ score_t shared_mem_buffer[WARPSIZE];
   score_t grad_max_val = kMinScore;
   score_t grad_min_val = kMaxScore;
   score_t hess_max_val = kMinScore;
File 6 of 7:

@@ -835,7 +835,7 @@ __global__ void FixHistogramDiscretizedKernel(
     const int* cuda_need_fix_histogram_features,
     const uint32_t* cuda_need_fix_histogram_features_num_bin_aligned,
     const CUDALeafSplitsStruct* cuda_smaller_leaf_splits) {
-  __shared__ int64_t shared_mem_buffer[32];
+  __shared__ int64_t shared_mem_buffer[WARPSIZE];
   const unsigned int blockIdx_x = blockIdx.x;
   const int feature_index = cuda_need_fix_histogram_features[blockIdx_x];
   const uint32_t num_bin_aligned = cuda_need_fix_histogram_features_num_bin_aligned[blockIdx_x];
File 7 of 7:

@@ -90,7 +90,7 @@ __global__ void CUDAInitValuesKernel3(const int16_t* cuda_gradients_and_hessians
     const score_t* grad_scale_pointer, const score_t* hess_scale_pointer) {
   const score_t grad_scale = *grad_scale_pointer;
   const score_t hess_scale = *hess_scale_pointer;
-  __shared__ int64_t shared_mem_buffer[32];
+  __shared__ int64_t shared_mem_buffer[WARPSIZE];
   const data_size_t data_index = static_cast<data_size_t>(threadIdx.x + blockIdx.x * blockDim.x);
   int64_t int_gradient = 0;
   int64_t int_hessian = 0;

@@ -121,7 +121,7 @@ __global__ void CUDAInitValuesKernel4(
     const data_size_t* cuda_data_indices_in_leaf,
     hist_t* cuda_hist_in_leaf,
     CUDALeafSplitsStruct* cuda_struct) {
-  __shared__ double shared_mem_buffer[32];
+  __shared__ double shared_mem_buffer[WARPSIZE];
   double thread_sum_of_gradients = 0.0f;
   double thread_sum_of_hessians = 0.0f;
   int64_t thread_sum_of_gradients_hessians = 0;
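A hypothetical sanity check, not part of this PR and assuming the WARPSIZE macro sketched earlier, could verify at startup that the compile-time constant matches the device's actual warp/wavefront width:

// Hypothetical check: compile-time WARPSIZE vs. the device's reported warpSize.
#include <cstdio>
#include <cstdlib>
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

void CheckWarpSize(int device_id) {
#ifdef USE_ROCM
  hipDeviceProp_t prop;
  hipGetDeviceProperties(&prop, device_id);
#else
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
#endif
  if (prop.warpSize != WARPSIZE) {
    std::fprintf(stderr, "WARPSIZE (%d) != device warpSize (%d)\n", WARPSIZE, prop.warpSize);
    std::abort();
  }
}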