Commit f0bb0a3

perf: split kv-cache for prefill/append kernels (#310)
Duplicate of #75, but rebased on the main branch. Note that to support CUDAGraph, we cannot make `kv_chunk_size` a function argument, which would be passed by value and could not change once captured by CUDAGraph. Instead, we pass `kv_chunk_size` through `kv_chunk_size_ptr`, a pointer to a global memory address that stores the `kv_chunk_size`; its value can be set in the `BeginForward` functions.
1 parent cf77d96 commit f0bb0a3
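To make the CUDAGraph constraint concrete, here is a minimal, self-contained sketch (hypothetical kernel and buffer names, not FlashInfer's kernels). A scalar passed by value is frozen into the graph at capture time, while a scalar read through a device pointer can be rewritten between replays; this is the role `kv_chunk_size_ptr` plays. The sketch assumes CUDA 11.4+ for `cudaGraphInstantiateWithFlags` and omits error checking for brevity.

```cuda
#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical kernel: reads its chunk size from global memory instead of a
// by-value argument, so graph capture bakes in only the pointer, not the value.
__global__ void ChunkedKernel(const int* __restrict__ kv_chunk_size_ptr, int* out) {
  int kv_chunk_size = *kv_chunk_size_ptr;  // value is read at replay time
  out[threadIdx.x] = kv_chunk_size;
}

int main() {
  int *kv_chunk_size_d, *out_d;
  cudaMalloc(&kv_chunk_size_d, sizeof(int));
  cudaMalloc(&out_d, 32 * sizeof(int));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture a graph that launches the kernel with the *pointer* argument.
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  ChunkedKernel<<<1, 32, 0, stream>>>(kv_chunk_size_d, out_d);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);

  // "BeginForward"-style update: write a new chunk size, then replay the same
  // graph; the kernel observes the new value without re-capturing.
  for (int chunk_size : {64, 128}) {
    cudaMemcpyAsync(kv_chunk_size_d, &chunk_size, sizeof(int),
                    cudaMemcpyHostToDevice, stream);
    cudaGraphLaunch(graph_exec, stream);
    cudaStreamSynchronize(stream);
    int host_out;
    cudaMemcpy(&host_out, out_d, sizeof(int), cudaMemcpyDeviceToHost);
    printf("kernel saw kv_chunk_size = %d\n", host_out);
  }

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(kv_chunk_size_d);
  cudaFree(out_d);
  return 0;
}
```

This mirrors the approach taken in this commit: `BeginForward` writes the chunk size into device memory, and the captured prefill/append kernels read it at replay time.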

17 files changed (+875 −378 lines)

include/flashinfer/attention/handler.cuh (+349 −51)
Large diffs are not rendered by default.

include/flashinfer/attention/prefill.cuh (+292 −152)
Large diffs are not rendered by default.

include/flashinfer/decode_attention_decl.cuh (+1 −1)

@@ -64,7 +64,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
   paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeKV, IdType> new_paged_kv = paged_kv;
   kv_partition_info_t<IdType> kv_partition_info;
   DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
-  float* tmp_s = handler->GetTempS<float>();
+  float* tmp_s = handler->GetTempS();
 
   if (handler->IsForwardStarted()) {
     if (tmp_v != nullptr) {

include/flashinfer/prefill_attention_decl.cuh (+63 −33)

@@ -38,44 +38,60 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
                                                uint32_t kv_len, float sm_scale, float rope_scale,
                                                float rope_theta, cudaStream_t stream);
 
-template <uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
-          QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
+template <uint32_t num_frags_x, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
+          QKVLayout KV_LAYOUT, PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION,
           MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
-    DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
-    DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
-    IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size,
-    uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale,
-    float rope_scale, float rope_theta, cudaStream_t stream = nullptr);
+    DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
+    IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask,
+    IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o,
+    DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask,
+    IdType* kv_chunk_size_ptr, const uint32_t total_num_rows, const uint32_t num_qo_heads,
+    const uint32_t padded_batch_size, const uint32_t num_kv_heads, const float sm_scale,
+    const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr);
 
-template <PageStorage PAGE_STORAGE, uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM,
-          LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE,
+template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
+          LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
           bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
-    DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
-    paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
-    IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles,
-    uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);
+    DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
+    IdType* q_indptr, IdType* q_offset,
+    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
+    IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse,
+    IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr,
+    uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, float sm_scale,
+    float rope_scale, float rope_theta, cudaStream_t stream);
 
 template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
           QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
           MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, IdType* q_offset,
     paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
     IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale,
     float rope_scale, float rope_theta, cudaStream_t stream) {
-  float* tmp = nullptr;
-  IdType* request_indices = nullptr;
-  IdType* tile_indices = nullptr;
+  DTypeOut* tmp_v = nullptr;
+  float* tmp_s = nullptr;
+  IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr,
+         *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
+  bool* block_valid_mask = nullptr;
   uint32_t num_frags_x = 0U;
-  uint32_t num_qo_tiles = 0U;
+  uint32_t padded_batch_size = 0U;
+  uint32_t total_num_rows = 0U;
   if (handler->IsForwardStarted()) {
+    tmp_v = handler->GetTempV<DTypeOut>();
+    tmp_s = handler->GetTempS();
     request_indices = handler->GetRequestIndices<IdType>();
-    tile_indices = handler->GetTileIndices<IdType>();
+    qo_tile_indices = handler->GetQOTileIndices<IdType>();
+    kv_tile_indices = handler->GetKVTileIndices<IdType>();
+    block_valid_mask = handler->GetBlockValidMask();
+    o_indptr = handler->GetOIndptr<IdType>();
+    merge_indptr = handler->GetMergeIndptr<IdType>();
+    kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
     num_frags_x = handler->GetNumFragsX();
-    num_qo_tiles = handler->GetNumQOTiles();
+    padded_batch_size = handler->GetPaddedBatchSize();
+    total_num_rows = handler->GetTotalNumRows();
   } else {
     std::ostringstream err_msg;
     err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
@@ -87,8 +103,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
     return BatchPrefillWithPagedKVCacheDispatched<
         PAGE_STORAGE, NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
         ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
-        q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o,
-        tmp, lse, num_qo_heads, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
+        q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv,
+        custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask,
+        kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, sm_scale, rope_scale,
+        rope_theta, stream);
   });
   return cudaSuccess;
 }
@@ -97,21 +115,32 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
           PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
           typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, DTypeIn* k, DTypeIn* v,
     IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
-    IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads,
+    IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads,
     uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta,
     cudaStream_t stream) {
-  float* tmp = nullptr;
-  IdType* request_indices = nullptr;
-  IdType* tile_indices = nullptr;
+  DTypeOut* tmp_v = nullptr;
+  float* tmp_s = nullptr;
+  IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr,
+         *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
+  bool* block_valid_mask = nullptr;
   uint32_t num_frags_x = 0U;
-  uint32_t num_qo_tiles = 0U;
+  uint32_t padded_batch_size = 0U;
+  uint32_t total_num_rows = 0U;
   if (handler->IsForwardStarted()) {
+    tmp_v = handler->GetTempV<DTypeOut>();
+    tmp_s = handler->GetTempS();
     request_indices = handler->GetRequestIndices<IdType>();
-    tile_indices = handler->GetTileIndices<IdType>();
+    qo_tile_indices = handler->GetQOTileIndices<IdType>();
+    kv_tile_indices = handler->GetKVTileIndices<IdType>();
+    block_valid_mask = handler->GetBlockValidMask();
+    o_indptr = handler->GetOIndptr<IdType>();
+    merge_indptr = handler->GetMergeIndptr<IdType>();
+    kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
     num_frags_x = handler->GetNumFragsX();
-    num_qo_tiles = handler->GetNumQOTiles();
+    padded_batch_size = handler->GetPaddedBatchSize();
+    total_num_rows = handler->GetTotalNumRows();
   } else {
     std::ostringstream err_msg;
     err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
@@ -123,9 +152,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     return BatchPrefillWithRaggedKVCacheDispatched<
        NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
        ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
-        q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr,
-        q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_heads, num_qo_tiles,
-        num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
+        q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr,
+        custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse,
+        merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads,
+        padded_batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
   });
   return cudaSuccess;
 }
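For orientation, the following is a schematic sketch of how a split-KV prefill block can consume the handler-provided arrays above; it is not the kernel in `prefill.cuh`, and the body is reduced to recording each tile's KV range. Each thread block handles one (request, qo tile, kv tile) work item, exits early if `block_valid_mask` marks it as padding (padding keeps the grid size fixed for CUDAGraph), and derives its KV range from the runtime value behind `kv_chunk_size_ptr`; partial outputs from different KV tiles are merged in a separate pass indexed by `merge_indptr` (not shown).

```cuda
// Schematic sketch only; the real kernel lives in include/flashinfer/attention/prefill.cuh.
#include <cstdint>

template <typename IdType>
__global__ void SplitKVPrefillSketch(const IdType* __restrict__ kv_indptr,
                                     const IdType* __restrict__ request_indices,
                                     const IdType* __restrict__ kv_tile_indices,
                                     const IdType* __restrict__ kv_chunk_size_ptr,
                                     const bool* __restrict__ block_valid_mask,
                                     IdType* __restrict__ kv_ranges /* [padded_batch_size, 2] */) {
  const uint32_t bx = blockIdx.x;  // one block per (qo tile, kv tile) work item
  // Padding blocks exist only to keep the grid shape CUDAGraph-stable; skip them.
  if (block_valid_mask != nullptr && !block_valid_mask[bx]) return;

  const IdType request_idx = request_indices[bx];
  const IdType kv_tile_idx = kv_tile_indices[bx];
  // The chunk size is read from global memory, so BeginForward can change it
  // without invalidating a captured graph.
  const IdType kv_chunk_size = *kv_chunk_size_ptr;

  const IdType kv_begin = kv_indptr[request_idx] + kv_tile_idx * kv_chunk_size;
  const IdType kv_last = kv_indptr[request_idx + 1];
  const IdType kv_end = (kv_begin + kv_chunk_size < kv_last) ? kv_begin + kv_chunk_size : kv_last;

  if (threadIdx.x == 0) {
    // Stand-in for the attention computation: record this tile's KV range.
    kv_ranges[2 * bx] = kv_begin;
    kv_ranges[2 * bx + 1] = kv_end;
  }
}
```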

include/flashinfer/utils.cuh (+0 −40)

@@ -50,15 +50,6 @@
   }
 #endif
 
-#define DISPATCH_SPLIT_QO_INDPTR(split_qo_indptr, SPLIT_QO_INDPTR, ...) \
-  if (split_qo_indptr) {                                                \
-    constexpr bool SPLIT_QO_INDPTR = true;                              \
-    __VA_ARGS__                                                         \
-  } else {                                                              \
-    constexpr bool SPLIT_QO_INDPTR = false;                             \
-    __VA_ARGS__                                                         \
-  }
-
 #define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, ...) \
   if (allow_fp16_qk_reduction) {                                                                \
     throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time");                     \
@@ -265,37 +256,6 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
   return (x + y - 1) / y;
 }
 
-template <typename IdType>
-std::tuple<IdType, IdType, std::vector<IdType>, std::vector<IdType>> split_qo_indptr(
-    IdType* qo_indptr, uint32_t batch_size, uint32_t gqa_group_size, uint32_t head_dim,
-    cudaStream_t stream = nullptr) {
-  constexpr uint32_t num_warps = 4;
-  std::vector<IdType> qo_indptr_h(batch_size + 1), request_indices, tile_indices;
-  if (is_device_ptr((void*)qo_indptr)) {
-    cudaMemcpyAsync(qo_indptr_h.data(), qo_indptr, sizeof(IdType) * (batch_size + 1),
-                    cudaMemcpyDeviceToHost, stream);
-  } else {
-    qo_indptr_h.assign(qo_indptr, qo_indptr + batch_size + 1);
-  }
-
-  const uint32_t total_q_len = qo_indptr_h[batch_size];
-  const bool avg_len_greater_than_64 = total_q_len * gqa_group_size > 64 * batch_size;
-  const uint32_t num_frags_x = (head_dim < 256 && avg_len_greater_than_64) ? 2 : 1;
-  const uint32_t num_rows_per_cta = num_frags_x * num_warps * 16;
-  uint32_t num_qo_tiles = 0;
-
-  for (uint32_t i = 0; i < batch_size; ++i) {
-    for (uint32_t j = qo_indptr_h[i] * gqa_group_size; j < qo_indptr_h[i + 1] * gqa_group_size;
-         j += num_rows_per_cta) {
-      request_indices.push_back(i);
-      tile_indices.push_back((j - qo_indptr_h[i] * gqa_group_size) / num_rows_per_cta);
-      ++num_qo_tiles;
-    }
-  }
-
-  return {num_frags_x, num_qo_tiles, std::move(request_indices), std::move(tile_indices)};
-}
-
 template <typename T>
 inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") {
   std::vector<T> host_array(size);
python/csrc/batch_prefill.cu (+33 −21)

@@ -21,8 +21,10 @@
 using namespace flashinfer;
 
 void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
-    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
+    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
+    torch::Tensor paged_kv_last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
+    unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
+    torch::Tensor empty_q_data) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -31,16 +33,23 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
   CHECK_DIM(1, workspace_buffer);
 
   qo_indptr = qo_indptr.to(torch::kInt32);
+  paged_kv_indptr = paged_kv_indptr.to(torch::kInt32);
+  paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   handler_->SetCUDAStream(torch_current_stream);
 
-  cudaError_t status =
-      handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
-                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
-                             batch_size, num_qo_heads, num_kv_heads, head_dim);
-  TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
-              cudaGetErrorString(status));
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] {
+    cudaError_t status = handler_->BeginForward<q_type, int32_t>(
+        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast<int32_t*>(qo_indptr.data_ptr()),
+        static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
+        static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()), batch_size, num_qo_heads,
+        num_kv_heads, head_dim, page_size);
+    TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
+                cudaGetErrorString(status));
+    return true;
+  });
 }
 
 void BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }
@@ -198,7 +207,6 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
   paged_kv_indptr = paged_kv_indptr.to(torch::kInt32);
   paged_kv_indices = paged_kv_indices.to(torch::kInt32);
   paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kInt32);
-  custom_mask = custom_mask.to(torch::kFloat32);
   qk_indptr = qk_indptr.to(torch::kInt32);
 
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
@@ -257,8 +265,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
 }
 
 void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
-    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, unsigned int batch_size,
-    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim) {
+    torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor kv_indptr,
+    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
+    unsigned int head_dim, torch::Tensor empty_q_data) {
   // NOTE(Zihao): not necessary to be a CUDA tensor
   CHECK_CONTIGUOUS(qo_indptr);
   CHECK_CONTIGUOUS(workspace_buffer);
@@ -267,16 +276,21 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
   CHECK_DIM(1, workspace_buffer);
 
   qo_indptr = qo_indptr.to(torch::kInt32);
+  kv_indptr = kv_indptr.to(torch::kInt32);
   size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   handler_->SetCUDAStream(torch_current_stream);
 
-  cudaError_t status =
-      handler_->BeginForward(static_cast<void*>(workspace_buffer.data_ptr()),
-                             workspace_size_in_bytes, static_cast<int32_t*>(qo_indptr.data_ptr()),
-                             batch_size, num_qo_heads, num_kv_heads, head_dim);
-  TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
-              cudaGetErrorString(status));
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] {
+    cudaError_t status = handler_->BeginForward<q_type, int32_t>(
+        static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
+        static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
+        /*last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
+        /*page_size=*/1);
+    TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
+                cudaGetErrorString(status));
+    return true;
+  });
 }
 
 void BatchPrefillWithRaggedKVCachePyTorchWrapper::EndForward() { handler_->EndForward(); }
@@ -348,8 +362,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
               /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
               static_cast<c_type*>(o.data_ptr()),
               /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
-              batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale,
-              rope_theta,
+              num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
              /*stream=*/torch_current_stream);
           TORCH_CHECK(status == cudaSuccess,
                       "BatchPrefillWithRaggedKVCache failed with error ",
@@ -406,7 +419,6 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   qo_indptr = qo_indptr.to(torch::kInt32);
   kv_indptr = kv_indptr.to(torch::kInt32);
   qk_indptr = qk_indptr.to(torch::kInt32);
-  custom_mask = custom_mask.to(torch::kFloat32);
 
   cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
   torch::Tensor o = torch::empty_like(q, q.options());
@@ -439,7 +451,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
               /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr,
               static_cast<c_type*>(o.data_ptr()),
               /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
-              batch_size, num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
+              num_qo_heads, num_kv_heads, sm_scale, rope_scale, rope_theta,
              /*stream=*/torch_current_stream);
           TORCH_CHECK(status == cudaSuccess,
                       "BatchPrefillWithRaggedKVCache failed with error ",
