|
| 1 | +""" |
| 2 | +Copyright (c) 2024 by FlashInfer team. |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
| 17 | +import numpy as np |
| 18 | +import torch |
| 19 | +from triton.testing import do_bench |
| 20 | + |
| 21 | +import flashinfer |
| 22 | + |
| 23 | +page_block_size = 16 |
| 24 | +num_kv_heads = 4 |
| 25 | +num_qo_heads = 32 |
| 26 | +head_dim = 128 |
| 27 | + |
| 28 | + |
| 29 | +def bench_batch_decode( |
| 30 | + batch_size, |
| 31 | + seq_len, |
| 32 | + num_qo_heads, |
| 33 | + num_kv_heads, |
| 34 | + head_dim, |
| 35 | + page_block_size, |
| 36 | + q_dtype, |
| 37 | + kv_dtype, |
| 38 | +): |
| 39 | + np.random.seed(42) |
| 40 | + seq_lens = torch.full((batch_size,), seq_len) |
| 41 | + seq_lens_blocks = torch.ceil(seq_lens / page_block_size).int() |
| 42 | + kv_indptr = torch.cat([torch.tensor([0]), torch.cumsum(seq_lens_blocks, 0)], dim=0) |
| 43 | + kv_indptr = kv_indptr.int() |
| 44 | + last_page_len = seq_lens - (seq_lens_blocks - 1) * page_block_size |
| 45 | + last_page_len = last_page_len.int() |
| 46 | + num_blocks = kv_indptr[-1].item() |
| 47 | + |
| 48 | + q = torch.rand(batch_size, num_qo_heads, head_dim, dtype=q_dtype, device="cuda:0") |
| 49 | + kv_data = torch.randn( |
| 50 | + num_blocks, 2, page_block_size, num_kv_heads, head_dim, device="cuda:0" |
| 51 | + ).to(kv_dtype) |
| 52 | + workspace_buffer = torch.empty( |
| 53 | + 128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0" |
| 54 | + ) |
| 55 | + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( |
| 56 | + workspace_buffer, kv_layout="NHD", use_tensor_cores=True |
| 57 | + ) |
| 58 | + wrapper.plan( |
| 59 | + kv_indptr.to(0), |
| 60 | + torch.arange(num_blocks).int().to(0), |
| 61 | + last_page_len.to(0), |
| 62 | + num_qo_heads, |
| 63 | + num_kv_heads, |
| 64 | + head_dim, |
| 65 | + page_block_size, |
| 66 | + data_type=kv_dtype, |
| 67 | + q_data_type=q_dtype, |
| 68 | + ) |
| 69 | + |
| 70 | + ms = do_bench(lambda: wrapper.run(q, kv_data)) |
| 71 | + |
| 72 | + io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size() |
| 73 | + print( |
| 74 | + f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_block_size={page_block_size}, q_dtype={q_dtype}, kv_dtype={kv_dtype}" |
| 75 | + ) |
| 76 | + print(f"execution time: {ms}ms") |
| 77 | + print(f"memory bandwidth: {io / ms / 1024 / 1024 :.2f} GB/s") |
| 78 | + |
| 79 | + |
| 80 | +if __name__ == "__main__": |
| 81 | + for q_dtype in [torch.bfloat16]: |
| 82 | + for kv_dtype in [torch.bfloat16, torch.float8_e4m3fn]: |
| 83 | + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]: |
| 84 | + for seq_len in [512, 1024, 2048, 4096, 8192, 16384]: |
| 85 | + bench_batch_decode( |
| 86 | + batch_size, |
| 87 | + seq_len, |
| 88 | + num_qo_heads, |
| 89 | + num_kv_heads, |
| 90 | + head_dim, |
| 91 | + page_block_size, |
| 92 | + q_dtype, |
| 93 | + kv_dtype, |
| 94 | + ) |
0 commit comments