#include <optional>

#include "pytorch_extension_utils.h"

#include "mla_config.inc"

#include <flashinfer/attention/decode_mla_cute_sm80.cuh>
#include <flashinfer/attention/scheduler.cuh>

using namespace flashinfer;

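// Computes the decode schedule (including any split-KV partitioning) for MLA
// (Multi-head Latent Attention) over a paged KV cache. The resulting
// DecodePlanInfo is flattened into a vector so the caller can hold on to it
// and hand it back to the run function below.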
std::vector<int64_t> BatchDecodeWithPagedKVCachePlanMLA(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
    unsigned int num_qo_heads, unsigned int page_size, bool enable_cuda_graph,
    int64_t cuda_stream) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer.size(0) * float_workspace_buffer.element_size();
  size_t int_workspace_size_in_bytes =
      int_workspace_buffer.size(0) * int_workspace_buffer.element_size();

  DecodePlanInfo plan_info;
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);

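  // Estimate each request's decode workload, then carve the per-CTA schedule
  // and temporary buffers out of the float/int workspaces.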
  auto work_estimation_func =
      BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80<
          HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN, AttentionVariant, Params>;
  cudaError_t status =
      DecodePlan<HEAD_DIM_CKV, flashinfer::PosEncodingMode::kNone, AttentionVariant, Params>(
          static_cast<void*>(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes,
          static_cast<void*>(int_workspace_buffer.data_ptr()),
          static_cast<void*>(page_locked_int_workspace_buffer.data_ptr()),
          int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(indptr.data_ptr()),
          batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream,
          work_estimation_func);

  TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCachePlanMLA failed with error ",
              cudaGetErrorString(status));

  return plan_info.ToVector();
}
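
// Runs MLA batch decode with a paged KV cache, using the plan produced above.
// q_nope/q_pe are the non-positional and RoPE halves of the query; per the MLA
// layout implied by HEAD_DIM_CKV / HEAD_DIM_KPE, paged_ckv_cache holds the
// compressed (latent) KV cache and paged_kpe_cache the decoupled key
// positional-embedding cache.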
void BatchDecodeWithPagedKVCacheRunMLA(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q_nope, at::Tensor q_pe,
    at::Tensor paged_ckv_cache, at::Tensor paged_kpe_cache, at::Tensor paged_kv_indptr,
    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o, float sm_scale,
    int window_left, float logits_soft_cap, float rope_scale, float rope_theta,
    std::optional<at::Tensor> maybe_lse, int64_t cuda_stream) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(plan_info_vec);

  auto device = q_nope.device();
  int64_t batch_size = q_nope.size(0);
  int64_t num_qo_heads = q_nope.size(1);
  int64_t page_size = paged_ckv_cache.size(1);

  if (maybe_lse) {
    const auto& lse = *maybe_lse;
    TORCH_CHECK(lse.size(0) == batch_size, "lse.size(0) must equal batch_size, got ",
                lse.size(0), " vs ", batch_size);
    TORCH_CHECK(lse.size(1) == num_qo_heads, "lse.size(1) must equal num_qo_heads, got ",
                lse.size(1), " vs ", num_qo_heads);
  }

  TORCH_CHECK(logits_soft_cap >= 0.f, "logits_soft_cap must be non-negative");

  void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
  void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());

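  // Assemble the paged MLA KV-cache view from the ckv/kpe tensors and their
  // page indices; strides are taken directly from the PyTorch tensors.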
  paged_kv_mla_t<DTypeKV, IdType> paged_kv(
      page_size, HEAD_DIM_CKV, HEAD_DIM_KPE, batch_size,
      static_cast<DTypeKV*>(paged_ckv_cache.data_ptr()), paged_ckv_cache.strides().data(),
      static_cast<DTypeKV*>(paged_kpe_cache.data_ptr()), paged_kpe_cache.strides().data(),
      static_cast<IdType*>(paged_kv_indices.data_ptr()),
      static_cast<IdType*>(paged_kv_indptr.data_ptr()),
      static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
  Params params(static_cast<DTypeQ*>(q_nope.data_ptr()), static_cast<DTypeQ*>(q_pe.data_ptr()),
                /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                /*lse=*/(maybe_lse ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr),
                num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);

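  // Wire the scheduler's outputs from the int workspace into the kernel
  // params: per-CTA request indices, KV-tile indices, output indptr, and
  // per-request KV chunk sizes.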
  DTypeO* tmp_v = nullptr;
  float* tmp_s = nullptr;
  params.request_indices =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.request_indices_offset);
  params.kv_tile_indices =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_tile_indices_offset);
  params.o_indptr = GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.o_indptr_offset);
  params.kv_chunk_size_ptr =
      GetPtrFromBaseOffset<IdType>(int_buffer, plan_info.kv_chunk_size_ptr_offset);
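  // Split-KV path: per-chunk partial outputs and LSEs are staged in tmp_v /
  // tmp_s in the float workspace, to be merged across chunks afterwards.
  // Under CUDA graph capture the schedule is padded, so a validity mask marks
  // which blocks carry real work.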
  if (plan_info.split_kv) {
    tmp_v = GetPtrFromBaseOffset<DTypeO>(float_buffer, plan_info.v_offset);
    tmp_s = GetPtrFromBaseOffset<float>(float_buffer, plan_info.s_offset);
    if (plan_info.enable_cuda_graph) {
      params.block_valid_mask =
          GetPtrFromBaseOffset<bool>(int_buffer, plan_info.block_valid_mask_offset);
    }
  }
  params.padded_batch_size = plan_info.padded_batch_size;

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  cudaError_t status =
      BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN,
                                                       Params>(params, tmp_v, tmp_s,
                                                               /*stream=*/stream);
  TORCH_CHECK(status == cudaSuccess, "BatchDecodeWithPagedKVCacheRunMLA failed with error ",
              cudaGetErrorString(status));
}
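
// --- Illustrative only: not part of this commit -----------------------------
// A minimal sketch of how these two entry points could be exposed to Python
// with pybind11. The module macro arguments and the Python-facing names below
// are assumptions; the real registration may live in a separate binding file.
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("batch_decode_with_paged_kv_cache_plan_mla",
//           &BatchDecodeWithPagedKVCachePlanMLA,
//           "Plan the MLA batch-decode schedule");
//     m.def("batch_decode_with_paged_kv_cache_run_mla",
//           &BatchDecodeWithPagedKVCacheRunMLA,
//           "Run MLA batch decode over a paged KV cache");
//   }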