@@ -137,3 +137,80 @@ void block_sparse_indices_to_vector_sparse_offsets(at::Tensor block_sparse_indic
   TORCH_CHECK(status == cudaSuccess, "BlockSparseIndicesToVectorSparseOffset failed with error: ",
               cudaGetErrorString(status));
 }
+
+void append_paged_mla_kv_cache(at::Tensor append_ckv, at::Tensor append_kpe,
+                               at::Tensor batch_indices, at::Tensor positions, at::Tensor ckv_cache,
+                               at::Tensor kpe_cache, at::Tensor kv_indices, at::Tensor kv_indptr,
+                               at::Tensor kv_last_page_len, int64_t cuda_stream) {
+  CHECK_LAST_DIM_CONTIGUOUS(append_ckv);
+  CHECK_LAST_DIM_CONTIGUOUS(append_kpe);
+  CHECK_INPUT(batch_indices);
+  CHECK_INPUT(positions);
+  // NOTE(Zihao): doesn't have to be contiguous
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(ckv_cache);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(kpe_cache);
+  CHECK_INPUT(kv_indices);
+  CHECK_INPUT(kv_indptr);
+  CHECK_INPUT(kv_last_page_len);
+  CHECK_DIM(2, append_ckv);
+  CHECK_DIM(2, append_kpe);
+  CHECK_DIM(1, batch_indices);
+  CHECK_DIM(1, positions);
+  CHECK_DIM(3, ckv_cache);
+  CHECK_DIM(3, kpe_cache);
+  CHECK_DIM(1, kv_indices);
+  CHECK_DIM(1, kv_indptr);
+  CHECK_DIM(1, kv_last_page_len);
+  unsigned int nnz = append_ckv.size(0);
+  unsigned int batch_size = kv_last_page_len.size(0);
+  CHECK_EQ(kv_indptr.size(0), batch_size + 1);
+  CHECK_EQ(batch_indices.size(0), nnz);
+  CHECK_EQ(positions.size(0), nnz);
+  auto device = append_ckv.device();
+  CHECK_EQ(append_ckv.device(), device);
+  CHECK_EQ(append_kpe.device(), device);
+  CHECK_EQ(ckv_cache.device(), device);
+
+  CHECK_EQ(kv_indices.device(), device);
+  CHECK_EQ(kv_indptr.device(), device);
+  CHECK_EQ(kv_last_page_len.device(), device);
+
+  unsigned int page_size, ckv_dim, kpe_dim;
+  page_size = ckv_cache.size(1);
+  ckv_dim = ckv_cache.size(2);
+  kpe_dim = kpe_cache.size(2);
+
+  // get kv_cache_strides
+  const int64_t* ckv_strides = ckv_cache.strides().data();
+  const int64_t* kpe_strides = kpe_cache.strides().data();
+
+  auto append_ckv_strides = append_ckv.strides();
+  auto append_ckv_stride_n = append_ckv_strides[0];
+  auto append_kpe_strides = append_kpe.strides();
+  auto append_kpe_stride_n = append_kpe_strides[0];
+
+  CHECK_EQ(append_ckv.size(1), ckv_dim);
+  CHECK_EQ(append_kpe.size(1), kpe_dim);
+
+  auto kv_scalar_dtype = ckv_cache.scalar_type();
+
+  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] {
+    paged_kv_mla_t<c_type, int32_t> paged_mla_kv(
+        page_size, ckv_dim, kpe_dim, batch_size, static_cast<c_type*>(ckv_cache.data_ptr()),
+        ckv_strides, static_cast<c_type*>(kpe_cache.data_ptr()), kpe_strides,
+        static_cast<int32_t*>(kv_indices.data_ptr()), static_cast<int32_t*>(kv_indptr.data_ptr()),
+        static_cast<int32_t*>(kv_last_page_len.data_ptr()));
+    cudaError_t status =
+        AppendPagedKVMlaCache(paged_mla_kv, static_cast<c_type*>(append_ckv.data_ptr()),
+                              static_cast<c_type*>(append_kpe.data_ptr()),
+                              static_cast<int32_t*>(batch_indices.data_ptr()),
+                              static_cast<int32_t*>(positions.data_ptr()), nnz, append_ckv_stride_n,
+                              append_kpe_stride_n, stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "AppendPagedKVMlaCache failed with error: ", cudaGetErrorString(status));
+    return true;
+  });
+
+  TORCH_CHECK(success, "AppendPagedKVMlaCache failed to dispatch with dtype ", kv_scalar_dtype);
+}
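
For reference, a minimal caller-side sketch of the new binding (not part of the commit). It builds tensors whose shapes and int32 index dtypes satisfy the CHECK_* constraints above and passes the current CUDA stream as the int64_t handle the binding expects. The head dims (512/64), page-pool size, single-request layout, and the example_append_mla_kv helper are illustrative assumptions, not values required by the kernel.

// Hypothetical usage sketch: concrete sizes and the example_* helper are assumptions;
// shapes and int32 index dtypes mirror the checks in append_paged_mla_kv_cache above.
#include <ATen/cuda/CUDAContext.h>
#include <torch/torch.h>

void example_append_mla_kv() {
  const int64_t num_pages = 4, page_size = 64;  // size of the paged cache pool (illustrative)
  const int64_t ckv_dim = 512, kpe_dim = 64;    // illustrative MLA compressed-KV / RoPE dims
  const int64_t batch_size = 1, nnz = 16;       // one request, 16 tokens to append

  auto f16 = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);

  // 3-D caches: [num_pages, page_size, dim]; 2-D append buffers: [nnz, dim].
  at::Tensor ckv_cache = torch::zeros({num_pages, page_size, ckv_dim}, f16);
  at::Tensor kpe_cache = torch::zeros({num_pages, page_size, kpe_dim}, f16);
  at::Tensor append_ckv = torch::randn({nnz, ckv_dim}, f16);
  at::Tensor append_kpe = torch::randn({nnz, kpe_dim}, f16);

  // 1-D int32 metadata: every appended token belongs to request 0, which owns pool page 0.
  at::Tensor batch_indices = torch::zeros({nnz}, i32);
  at::Tensor positions = torch::arange(nnz, i32);
  at::Tensor kv_indices = torch::zeros({1}, i32);
  at::Tensor kv_indptr = torch::tensor({0, 1}, i32);                  // batch_size + 1 entries
  at::Tensor kv_last_page_len = torch::full({batch_size}, nnz, i32);  // valid slots in the last page

  // Pass the current CUDA stream through the int64_t handle expected by the binding.
  int64_t stream = reinterpret_cast<int64_t>(at::cuda::getCurrentCUDAStream().stream());
  append_paged_mla_kv_cache(append_ckv, append_kpe, batch_indices, positions, ckv_cache,
                            kpe_cache, kv_indices, kv_indptr, kv_last_page_len, stream);
}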