[ROCm] add support for ROCm/HIP device #6086

Status: Open. jeffdaily wants to merge 36 commits into master from the rocm2 branch of jeffdaily/LightGBM.

Commits (36):
01ff268
[ROCm] add support for ROCm/HIP
jeffdaily Sep 7, 2023
ac966a5
more rocm updates
jeffdaily Sep 7, 2023
6ae6432
more bug fixes
jeffdaily Sep 7, 2023
5b87bcd
warp 32 vs 64 updates
jeffdaily Sep 7, 2023
3ad89f8
lint fixes
jeffdaily Sep 8, 2023
7b8b6a0
missing device_index variable
jeffdaily Sep 8, 2023
62aa30b
accidental inclusion of hip headers
jeffdaily Sep 8, 2023
bb27c55
copyright notice compliance
jeffdaily Sep 11, 2023
58ace9c
Merge branch 'master' into rocm2
shiyu1994 Sep 12, 2023
0bc1cfb
Merge branch 'master' into rocm2
shiyu1994 Sep 13, 2023
a7c9653
Merge branch 'master' into rocm2
shiyu1994 Oct 8, 2023
96e3a52
Merge branch 'master' into rocm2
shiyu1994 Oct 12, 2023
cb7623a
Merge branch 'master' into rocm2
shiyu1994 Nov 1, 2023
9ba27bb
Merge branch 'master' into rocm2
shiyu1994 Nov 3, 2023
5ba59b8
Merge branch 'master' into rocm2
shiyu1994 Dec 1, 2023
c0abd17
fix conflicts
shiyu1994 Oct 28, 2024
e7129a0
Update CMakeLists.txt
shiyu1994 Oct 28, 2024
eb0036f
Merge branch 'master' into rocm2
shiyu1994 Dec 17, 2024
dbd972e
fix lint issue
shiyu1994 Dec 17, 2024
4cd0dea
Merge branch 'master' into rocm2
shiyu1994 Dec 18, 2024
3ad2482
Merge branch 'master' into rocm2
shiyu1994 Dec 19, 2024
8f6600e
clean up
shiyu1994 Dec 24, 2024
47fc353
Merge branch 'rocm2' of https://github.com/jeffdaily/LightGBM into HEAD
shiyu1994 Dec 24, 2024
2e8869c
Merge branch 'master' into rocm2
shiyu1994 Dec 24, 2024
b173124
Update CMakeLists.txt
shiyu1994 Feb 5, 2025
785b341
Update CMakeLists.txt
shiyu1994 Feb 5, 2025
f4c605d
merge master
shiyu1994 Feb 25, 2025
3cd34a2
clean up CMakeLists.txt
shiyu1994 Feb 26, 2025
89605b8
Merge branch 'master' into rocm2
shiyu1994 Mar 4, 2025
c66b8f3
Merge branch 'master' into rocm2
shiyu1994 Apr 16, 2025
8591248
Merge branch 'master' into rocm2
StrikerRUS Apr 19, 2025
8b1c95a
Merge branch 'master' into rocm2
shiyu1994 Jun 10, 2025
6732b79
use WARPSIZE
Jun 10, 2025
28d4648
Merge branch 'rocm2' of https://github.com/jeffdaily/LightGBM into HEAD
Jun 10, 2025
9a19de6
Merge branch 'master' into rocm2
shiyu1994 Jun 17, 2025
d4676d9
fix share buffer size
shiyu1994 Jun 17, 2025
3 changes: 3 additions & 0 deletions .gitignore
@@ -463,3 +463,6 @@ dask-worker-space/
*.pub
*.rdp
*_rsa

# hipify-perl -inplace leaves behind *.prehip files
*.prehip
194 changes: 172 additions & 22 deletions CMakeLists.txt
@@ -5,6 +5,7 @@ option(USE_SWIG "Enable SWIG to generate Java API" OFF)
option(USE_HDFS "Enable HDFS support (EXPERIMENTAL)" OFF)
option(USE_TIMETAG "Set to ON to output time costs" OFF)
option(USE_CUDA "Enable CUDA-accelerated training " OFF)
option(USE_ROCM "Enable ROCM-accelerated training " OFF)
option(USE_DEBUG "Set to ON for Debug mode" OFF)
option(USE_SANITIZER "Use sanitizer flags" OFF)
set(
@@ -35,6 +36,8 @@ elseif(USE_GPU OR APPLE)
cmake_minimum_required(VERSION 3.2)
elseif(USE_CUDA)
cmake_minimum_required(VERSION 3.16)
elseif(USE_ROCM)
cmake_minimum_required(VERSION 3.21)
else()
cmake_minimum_required(VERSION 3.0)
endif()
@@ -153,6 +156,11 @@ if(USE_CUDA)
set(USE_OPENMP ON CACHE BOOL "CUDA requires OpenMP" FORCE)
endif()

if(USE_ROCM)
enable_language(HIP)
set(USE_OPENMP ON CACHE BOOL "ROCM requires OpenMP" FORCE)
endif()

if(USE_OPENMP)
if(APPLE)
find_package(OpenMP)
@@ -282,6 +290,76 @@ if(USE_CUDA)
endforeach()
endif()

if(USE_ROCM)
find_package(HIP)
include_directories(${HIP_INCLUDE_DIRS})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_AMD__")
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} ${OpenMP_CXX_FLAGS} -fPIC -Wall")

# avoid warning: unused variable 'mask' due to __shfl_down_sync work-around
set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-unused-variable")
# avoid warning: 'hipHostAlloc' is deprecated: use hipHostMalloc instead
set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-deprecated-declarations")
# avoid many warnings about missing overrides
set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-inconsistent-missing-override")
# avoid warning: shift count >= width of type in feature_histogram.hpp
set(DISABLED_WARNINGS "${DISABLED_WARNINGS} -Wno-shift-count-overflow")

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${DISABLED_WARNINGS}")
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} ${DISABLED_WARNINGS}")

if(USE_DEBUG)
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -g -O0")
else()
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -O3")
endif()
message(STATUS "CMAKE_HIP_FLAGS: ${CMAKE_HIP_FLAGS}")

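# ROCm builds deliberately reuse the CUDA code paths (see the ROCM_SOURCES
# handling below and cuda_rocm_interop.h), hence USE_CUDA is defined here too.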
add_definitions(-DUSE_CUDA)

set(
BASE_DEFINES
-DPOWER_FEATURE_WORKGROUPS=12
-DUSE_CONSTANT_BUF=0
)
set(
ALLFEATS_DEFINES
${BASE_DEFINES}
-DENABLE_ALL_FEATURES
)
set(
FULLDATA_DEFINES
${ALLFEATS_DEFINES}
-DIGNORE_INDICES
)

message(STATUS "ALLFEATS_DEFINES: ${ALLFEATS_DEFINES}")
message(STATUS "FULLDATA_DEFINES: ${FULLDATA_DEFINES}")

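# add_histogram(hsize hname hadd hconst hdir) builds one object library per
# histogram kernel variant: hadd controls whether it is linked into the final
# targets, hconst sets CONST_HESSIAN, and hdir supplies the feature defines.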
function(add_histogram hsize hname hadd hconst hdir)
add_library(histo${hsize}${hname} OBJECT src/treelearner/kernels/histogram${hsize}.cu)
if(hadd)
list(APPEND histograms histo${hsize}${hname})
set(histograms ${histograms} PARENT_SCOPE)
endif()
target_compile_definitions(
histo${hsize}${hname}
PRIVATE
-DCONST_HESSIAN=${hconst}
${hdir}
)
endfunction()

foreach(hsize _16_64_256)
add_histogram("${hsize}" "_sp_const" "True" "1" "${BASE_DEFINES}")
add_histogram("${hsize}" "_sp" "True" "0" "${BASE_DEFINES}")
add_histogram("${hsize}" "-allfeats_sp_const" "False" "1" "${ALLFEATS_DEFINES}")
add_histogram("${hsize}" "-allfeats_sp" "False" "0" "${ALLFEATS_DEFINES}")
add_histogram("${hsize}" "-fulldata_sp_const" "True" "1" "${FULLDATA_DEFINES}")
add_histogram("${hsize}" "-fulldata_sp" "True" "0" "${FULLDATA_DEFINES}")
endforeach()
endif()

if(USE_HDFS)
find_package(JNI REQUIRED)
find_path(HDFS_INCLUDE_DIR hdfs.h REQUIRED)
@@ -404,31 +482,93 @@ if(USE_MPI)
include_directories(${MPI_CXX_INCLUDE_PATH})
endif()

file(
GLOB
list(
APPEND
SOURCES
src/boosting/*.cpp
src/io/*.cpp
src/metric/*.cpp
src/objective/*.cpp
src/network/*.cpp
src/treelearner/*.cpp
src/boosting/boosting.cpp
src/boosting/gbdt.cpp
src/boosting/gbdt_model_text.cpp
src/boosting/gbdt_prediction.cpp
src/boosting/prediction_early_stop.cpp
src/boosting/sample_strategy.cpp
src/io/bin.cpp
src/io/config.cpp
src/io/config_auto.cpp
src/io/dataset.cpp
src/io/dataset_loader.cpp
src/io/file_io.cpp
src/io/json11.cpp
src/io/metadata.cpp
src/io/parser.cpp
src/io/train_share_states.cpp
src/io/tree.cpp
src/metric/dcg_calculator.cpp
src/metric/metric.cpp
src/network/linker_topo.cpp
src/network/linkers_mpi.cpp
src/network/linkers_socket.cpp
src/network/network.cpp
src/objective/objective_function.cpp
src/treelearner/data_parallel_tree_learner.cpp
src/treelearner/feature_parallel_tree_learner.cpp
src/treelearner/gpu_tree_learner.cpp
src/treelearner/gradient_discretizer.cpp
src/treelearner/linear_tree_learner.cpp
src/treelearner/serial_tree_learner.cpp
src/treelearner/tree_learner.cpp
src/treelearner/voting_parallel_tree_learner.cpp
)

list(
APPEND
CUDA_SOURCES
src/boosting/cuda/cuda_score_updater.cpp
src/boosting/cuda/cuda_score_updater.cu
src/cuda/cuda_algorithms.cu
src/cuda/cuda_utils.cpp
src/io/cuda/cuda_column_data.cpp
src/io/cuda/cuda_column_data.cu
src/io/cuda/cuda_metadata.cpp
src/io/cuda/cuda_row_data.cpp
src/io/cuda/cuda_tree.cpp
src/io/cuda/cuda_tree.cu
src/metric/cuda/cuda_binary_metric.cpp
src/metric/cuda/cuda_pointwise_metric.cpp
src/metric/cuda/cuda_pointwise_metric.cu
src/metric/cuda/cuda_regression_metric.cpp
src/objective/cuda/cuda_binary_objective.cpp
src/objective/cuda/cuda_binary_objective.cu
src/objective/cuda/cuda_multiclass_objective.cpp
src/objective/cuda/cuda_multiclass_objective.cu
src/objective/cuda/cuda_rank_objective.cpp
src/objective/cuda/cuda_rank_objective.cu
src/objective/cuda/cuda_regression_objective.cpp
src/objective/cuda/cuda_regression_objective.cu
src/treelearner/cuda/cuda_best_split_finder.cpp
src/treelearner/cuda/cuda_best_split_finder.cu
src/treelearner/cuda/cuda_data_partition.cpp
src/treelearner/cuda/cuda_data_partition.cu
src/treelearner/cuda/cuda_histogram_constructor.cpp
src/treelearner/cuda/cuda_histogram_constructor.cu
src/treelearner/cuda/cuda_leaf_splits.cpp
src/treelearner/cuda/cuda_leaf_splits.cu
src/treelearner/cuda/cuda_single_gpu_tree_learner.cpp
src/treelearner/cuda/cuda_single_gpu_tree_learner.cu
)

if(USE_CUDA)
src/treelearner/*.cu
src/boosting/cuda/*.cpp
src/boosting/cuda/*.cu
src/metric/cuda/*.cpp
src/metric/cuda/*.cu
src/objective/cuda/*.cpp
src/objective/cuda/*.cu
src/treelearner/cuda/*.cpp
src/treelearner/cuda/*.cu
src/io/cuda/*.cu
src/io/cuda/*.cpp
src/cuda/*.cpp
src/cuda/*.cu
list(APPEND SOURCES ${CUDA_SOURCES})
endif()

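# ROCm reuses the CUDA source list; each .cu file is compiled as HIP below.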
set(ROCM_SOURCES ${CUDA_SOURCES})
if(USE_ROCM)
foreach(f ${CUDA_SOURCES})
if(f MATCHES ".*\\.cu$")
set_source_files_properties(${f} PROPERTIES LANGUAGE HIP)
endif()
endforeach()
list(APPEND SOURCES ${ROCM_SOURCES})
endif()
)

add_library(lightgbm_objs OBJECT ${SOURCES})

@@ -579,6 +719,16 @@ if(USE_CUDA)
target_link_libraries(_lightgbm PRIVATE ${histograms})
endif()

if(USE_ROCM)
# histograms is a list of object libraries. Linking an object library into
# other object libraries only propagates usage requirements; the linked
# objects themselves won't be used. Thus we have to call
# target_link_libraries on the final targets here.
if(BUILD_CLI)
target_link_libraries(lightgbm PRIVATE ${histograms})
endif()
target_link_libraries(_lightgbm PRIVATE ${histograms})
endif()

if(USE_HDFS)
target_link_libraries(lightgbm_objs PUBLIC ${HDFS_CXX_LIBRARIES})
endif()
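Worth noting before the next file: the ROCm path reuses the CUDA sources verbatim, defining USE_CUDA and compiling each .cu file as HIP via set_source_files_properties(... LANGUAGE HIP). A minimal sketch of the kind of vendor-neutral kernel this relies on (illustrative code, not from the PR):

```cpp
// A kernel like this compiles unchanged under nvcc and under the HIP
// toolchain once CMake marks the file's LANGUAGE as HIP; only code touching
// warp width or the *_sync shuffle intrinsics needs the interop macros from
// include/LightGBM/cuda/cuda_rocm_interop.h (shown later in this diff).
__global__ void ScaleKernel(float* x, float a, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= a;
  }
}
```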
16 changes: 16 additions & 0 deletions helpers/hipify.sh
@@ -0,0 +1,16 @@
#!/bin/bash

for DIR in ./src ./include
do
  for EXT in cpp h hpp cu
  do
    # quote the pattern so the shell doesn't expand it before find runs
    for FILE in $(find "${DIR}" -name "*.${EXT}")
    do
      echo "hipifying ${FILE} in-place"
      hipify-perl "${FILE}" -inplace &
    done
  done
done

echo "waiting for all hipify-perl invocations to finish"
wait
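For context, these are the kinds of rewrites hipify-perl performs in-place on each matched file (illustrative examples of the standard mappings, not an exhaustive list; the untouched original survives as *.prehip, which the .gitignore change above excludes):

```cpp
// before: cudaError_t err = cudaMalloc(&ptr, bytes);
// after:  hipError_t err = hipMalloc(&ptr, bytes);
//
// before: cudaStreamSynchronize(stream);
// after:  hipStreamSynchronize(stream);
//
// before: cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
// after:  hipMemcpy(dst, src, bytes, hipMemcpyDeviceToHost);
```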
49 changes: 45 additions & 4 deletions include/LightGBM/cuda/cuda_algorithms.hpp
@@ -14,6 +14,7 @@

#include <LightGBM/bin.h>
#include <LightGBM/cuda/cuda_utils.h>
#include <LightGBM/cuda/cuda_rocm_interop.h>
#include <LightGBM/utils/log.h>

#include <algorithm>
@@ -174,7 +175,7 @@ __device__ __forceinline__ void GlobalMemoryPrefixSum(T* array, const size_t len
for (size_t index = start; index < end; ++index) {
thread_sum += array[index];
}
__shared__ T shared_mem[32];
__shared__ T shared_mem[WARPSIZE];
const T thread_base = ShufflePrefixSumExclusive<T>(thread_sum, shared_mem);
if (start < end) {
array[start] += thread_base;
@@ -483,7 +484,7 @@ __device__ void ShuffleSortedPrefixSumDevice(const VAL_T* in_values,
const INDEX_T* sorted_indices,
REDUCE_VAL_T* out_values,
const INDEX_T num_data) {
__shared__ REDUCE_VAL_T shared_buffer[32];
__shared__ REDUCE_VAL_T shared_buffer[WARPSIZE];
const INDEX_T num_data_per_thread = (num_data + static_cast<INDEX_T>(blockDim.x) - 1) / static_cast<INDEX_T>(blockDim.x);
const INDEX_T start = num_data_per_thread * static_cast<INDEX_T>(threadIdx.x);
const INDEX_T end = min(start + num_data_per_thread, num_data);
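The two hunks above widen the per-warp scratch buffers from a hard-coded 32 to WARPSIZE. A sketch of the sizing argument, hedged because the full scan implementation is not visible in this diff:

```cpp
// The buffer holds one partial per warp and is also touched by a single
// warp during the combining pass. With up to 1024 threads per block:
//   CUDA: warpSize 32 -> up to 32 warps, so 32 slots suffice.
//   ROCm: warpSize 64 -> only 16 warps, but lanes 0..63 may index the
//         buffer in the combining pass, so 32 slots could be over-read;
//         WARPSIZE (64) is the safe bound on either vendor.
__shared__ T shared_mem[WARPSIZE];
```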
@@ -572,8 +573,48 @@ __device__ VAL_T PercentileDevice(const VAL_T* values,
INDEX_T* indices,
REDUCE_WEIGHT_T* weights_prefix_sum,
const double alpha,
const INDEX_T len);

const INDEX_T len) {
if (len <= 1) {
return values[0];
}
if (!USE_WEIGHT) {
BitonicArgSortDevice<VAL_T, INDEX_T, ASCENDING, BITONIC_SORT_NUM_ELEMENTS / 2, 10>(values, indices, len);
const double float_pos = (1.0f - alpha) * len;
const INDEX_T pos = static_cast<INDEX_T>(float_pos);
if (pos < 1) {
return values[indices[0]];
} else if (pos >= len) {
return values[indices[len - 1]];
} else {
const double bias = float_pos - pos;
const VAL_T v1 = values[indices[pos - 1]];
const VAL_T v2 = values[indices[pos]];
return static_cast<VAL_T>(v1 - (v1 - v2) * bias);
}
} else {
BitonicArgSortDevice<VAL_T, INDEX_T, ASCENDING, BITONIC_SORT_NUM_ELEMENTS / 4, 9>(values, indices, len);
ShuffleSortedPrefixSumDevice<WEIGHT_T, REDUCE_WEIGHT_T, INDEX_T>(weights, indices, weights_prefix_sum, len);
const REDUCE_WEIGHT_T threshold = weights_prefix_sum[len - 1] * (1.0f - alpha);
__shared__ INDEX_T pos;
if (threadIdx.x == 0) {
pos = len;
}
__syncthreads();
for (INDEX_T index = static_cast<INDEX_T>(threadIdx.x); index < len; index += static_cast<INDEX_T>(blockDim.x)) {
if (weights_prefix_sum[index] > threshold && (index == 0 || weights_prefix_sum[index - 1] <= threshold)) {
pos = index;
}
}
__syncthreads();
pos = min(pos, len - 1);
if (pos == 0 || pos == len - 1) {
return values[pos];
}
const VAL_T v1 = values[indices[pos - 1]];
const VAL_T v2 = values[indices[pos]];
return static_cast<VAL_T>(v1 - (v1 - v2) * (threshold - weights_prefix_sum[pos - 1]) / (weights_prefix_sum[pos] - weights_prefix_sum[pos - 1]));
}
}

} // namespace LightGBM

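As a readable reference for the unweighted branch above, here is a host-side mirror of the same interpolation (a sketch for illustration; the function name is ours, and like the device code it assumes a non-empty input):

```cpp
#include <algorithm>
#include <vector>

// Sort, locate float_pos = (1 - alpha) * len, then linearly interpolate
// between the two neighboring order statistics, as the device code does
// after its bitonic argsort.
double PercentileHost(std::vector<double> values, double alpha) {
  const int len = static_cast<int>(values.size());
  if (len <= 1) {
    return values[0];  // mirrors the device code; assumes len >= 1
  }
  std::sort(values.begin(), values.end());
  const double float_pos = (1.0 - alpha) * len;
  const int pos = static_cast<int>(float_pos);
  if (pos < 1) {
    return values[0];
  } else if (pos >= len) {
    return values[len - 1];
  }
  const double bias = float_pos - pos;
  const double v1 = values[pos - 1];
  const double v2 = values[pos];
  return v1 - (v1 - v2) * bias;
}
```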
22 changes: 22 additions & 0 deletions include/LightGBM/cuda/cuda_rocm_interop.h
@@ -0,0 +1,22 @@
/*!
* Copyright (c) 2021 Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for
* license information.
*/
#ifdef USE_CUDA

#if defined(__HIP_PLATFORM_AMD__) || defined(__HIP__)
// ROCm doesn't have __shfl_down_sync, only __shfl_down without a mask.
// Since the mask is always the full 0xffffffff, we can use __shfl_down instead.
#define __shfl_down_sync(mask, val, offset) __shfl_down(val, offset)
#define __shfl_up_sync(mask, val, offset) __shfl_up(val, offset)
// ROCm warpSize is constexpr and is either 32 or 64 depending on gfx arch.
#define WARPSIZE warpSize
// ROCm doesn't have atomicAdd_block, but it should be semantically the same as atomicAdd
#define atomicAdd_block atomicAdd
#else
// CUDA warpSize is not a constexpr, but always 32
#define WARPSIZE 32
#endif

#endif
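A minimal sketch of what these macros buy (illustrative code, not from the PR): a warp-level reduction written once against the CUDA *_sync spelling that compiles for both 32- and 64-lane warps:

```cpp
#include <LightGBM/cuda/cuda_rocm_interop.h>

// On CUDA this calls __shfl_down_sync with the full mask; on ROCm the macro
// above rewrites it to __shfl_down, and WARPSIZE resolves to the
// architecture's warpSize (32 or 64).
template <typename T>
__device__ __forceinline__ T WarpReduceSum(T val) {
  for (int offset = WARPSIZE / 2; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;  // lane 0 ends up holding the warp-wide sum
}
```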
6 changes: 3 additions & 3 deletions include/LightGBM/cuda/cuda_split_info.hpp
@@ -38,13 +38,13 @@ class CUDASplitInfo {
uint32_t* cat_threshold = nullptr;
int* cat_threshold_real = nullptr;

__device__ CUDASplitInfo() {
__host__ __device__ CUDASplitInfo() {
num_cat_threshold = 0;
cat_threshold = nullptr;
cat_threshold_real = nullptr;
}

__device__ ~CUDASplitInfo() {
__host__ __device__ ~CUDASplitInfo() {
if (num_cat_threshold > 0) {
if (cat_threshold != nullptr) {
cudaFree(cat_threshold);
@@ -55,7 +55,7 @@
}
}

__device__ CUDASplitInfo& operator=(const CUDASplitInfo& other) {
__host__ __device__ CUDASplitInfo& operator=(const CUDASplitInfo& other) {
is_valid = other.is_valid;
leaf_index = other.leaf_index;
gain = other.gain;
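The added __host__ qualifiers make the constructor, destructor, and copy assignment callable from host code. A hedged sketch of why that matters; the std::vector use below is an assumed illustration, not taken from the PR:

```cpp
#include <vector>

// Minimal analogue of CUDASplitInfo: with __device__-only special member
// functions, host-side code like the vector below would fail to compile
// when the translation unit goes through the HIP (or CUDA) compiler.
struct SplitInfoLike {
  int num_cat_threshold;
  __host__ __device__ SplitInfoLike() : num_cat_threshold(0) {}
  __host__ __device__ ~SplitInfoLike() {}
};

int main() {
  std::vector<SplitInfoLike> splits(8);  // needs host-callable ctor/dtor
  return static_cast<int>(splits.size());
}
```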
2 changes: 1 addition & 1 deletion include/LightGBM/cuda/vector_cudahost.h
@@ -45,7 +45,7 @@ struct CHAllocator {
n = SIZE_ALIGNED(n);
#ifdef USE_CUDA
if (LGBM_config_::current_device == lgbm_device_cuda) {
cudaError_t ret = cudaHostAlloc(&ptr, n*sizeof(T), cudaHostAllocPortable);
cudaError_t ret = cudaHostAlloc((void**)&ptr, n*sizeof(T), cudaHostAllocPortable);
if (ret != cudaSuccess) {
Log::Warning("Defaulting to malloc in CHAllocator!!!");
ptr = reinterpret_cast<T*>(_mm_malloc(n*sizeof(T), 16));
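The only change here is the explicit cast on &ptr. cudaHostAlloc is declared with a void** first parameter; CUDA's runtime header additionally ships a templated T** convenience overload, which is presumably why the uncast call compiled there, and the assumption is that the HIP headers lack that overload. A small sketch:

```cpp
#include <cuda_runtime.h>

int main() {
  float* p = nullptr;
  // void** is the declared parameter type; the explicit cast keeps one
  // spelling valid whether or not a templated overload exists.
  cudaError_t ret = cudaHostAlloc(reinterpret_cast<void**>(&p),
                                  1024 * sizeof(float), cudaHostAllocPortable);
  if (ret == cudaSuccess) {
    cudaFreeHost(p);
  }
  return 0;
}
```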