@@ -92,7 +92,7 @@ inline std::tuple<bool, uint32_t, uint32_t> PrefillBinarySearchKVChunkSize(
     const uint32_t qo_chunk_size, const uint32_t min_kv_chunk_size = 1) {
   int64_t low = min_kv_chunk_size, high = 0;
   int64_t batch_size = packed_qo_len_arr.size();
-  int64_t max_kv_len;
+  int64_t max_kv_len = 0;
   for (const int64_t& kv_len : kv_len_arr) {
     max_kv_len = std::max(max_kv_len, kv_len);
   }
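Note on this hunk: the old code declared `max_kv_len` without an initializer, so the `std::max` reduction started from an indeterminate value (undefined behavior, and wrong whenever the garbage value exceeded every real KV length). A minimal standalone sketch of the corrected pattern, with made-up lengths standing in for `kv_len_arr`:

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical per-request KV lengths; stands in for kv_len_arr.
  std::vector<int64_t> kv_len_arr{7, 42, 3};

  // Initializing the accumulator to 0 makes the reduction well defined,
  // including the edge case where kv_len_arr is empty.
  int64_t max_kv_len = 0;
  for (const int64_t& kv_len : kv_len_arr) {
    max_kv_len = std::max(max_kv_len, kv_len);
  }
  return max_kv_len == 42 ? 0 : 1;
}
```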
@@ -174,9 +174,9 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     new_batch_size = batch_size;
   } else {
     // compute max_num_pages_per_batch and new_batch_size
-    std::vector<IdType> page_indptr_h(batch_size + 1), num_pages(batch_size);
+    std::vector<IdType> num_pages(batch_size);
     for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-      num_pages[batch_idx] = page_indptr_h[batch_idx + 1] - page_indptr_h[batch_idx];
+      num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
     }
     std::tie(max_num_pages_per_batch, new_batch_size) =
         PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages,
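Here the old code allocated a local `page_indptr_h` but never filled it before differencing its entries; the fix reads the page counts from the `kv_indptr_h` argument instead, treating it as a CSR-style offsets array where request `i` owns pages `[kv_indptr_h[i], kv_indptr_h[i + 1])`. A small self-contained sketch of that offsets-to-counts step, with hypothetical values:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  using IdType = int32_t;
  // Hypothetical kv_indptr_h for 3 requests: request i owns pages
  // [kv_indptr_h[i], kv_indptr_h[i + 1]).
  std::vector<IdType> kv_indptr_h{0, 4, 4, 9};
  const uint32_t batch_size = 3;

  std::vector<IdType> num_pages(batch_size);
  for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
  }
  assert(num_pages[0] == 4 && num_pages[1] == 0 && num_pages[2] == 5);
  return 0;
}
```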
@@ -517,14 +517,16 @@ class BatchDecodeHandler {
 };
 
 template <typename IdType>
-cudaError_t PrefillSplitQOKVIndptr(
-    bool& split_kv, uint32_t& split_max_batch_size, uint32_t& total_num_tiles_q,
-    uint32_t& new_batch_size, WarpLayout& warp_layout, uint32_t& kv_chunk_size,
-    uint32_t& total_num_rows, std::vector<IdType>& request_indices,
-    std::vector<IdType>& qo_tile_indices, std::vector<IdType>& kv_tile_indices,
-    std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr, IdType* qo_indptr_h,
-    IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads,
-    uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, cudaStream_t stream = nullptr) {
+cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_size,
+                                   uint32_t& total_num_tiles_q, uint32_t& new_batch_size,
+                                   WarpLayout& warp_layout, uint32_t& kv_chunk_size,
+                                   uint32_t& total_num_rows, std::vector<IdType>& request_indices,
+                                   std::vector<IdType>& qo_tile_indices,
+                                   std::vector<IdType>& kv_tile_indices,
+                                   std::vector<IdType>& merge_indptr, std::vector<IdType>& o_indptr,
+                                   IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t batch_size,
+                                   uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
+                                   uint32_t page_size) {
   request_indices.clear();
   qo_tile_indices.clear();
   kv_tile_indices.clear();
@@ -536,8 +538,6 @@ cudaError_t PrefillSplitQOKVIndptr(
   const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
   total_num_rows = qo_indptr_h[batch_size];
 
-  bool has_kv_last_page_len = kv_last_page_len_h != nullptr;
-
   // step 0: get the number of SMs
   int num_sm = 0;
   int dev_id = 0;
@@ -570,7 +570,7 @@ cudaError_t PrefillSplitQOKVIndptr(
   // step 2: determine kv_chunk_size
   std::tie(split_kv, kv_chunk_size, new_batch_size) =
       PrefillBinarySearchKVChunkSize(max_grid_size, num_kv_heads, packed_qo_len_arr, kv_len_arr,
-                                     qo_chunk_size, /*min_kv_chunk_size=*/(128 / page_size));
+                                     qo_chunk_size, /*min_kv_chunk_size=*/(512 / page_size));
 
   // step 3: split qo_indptr and kv_indptr
   total_num_tiles_q = 0;
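This hunk only raises the floor passed to `PrefillBinarySearchKVChunkSize`: the smallest KV chunk the binary search will consider now covers 512 tokens' worth of pages instead of 128 (the diff itself does not state the motivation). A throwaway sketch of how that floor scales with `page_size`:

```cpp
#include <cstdint>
#include <cstdio>
#include <initializer_list>

int main() {
  // Hypothetical page sizes; just illustrates how the lower bound changes.
  for (uint32_t page_size : {1u, 16u, 32u}) {
    std::printf("page_size=%2u  old floor=%3u pages  new floor=%3u pages\n", page_size,
                128 / page_size, 512 / page_size);
  }
  return 0;
}
```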
@@ -656,9 +656,8 @@ class BatchPrefillHandler {
 
   template <typename DTypeOut, typename IdType>
   cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* qo_indptr_h,
-                           IdType* kv_indptr_h, IdType* kv_last_page_len_h, uint32_t batch_size,
-                           uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
-                           uint32_t page_size) {
+                           IdType* kv_indptr_h, uint32_t batch_size, uint32_t num_qo_heads,
+                           uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size) {
     if (num_qo_heads % num_kv_heads != 0) {
       std::ostringstream err_msg;
       err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
@@ -672,8 +671,8 @@ class BatchPrefillHandler {
     FLASHINFER_CUDA_CALL(PrefillSplitQOKVIndptr(
         split_kv, split_max_batch_size, total_num_tiles_q, new_batch_size, warp_layout_,
         kv_chunk_size, total_num_rows_, request_indices_vec, qo_tile_indices_vec,
-        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h,
-        kv_last_page_len_h, batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, stream_));
+        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec, qo_indptr_h, kv_indptr_h, batch_size,
+        num_qo_heads, num_kv_heads, head_dim, page_size));
     const uint32_t qo_tile_size = get_num_rows_per_cta(warp_layout_);
 
     if (IsCUDAGraphEnabled()) {