Skip to content

Commit 79a2125

Browse files
authored
bugfix: avoid potential illegal memory access (#267)
Followup of #266, add guard to mask array access.
1 parent 7304282 commit 79a2125

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

include/flashinfer/attention/prefill.cuh

+5-5
Original file line numberDiff line numberDiff line change
@@ -573,11 +573,11 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_idx_base, const uint32_
573573
? (kv_idx > kv_len + q_idx - qo_len || (partition_kv && kv_idx >= chunk_end))
574574
: kv_idx >= chunk_end);
575575
s_frag[fx][fz][reg_id] =
576-
out_of_boundary
577-
? DTypeQKAccum(-5e4)
578-
: s_frag[fx][fz][reg_id] + DTypeQKAccum(mask_mode == MaskMode::kCustom
579-
? custom_mask[q_idx * kv_len + kv_idx]
580-
: 0.f);
576+
out_of_boundary ? DTypeQKAccum(-5e4)
577+
: s_frag[fx][fz][reg_id] +
578+
DTypeQKAccum((mask_mode == MaskMode::kCustom && q_idx < qo_len)
579+
? custom_mask[q_idx * kv_len + kv_idx]
580+
: 0.f);
581581
}
582582
}
583583
}

0 commit comments

Comments
 (0)