Skip to content

Integrate MLX SDPA kernels with mask #2820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion candle-flash-attn/cutlass
Submodule cutlass updated 582 files
300 changes: 164 additions & 136 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1570,174 +1570,201 @@ pub fn call_sdpa_full(
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_strides: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_strides: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_buffer: &Buffer,
v_strides: &[usize],
mask_type: Option<SdpaDType>,
mask_buffer: Option<&Buffer>,
m_strides: Option<&[usize]>,
output: &Buffer,
alpha: f32,
softcapping: f32,
o_strides: &[usize],
scale: f32,
do_causal: bool,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
#[derive(Debug)]
#[repr(C)]
struct MLXFastAttentionParams {
m: i32,
n: i32,
k: i32,

ldq: i32, // ldq == ldo
ldk: i32,
ldv: i32,
lds: i32,
ldo: i32,

tiles_n: i32,
tiles_m: i32,

batch_stride_q: i32,
batch_stride_k: i32,
batch_stride_v: i32,
batch_stride_o: i32,

swizzle_log: i32,
gemm_n_iterations_aligned: i32,
gemm_k_iterations_aligned: i32,
gemm_sv_m_block_iterations: i32,

batch_ndim: i32,
alpha: f32,
softcapping: f32,
struct AttnParams {
b: i32,
h: i32,
d: i32,
ql: i32,
kl: i32,
gqa_factor: i32,
scale: f32,
nq: i32,
nk: i32,
nq_aligned: i32,
nk_aligned: i32,
ql_rem: i32,
kl_rem: i32,
ql_off: i32,
q_strides: [i64; 3],
k_strides: [i64; 3],
v_strides: [i64; 3],
o_strides: [i64; 3],
}

let bk = q_shape.last().unwrap();
#[derive(Debug)]
#[repr(C)]
struct AttnMaskParams {
m_strides: [i64; 3],
}

const BN: usize = 16;
const BM: usize = 16;
const WM: usize = 2;
const WN: usize = 2;
const WM: usize = 4;
const WN: usize = 1;

let name = match (bk, itype) {
(32, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_half",
(64, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_half",
(96, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_half",
(128, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_half",
(256, SdpaDType::F16) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_half",
(32, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_32_itype_float",
(64, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_64_itype_float",
(96, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_96_itype_float",
(128, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_128_itype_float",
(256, SdpaDType::F32) => "steel_gemm_attention_bm_16_bn_16_bk_256_itype_float",
(other, SdpaDType::F16 | SdpaDType::F32) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "full",
got: *other,
expected: vec![32, 64, 96, 128, 256],
})
}
(_, SdpaDType::BF16) => {
return Err(MetalKernelError::SdpaHeadDTypeMismatch {
variation: "full",
got: SdpaDType::BF16,
})
}
const BQ: usize = 32;
let bd = q_shape[q_shape.len() - 1];
let bk = if bd < 128 { 32 } else { 16 };

let b = q_shape[0];
let h = q_shape[1];
let d = q_shape[3];
let gqa_factor = q_shape[1] / k_shape[1];

let ql = q_shape[2];
let kl = k_shape[2];

let align_q = (ql % BQ) == 0;
let align_k = (kl % bk) == 0;
let has_mask = mask_buffer.is_some();

let itype_repr = match itype {
SdpaDType::BF16 => "bfloat16",
SdpaDType::F16 => "float16",
SdpaDType::F32 => "float32",
};
let mask_repr = match mask_type {
Some(SdpaDType::BF16) => "bfloat16",
Some(SdpaDType::F16) => "float16",
Some(SdpaDType::F32) => "float32",
None => itype_repr,
};
let name =
format!("steel_attention_{itype_repr}_bq{BQ}_bk{bk}_bd{bd}_wm{WM}_wn{WN}_mask{mask_repr}");

let constants = Some(ConstantValues::new(vec![
(200, Value::Bool(/* align_Q */ align_q)),
(201, Value::Bool(/* align_K */ align_k)),
(300, Value::Bool(/* has_mask */ has_mask)),
(301, Value::Bool(/* do_causal */ do_causal)),
]));

let pipeline = kernels.load_pipeline(device, Source::Sdpa, name)?;
let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);

// q = (bs, qhead, seq, hidden)
// k/v = (bs, kv_head, seq, hidden)

let qseq = q_shape[q_shape.len() - 2];

let m = q_shape[q_shape.len() - 2];
let n = m;
let k = q_shape[q_shape.len() - 1];
let bs_out = q_shape[0] * q_shape[1];

let batch_shape = [q_shape[0] * q_shape[1]];
let dk = q_shape[q_shape.len() - 1];
let ldq = dk;
let ldk = dk;
let ldv = dk;
let lds = BN;
let ldo = dk;

let tn = 1;
let tm = m.div_ceil(BM);

let b_stride_q = dk * qseq;
let b_stride_k = dk * qseq;
let b_stride_v = dk * qseq;
let b_stride_o = dk * qseq;
let swizzle_log = 0;
let gemm_n_iterations_aligned = n.div_ceil(BN);
let gemm_k_iterations_aligned = k.div_ceil(*bk);
let gemm_sv_m_block_iterations = m.div_ceil(BM);
let batch_ndim = batch_shape.len();

let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
let nq = (ql + BQ - 1) / BQ;
let nk = (kl + bk - 1) / bk;

let nq_aligned = ql / BQ;
let nk_aligned = kl / bk;

let params = AttnParams {
b: b as i32,
h: h as i32,
d: d as i32,
ql: ql as i32,
kl: kl as i32,
gqa_factor: gqa_factor as i32,
scale,
nq: nq as i32,
nk: nk as i32,
nq_aligned: nq_aligned as i32,
nk_aligned: nk_aligned as i32,
ql_rem: (ql - nq_aligned * BQ) as i32,
kl_rem: (kl - nk_aligned * bk) as i32,
ql_off: (kl - ql) as i32,
q_strides: [
q_strides[0] as i64,
q_strides[1] as i64,
q_strides[2] as i64,
],
k_strides: [
k_strides[0] as i64,
k_strides[1] as i64,
k_strides[2] as i64,
],
v_strides: [
v_strides[0] as i64,
v_strides[1] as i64,
v_strides[2] as i64,
],
o_strides: [
o_strides[0] as i64,
o_strides[1] as i64,
o_strides[2] as i64,
],
};

let params = MLXFastAttentionParams {
m: m as i32,
n: n as i32,
k: k as i32,
ldq: ldq as i32,
ldk: ldk as i32,
ldv: ldv as i32,
lds: lds as i32,
ldo: ldo as i32,
tiles_n: tn,
tiles_m: tm as i32,
batch_stride_q: b_stride_q as i32,
batch_stride_k: b_stride_k as i32,
batch_stride_v: b_stride_v as i32,
batch_stride_o: b_stride_o as i32,
swizzle_log,
gemm_n_iterations_aligned: gemm_n_iterations_aligned as i32,
gemm_k_iterations_aligned: gemm_k_iterations_aligned as i32,
gemm_sv_m_block_iterations: gemm_sv_m_block_iterations as i32,
batch_ndim: batch_ndim as i32,
alpha,
softcapping,
};
let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
impl EncoderParam for AttnParams {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<AttnParams>() as u64,
&data as *const AttnParams as *const c_void,
);
}
}

impl EncoderParam for MLXFastAttentionParams {
impl EncoderParam for AttnMaskParams {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
encoder.set_bytes(
position,
core::mem::size_of::<MLXFastAttentionParams>() as u64,
&data as *const MLXFastAttentionParams as *const c_void,
core::mem::size_of::<AttnMaskParams>() as u64,
&data as *const AttnMaskParams as *const c_void,
);
}
}

set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params,
&batch_shape[..],
&batch_strides[..]
)
);
if let Some(mask) = mask_buffer {
let mask_strides = m_strides.unwrap();
let mask_params = AttnMaskParams {
m_strides: [
mask_strides[0] as i64,
mask_strides[1] as i64,
mask_strides[2] as i64,
],
};
encoder.use_resource(mask, metal::MTLResourceUsage::Read);

set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params,
mask_params,
mask
)
);
} else {
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params
)
);
}

let grid_dims = MTLSize {
width: 1,
height: tm as u64,
depth: bs_out as u64,
width: nq as u64,
height: h as u64,
depth: b as u64,
};
let group_dims = MTLSize {
width: 32,
Expand All @@ -1749,6 +1776,7 @@ pub fn call_sdpa_full(
encoder.use_resource(v_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);

Ok(())
}

Expand Down
Loading