/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>

#include <cstdint>
#include <nvbench/nvbench.cuh>
#include <optional>

#include "flashinfer/attention/handler.cuh"
#include "flashinfer/layout.cuh"
#include "flashinfer/pos_enc.cuh"
#include "flashinfer_ops.cuh"

using namespace flashinfer;

inline uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

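// Benchmarks FlashInfer's batch prefill attention with a ragged KV cache. Every request in
// the batch uses the same qo_len == kv_len, so the indptr arrays built below describe a
// uniformly sized batch.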
template <typename dtype_in, typename dtype_out, bool append>
void bench_flashinfer_batch_prefill_with_ragged_kv(nvbench::state& state) {
  size_t kv_len = state.get_int64("kv_len");
  size_t qo_len = kv_len;
  size_t batch_size = state.get_int64("batch_size");
  size_t num_qo_heads = state.get_int64("num_qo_heads");
  size_t num_kv_heads = state.get_int64("num_kv_heads");
  size_t head_dim = state.get_int64("head_dim");
  size_t pos_encoding_mode = state.get_int64("pos_encoding_mode");
  size_t kv_layout = state.get_int64("kv_layout");
  bool causal = state.get_int64("causal");
  bool cooperative = state.get_int64("cooperative");
  bool allow_fp16_qk_reduction = state.get_int64("allow_fp16_qk_reduction");

  // Allocate input data:
  thrust::device_vector<dtype_in> Q(batch_size * qo_len * num_qo_heads * head_dim);
  thrust::device_vector<dtype_in> K(batch_size * kv_len * num_kv_heads * head_dim);
  thrust::device_vector<dtype_in> V(batch_size * kv_len * num_kv_heads * head_dim);
  thrust::device_vector<dtype_out> O(batch_size * qo_len * num_qo_heads * head_dim);
  thrust::device_vector<uint8_t> workspace(128 * 1024 * 1024);

  // Provide throughput information:
  state.add_global_memory_reads<dtype_in>(
      (batch_size * qo_len * num_qo_heads + 2 * batch_size * kv_len * num_kv_heads) * head_dim,
      "Read");
  state.add_global_memory_writes<dtype_out>(qo_len * batch_size * num_qo_heads * head_dim, "Write");

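  // Build ragged indptr arrays: request i's queries occupy rows [i * qo_len, (i + 1) * qo_len)
  // of Q, and its keys/values occupy rows [i * kv_len, (i + 1) * kv_len) of K/V.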
  std::vector<int32_t> qo_indptr_h(batch_size + 1);
  std::vector<int32_t> kv_indptr_h(batch_size + 1);

  for (uint32_t i = 0; i <= batch_size; ++i) {
    qo_indptr_h[i] = i * qo_len;
    kv_indptr_h[i] = i * kv_len;
  }

  thrust::device_vector<int32_t> qo_indptr_d(qo_indptr_h);
  thrust::device_vector<int32_t> kv_indptr_d(kv_indptr_h);

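  // One-time handler setup on the host, outside the timed region: the handler plans the
  // batch over the given indptr arrays and uses `workspace` for its temporary buffers.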
  BatchPrefillHandler handler;

  handler.BeginForward<dtype_out>(thrust::raw_pointer_cast(workspace.data()),
                                  workspace.size() * sizeof(uint8_t), qo_indptr_h.data(),
                                  kv_indptr_h.data(), batch_size, num_qo_heads, num_kv_heads,
                                  head_dim, /*page_size=*/1);

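  // Time only the attention kernel launch; skip the measurement if the wrapper reports a
  // CUDA error.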
  state.exec(nvbench::exec_tag::timer, [&](nvbench::launch& launch, auto& timer) {
    timer.start();
    cudaError_t status;
    status = BatchPrefillWithRaggedKVCacheWrapper<dtype_in, dtype_out, int32_t>(
        &handler, thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(qo_indptr_d.data()),
        thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()),
        thrust::raw_pointer_cast(kv_indptr_d.data()),
        /*q_offset=*/nullptr, /*k_rope_pos_offset=*/nullptr, thrust::raw_pointer_cast(O.data()),
        /*lse=*/nullptr, batch_size, num_qo_heads, num_kv_heads, head_dim, causal,
        QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction);
    if (status != cudaSuccess) {
      state.skip("CUDA error: " + std::string(cudaGetErrorString(status)));
    }
    timer.stop();
  });
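  // Derive achieved TFLOP/s from the mean kernel time: each attended query/key pair costs
  // roughly 4 * head_dim FLOPs per head (QK^T plus PV), and the causal mask halves the number
  // of attended pairs when qo_len == kv_len.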
  const auto measured_mean = static_cast<nvbench::float32_t>(
      state.get_summary("nv/cold/time/gpu/mean").get_float64("value"));
  auto& summ = state.add_summary("nv/tflops");
  summ.set_string("description", "Achieved TFlops/s");
  summ.set_string("name", "TFlops/s");
  float tflops;
  if (causal) {
    tflops = (batch_size * (qo_len * (2 * kv_len - qo_len) * 2 * num_qo_heads * head_dim)) /
             measured_mean / 1e12;
  } else {
    tflops = (batch_size * qo_len * kv_len * 4 * num_qo_heads * head_dim) / measured_mean / 1e12;
  }
  summ.set_float64("value", tflops);
}

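// Instantiate the benchmark for an input/output dtype pair and register it with nvbench,
// together with its parameter sweep over sequence lengths, batch sizes, and head configurations.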
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#define BENCH_FLASHINFER_BATCH_PREFILL_WITH_RAGGED_KV(dtype_in, dtype_out)                     \
  auto bench_flashinfer_batch_prefill_with_ragged_kv_##dtype_in##_##dtype_out##_ =             \
      bench_flashinfer_batch_prefill_with_ragged_kv<dtype_in, dtype_out, false>;               \
  NVBENCH_BENCH(bench_flashinfer_batch_prefill_with_ragged_kv_##dtype_in##_##dtype_out##_)     \
      .set_name(                                                                               \
          ("bench_flashinfer_batch_prefill_with_ragged_kv_" STR(dtype_in) "_" STR(dtype_out))) \
      .add_int64_axis("kv_len", {32, 64, 128, 256, 512, 1024, 2048, 4096})                     \
      .add_int64_axis("batch_size", {4, 8, 32})                                                \
      .add_int64_axis("num_qo_heads", {32})                                                    \
      .add_int64_axis("num_kv_heads", {32})                                                    \
      .add_int64_axis("head_dim", {128})                                                       \
      .add_int64_axis("causal", {0, 1})                                                        \
      .add_int64_axis("kv_layout", {0})                                                        \
      .add_int64_axis("pos_encoding_mode", {0})                                                \
      .add_int64_axis("allow_fp16_qk_reduction", {0})                                          \
      .add_int64_axis("cooperative", {1})

BENCH_FLASHINFER_BATCH_PREFILL_WITH_RAGGED_KV(half, half);