
Commit 1092e7e

bugfix: fix cudagraph-compatible prefill/decode apis (#281)
The `indptr` array length should be an upper bound of `batch_size + 1` in CUDA graph mode.
1 parent 7def34e commit 1092e7e
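
Background: in CSR-style batching, `indptr[i]` and `indptr[i + 1]` delimit request `i`'s entries, so a batch of size n needs n + 1 entries. Under CUDA graph capture, buffer addresses and shapes are frozen at capture time, so the buffer must be sized for the largest batch the graph may ever see, and only a prefix is meaningful at replay. A minimal sketch of this invariant in plain PyTorch (not flashinfer's API; `max_batch_size` and `run_step` are illustrative names):

    import torch

    max_batch_size = 16384  # capture-time upper bound (hypothetical value)
    # Fixed-address buffer sized for the worst case; CUDA graph replay
    # launches kernels against this exact allocation every iteration.
    indptr_buf = torch.empty(max_batch_size + 1, dtype=torch.int32, device="cuda")

    def run_step(indptr_host: torch.Tensor) -> None:
        # indptr_host holds batch_size + 1 entries for the current batch.
        batch_size = indptr_host.numel() - 1
        assert batch_size <= max_batch_size
        # Only the first batch_size + 1 entries are written; the tail is
        # stale but never read by the kernels, hence CHECK_GE not CHECK_EQ.
        indptr_buf[: batch_size + 1].copy_(indptr_host, non_blocking=True)
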

File tree

4 files changed: +19 −18 lines


include/flashinfer/attention/handler.cuh

+1 −1

@@ -420,7 +420,7 @@ class BatchDecodeHandler {
    * \note (Zihao): when enable_cuda_graph is true, max_workspace_size_in_bytes will be ignored,
    * when enable_cuda_graph is false, max_batch_size will be ignored.
    */
-  BatchDecodeHandler(size_t max_workspace_size_in_bytes = 128 * 64 * 64,
+  BatchDecodeHandler(size_t max_workspace_size_in_bytes = 128 * 1024 * 1024,
                      size_t max_batch_size = 16384, bool enable_cuda_graph = false)
       : batch_size_after_partition_(0U),
         float_buffer_(nullptr),
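
Note: the old default of 128 * 64 * 64 bytes (512 KiB) looks like a typo; the new default of 128 * 1024 * 1024 bytes (128 MiB) matches the workspace buffer size the Python tests allocate.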

python/csrc/batch_decode.cu

+2 −2

@@ -222,8 +222,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
   }
   CHECK_EQ(paged_kv_data.size(1), 2);
   CHECK_EQ(paged_kv_data.size(4), head_dim);
-  CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_last_page_len.size(0), batch_size);
+  CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
   // TODO(Zihao): support dispatching to different data types
   CHECK_EQ(paged_kv_indptr.scalar_type(), torch::kInt32);
   CHECK_EQ(paged_kv_indices.scalar_type(), torch::kInt32);
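
Relaxing CHECK_EQ to CHECK_GE lets callers pass over-allocated `paged_kv_indptr` and `paged_kv_last_page_len` buffers; the kernel reads only the first `batch_size + 1` (respectively `batch_size`) entries, which is exactly what CUDA graph replay with a varying batch size requires.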

python/csrc/batch_prefill.cu

+12 −12

@@ -83,9 +83,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
     num_kv_heads = paged_kv_data.size(3);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
-  CHECK_EQ(qo_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_last_page_len.size(0), batch_size);
+  CHECK_GE(qo_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
   CHECK_EQ(paged_kv_data.size(1), 2);
   CHECK_EQ(paged_kv_data.size(4), head_dim);
   qo_indptr = qo_indptr.to(torch::kInt32);

@@ -186,12 +186,12 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
     num_kv_heads = paged_kv_data.size(3);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
-  CHECK_EQ(qo_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_last_page_len.size(0), batch_size);
+  CHECK_GE(qo_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
   CHECK_EQ(paged_kv_data.size(1), 2);
   CHECK_EQ(paged_kv_data.size(4), head_dim);
-  CHECK_EQ(qk_indptr.size(0), batch_size + 1);
+  CHECK_GE(qk_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::kInt32);
   paged_kv_indptr = paged_kv_indptr.to(torch::kInt32);
   paged_kv_indices = paged_kv_indices.to(torch::kInt32);

@@ -303,7 +303,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   int64_t nnz_qo = q.size(0);
   int64_t num_qo_heads = q.size(1);
   int64_t head_dim = q.size(2);
-  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(kv_indptr.size(0), batch_size + 1);
   int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
   CHECK_EQ(k.size(0), v.size(0));
   CHECK_EQ(k.size(1), v.size(1));

@@ -366,8 +366,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
 std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask(
     torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
     torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr,
-    unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale,
-    float rope_theta, bool return_lse) {
+    unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
+    float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
   CHECK_INPUT(q);
   CHECK_INPUT(qo_indptr);
   CHECK_INPUT(k);

@@ -386,8 +386,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   int64_t nnz_qo = q.size(0);
   int64_t num_qo_heads = q.size(1);
   int64_t head_dim = q.size(2);
-  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(qk_indptr.size(0), batch_size + 1);
+  CHECK_GE(kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(qk_indptr.size(0), batch_size + 1);
   int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
   CHECK_EQ(k.size(0), v.size(0));
   CHECK_EQ(k.size(1), v.size(1));
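
The same relaxation applies to the prefill wrappers' `qo_indptr`, `paged_kv_indptr`, `kv_indptr`, and `qk_indptr` checks; the final two hunks only re-wrap the `ForwardCustomMask` parameter list and do not change behavior.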

python/tests/test_batch_decode_kernels.py

+4 −3

@@ -156,9 +156,10 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache(
         (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
     )

-    kv_indptr_device_buffer = torch.empty(batch_size + 1).int().to(0)
-    kv_indices_device_buffer = torch.empty(total_num_pages).int().to(0)
-    kv_last_page_device_buffer = torch.empty(batch_size).int().to(0)
+    # NOTE(Zihao): allocate more space than needed for testing
+    kv_indptr_device_buffer = torch.empty(batch_size + 11).int().to(0)
+    kv_indices_device_buffer = torch.empty(total_num_pages + 10).int().to(0)
+    kv_last_page_device_buffer = torch.empty(batch_size + 10).int().to(0)

     workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
     wrapper = flashinfer.CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
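
Over-allocating the device buffers by 10 or 11 elements exercises the relaxed CHECK_GE checks: capture and replay must still succeed when the buffers are strictly larger than what a given batch needs, since the kernels only read the first `batch_size + 1` (or `batch_size`) entries.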
