Skip to content

vulkan: support softmax/FA batch and broadcast #14449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ struct vk_flash_attn_push_constants {
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;

uint32_t nb01;
uint32_t nb02;
Expand All @@ -637,7 +638,6 @@ struct vk_flash_attn_push_constants {
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;

float scale;
float max_bias;
Expand All @@ -652,6 +652,7 @@ struct vk_flash_attn_push_constants {
uint32_t split_kv;
uint32_t k_num;
};
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");

struct vk_op_push_constants {
uint32_t KX;
Expand Down Expand Up @@ -743,6 +744,14 @@ struct vk_op_rope_push_constants {
struct vk_op_soft_max_push_constants {
uint32_t KX;
uint32_t KY;
uint32_t ne00;
uint32_t ne01;
uint32_t ne02;
uint32_t ne12;
uint32_t ne13;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
float scale;
float max_bias;
float m0;
Expand Down Expand Up @@ -5977,7 +5986,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)

const uint32_t nem1 = mask ? mask->ne[1] : 0;
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
const uint32_t nem2 = mask ? mask->ne[2] : 0;

const uint32_t D = neq0;
uint32_t N = neq1;
Expand Down Expand Up @@ -6140,7 +6149,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
// Try to use split_k when KV is large enough to be worth the overhead
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
// Try to run two workgroups per SM.
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
if (split_k > 1) {
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
// of "align", so recompute split_k based on that.
Expand All @@ -6150,9 +6159,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
}

// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows).
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
// Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
// and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
if (split_k_size > ctx->device->max_memory_allocation_size) {
GGML_ABORT("Requested preallocation size is too large");
}
Expand Down Expand Up @@ -6244,11 +6253,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
(uint32_t)neq2, (uint32_t)neq3,
(uint32_t)nek2, (uint32_t)nek3,
(uint32_t)nev2, (uint32_t)nev3,
nem1,
nem1, nem2,
q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
nbm1,
scale, max_bias, logit_softcap,
mask != nullptr, n_head_log2, m0, m1,
gqa_ratio, split_kv, split_k };
Expand All @@ -6271,13 +6279,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });

ggml_vk_sync_buffers(subctx);
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
{
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
},
pc2, { (uint32_t)ne1, 1, 1 });
pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
} else {
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
Expand Down Expand Up @@ -7562,7 +7570,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
const uint32_t nrows_y = (uint32_t)src0->ne[1];

const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;

const uint32_t n_head_kv = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
Expand All @@ -7571,6 +7585,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
ne12, ne13,
nb11, nb12, nb13,
scale, max_bias,
m0, m1,
n_head_log2,
Expand Down Expand Up @@ -10066,11 +10083,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
if (op->src[0]->ne[3] != 1) {
return false;
}
// It's straightforward to support different K/V dequant, but would
// significantly increase the number of pipelines
if (op->src[1]->type != op->src[2]->type) {
Expand Down Expand Up @@ -10231,13 +10243,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_DIAG_MASK_INF:
return true;
case GGML_OP_SOFT_MAX:
// TODO: support batching
if (op->src[0]->ne[3] != 1) {
return false;
}
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_ARGSORT:
case GGML_OP_SUM:
Expand Down
12 changes: 8 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
}

[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
Expand Down Expand Up @@ -150,7 +154,7 @@ void main() {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
}
}
barrier();
Expand Down Expand Up @@ -277,7 +281,7 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index;
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);

[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
Expand All @@ -289,7 +293,7 @@ void main() {
}
}

o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
Expand All @@ -311,7 +315,7 @@ void main() {
}
}

uint32_t o_offset = iq3*p.ne2*p.ne1;
uint32_t o_offset = iq3*p.ne2*p.ne1*D;

if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ layout (push_constant) uniform parameter {
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nem2;

uint32_t nb01;
uint32_t nb02;
Expand All @@ -34,7 +35,6 @@ layout (push_constant) uniform parameter {
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;

float scale;
float max_bias;
Expand Down
12 changes: 8 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ void main() {
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV;
}

[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
Expand Down Expand Up @@ -181,7 +185,7 @@ void main() {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]));
sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
}
}
barrier();
Expand Down Expand Up @@ -300,7 +304,7 @@ void main() {
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index;
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);

[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
Expand All @@ -312,7 +316,7 @@ void main() {
}
}

o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
if (tile_row(r) < N) {
perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
Expand All @@ -334,7 +338,7 @@ void main() {
}
}

uint32_t o_offset = iq3*p.ne2*p.ne1;
uint32_t o_offset = iq3*p.ne2*p.ne1*D;

if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Expand Down
13 changes: 9 additions & 4 deletions ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}

uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}

[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {

Expand All @@ -155,7 +160,7 @@ void main() {

coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
Expand Down Expand Up @@ -229,10 +234,10 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);

uint32_t o_offset = D * p.ne1 * split_k_index;
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
Expand All @@ -250,7 +255,7 @@ void main() {

O = Ldiag*O;

uint32_t o_offset = iq3*p.ne2*p.ne1;
uint32_t o_offset = iq3*p.ne2*p.ne1*D;

coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@ layout (binding = 1) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint ne3;
uint k_num;
} p;

void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint iq3 = gl_WorkGroupID.z;

uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;

uint l_offset = D * N * k_num + n;
uint m_offset = D * N * k_num + N + n;
uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n;
uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n;
uint lm_stride = N * 2;

// Compute the max m value for the row
Expand All @@ -49,11 +51,11 @@ void main() {
for (uint d = tid; d < D; d += BLOCK_SIZE) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * k + D * n + d;
uint o_offset = D * N * (k + iq3 * k_num) + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
O *= L;
data_d[D * n + d] = O;
data_d[iq3 * D * N + D * n + d] = O;
}
}
Loading
Loading