Skip to content

Commit 015022b

Browse files
authored
vulkan: enable coopmat2 FA gqa and split_k optimizations more often (ggml-org#12931)
The grouped query attention optmization doesn't require a power of two ratio, the only thing relying on it was the modulo operation written as bitwise &. split_k need not depend on gqa_ratio - enable it any time there's only one workgroup in the X dimension. The shader gets the split index from the x coord, and multiple workgroups in the X dimension (pre-split) indicates a larger FA operation that wouldn't need splitting.
1 parent b43d89e commit 015022b

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -5531,7 +5531,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
55315531
uint32_t workgroups_y = (uint32_t)neq2;
55325532
uint32_t workgroups_z = (uint32_t)neq3;
55335533

5534-
if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows &&
5534+
if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
55355535
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
55365536
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
55375537
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
@@ -5544,8 +5544,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
55445544
uint32_t split_kv = KV;
55455545
uint32_t split_k = 1;
55465546

5547-
if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) {
5548-
GGML_ASSERT(workgroups_x == 1);
5547+
// Try to use split_k when KV is large enough to be worth the overhead
5548+
if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
55495549
// Try to run two workgroups per SM.
55505550
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
55515551
if (split_k > 1) {

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

+1-1
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in A
131131
// Load the slope matrix, indexed by Q's dimension 2.
132132
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
133133
{
134-
const uint32_t h = iq2 + (r & (p.gqa_ratio - 1));
134+
const uint32_t h = iq2 + (r % p.gqa_ratio);
135135

136136
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
137137
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

tests/test-backend-ops.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -4532,7 +4532,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
45324532

45334533
for (int kv : { 4096, 8192, 16384, }) {
45344534
for (int hs : { 64, 128, }) {
4535-
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
4535+
for (int nr : { 1, 4, }) {
4536+
test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
4537+
}
45364538
}
45374539
}
45384540

0 commit comments

Comments
 (0)