@@ -38,44 +38,60 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
                                                uint32_t kv_len, float sm_scale, float rope_scale,
                                                float rope_theta, cudaStream_t stream);
 
-template <uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
-          QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
+template <uint32_t num_frags_x, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
+          QKVLayout KV_LAYOUT, PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION,
           MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
-    DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
-    DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
-    IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, uint32_t batch_size,
-    uint32_t num_qo_tiles, uint32_t num_qo_heads, uint32_t num_kv_heads, float sm_scale,
-    float rope_scale, float rope_theta, cudaStream_t stream = nullptr);
+    DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
+    IdType* q_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, uint8_t* custom_mask,
+    IdType* qk_indptr, IdType* q_offset, IdType* k_rope_pos_offset, IdType* o_indptr, DTypeOut* o,
+    DTypeOut* tmp_v, float* tmp_s, float* lse, IdType* merge_indptr, bool* block_valid_mask,
+    IdType* kv_chunk_size_ptr, const uint32_t total_num_rows, const uint32_t num_qo_heads,
+    const uint32_t padded_batch_size, const uint32_t num_kv_heads, const float sm_scale,
+    const float rope_scale, const float rope_theta, cudaStream_t stream = nullptr);
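
The new signature splits the old `tile_indices` into `q_tile_indices` and `kv_tile_indices`, so a request can be tiled along the KV axis as well as the query axis, and it threads the split-KV bookkeeping (`tmp_v`/`tmp_s` partial outputs, `merge_indptr`, `block_valid_mask`, `kv_chunk_size_ptr`) through the call. The `*_indptr` arguments follow the usual CSR "indptr" convention. A minimal host-side sketch of that convention; the helper below is illustrative, not part of the API:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // CSR convention assumed by q_indptr / kv_indptr: request r owns query rows
    // [q_indptr[r], q_indptr[r+1]) and KV rows [kv_indptr[r], kv_indptr[r+1]),
    // so total_num_rows equals q_indptr.back().
    void DumpRaggedExtents(const std::vector<int32_t>& q_indptr,
                           const std::vector<int32_t>& kv_indptr) {
      for (size_t r = 0; r + 1 < q_indptr.size(); ++r) {
        std::printf("request %zu: qo_len=%d, kv_len=%d\n", r,
                    q_indptr[r + 1] - q_indptr[r], kv_indptr[r + 1] - kv_indptr[r]);
      }
      std::printf("total_num_rows=%d\n", q_indptr.back());
    }
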
 
-template <PageStorage PAGE_STORAGE, uint32_t NUM_FRAGS_X, uint32_t HEAD_DIM,
-          LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE,
+template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
+          LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
           bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
           typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
-    DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, IdType* q_offset,
-    paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
-    IdType* qk_indptr, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles,
-    uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream);
+    DTypeIn* q, IdType* request_indices, IdType* q_tile_indices, IdType* kv_tile_indices,
+    IdType* q_indptr, IdType* q_offset,
+    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
+    IdType* qk_indptr, IdType* o_indptr, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse,
+    IdType* merge_indptr, bool* block_valid_mask, IdType* kv_chunk_size_ptr,
+    uint32_t total_num_rows, uint32_t num_qo_heads, uint32_t padded_batch_size, float sm_scale,
+    float rope_scale, float rope_theta, cudaStream_t stream);
 
 template <PageStorage PAGE_STORAGE, uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK,
           QKVLayout KV_LAYOUT, PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION,
           MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, IdType* q_offset,
     paged_kv_t<PAGE_STORAGE, KV_LAYOUT, DTypeIn, IdType> paged_kv, uint8_t* custom_mask,
     IdType* qk_indptr, DTypeOut* o, float* lse, uint32_t num_qo_heads, float sm_scale,
     float rope_scale, float rope_theta, cudaStream_t stream) {
-  float* tmp = nullptr;
-  IdType* request_indices = nullptr;
-  IdType* tile_indices = nullptr;
+  DTypeOut* tmp_v = nullptr;
+  float* tmp_s = nullptr;
+  IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr,
+         *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
+  bool* block_valid_mask = nullptr;
   uint32_t num_frags_x = 0U;
-  uint32_t num_qo_tiles = 0U;
+  uint32_t padded_batch_size = 0U;
+  uint32_t total_num_rows = 0U;
   if (handler->IsForwardStarted()) {
+    tmp_v = handler->GetTempV<DTypeOut>();
+    tmp_s = handler->GetTempS();
     request_indices = handler->GetRequestIndices<IdType>();
-    tile_indices = handler->GetTileIndices<IdType>();
+    qo_tile_indices = handler->GetQOTileIndices<IdType>();
+    kv_tile_indices = handler->GetKVTileIndices<IdType>();
+    block_valid_mask = handler->GetBlockValidMask();
+    o_indptr = handler->GetOIndptr<IdType>();
+    merge_indptr = handler->GetMergeIndptr<IdType>();
+    kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
     num_frags_x = handler->GetNumFragsX();
-    num_qo_tiles = handler->GetNumQOTiles();
+    padded_batch_size = handler->GetPaddedBatchSize();
+    total_num_rows = handler->GetTotalNumRows();
   } else {
     std::ostringstream err_msg;
     err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
@@ -87,8 +103,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
     return BatchPrefillWithPagedKVCacheDispatched<
         PAGE_STORAGE, NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
         ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
-        q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, custom_mask, qk_indptr, o,
-        tmp, lse, num_qo_heads, num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
+        q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, q_offset, paged_kv,
+        custom_mask, qk_indptr, o_indptr, o, tmp_v, tmp_s, lse, merge_indptr, block_valid_mask,
+        kv_chunk_size_ptr, total_num_rows, num_qo_heads, padded_batch_size, sm_scale, rope_scale,
+        rope_theta, stream);
   });
   return cudaSuccess;
 }
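
Both wrappers now pull their launch plan from the handler instead of computing it inline: each KV chunk writes a partial output into `tmp_v` with its log-sum-exp (LSE) in `tmp_s`, and `merge_indptr` groups the partial rows that belong to one query row for a later merge pass. A minimal sketch of the standard two-state LSE merge; flashinfer's actual merge kernel is fused and vectorized, so this shows only the arithmetic:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Merge two partial attention results for the same query row. Each partial
    // state is a value vector plus the log-sum-exp of the scores it covered.
    void MergeState(const float* v_a, float s_a, const float* v_b, float s_b,
                    float* v_out, float* s_out, uint32_t head_dim) {
      float s_max = std::max(s_a, s_b);
      float w_a = std::exp(s_a - s_max);  // softmax mass contributed by chunk a
      float w_b = std::exp(s_b - s_max);  // softmax mass contributed by chunk b
      float w_sum = w_a + w_b;
      for (uint32_t i = 0; i < head_dim; ++i) {
        v_out[i] = (w_a * v_a[i] + w_b * v_b[i]) / w_sum;
      }
      *s_out = s_max + std::log(w_sum);  // LSE over the union of both chunks
    }
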
@@ -97,21 +115,32 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, QKVLayout KV_LAYOU
           PosEncodingMode POS_ENCODING_MODE, bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE,
           typename DTypeIn, typename DTypeOut, typename IdType>
 cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
-    BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
+    BatchPrefillHandler* handler, DTypeIn* q, IdType* q_indptr, DTypeIn* k, DTypeIn* v,
     IdType* kv_indptr, uint8_t* custom_mask, IdType* qk_indptr, IdType* q_offset,
-    IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t batch_size, uint32_t num_qo_heads,
+    IdType* k_rope_pos_offset, DTypeOut* o, float* lse, uint32_t num_qo_heads,
     uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta,
     cudaStream_t stream) {
-  float* tmp = nullptr;
-  IdType* request_indices = nullptr;
-  IdType* tile_indices = nullptr;
+  DTypeOut* tmp_v = nullptr;
+  float* tmp_s = nullptr;
+  IdType *request_indices = nullptr, *qo_tile_indices = nullptr, *kv_tile_indices = nullptr,
+         *o_indptr = nullptr, *merge_indptr = nullptr, *kv_chunk_size_ptr = nullptr;
+  bool* block_valid_mask = nullptr;
   uint32_t num_frags_x = 0U;
-  uint32_t num_qo_tiles = 0U;
+  uint32_t padded_batch_size = 0U;
+  uint32_t total_num_rows = 0U;
   if (handler->IsForwardStarted()) {
+    tmp_v = handler->GetTempV<DTypeOut>();
+    tmp_s = handler->GetTempS();
     request_indices = handler->GetRequestIndices<IdType>();
-    tile_indices = handler->GetTileIndices<IdType>();
+    qo_tile_indices = handler->GetQOTileIndices<IdType>();
+    kv_tile_indices = handler->GetKVTileIndices<IdType>();
+    block_valid_mask = handler->GetBlockValidMask();
+    o_indptr = handler->GetOIndptr<IdType>();
+    merge_indptr = handler->GetMergeIndptr<IdType>();
+    kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
     num_frags_x = handler->GetNumFragsX();
-    num_qo_tiles = handler->GetNumQOTiles();
+    padded_batch_size = handler->GetPaddedBatchSize();
+    total_num_rows = handler->GetTotalNumRows();
   } else {
     std::ostringstream err_msg;
     err_msg << "Please call BatchPrefillHandler's BeginForward() before calling "
@@ -123,9 +152,10 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
     return BatchPrefillWithRaggedKVCacheDispatched<
         NUM_FRAGS_X, HEAD_DIM, LOGITS_POST_HOOK, KV_LAYOUT, POS_ENCODING_MODE,
         ALLOW_FP16_QK_REDUCTION, MASK_MODE, DTypeIn, DTypeOut, IdType>(
-        q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, custom_mask, qk_indptr,
-        q_offset, k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_heads, num_qo_tiles,
-        num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
+        q, request_indices, qo_tile_indices, kv_tile_indices, q_indptr, k, v, kv_indptr,
+        custom_mask, qk_indptr, q_offset, k_rope_pos_offset, o_indptr, o, tmp_v, tmp_s, lse,
+        merge_indptr, block_valid_mask, kv_chunk_size_ptr, total_num_rows, num_qo_heads,
+        padded_batch_size, num_kv_heads, sm_scale, rope_scale, rope_theta, stream);
   });
   return cudaSuccess;
 }
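
The `padded_batch_size` returned by the handler, together with `block_valid_mask`, suggests the kernel grid is launched at a fixed padded size so the launch shape stays uniform across plans, with padding blocks masked out. A hypothetical CUDA sketch of that early-exit pattern (kernel name and parameter list are illustrative, not the actual kernel):

    // Illustrative only: grid.x == padded_batch_size; blocks whose slot is
    // padding (block_valid_mask[bx] == false) return before doing any work.
    __global__ void PrefillBlockSketch(const bool* __restrict__ block_valid_mask
                                       /* , request/tile index arrays, ... */) {
      const unsigned int bx = blockIdx.x;
      if (block_valid_mask != nullptr && !block_valid_mask[bx]) {
        return;  // padding block: no (request, q-tile, kv-tile) work assigned
      }
      // A real block would look up its work item via request_indices[bx],
      // q_tile_indices[bx], and kv_tile_indices[bx].
    }
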