Skip to content

Commit da83cf5

Browse files
authored
Bugfix: bugfix to #322 (#325)
Some final commits for the bugfix of #322 were missing.
1 parent 545b9ca commit da83cf5

File tree

10 files changed

+51
-65
lines changed

10 files changed

+51
-65
lines changed

include/flashinfer/attention/handler.cuh

+18-19
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
9292
const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) {
9393
int64_t low = min_kv_chunk_size, high = 0;
9494
int64_t batch_size = packed_qo_len_arr.size();
95-
int64_t max_kv_len;
95+
int64_t max_kv_len = 0;
9696
for (const int64_t& kv_len : kv_len_arr) {
9797
max_kv_len = std::max(max_kv_len, kv_len);
9898
}
@@ -174,9 +174,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
174174
new_batch_size = batch_size;
175175
} else {
176176
// compute max_num_pages_per_batch and new_batch_size
177-
std::vector<IdType> page_indptr_h(batch_size + 1), num_pages(batch_size);
177+
std::vector<IdType> num_pages(batch_size);
178178
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
179-
num_pages[batch_idx] = page_indptr_h[batch_idx + 1] - page_indptr_h[batch_idx];
179+
num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
180180
}
181181
std::tie(max_num_pages_per_batch, new_batch_size) =
182182
PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages,
@@ -517,14 +517,16 @@ class BatchDecodeHandler {
517517
};
518518

