
Commit f7a2958

Revert "musa: enable fp16 mma (all) and cublas on qy2 (ggml-org#13842)"
This reverts commit 716301d.
1 parent 485148b commit f7a2958

File tree

4 files changed: 24 additions, 34 deletions

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-musa/mudnn.cuh


ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 13 deletions
@@ -76,9 +76,11 @@
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
-#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
-#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
+
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
 
 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
@@ -205,9 +207,9 @@ typedef float2 dfloat2;
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #define FP16_MMA_AVAILABLE
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 #define FP16_MMA_AVAILABLE
@@ -221,9 +223,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -235,8 +237,7 @@ static bool fast_fp16_available(const int cc) {
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
-        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -245,8 +246,7 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
-        GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
         return true;
     } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
@@ -263,8 +263,7 @@ static bool fp16_mma_available(const int cc) {
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
-        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
 static bool bf16_mma_hardware_available(const int cc) {
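Note: the macros and helpers touched in this hunk are the feature gates the rest of the backend keys off: the `*_AVAILABLE` macros decide what gets compiled for the current architecture, while the `*_hardware_available(cc)` helpers decide at run time whether a library path (e.g. cuBLAS) may be taken. The following is a minimal illustrative sketch of how a dispatch site typically combines the two; the wrapper function itself is not part of this commit.

    // Illustrative only: combine the compile-time and run-time gates defined above.
    static bool can_use_fp16_mma_for_library_calls(int device) {
        const int cc = ggml_cuda_info().devices[device].cc;  // compute capability, as used elsewhere in this diff
        return fp16_mma_hardware_available(cc);              // run-time check for cuBLAS-style feature selection
    }

    #ifdef FP16_MMA_AVAILABLE
    // ... wmma-based kernels are only compiled when this macro is defined ...
    #endif // FP16_MMA_AVAILABLE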

ggml/src/ggml-cuda/fattn-wmma-f16.cu

Lines changed: 0 additions & 4 deletions
@@ -9,11 +9,7 @@
 #ifdef FP16_MMA_AVAILABLE
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
-#ifdef GGML_USE_MUSA
-namespace wmma = mtmusa::wmma;
-#else // GGML_USE_MUSA
 namespace wmma = nvcuda::wmma;
-#endif // GGML_USE_MUSA
 #elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
 #undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
 #include <rocwmma/rocwmma.hpp>
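With the MUSA alias removed, the `wmma` namespace again refers only to `nvcuda::wmma`. For context, this is the standard CUDA WMMA API that fattn-wmma-f16.cu builds on; the tiny kernel below is a generic illustration of that API (it is not code from this commit) and needs a Volta-or-newer target, matching the GGML_CUDA_CC_VOLTA gate above.

    #include <mma.h>
    using namespace nvcuda;

    // One warp multiplies a single 16x16 tile: C = A * B + C, fp16 inputs, fp32 accumulator.
    __global__ void wmma_gemm_tile(const half * a, const half * b, float * c) {
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> fa;
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> fb;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float> fc;

        wmma::fill_fragment(fc, 0.0f);
        wmma::load_matrix_sync(fa, a, 16);  // leading dimension 16
        wmma::load_matrix_sync(fb, b, 16);
        wmma::mma_sync(fc, fa, fb, fc);     // tensor core multiply-accumulate
        wmma::store_matrix_sync(c, fc, 16, wmma::mem_row_major);
    }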

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 10 additions & 15 deletions
@@ -1228,12 +1228,9 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int cc = ggml_cuda_info().devices[id].cc;
 
-    const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
-        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
-
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
-    if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
         if (src1->type != GGML_TYPE_BF16) {
             const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
@@ -1261,7 +1258,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-    } else if (fast_fp16_hardware_available(cc) && use_fp16) {
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -3069,16 +3066,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     return false;
                 }
 #ifdef GGML_USE_MUSA
-                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
-                    if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
-                        a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
-                        return false;
-                    }
-                    if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
-                        a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
-                        return false;
-                    }
+                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    return false;
                 }
 #endif // GGML_USE_MUSA
                 switch (a->type) {
@@ -3105,6 +3095,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_BF16:
+#ifdef GGML_USE_MUSA
+                        if (a->type == GGML_TYPE_Q3_K) {
+                            return false;
+                        }
+#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
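The restored GGML_USE_MUSA branch once again rejects batched F16 x F16 matmuls in ggml_backend_cuda_device_supports_op, so callers only see the restriction through the backend's supports-op query. A small sketch of how that query can be exercised from the public backend API; the helper function and tensor shapes are assumptions for illustration, not code from this commit.

    #include "ggml.h"
    #include "ggml-backend.h"

    // Probe whether a device accepts a batched F16 x F16 matmul (ne[2] > 1).
    static bool accepts_batched_f16_matmul(ggml_backend_dev_t dev, struct ggml_context * ctx) {
        struct ggml_tensor * a  = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 64, 64, 4);
        struct ggml_tensor * b  = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 64, 64, 4);
        struct ggml_tensor * mm = ggml_mul_mat(ctx, a, b);
        // With this revert applied, a MUSA device is expected to report false here.
        return ggml_backend_dev_supports_op(dev, mm);
    }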

ggml/src/ggml-musa/mudnn.cuh

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 #pragma once
 
-#include "ggml-cuda/common.cuh"
-#include "ggml.h"
+#include "../include/ggml.h"
+#include "../ggml-cuda/common.cuh"
 
 // Asynchronously copies data from src tensor to dst tensor using the provided context.
 // Returns a musaError_t indicating success or failure.
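The comment above documents the contract of an asynchronous tensor-copy helper whose declaration follows later in the header, outside this hunk. Purely to illustrate that contract, a hypothetical caller might look like the sketch below; the helper name and signature are assumptions, and the error handling assumes the MUSA runtime mirrors CUDA's conventions (musaSuccess).

    // Hypothetical usage of the documented async-copy contract (names are illustrative).
    musaError_t err = copy_tensor_async(ctx, dst_tensor, src_tensor);
    if (err != musaSuccess) {
        GGML_ABORT("async tensor copy failed: %d", (int) err);
    }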
