1 parent 9569106 commit 23413e0
flashinfer/jit/attention.py

@@ -180,6 +180,7 @@ def gen_batch_decode_mla_module(
         dtype_o,
         dtype_idx,
         head_dim,
+        head_dim,
         use_sliding_window,
         use_logits_soft_cap,
     )
include/flashinfer/attention/decode.cuh

@@ -857,6 +857,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params) {
   const float rope_rcp_scale = params.rope_rcp_scale;
   const float rope_rcp_theta = params.rope_rcp_theta;
   const bool partition_kv = params.partition_kv;
+  params.sm_scale *= math::log2e;

   constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv;
   constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe;
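Context for the decode.cuh change: folding math::log2e into params.sm_scale once at kernel entry is presumably the usual base-2 softmax trick, so the attention inner loop can evaluate exp2(qk * sm_scale) instead of exp(qk * sm_scale), since exp(x) == exp2(x * log2(e)). The snippet below is a minimal host-side sketch of that identity only; it is not FlashInfer code, and the values of qk and sm_scale are made-up illustrative numbers.

    // Host-side sketch: pre-scaling the softmax scale by log2(e) makes a
    // base-2 exponential equivalent to the base-e one.
    #include <cmath>
    #include <cstdio>

    int main() {
      const float log2e = 1.4426950408889634f;  // log2(e)
      const float qk = 3.7f;                    // hypothetical query-key dot product
      float sm_scale = 0.125f;                  // hypothetical scale, e.g. 1/sqrt(head_dim)

      float ref = std::exp(qk * sm_scale);      // base-e formulation
      sm_scale *= log2e;                        // the one-time pre-scaling done in the patch
      float fast = std::exp2(qk * sm_scale);    // base-2 formulation used inside the kernel

      // The two results agree up to float rounding.
      std::printf("exp: %f  exp2: %f\n", ref, fast);
      return 0;
    }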