519519
template <typename IdType>
520-
cudaError_t PrefillSplitQOKVIndptr(
521-
bool& split_kv, uint32_t& split_max_batch_size, uint32_t& total_num_tiles_q,
522-
uint32_t& new_batch_size, WarpLayout& warp_layout, uint32_t& kv_chunk_size,
523-
uint32_t& total_num_rows, std::vector<IdType>& request_indices,
524-
std::vector<IdType>& qo_tile_indices, std::vector<IdType>& kv_tile_indices,
525-
std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr, IdType* qo_indptr_h,
526-
IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads,
527-
uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, cudaStream_t stream = nullptr) {
520+
cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_size,
521+
uint32_t& total_num_tiles_q, uint32_t& new_batch_size,
522+
WarpLayout& warp_layout, uint32_t& kv_chunk_size,
523+
uint32_t& total_num_rows, std::vector<IdType>& request_indices,
524+
std::vector<IdType>& qo_tile_indices,
525+
std::vector<IdType>& kv_tile_indices,
526+
std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr,
527+
IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
528+
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
529+
uint32_t page_size) {
528530
request_indices.clear();
529531
qo_tile_indices.clear();
530532
kv_tile_indices.clear();
@@ -536,8 +538,6 @@ cudaError_t PrefillSplitQOKVIndptr(
536538
const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
537539
total_num_rows = qo_indptr_h[batch_size];
538540

539-
bool has_kv_last_page_len = kv_last_page_len_h != nullptr;
540-
541541
// step 0: get the number of SMs
542542
int num_sm = 0;
543543
int dev_id = 0;
@@ -570,7 +570,7 @@ cudaError_t PrefillSplitQOKVIndptr(
570570
// step 2: determine kv_chunk_size
571571
std::tie(split_kv, kv_chunk_size, new_batch_size) =
572572
PrefillBinarySearchKVChunkSize(max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr,
573-
qo_chunk_size, /*min_kv_chunk_size=*/(128 / page_size));
573+
qo_chunk_size, /*min_kv_chunk_size=*/(512 / page_size));
574574

575575
// step 3: split qo_indptr and kv_indptr
576576
total_num_tiles_q = 0;
@@ -656,9 +656,8 @@ class BatchPrefillHandler {
656656

657657
template <typename DTypeOut, typename IdType>
658658
cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr_h,
659-
IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size,
660-
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
661-
uint32_t page_size) {
659+
IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads,
660+
uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size) {
662661
if (num_qo_heads % num_kv_heads != 0) {
663662
std::ostringstream err_msg;
664663
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -672,8 +671,8 @@ class BatchPrefillHandler {
672671
FLASHINFER_CUDA_CALL(PrefillSplitQOKVIndptr(
673672
split_kv, split_max_batch_size, total_num_tiles_q, new_batch_size, warp_layout_,
674673
kv_chunk_size, total_num_rows_, request_indices_vec, qo_tile_indices_vec,
675-
kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h,
676-
kv_last_page_len_h, batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, stream_));
674+
kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h, batch_size,
675+
num_qo_heads, num_kv_heads, head_dim, page_size));
677676
const uint32_t qo_tile_size = get_num_rows_per_cta(warp_layout_);
678677

679678
if (IsCUDAGraphEnabled()) {

python/csrc/batch_prefill.cu

+5-8
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ using namespace flashinfer;
2222

2323
void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
2424
torch::Tensor workspace_buffer, torch::Tensor qo_indptr, torch::Tensor paged_kv_indptr,
25-
torch::Tensor paged_kv_last_page_len, unsigned int batch_size, unsigned int num_qo_heads,
26-
unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size,
27-
torch::Tensor empty_q_data) {
25+
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
26+
unsigned int head_dim, unsigned int page_size, torch::Tensor empty_q_data) {
2827
// NOTE(Zihao): not necessary to be a CUDA tensor
2928
CHECK_CONTIGUOUS(qo_indptr);
3029
CHECK_CONTIGUOUS(workspace_buffer);
@@ -33,7 +32,6 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
3332
CHECK_DIM(1, workspace_buffer);
3433
qo_indptr = qo_indptr.to(torch::kCPU).to(torch::kInt32);
3534
paged_kv_indptr = paged_kv_indptr.to(torch::kCPU).to(torch::kInt32);
36-
paged_kv_last_page_len = paged_kv_last_page_len.to(torch::kCPU).to(torch::kInt32);
3735

3836
size_t workspace_size_in_bytes = workspace_buffer.size(0) * workspace_buffer.element_size();
3937
cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream();
@@ -43,9 +41,8 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward(
4341
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
4442
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
4543
static_cast<int32_t*>(qo_indptr.data_ptr()),
46-
static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
47-
static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()), batch_size, num_qo_heads,
48-
num_kv_heads, head_dim, page_size);
44+
static_cast<int32_t*>(paged_kv_indptr.data_ptr()), batch_size, num_qo_heads, num_kv_heads,
45+
head_dim, page_size);
4946
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
5047
cudaGetErrorString(status));
5148
return true;
@@ -285,7 +282,7 @@ void BatchPrefillWithRaggedKVCachePyTorchWrapper::BeginForward(
285282
cudaError_t status = handler_->BeginForward<q_type, int32_t>(
286283
static_cast<void*>(workspace_buffer.data_ptr()), workspace_size_in_bytes,
287284
static_cast<int32_t*>(qo_indptr.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
288-
/*last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
285+
batch_size, num_qo_heads, num_kv_heads, head_dim,
289286
/*page_size=*/1);
290287
TORCH_CHECK(status == cudaSuccess, "BatchPrefillWithPagedKVCache failed with error ",
291288
cudaGetErrorString(status));

python/csrc/flashinfer_ops.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,9 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
112112
class BatchPrefillWithPagedKVCachePyTorchWrapper {
113113
public:
114114
void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
115-
torch::Tensor page_kv_indptr, torch::Tensor page_kv_last_page_len,
116-
unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads,
117-
unsigned int head_dim, unsigned page_size, torch::Tensor empty_q_data);
115+
torch::Tensor page_kv_indptr, unsigned int batch_size,
116+
unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim,
117+
unsigned page_size, torch::Tensor empty_q_data);
118118
void EndForward();
119119
bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); }
120120
void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes);

python/flashinfer/decode.py

