Commit b66423b

Revert commit "CUDA: FA support for Deepseek (Ampere or newer) (ggml-org#13306)"

1 parent a8e1c88 commit b66423b

32 files changed, +520 -825 lines

ggml/src/ggml-cuda/common.cuh

Lines changed: 0 additions & 19 deletions
@@ -295,25 +295,6 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
-// The compiler is always able to unroll loops if they contain continue expressions.
-// In such cases loop unrolling can still be achieved via recursion:
-template <int n>
-struct ggml_cuda_unroll {
-    template <typename Func, typename... Args>
-    __device__ void operator()(const Func & f, Args... args) const {
-        f(n - 1, args...);
-        ggml_cuda_unroll<n - 1>{}(f, args...);
-    }
-};
-
-template <>
-struct ggml_cuda_unroll<1> {
-    template <typename Func, typename... Args>
-    __device__ void operator()(const Func & f, Args... args) const {
-        f(0, args...);
-    }
-};
-
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
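
Note: the removed ggml_cuda_unroll helper works around loops that `#pragma unroll` may fail to unroll because they contain continue statements; the recursion turns every iteration into its own template instantiation. The following is a minimal, self-contained sketch of how such a helper is typically used. The `unroll_sketch` functor and the `sum_even_elements` kernel are names invented here for illustration, not code from this repository.

// Illustrative sketch: recursive compile-time unrolling in place of a
// #pragma unroll loop that contains a continue statement.
template <int n>
struct unroll_sketch {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(n - 1, args...);                  // run iteration n-1 ...
        unroll_sketch<n - 1>{}(f, args...); // ... then recurse over the rest
    }
};

template <>
struct unroll_sketch<1> {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(0, args...); // base case: iteration 0
    }
};

// Hypothetical kernel: sum only the even-indexed elements of an 8-element
// per-thread array. Each iteration becomes a separate instantiation, so the
// early return (playing the role of 'continue') cannot block unrolling.
__global__ void sum_even_elements(const float * x, float * y) {
    const float * xi = x + 8*threadIdx.x;
    float sum = 0.0f;
    unroll_sketch<8>{}([&](const int i) {
        if (i % 2 != 0) {
            return; // equivalent of 'continue' in a plain loop
        }
        sum += xi[i];
    });
    y[threadIdx.x] = sum;
}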

ggml/src/ggml-cuda/cp-async.cuh

Lines changed: 0 additions & 11 deletions
@@ -2,17 +2,6 @@
 
 #include "common.cuh"
 
-
-static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
-#ifdef CP_ASYNC_AVAILABLE
-    return __cvta_generic_to_shared(generic_ptr);
-#else
-    GGML_UNUSED(generic_ptr);
-    NO_DEVICE_CODE;
-    return 0;
-#endif // CP_ASYNC_AVAILABLE
-}
-
 // Copies data from global to shared memory, cg == cache global.
 // Both the src and dst pointers must be aligned to 16 bit.
 // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
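
Note: the removed ggml_cuda_cvta_generic_to_shared wrapper converts a generic pointer into the 32-bit shared-state-space address that the cp.async PTX instructions expect, falling back to NO_DEVICE_CODE where asynchronous copies are unavailable. Below is a rough sketch of the kind of call site such a wrapper serves; the kernel, tile layout, and inline PTX are illustrative assumptions, not code from this commit, and src is assumed to be 16-byte aligned.

// Illustrative sketch: staging 16 bytes per thread from global to shared
// memory with cp.async on Ampere or newer. The shared-memory destination is
// first converted to a 32-bit shared address, which is what the removed
// wrapper encapsulated.
#include <cuda_runtime.h>

__global__ void stage_tile(const float * __restrict__ src, float * __restrict__ dst) {
    __shared__ __align__(16) float tile[32*4]; // one float4 (16 bytes) per thread of a warp

    float * dst_generic = tile + 4*threadIdx.x;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // Convert the generic pointer into the shared-memory address space:
    const unsigned int dst_shared = (unsigned int) __cvta_generic_to_shared(dst_generic);

    // cg == cache global: bypass L1 and copy 16 bytes asynchronously.
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
                 : : "r"(dst_shared), "l"(src + 4*threadIdx.x));

    // Wait until all asynchronous copies issued by this thread have finished.
    asm volatile("cp.async.wait_all;");
#else
    // Fallback for architectures without cp.async: plain synchronous copy.
    *(float4 *) dst_generic = *(const float4 *) (src + 4*threadIdx.x);
#endif // __CUDA_ARCH__ >= 800
    __syncthreads();

    // Consume the staged data so the copy is observable.
    dst[4*threadIdx.x] = tile[4*threadIdx.x];
}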

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 13 additions & 13 deletions
@@ -653,7 +653,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template<int D, int ncols1, int ncols2> // D == head size
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -811,13 +811,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q6_0, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int DV, int ncols1, int ncols2>
+template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
         const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -837,7 +837,7 @@ void launch_fattn(
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+                     "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
@@ -898,13 +898,10 @@ void launch_fattn(
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
-    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = max_blocks_per_sm*nsm;
+        const int max_blocks = 2*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
@@ -916,11 +913,14 @@ void launch_fattn(
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
        // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
@@ -997,19 +997,19 @@ void launch_fattn(
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(DV, 1, 1);
+            const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(DV, 1, 1);
+        const dim3 block_dim_combine(D, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<DV>
+        flash_attn_combine_results<D>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
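
Note: the functional core of this file's revert is in launch_fattn: the cudaOccupancyMaxActiveBlocksPerMultiprocessor query moves back into the non-stream-k branch, the stream-k grid is again sized as a fixed 2*nsm blocks instead of max_blocks_per_sm*nsm, and the template parameters return from a separate value-head size DV to a single D plus KQ_stride. The following standalone host-side sketch shows how such an occupancy query and the wave/efficiency arithmetic fit together; the kernel, block size, shared-memory size, and tile count are placeholders, not values from the commit.

// Standalone sketch (not part of the commit) of occupancy-driven grid sizing
// in the style of launch_fattn.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_fattn_kernel() {}

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int nsm = prop.multiProcessorCount;

    const int    block_size    = 32*4;    // warp_size * nwarps, placeholder
    const size_t nbytes_shared = 16*1024; // dynamic shared memory, placeholder

    // How many blocks of this kernel can be resident on one SM at a time:
    int max_blocks_per_sm = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_sm, dummy_fattn_kernel, block_size, nbytes_shared);

    // Stream-k style scheduling: with max_blocks blocks in flight, how many
    // waves are needed for ntiles_total tiles and how well are they filled?
    const int ntiles_total = 1000;                  // placeholder tile count
    const int max_blocks   = max_blocks_per_sm*nsm; // pre-revert sizing
    // const int max_blocks = 2*nsm;                // post-revert sizing
    const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
    const int tiles_efficiency_percent = 100*ntiles_total / (max_blocks*tiles_nwaves);

    printf("SMs=%d blocks/SM=%d waves=%d efficiency=%d%%\n",
           nsm, max_blocks_per_sm, tiles_nwaves, tiles_efficiency_percent);
    return 0;
}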
