@@ -653,7 +653,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template <int D, int ncols1, int ncols2> // D == head size
+template <int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -811,13 +811,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q6_0, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
+        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int DV, int ncols1, int ncols2>
+template <int D, int ncols1, int ncols2, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
         const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -837,7 +837,7 @@ void launch_fattn(
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
@@ -898,13 +898,10 @@ void launch_fattn(
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
-    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = max_blocks_per_sm*nsm;
+        const int max_blocks = 2*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
@@ -916,11 +913,14 @@ void launch_fattn(
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
+        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
         // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
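For reference, the occupancy query and the wave-efficiency arithmetic that these two hunks move around follow the standard CUDA pattern sketched below. This is a standalone illustration: the kernel, block shape, and tile count are placeholders, not values from this file.

// Sketch only: occupancy-based sizing of a persistent (stream-k style) grid.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_fattn_kernel(float * dst) {
    dst[blockIdx.x*blockDim.x + threadIdx.x] = 0.0f;
}

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const int nsm = prop.multiProcessorCount;

    const dim3   block_dim(32, 4, 1); // e.g. warp_size x nwarps
    const size_t nbytes_shared = 0;   // dynamic shared memory per block

    // How many blocks of this kernel can be resident on one SM at once:
    int max_blocks_per_sm = 1;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(
        &max_blocks_per_sm, dummy_fattn_kernel,
        block_dim.x*block_dim.y*block_dim.z, nbytes_shared);

    // One full wave of blocks across the GPU, then the same wave/efficiency
    // arithmetic as in the stream-k branch above:
    const int max_blocks               = max_blocks_per_sm*nsm;
    const int ntiles_total             = 1000; // placeholder work size
    const int tiles_nwaves             = (ntiles_total + max_blocks - 1) / max_blocks;
    const int tiles_efficiency_percent = 100*ntiles_total / (max_blocks*tiles_nwaves);

    printf("SMs=%d, blocks/SM=%d, waves=%d, wave efficiency=%d%%\n",
        nsm, max_blocks_per_sm, tiles_nwaves, tiles_efficiency_percent);
    return 0;
}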
@@ -997,19 +997,19 @@ void launch_fattn(
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(DV, 1, 1);
+            const dim3 block_dim_combine(D, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(DV, 1, 1);
+        const dim3 block_dim_combine(D, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<DV>
+        flash_attn_combine_results<D>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
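For readers unfamiliar with the combine launch in the last hunk: flash_attn_combine_results merges the partial results produced by parallel_blocks independent KV slices. The sketch below shows only the standard numerically stable split-K merge that such a combine step is based on; the struct, data layout, and names are assumptions for illustration, not the kernel's actual interface.

// Sketch only: merging per-block partial attention results for one query.
#include <cmath>
#include <vector>

// One parallel block's contribution: an unnormalized partial output of length D
// plus the metadata needed to rescale it (running KQ max, local softmax sum).
struct PartialResult {
    std::vector<float> vkq; // sum_j exp(KQ_j - kqmax)*V_j over this block's KV slice
    float kqmax;            // max KQ value seen by this block
    float kqsum;            // sum_j exp(KQ_j - kqmax) over this block's KV slice
};

std::vector<float> combine_partials(const std::vector<PartialResult> & parts, const int D) {
    // All partials must be rescaled to a common maximum for numerical stability.
    float kqmax_global = -INFINITY;
    for (const PartialResult & p : parts) {
        kqmax_global = std::fmax(kqmax_global, p.kqmax);
    }

    std::vector<float> numerator(D, 0.0f);
    float denominator = 0.0f;
    for (const PartialResult & p : parts) {
        const float scale = std::exp(p.kqmax - kqmax_global);
        denominator += scale*p.kqsum;
        for (int i = 0; i < D; ++i) {
            numerator[i] += scale*p.vkq[i];
        }
    }

    for (int i = 0; i < D; ++i) {
        numerator[i] /= denominator; // final softmax normalization
    }
    return numerator;
}

In the launch above, the per-block (max, sum) pairs appear to come from dst_tmp_meta and the partial outputs from dst_tmp, with the merged result written to KQV->data.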