-1
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,6 @@ def begin_forward(
730730
self._workspace_buffer,
731731
self._qo_indptr_buf,
732732
indptr,
733-
last_page_len,
734733
batch_size,
735734
num_qo_heads,
736735
num_kv_heads,

python/flashinfer/prefill.py

-1
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,6 @@ def begin_forward(
773773
self._workspace_buffer,
774774
qo_indptr,
775775
paged_kv_indptr,
776-
paged_kv_last_page_len,
777776
batch_size,
778777
num_qo_heads,
779778
num_kv_heads,

src/bench_batch_decode.cu

+3-4
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,9 @@ void bench_flashinfer_batch_decode_with_prefill(nvbench::state& state) {
149149
size_t workspace_size_in_bytes = 128 * 1024 * 1024;
150150
thrust::device_vector<char> buffer(workspace_size_in_bytes);
151151

152-
handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
153-
workspace_size_in_bytes, qo_indptr_h.data(),
154-
kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size,
155-
num_qo_heads, num_kv_heads, head_dim, page_size);
152+
handler.BeginForward<T, int32_t>(
153+
(void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
154+
kv_indptr_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
156155

157156
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) {
158157
cudaError_t status =

src/bench_cascade.cu

+2-4
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
248248
thrust::device_vector<char> buffer(workspace_size_in_bytes);
249249
cascade_handler.BeginForward<T, int32_t>(
250250
(void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
251-
kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads,
252-
num_kv_heads, head_dim, page_size);
251+
kv_indptr_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
253252
state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
254253
timer.start();
255254
cudaError_t status = SinglePrefillWithKVCache(
@@ -305,8 +304,7 @@ void bench_two_level_single_prefix_cascade_append(nvbench::state& state) {
305304
thrust::device_vector<char> buffer(workspace_size_in_bytes);
306305
baseline_handler.BeginForward<T, int32_t>(
307306
(void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, qo_indptr_h.data(),
308-
kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads,
309-
num_kv_heads, head_dim, page_size);
307+
kv_indptr_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
310308
state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
311309
timer.start();
312310
cudaError_t status =

src/test_batch_prefill.cu

+9-12
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ void _TestBatchPagedPrefillKernelOneHotCorrectness(size_t num_kv_heads, size_t n
104104

105105
handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
106106
workspace_size_in_bytes, q_indptr.data(), kv_indptr.data(),
107-
kv_last_page_len.data(), batch_size, num_qo_heads,
108-
num_kv_heads, head_dim, page_size);
107+
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
109108

110109
for (uint32_t num_runs = 0; num_runs < 10; ++num_runs) {
111110
auto status = flashinfer::BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices,
@@ -190,10 +189,9 @@ void _TestBatchRaggedPrefillKernelCorrectness(size_t num_kv_heads, size_t num_qo
190189
thrust::device_vector<int32_t> append_indptr_device(append_indptr);
191190
thrust::device_vector<int32_t> kv_indptr_device(kv_indptr);
192191

193-
handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
194-
workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
195-
/*kv_last_page_len=*/nullptr, batch_size, num_qo_heads,
196-
num_kv_heads, head_dim, /*page_size=*/1);
192+
handler.BeginForward<T, int32_t>(
193+
(void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
194+
kv_indptr.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, /*page_size=*/1);
197195

198196
auto status = BatchPrefillWithRaggedKVCacheWrapper<T, T, int32_t>(
199197
&handler, thrust::raw_pointer_cast(queries_device.data()),
@@ -321,8 +319,7 @@ void _TestBatchPagedPrefillKernelShortContextCorrectness(size_t num_kv_heads, si
321319

322320
handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
323321
workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
324-
kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads,
325-
head_dim, page_size);
322+
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
326323

327324
auto status =
328325
BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(
@@ -416,10 +413,10 @@ void _TestBatchPagedPrefillKernelLongContextCorrectness(size_t num_kv_heads, siz
416413
size_t workspace_size_in_bytes = 32 * 1024 * 1024;
417414
thrust::device_vector<char> buffer(workspace_size_in_bytes);
418415

419-
handler.BeginForward<T, int32_t>(
420-
(void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, append_indptr.data(),
421-
kv_indptr.data(), kv_last_page_len.data(),
422-
/*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim, page_size);
416+
handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer.data()),
417+
workspace_size_in_bytes, append_indptr.data(), kv_indptr.data(),
418+
/*batch_size=*/1, num_qo_heads, num_kv_heads, head_dim,
419+
page_size);
423420

424421
auto status =
425422
BatchPrefillWithPagedKVCacheWrapper<PageStorage::kIndices, kv_layout, T, T, int32_t>(

src/test_cascade.cu

+8-8
Original file line numberDiff line numberDiff line change
@@ -409,14 +409,14 @@ void _TestTwoLevelSinglePrefixCascadeAppendCorrectness(size_t batch_size,
409409
thrust::device_vector<char> buffer_baseline(workspace_size_in_bytes),
410410
buffer_cascade(workspace_size_in_bytes);
411411

412-
baseline_handler.BeginForward<T, int32_t>(
413-
(void*)thrust::raw_pointer_cast(buffer_baseline.data()), workspace_size_in_bytes,
414-
qo_indptr_h.data(), kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
415-
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
416-
cascade_handler.BeginForward<T, int32_t>(
417-
(void*)thrust::raw_pointer_cast(buffer_cascade.data()), workspace_size_in_bytes,
418-
qo_indptr_h.data(), kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size,
419-
num_qo_heads, num_kv_heads, head_dim, page_size);
412+
baseline_handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer_baseline.data()),
413+
workspace_size_in_bytes, qo_indptr_h.data(),
414+
kv_indptr_combined_h.data(), batch_size, num_qo_heads,
415+
num_kv_heads, head_dim, page_size);
416+
cascade_handler.BeginForward<T, int32_t>((void*)thrust::raw_pointer_cast(buffer_cascade.data()),
417+
workspace_size_in_bytes, qo_indptr_h.data(),
418+
kv_indptr_unique_h.data(), batch_size, num_qo_heads,
419+
num_kv_heads, head_dim, page_size);
420420

421421
cudaError_t status = BatchPrefillWithPagedKVCacheWrapper<page_storage, kv_layout, T, T, int32_t>(
422422
&baseline_handler, thrust::raw_pointer_cast(q_d.data()),

src/tvm_wrapper.cu

+3-5
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q
272272

273273
void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
274274
int64_t handler_idx, DLTensor* workspace_buffer, DLTensor* qo_indptr, DLTensor* kv_indptr,
275-
DLTensor* kv_last_page_len, int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads,
276-
int64_t head_dim, int64_t page_size, TVMStreamHandle copy_stream) {
275+
int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
276+
int64_t page_size, TVMStreamHandle copy_stream) {
277277
CHECK_EQ(workspace_buffer->ndim, 1) << "The workspace buffer must be a 1-D tensor";
278278
size_t workspace_size_in_bytes = workspace_buffer->shape[0] * workspace_buffer->dtype.bits / 8;
279279
CHECK(handler_idx < max_num_handlers) << "The handler id must be less than " << max_num_handlers;
@@ -290,8 +290,6 @@ void _FlashInferAttentionPrefillWithPagedKVCacheBeginForward(
290290
static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
291291
static_cast<dtype_idx*>(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
292292
static_cast<dtype_idx*>(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
293-
static_cast<dtype_idx*>(kv_last_page_len->data) +
294-
kv_last_page_len->byte_offset / sizeof(dtype_idx),
295293
batch_size, num_qo_heads, num_kv_heads, head_dim, page_size);
296294
if (status != cudaSuccess) {
297295
LOG(FATAL) << "FlashInfer prefill BeginForward error " << cudaGetErrorString(status);
@@ -568,7 +566,7 @@ void _FlashInferAttentionPrefillWithRaggedKVCacheBeginForward(
568566
static_cast<void*>(workspace_buffer->data), workspace_size_in_bytes,
569567
static_cast<dtype_idx*>(qo_indptr->data) + qo_indptr->byte_offset / sizeof(dtype_idx),
570568
static_cast<dtype_idx*>(kv_indptr->data) + kv_indptr->byte_offset / sizeof(dtype_idx),
571-
/*kv_last_page_len=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim,
569+
batch_size, num_qo_heads, num_kv_heads, head_dim,
572570
/*page_size=*/1);
573571
if (status != cudaSuccess) {
574572
LOG(FATAL) << "FlashInfer PrefillWithRaggedKVCache BeginForward error "

0 commit comments

Comments (0)