
Commit 9dd4089

Update steel::AttnParams
1 parent 57f9817 commit 9dd4089

6 files changed (+83, -37 lines)

mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h

Lines changed: 17 additions & 6 deletions
@@ -1,4 +1,4 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2024-25 Apple Inc.
 
 using namespace mlx::steel;
 
@@ -9,6 +9,9 @@ using namespace mlx::steel;
 constant bool align_Q [[function_constant(200)]];
 constant bool align_K [[function_constant(201)]];
 
+constant bool has_mask [[function_constant(300)]];
+constant bool do_causal [[function_constant(301)]];
+
 template <typename T>
 struct TransformScale {
   T scale;
@@ -69,13 +72,16 @@ template <
     int BD,
     int WM,
     int WN,
+    typename MaskType = float,
     typename AccumType = float>
 [[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention(
     const device T* Q [[buffer(0)]],
     const device T* K [[buffer(1)]],
     const device T* V [[buffer(2)]],
     device T* O [[buffer(3)]],
     const constant AttnParams* params [[buffer(4)]],
+    const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]],
+    const device MaskType* mask [[buffer(6), function_constant(has_mask)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]],
@@ -102,6 +108,12 @@ template <
       tidl.y * params->O_strides[1] + // Head
       tidl.x * BQ * params->O_strides[2]; // Sequence
 
+  if (has_mask) {
+    mask += tidl.z * mask_params->M_strides[0] + // Batch
+        tidl.y * mask_params->M_strides[1] + // Head
+        tidl.x * BQ * mask_params->M_strides[2]; // Sequence
+  }
+
   // Prepare threadgroup memory
   constexpr short padQ = 16 / sizeof(T);
   constexpr short padK = 16 / sizeof(T);
@@ -203,7 +215,7 @@ template <
 
   // Load Q blocks apply scale
   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
-    loader_q.load_safe(short2(BD, params->qL - params->NQ_aligned * BQ));
+    loader_q.load_safe(short2(BD, params->qL_rem));
   } else {
     loader_q.load_unsafe();
   }
@@ -226,7 +238,7 @@ template <
     // Load K block and apply scale
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (!align_K && kb == (params->NK_aligned)) {
-      loader_k.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+      loader_k.load_safe(short2(BD, params->kL_rem));
     } else {
       loader_k.load_unsafe();
     }
@@ -276,7 +288,7 @@ template <
 
     // Load V blocks
     if (!align_K && kb == (params->NK_aligned)) {
-      loader_v.load_safe(short2(BD, params->kL - params->NK_aligned * BK));
+      loader_v.load_safe(short2(BD, params->kL_rem));
     } else {
       loader_v.load_unsafe();
     }
@@ -367,8 +379,7 @@ template <
   O += (tm + sm) * params->O_strides[2] + sn;
 
   if (!align_Q && int(tid.x) == (params->NQ_aligned)) {
-    auto dst_tile_dims =
-        short2(BD - sn, params->qL - BQ * params->NQ_aligned - (tm + sm));
+    auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm));
 
     if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0)
       return;
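For reference, the new mask indexing mirrors the existing Q/K/V/O indexing: each threadgroup advances the mask pointer by its batch, head, and query-block position using the strides carried in AttnMaskParams. Below is a minimal host-side C++ sketch of the same arithmetic, assuming a row-major (B, H, qL, kL) mask; the helper name and the sample shape are illustrative and not part of the commit.

#include <cstdint>
#include <cstdio>

// Element offset of the first mask row seen by threadgroup (tid_x, tid_y, tid_z),
// mirroring: mask += tidl.z * M_strides[0] + tidl.y * M_strides[1] + tidl.x * BQ * M_strides[2];
int64_t mask_block_offset(
    const int64_t M_strides[3], // (batch, head, query-row) strides in elements
    int tid_x, // query block index along the sequence
    int tid_y, // head index
    int tid_z, // batch index
    int BQ) {  // query rows per block
  return tid_z * M_strides[0] + tid_y * M_strides[1] +
      int64_t(tid_x) * BQ * M_strides[2];
}

int main() {
  // Illustrative shape: B = 2, H = 8, qL = kL = 96, row-major bool mask.
  const int64_t kL = 96, qL = 96, H = 8;
  const int64_t M_strides[3] = {H * qL * kL, qL * kL, kL};
  // Threadgroup for batch 1, head 3, third query block of BQ = 32 rows.
  std::printf(
      "offset = %lld\n",
      (long long)mask_block_offset(M_strides, /*tid_x=*/2, /*tid_y=*/3, /*tid_z=*/1, /*BQ=*/32));
  return 0;
}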
mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal

Lines changed: 17 additions & 20 deletions
@@ -1,31 +1,28 @@
-// Copyright © 2024 Apple Inc.
+// Copyright © 2024-25 Apple Inc.
 
 // clang-format off
 #include "mlx/backend/metal/kernels/utils.h"
 
 #include "mlx/backend/metal/kernels/steel/attn/attn.h"
 #include "mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h"
 
-#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn) \
-  template [[host_name("steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd "_wm" #wm "_wn" #wn)]] \
-  [[kernel]] void attention<dtype, bq, bk, bd, wm, wn, float>( \
-      const device dtype* Q [[buffer(0)]], \
-      const device dtype* K [[buffer(1)]], \
-      const device dtype* V [[buffer(2)]], \
-      device dtype* O [[buffer(3)]], \
-      const constant AttnParams* params [[buffer(4)]], \
-      uint simd_lane_id [[thread_index_in_simdgroup]], \
-      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
-      uint3 tid [[threadgroup_position_in_grid]], \
-      uint3 lid [[thread_position_in_threadgroup]]);
+#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \
+  instantiate_kernel( \
+      "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \
+      "_wm" #wm "_wn" #wn "_mask" #mname, \
+      attention, dtype, bq, bk, bd, wm, wn, mtype, float)
 
-#define instantiate_attn_shapes_helper(iname, itype) \
-    instantiate_attn(iname, itype, 32, 16, 128, 4, 1) \
-    instantiate_attn(iname, itype, 32, 32, 80, 4, 1) \
-    instantiate_attn(iname, itype, 32, 32, 64, 4, 1)
+#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
+    instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
+    instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
+    instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
 
-instantiate_attn_shapes_helper(float16, half);
-instantiate_attn_shapes_helper(bfloat16, bfloat16_t);
+#define instantiate_attn_mask_helper(iname, itype) \
+    instantiate_attn_shapes_helper(iname, itype, iname, itype) \
+    instantiate_attn_shapes_helper(iname, itype, bool_, bool)
 
-instantiate_attn_shapes_helper(float32, float);
+instantiate_attn_mask_helper(float16, half);
+instantiate_attn_mask_helper(bfloat16, bfloat16_t);
+
+instantiate_attn_mask_helper(float32, float);
 // clang-format on
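With the extra mname/mtype arguments, each dtype is now instantiated for two mask types (the dtype itself and bool_) across the same three (bq, bk, bd) tile shapes, giving 3 x 3 x 2 = 18 kernel variants. As a sketch of the naming scheme only (not code from the commit), the following standalone C++ prints the host names these macros produce, e.g. steel_attention_float16_bq32_bk16_bd128_wm4_wn1_maskbool_.

#include <cstdio>
#include <string>

int main() {
  // Mirrors instantiate_attn_mask_helper / instantiate_attn_shapes_helper:
  // 3 dtypes x 3 tile shapes x 2 mask types = 18 kernel instantiations.
  const char* dtypes[] = {"float16", "bfloat16", "float32"};
  const int shapes[][3] = {{32, 16, 128}, {32, 32, 80}, {32, 32, 64}}; // {bq, bk, bd}
  const int wm = 4, wn = 1;
  for (const char* tname : dtypes) {
    const std::string masks[] = {tname, "bool_"};
    for (const auto& s : shapes) {
      for (const auto& mname : masks) {
        std::printf(
            "steel_attention_%s_bq%d_bk%d_bd%d_wm%d_wn%d_mask%s\n",
            tname, s[0], s[1], s[2], wm, wn, mname.c_str());
      }
    }
  }
  return 0;
}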

mlx/backend/metal/kernels/steel/attn/params.h

Lines changed: 8 additions & 0 deletions
@@ -26,11 +26,19 @@ struct AttnParams {
   int NQ_aligned; ///< Number of full query blocks
   int NK_aligned; ///< Number of full key/value blocks
 
+  int qL_rem; ///< Remainder in last query block
+  int kL_rem; ///< Remainder in last key/value block
+  int qL_off; ///< Offset in query sequence start
+
   int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1)
   int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1)
   int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1)
   int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1)
 };
 
+struct AttnMaskParams {
+  int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1)
+};
+
 } // namespace steel
 } // namespace mlx
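The three new fields let the kernel read precomputed remainders instead of re-deriving them from qL and kL (see the loader_q/loader_k/loader_v changes above); the host fills them in scaled_dot_product_attention.cpp below. Here is a small self-contained C++ sketch of that arithmetic, using illustrative sequence lengths and block sizes that match one of the instantiated tile shapes.

#include <cstdio>

int main() {
  // Illustrative sequence lengths and block sizes.
  const int qL = 100, kL = 260; // query / key-value sequence lengths
  const int bq = 32, bk = 16;   // query / key-value block sizes

  const int NQ_aligned = qL / bq; // number of full query blocks
  const int NK_aligned = kL / bk; // number of full key/value blocks

  // Same expressions as the AttnParams initializer in the host code:
  const int qL_rem = qL - NQ_aligned * bq;          // rows in the last, partial Q block
  const int kL_rem = kL - NK_aligned * bk;          // rows in the last, partial K/V block
  const int qL_off = (kL - qL) < 0 ? 0 : (kL - qL); // query-position offset when kL > qL

  std::printf("NQ_aligned=%d NK_aligned=%d qL_rem=%d kL_rem=%d qL_off=%d\n",
              NQ_aligned, NK_aligned, qL_rem, kL_rem, qL_off);
  return 0;
}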

mlx/backend/metal/scaled_dot_product_attention.cpp

Lines changed: 32 additions & 6 deletions
@@ -21,7 +21,9 @@ void sdpa_full_self_attention_metal(
     const array& k,
     const array& v,
     const float scale,
-    array& o) {
+    array& o,
+    bool do_causal_ = false,
+    const std::optional<array>& mask = std::nullopt) {
   using namespace mlx::steel;
 
   int wm = 4;
@@ -41,11 +43,14 @@
 
   const bool align_Q = (qL % bq) == 0;
   const bool align_K = (kL % bk) == 0;
+  const bool has_mask = !!mask;
+  const bool do_causal = do_causal_;
 
   metal::MTLFCList func_consts = {
       {&align_Q, MTL::DataType::DataTypeBool, 200},
       {&align_K, MTL::DataType::DataTypeBool, 201},
-  };
+      {&has_mask, MTL::DataType::DataTypeBool, 300},
+      {&do_causal, MTL::DataType::DataTypeBool, 301}};
 
   std::ostringstream kname;
   // clang-format off
@@ -54,13 +59,17 @@
       << "_bq" << bq
       << "_bk" << bk
       << "_bd" << bd
-      << "_wm" << wm << "_wn" << wn; // clang-format on
+      << "_wm" << wm
+      << "_wn" << wn
+      << "_mask" << (type_to_name(mask ? *mask : q)); // clang-format on
 
   std::string base_name = kname.str();
 
   // clang-format off
   kname << "_align_Q_" << (align_Q ? 't' : 'n')
-        << "_align_K_" << (align_K ? 't' : 'n'); // clang-format on
+        << "_align_K_" << (align_K ? 't' : 'n')
+        << "_has_mask_" << (has_mask ? 't' : 'n')
+        << "_do_causal_" << (do_causal ? 't' : 'n'); // clang-format on
 
   std::string hash_name = kname.str();
 
@@ -91,6 +100,10 @@
       /* int NQ_aligned = */ NQ_aligned,
       /* int NK_aligned = */ NK_aligned,
 
+      /* int qL_rem = */ (qL - NQ_aligned * bq),
+      /* int kL_rem = */ (kL - NK_aligned * bk),
+      /* int qL_off = */ ((kL - qL) < 0 ? 0 : (kL - qL)),
+
       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
@@ -102,6 +115,15 @@
   compute_encoder.set_output_array(o, 3);
   compute_encoder.set_bytes(params, 4);
 
+  if (mask) {
+    auto m = *mask;
+    AttnMaskParams mask_params{/* int64_t M_strides[3] = */ {
+        m.strides(0), m.strides(1), m.strides(2)}};
+
+    compute_encoder.set_bytes(mask_params, 5);
+    compute_encoder.set_input_array(m, 6);
+  }
+
   MTL::Size grid_dims = MTL::Size(NQ, H, B);
   MTL::Size group_dims = MTL::Size(32, wm, wn);
 
@@ -324,7 +346,7 @@ void ScaledDotProductAttention::eval_gpu(
 
   // Checks that the headdim dimension has stride 1.
   auto is_matrix_contiguous = [](const array& arr) {
-    return arr.strides(3) == 1;
+    return arr.strides(-1) == 1;
   };
 
   // We are in vector mode ie single query
@@ -381,7 +403,11 @@
         {str_oB, str_oH, str_oL, str_oD},
         flags);
 
-    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o);
+    auto mask = inputs.size() > 3
+        ? std::optional<array>{copy_unless(is_matrix_contiguous, inputs[3])}
+        : std::nullopt;
+
+    sdpa_full_self_attention_metal(s, d, q, k, v, scale_, o, do_causal_, mask);
   }
 
   d.add_temporaries(std::move(copies), s.index);
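The kernel name now carries the mask dtype, while the function-constant flags (_align_Q_, _align_K_, _has_mask_, _do_causal_) are appended only to the hashed variant, which appears to serve as the cache key for the specialized pipeline. A minimal standalone sketch of the resulting strings is below; the dtype and mask names are written out by hand here instead of going through type_to_name, and the sample configuration is illustrative.

#include <iostream>
#include <sstream>
#include <string>

int main() {
  // Illustrative configuration: float16 inputs, bool mask, the bd = 128 tile shape.
  const std::string tname = "float16", mname = "bool_";
  const int bq = 32, bk = 16, bd = 128, wm = 4, wn = 1;
  const bool align_Q = true, align_K = false, has_mask = true, do_causal = false;

  // Same layout as the host's kname stream.
  std::ostringstream kname;
  kname << "steel_attention_" << tname
        << "_bq" << bq << "_bk" << bk << "_bd" << bd
        << "_wm" << wm << "_wn" << wn
        << "_mask" << mname;
  std::string base_name = kname.str();

  kname << "_align_Q_" << (align_Q ? 't' : 'n')
        << "_align_K_" << (align_K ? 't' : 'n')
        << "_has_mask_" << (has_mask ? 't' : 'n')
        << "_do_causal_" << (do_causal ? 't' : 'n');
  std::string hash_name = kname.str();

  std::cout << base_name << "\n" << hash_name << "\n";
  return 0;
}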

mlx/fast.cpp

Lines changed: 8 additions & 4 deletions
@@ -710,7 +710,7 @@ array scaled_dot_product_attention(
     auto k_idx = arange(0, kL, s);
     q_idx = expand_dims(q_idx, 1, s);
     k_idx = expand_dims(k_idx, 0, s);
-    mask = less(q_idx, k_idx, s);
+    mask = greater_equal(q_idx, k_idx, s);
   }
 
   if (n_repeats > 1 && mask.ndim() >= 3) {
@@ -746,11 +746,15 @@
   const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
       (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
 
+  const bool sdpa_vector_supported_mask = (!has_mask || has_bool_mask);
+  const bool sdpa_full_supported_mask = !has_mask;
+
   const bool supports_sdpa_full = query_sequence_length >= threshold &&
-      !has_mask && sdpa_full_supported_head_dim && stream.device == Device::gpu;
+      sdpa_full_supported_mask && sdpa_full_supported_head_dim &&
+      stream.device == Device::gpu;
 
-  const bool supports_sdpa_vector = query_sequence_length == 1 &&
-      (!has_mask || has_bool_mask) && sdpa_vector_supported_head_dim &&
+  const bool supports_sdpa_vector = query_sequence_length < 8 &&
+      sdpa_vector_supported_mask && sdpa_vector_supported_head_dim &&
       stream.device == Device::gpu;
 
   const bool implementation_supports_use_case =
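Both the C++ op and the Python reference test below now build the boolean causal mask as q_idx >= k_idx, with the query indices offset by max(0, kL - qL) when the key sequence is longer, so true marks the key positions a query is allowed to attend to. A small standalone C++ sketch of the resulting pattern, using illustrative lengths, follows.

#include <algorithm>
#include <cstdio>

int main() {
  // Illustrative lengths: 3 queries attending over 5 keys.
  const int qL = 3, kL = 5;
  const int q_offset = std::max(0, kL - qL); // same offset as in the reference test

  // mask[i][j] == true  <=>  (q_offset + i) >= j, i.e. greater_equal(q_idx, k_idx).
  for (int i = 0; i < qL; ++i) {
    for (int j = 0; j < kL; ++j) {
      bool keep = (q_offset + i) >= j;
      std::printf("%c ", keep ? '1' : '0');
    }
    std::printf("\n");
  }
  return 0;
}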

python/tests/test_fast_sdpa.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def mlx_ref_attn(q, k, v, scale=1.0, mask=None):
         q_offset = max(0, kL - L)
         q_indices = mx.arange(q_offset, q_offset + L)
         k_indices = mx.arange(kL)
-        mask = q_indices[:, None] < k_indices[None]
+        mask = q_indices[:, None] >= k_indices[None]
 
     if n_repeats > 1 and mask.ndim >= 3:
         if mask.shape[-3] == 1:
