Skip to content

Commit 716301d

Browse files
musa: enable fp16 mma (all) and cublas on qy2 (#13842)
* musa: enable fp16 mma (all) and cublas on qy2 Signed-off-by: Xiaodong Ye <[email protected]> * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler <[email protected]> * Address review comments Signed-off-by: Xiaodong Ye <[email protected]> * Address review comments Signed-off-by: Xiaodong Ye <[email protected]> * musa: disable MUL_MAT_ID (q2_k × f32) due to precision issues Signed-off-by: Xiaodong Ye <[email protected]> --------- Signed-off-by: Xiaodong Ye <[email protected]> Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 60ef23d commit 716301d

File tree

4 files changed

+34
-24
lines changed

4 files changed

+34
-24
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,9 @@
7676
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
7777

7878
// Moore Threads
79-
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
80-
81-
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
82-
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
83-
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
79+
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
80+
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
81+
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
8482

8583
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
8684
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
@@ -203,9 +201,9 @@ typedef float2 dfloat2;
203201
#define FAST_FP16_AVAILABLE
204202
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
205203

206-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
204+
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
207205
#define FP16_MMA_AVAILABLE
208-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
206+
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
209207

210208
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
211209
#define FP16_MMA_AVAILABLE
@@ -219,9 +217,9 @@ typedef float2 dfloat2;
219217
#define CP_ASYNC_AVAILABLE
220218
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
221219

222-
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
220+
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
223221
#define FLASH_ATTN_AVAILABLE
224-
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
222+
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
225223

226224
static bool fp16_available(const int cc) {
227225
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -233,7 +231,8 @@ static bool fast_fp16_available(const int cc) {
233231

234232
// To be used for feature selection of external libraries, e.g. cuBLAS.
235233
static bool fast_fp16_hardware_available(const int cc) {
236-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
234+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
235+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
237236
}
238237

239238
// Any FP16 tensor core instructions are available for ggml code.
@@ -242,7 +241,8 @@ static bool fp16_mma_available(const int cc) {
242241
return false;
243242
#else
244243
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
245-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
244+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
245+
GGML_CUDA_CC_IS_MTHREADS(cc)) {
246246
return true;
247247
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
248248
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
@@ -259,7 +259,8 @@ static bool fp16_mma_available(const int cc) {
259259
// To be used for feature selection of external libraries, e.g. cuBLAS.
260260
static bool fp16_mma_hardware_available(const int cc) {
261261
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
262-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
262+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
263+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
263264
}
264265

265266
static bool bf16_mma_hardware_available(const int cc) {

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
#ifdef FP16_MMA_AVAILABLE
1010
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
1111
#include <mma.h>
12+
#ifdef GGML_USE_MUSA
13+
namespace wmma = mtmusa::wmma;
14+
#else // GGML_USE_MUSA
1215
namespace wmma = nvcuda::wmma;
16+
#endif // GGML_USE_MUSA
1317
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
1418
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
1519
#include <rocwmma/rocwmma.hpp>

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,9 +1227,12 @@ static void ggml_cuda_op_mul_mat_cublas(
12271227

12281228
const int cc = ggml_cuda_info().devices[id].cc;
12291229

1230+
const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
1231+
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
1232+
12301233
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
12311234

1232-
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
1235+
if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
12331236
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
12341237
if (src1->type != GGML_TYPE_BF16) {
12351238
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1257,7 +1260,7 @@ static void ggml_cuda_op_mul_mat_cublas(
12571260

12581261
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
12591262
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
1260-
} else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
1263+
} else if (fast_fp16_hardware_available(cc) && use_fp16) {
12611264
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
12621265
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
12631266
if (src0->type != GGML_TYPE_F16) {
@@ -3061,9 +3064,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30613064
return false;
30623065
}
30633066
#ifdef GGML_USE_MUSA
3064-
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
3065-
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3066-
return false;
3067+
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
3068+
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
3069+
if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
3070+
a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
3071+
return false;
3072+
}
3073+
if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
3074+
a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
3075+
return false;
3076+
}
30673077
}
30683078
#endif // GGML_USE_MUSA
30693079
switch (a->type) {
@@ -3090,11 +3100,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
30903100
case GGML_TYPE_IQ4_NL:
30913101
case GGML_TYPE_IQ4_XS:
30923102
case GGML_TYPE_BF16:
3093-
#ifdef GGML_USE_MUSA
3094-
if (a->type == GGML_TYPE_Q3_K) {
3095-
return false;
3096-
}
3097-
#endif // GGML_USE_MUSA
30983103
return true;
30993104
default:
31003105
return false;

ggml/src/ggml-musa/mudnn.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

3-
#include "../include/ggml.h"
4-
#include "../ggml-cuda/common.cuh"
3+
#include "ggml-cuda/common.cuh"
4+
#include "ggml.h"
55

66
// Asynchronously copies data from src tensor to dst tensor using the provided context.
77
// Returns a musaError_t indicating success or failure.

0 commit comments

Comments
 (0)