|
| 1 | +/* |
| 2 | + * Copyright (c) 2024 by FlashInfer team. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +#ifndef FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_ |
| 17 | +#define FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_ |
| 18 | +#include <sstream> |
| 19 | + |
| 20 | +#include "../cutlass_utils.cuh" |
| 21 | +#include "../exception.h" |
| 22 | +#include "cutlass/kernel_hardware_info.h" |
| 23 | + |
| 24 | +// From 3rdparty/cutlass/examples/77_blackwell_fmha |
| 25 | +#include "device/sm100_mla.hpp" |
| 26 | +#include "kernel/sm100_mla_tile_scheduler.hpp" |
| 27 | + |
| 28 | +namespace flashinfer { |
| 29 | + |
| 30 | +namespace attention { |
| 31 | + |
| 32 | +using namespace cute; |
| 33 | +using namespace cutlass::fmha::kernel; |
| 34 | + |
| 35 | +#define CUTLASS_CHECK(cmd) \ |
| 36 | + do { \ |
| 37 | + auto status = cmd; \ |
| 38 | + if (status != cutlass::Status::kSuccess) { \ |
| 39 | + std::ostringstream err_msg; \ |
| 40 | + err_msg << "cutlass " << #cmd << " failed: " << cutlassGetStatusString(status); \ |
| 41 | + FLASHINFER_ERROR(err_msg.str()); \ |
| 42 | + } \ |
| 43 | + } while (0) |
| 44 | + |
| 45 | +template <bool v> |
| 46 | +struct IsPersistent { |
| 47 | + static const bool value = v; |
| 48 | +}; |
| 49 | + |
| 50 | +template <typename T, typename PersistenceOption = IsPersistent<true>> |
| 51 | +struct MlaSm100 { |
| 52 | + using Element = T; |
| 53 | + using ElementAcc = float; |
| 54 | + using ElementOut = T; |
| 55 | + |
| 56 | + using TileShape = Shape<_128, _128, Shape<_512, _64>>; |
| 57 | + using TileShapeH = cute::tuple_element_t<0, TileShape>; |
| 58 | + using TileShapeD = cute::tuple_element_t<2, TileShape>; |
| 59 | + |
| 60 | + // H K (D_latent D_rope) B |
| 61 | + using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>; |
| 62 | + |
| 63 | + using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B |
| 64 | + using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B |
| 65 | + using StrideO = StrideK; // H D B |
| 66 | + using StrideLSE = cute::tuple<_1, int>; // H B |
| 67 | + |
| 68 | + using TileScheduler = |
| 69 | + std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, |
| 70 | + Sm100MlaIndividualTileScheduler>; |
| 71 | + |
| 72 | + using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< |
| 73 | + TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, /*kIsCpAsync=*/true>; |
| 74 | + using Fmha = cutlass::fmha::device::MLA<FmhaKernel>; |
| 75 | +}; |
| 76 | + |
| 77 | +template <typename T> |
| 78 | +typename T::Fmha::Arguments args_from_options(void* out_ptr, void* lse_ptr, void* q_absorbed_ptr, |
| 79 | + void* ckv_kpe_cache_ptr, void* seq_lens_ptr, |
| 80 | + void* page_table_ptr, int batches, |
| 81 | + int page_count_per_seq, int page_count_total, |
| 82 | + int page_size, int device_index) { |
| 83 | + cutlass::KernelHardwareInfo hw_info; |
| 84 | + hw_info.device_id = device_index; |
| 85 | + hw_info.sm_count = |
| 86 | + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); |
| 87 | + |
| 88 | + int max_seq_len = page_size * page_count_per_seq; |
| 89 | + using TileShapeH = typename T::TileShapeH; |
| 90 | + using TileShapeD = typename T::TileShapeD; |
| 91 | + auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); |
| 92 | + |
| 93 | + auto [H, K, D, B] = problem_shape; |
| 94 | + auto [D_latent, D_rope] = D; |
| 95 | + |
| 96 | + // the scale is based on the non-absorbed sizes, change as appropriate |
| 97 | + // we can't determine this parameter from the info we have, it's an input |
| 98 | + int D_non_latent = 128; |
| 99 | + float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope)); |
| 100 | + |
| 101 | + using StrideQ = typename T::StrideQ; |
| 102 | + using StrideK = typename T::StrideK; |
| 103 | + using StrideO = typename T::StrideO; |
| 104 | + using StrideLSE = typename T::StrideLSE; |
| 105 | + |
| 106 | + StrideQ stride_Q = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, |
| 107 | + static_cast<int64_t>(H * (0 + D_latent + D_rope))); |
| 108 | + StrideK stride_C = cute::make_tuple(static_cast<int64_t>(0 + D_latent + D_rope), _1{}, |
| 109 | + static_cast<int64_t>(page_size * (D_latent + D_rope))); |
| 110 | + StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); |
| 111 | + StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H); |
| 112 | + StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, |
| 113 | + static_cast<int64_t>(0 + H * D_latent)); |
| 114 | + |
| 115 | + using Element = typename T::Element; |
| 116 | + using ElementOut = typename T::ElementOut; |
| 117 | + using ElementAcc = typename T::ElementAcc; |
| 118 | + auto Q_ptr = reinterpret_cast<Element*>(q_absorbed_ptr); |
| 119 | + auto C_ptr = reinterpret_cast<Element*>(ckv_kpe_cache_ptr); |
| 120 | + typename T::Fmha::Arguments arguments{ |
| 121 | + problem_shape, |
| 122 | + {scale, Q_ptr, stride_Q, Q_ptr + D_latent, stride_Q, C_ptr, stride_C, C_ptr + D_latent, |
| 123 | + stride_C, reinterpret_cast<int*>(seq_lens_ptr), reinterpret_cast<int*>(page_table_ptr), |
| 124 | + stride_PT, page_count_total, page_size}, |
| 125 | + {reinterpret_cast<ElementOut*>(out_ptr), stride_O, |
| 126 | + // static_cast<ElementAcc*>(lse.data_ptr()), stride_LSE}, |
| 127 | + static_cast<ElementAcc*>(nullptr), stride_LSE}, |
| 128 | + hw_info, |
| 129 | + -1, // split_kv |
| 130 | + nullptr, // is_var_split_kv=false |
| 131 | + }; |
| 132 | + // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute |
| 133 | + // split_kv automatically based on batch size and sequence length to balance |
| 134 | + // workload across available SMs. Consider using var_split_kv for manual |
| 135 | + // control if needed. |
| 136 | + T::Fmha::set_split_kv(arguments); |
| 137 | + return arguments; |
| 138 | +} |
| 139 | + |
| 140 | +template <typename Element> |
| 141 | +cudaError_t runMla(void* workspace_ptr, void* out_ptr, void* lse_ptr, void* q_absorbed_ptr, |
| 142 | + void* ckv_kpe_cache_ptr, void* seq_lens_ptr, void* page_table_ptr, int batches, |
| 143 | + int page_count_per_seq, int page_count_total, int page_size, int device_index, |
| 144 | + cudaStream_t stream) { |
| 145 | + using MlaSm100Type = MlaSm100<Element>; |
| 146 | + typename MlaSm100Type::Fmha fmha; |
| 147 | + auto arguments = args_from_options<MlaSm100Type>( |
| 148 | + out_ptr, lse_ptr, q_absorbed_ptr, ckv_kpe_cache_ptr, seq_lens_ptr, page_table_ptr, batches, |
| 149 | + page_count_per_seq, page_count_total, page_size, device_index); |
| 150 | + |
| 151 | + CUTLASS_CHECK(fmha.can_implement(arguments)); |
| 152 | + |
| 153 | + CUTLASS_CHECK(fmha.initialize(arguments, workspace_ptr, stream)); |
| 154 | + |
| 155 | + CUTLASS_CHECK(fmha.run(arguments, workspace_ptr, stream)); |
| 156 | + |
| 157 | + return cudaSuccess; |
| 158 | +} |
| 159 | + |
| 160 | +} // namespace attention |
| 161 | + |
| 162 | +} // namespace flashinfer |
| 163 | +#endif // FLASHINFER_ATTENTION_CUTLASS_MLA_CUH_ |
0 commit comments