@@ -83,9 +83,9 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(
     num_kv_heads = paged_kv_data.size(3);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
-  CHECK_EQ(qo_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_last_page_len.size(0), batch_size);
+  CHECK_GE(qo_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
   CHECK_EQ(paged_kv_data.size(1), 2);
   CHECK_EQ(paged_kv_data.size(4), head_dim);
   qo_indptr = qo_indptr.to(torch::kInt32);
@@ -186,12 +186,12 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu
     num_kv_heads = paged_kv_data.size(3);
   }
   CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads);
-  CHECK_EQ(qo_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(paged_kv_last_page_len.size(0), batch_size);
+  CHECK_GE(qo_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
   CHECK_EQ(paged_kv_data.size(1), 2);
   CHECK_EQ(paged_kv_data.size(4), head_dim);
-  CHECK_EQ(qk_indptr.size(0), batch_size + 1);
+  CHECK_GE(qk_indptr.size(0), batch_size + 1);
   qo_indptr = qo_indptr.to(torch::kInt32);
   paged_kv_indptr = paged_kv_indptr.to(torch::kInt32);
   paged_kv_indices = paged_kv_indices.to(torch::kInt32);
@@ -303,7 +303,7 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
   int64_t nnz_qo = q.size(0);
   int64_t num_qo_heads = q.size(1);
   int64_t head_dim = q.size(2);
-  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(kv_indptr.size(0), batch_size + 1);
   int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
   CHECK_EQ(k.size(0), v.size(0));
   CHECK_EQ(k.size(1), v.size(1));
@@ -366,8 +366,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::Forward(
 std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardCustomMask(
     torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v,
     torch::Tensor kv_indptr, torch::Tensor custom_mask, torch::Tensor qk_indptr,
-    unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float sm_scale, float rope_scale,
-    float rope_theta, bool return_lse) {
+    unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction,
+    float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
   CHECK_INPUT(q);
   CHECK_INPUT(qo_indptr);
   CHECK_INPUT(k);
@@ -386,8 +386,8 @@ std::vector<torch::Tensor> BatchPrefillWithRaggedKVCachePyTorchWrapper::ForwardC
   int64_t nnz_qo = q.size(0);
   int64_t num_qo_heads = q.size(1);
   int64_t head_dim = q.size(2);
-  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
-  CHECK_EQ(qk_indptr.size(0), batch_size + 1);
+  CHECK_GE(kv_indptr.size(0), batch_size + 1);
+  CHECK_GE(qk_indptr.size(0), batch_size + 1);
   int64_t num_kv_heads = (kv_layout_ == QKVLayout::kNHD) ? k.size(1) : k.size(0);
   CHECK_EQ(k.size(0), v.size(0));
   CHECK_EQ(k.size(1), v.size(1));
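
Every hunk above makes the same change: the exact-size checks on the indptr and last-page-length tensors become lower-bound checks, so a caller may pass buffers larger than batch_size + 1 (for example, buffers padded to a fixed capacity and reused across batches). The sketch below is illustrative only and is not taken from this repository: it assumes the CHECK_EQ/CHECK_GE helpers are thin wrappers over TORCH_CHECK, and the check_prefill_indptr function is a hypothetical name used just to show the relaxed assertion in context.

// Illustrative sketch, not the project's actual macro definitions.
#include <torch/extension.h>

#define CHECK_EQ(a, b) \
  TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed: ", (a), " vs ", (b))

#define CHECK_GE(a, b) \
  TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed: ", (a), " vs ", (b))

// With CHECK_GE, an over-allocated indptr buffer no longer trips the check;
// only the first batch_size + 1 entries are consumed by the kernel wrapper.
void check_prefill_indptr(const torch::Tensor& qo_indptr, int64_t batch_size) {
  CHECK_GE(qo_indptr.size(0), batch_size + 1);
}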