Skip to content

Commit 04c0d48

Browse files
committed
Move all HIP stuff to ggml-cuda.cu
1 parent d83cfba commit 04c0d48

File tree

3 files changed

+46
-54
lines changed

3 files changed

+46
-54
lines changed

CMakeLists.txt

+5-5
Original file line numberDiff line numberDiff line change
@@ -232,16 +232,16 @@ if (LLAMA_HIPBLAS)
232232
find_package(hipblas)
233233

234234
if (${hipblas_FOUND} AND ${hip_FOUND})
235-
message(STATUS "hipBLAS found")
236-
add_compile_definitions(GGML_USE_HIPBLAS)
237-
add_library(ggml-hip OBJECT ggml-cuda.cu ggml-cuda.h)
235+
message(STATUS "HIP and hipBLAS found")
236+
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
237+
add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
238238
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
239-
target_link_libraries(ggml-hip PRIVATE hip::device)
239+
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::hipblas)
240240

241241
if (LLAMA_STATIC)
242242
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
243243
endif()
244-
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} hip::host roc::hipblas ggml-hip)
244+
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
245245
else()
246246
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
247247
endif()

ggml-cuda.cu

+41-3
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,47 @@
55
#include <atomic>
66

77
#if defined(GGML_USE_HIPBLAS)
8-
#include "hip/hip_runtime.h"
9-
#include "hipblas/hipblas.h"
10-
#include "hip/hip_fp16.h"
8+
#include <hip/hip_runtime.h>
9+
#include <hipblas/hipblas.h>
10+
#include <hip/hip_fp16.h>
11+
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
12+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
13+
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
14+
#define CUBLAS_OP_N HIPBLAS_OP_N
15+
#define CUBLAS_OP_T HIPBLAS_OP_T
16+
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
17+
#define CUBLAS_TF32_TENSOR_OP_MATH 0
18+
#define CUDA_R_16F HIPBLAS_R_16F
19+
#define CUDA_R_32F HIPBLAS_R_32F
20+
#define cublasCreate hipblasCreate
21+
#define cublasGemmEx hipblasGemmEx
22+
#define cublasHandle_t hipblasHandle_t
23+
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
24+
#define cublasSetStream hipblasSetStream
25+
#define cublasSgemm hipblasSgemm
26+
#define cublasStatus_t hipblasStatus_t
27+
#define cudaDeviceSynchronize hipDeviceSynchronize
28+
#define cudaError_t hipError_t
29+
#define cudaEventCreateWithFlags hipEventCreateWithFlags
30+
#define cudaEventDisableTiming hipEventDisableTiming
31+
#define cudaEventRecord hipEventRecord
32+
#define cudaEvent_t hipEvent_t
33+
#define cudaFree hipFree
34+
#define cudaFreeHost hipHostFree
35+
#define cudaGetErrorString hipGetErrorString
36+
#define cudaGetLastError hipGetLastError
37+
#define cudaMalloc hipMalloc
38+
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocPortable)
39+
#define cudaMemcpy2DAsync hipMemcpy2DAsync
40+
#define cudaMemcpyAsync hipMemcpyAsync
41+
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
42+
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
43+
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
44+
#define cudaStreamNonBlocking hipStreamNonBlocking
45+
#define cudaStreamSynchronize hipStreamSynchronize
46+
#define cudaStreamWaitEvent hipStreamWaitEvent
47+
#define cudaStream_t hipStream_t
48+
#define cudaSuccess hipSuccess
1149
#else
1250
#include <cuda_runtime.h>
1351
#include <cublas_v2.h>

ggml-cuda.h

-46
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,3 @@
1-
#if defined(GGML_USE_HIPBLAS)
2-
#include "hipblas/hipblas.h"
3-
#include "hip/hip_runtime.h"
4-
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
5-
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
6-
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
7-
#define CUBLAS_OP_N HIPBLAS_OP_N
8-
#define CUBLAS_OP_T HIPBLAS_OP_T
9-
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
10-
#define CUBLAS_TF32_TENSOR_OP_MATH 0
11-
#define CUDA_R_16F HIPBLAS_R_16F
12-
#define CUDA_R_32F HIPBLAS_R_32F
13-
#define cublasCreate hipblasCreate
14-
#define cublasGemmEx hipblasGemmEx
15-
#define cublasHandle_t hipblasHandle_t
16-
#define cublasSetMathMode(h, m) HIPBLAS_STATUS_SUCCESS
17-
#define cublasSetStream hipblasSetStream
18-
#define cublasSgemm hipblasSgemm
19-
#define cublasStatus_t hipblasStatus_t
20-
#define cudaDeviceSynchronize hipDeviceSynchronize
21-
#define cudaError_t hipError_t
22-
#define cudaEventCreateWithFlags hipEventCreateWithFlags
23-
#define cudaEventDisableTiming hipEventDisableTiming
24-
#define cudaEventRecord hipEventRecord
25-
#define cudaEvent_t hipEvent_t
26-
#define cudaFree hipFree
27-
#define cudaFreeHost hipFreeHost
28-
#define cudaGetErrorString hipGetErrorString
29-
#define cudaGetLastError hipGetLastError
30-
#define cudaMalloc hipMalloc
31-
#define cudaMallocHost hipMallocHost
32-
#define cudaMemcpy2DAsync hipMemcpy2DAsync
33-
#define cudaMemcpyAsync hipMemcpyAsync
34-
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
35-
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
36-
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
37-
#define cudaStreamNonBlocking hipStreamNonBlocking
38-
#define cudaStreamSynchronize hipStreamSynchronize
39-
#define cudaStreamWaitEvent hipStreamWaitEvent
40-
#define cudaStream_t hipStream_t
41-
#define cudaSuccess hipSuccess
42-
#define GGML_USE_CUBLAS
43-
#else
44-
#include <cublas_v2.h>
45-
#include <cuda_runtime.h>
46-
#endif
471
#include "ggml.h"
482

493
#ifdef __cplusplus

0 commit comments

Comments
 (0)