"""
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import torch
import triton
import triton.language as tl
from triton.testing import do_bench

from flashinfer.gemm import gemm_fp8_nt_groupwise


@triton.jit
def _w8a8_block_fp8_matmul(
    # Pointers to inputs and output
    A,
    B,
    C,
    As,
    Bs,
    # Shape for matmul
    M,
    N,
    K,
    # Block size for block-wise quantization
    group_n,
    group_k,
    # Strides for inputs and output
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    stride_As_m,
    stride_As_k,
    stride_Bs_k,
    stride_Bs_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Triton kernel that performs a matrix multiplication of the block-wise
    FP8-quantized inputs `A` and `B`, applies the per-block scales `As` and
    `Bs`, and stores the result in the output tensor `C`.
    """

    # Map the flat program id to a (pid_m, pid_n) tile, using grouped ordering
    # along M to improve L2 reuse of the B tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Scale pointers: As is indexed per row of A, Bs per group of N columns.
    As_ptrs = As + offs_am * stride_As_m
    offs_bsn = offs_bn // group_n
    Bs_ptrs = Bs + offs_bsn * stride_Bs_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

        # Load the scales for the current K-group and rescale the partial product.
        k_start = k * BLOCK_SIZE_K
        offs_ks = k_start // group_k
        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)

        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Cast the fp32 accumulator to the output dtype.
    if C.dtype.element_ty == tl.bfloat16:
        c = accumulator.to(tl.bfloat16)
    elif C.dtype.element_ty == tl.float16:
        c = accumulator.to(tl.float16)
    else:
        c = accumulator.to(tl.float32)

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def triton_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    """Launch the reference Triton kernel with 128x128 block-wise FP8 scales.

    `A` is (M, K) and `B` is (N, K); `As` holds per-row, per-128-K-group scales
    of shape (M, K // 128) and `Bs` holds per-128x128-block scales of shape
    (N // 128, K // 128). The result `out = A @ B.T` is written in place.
    """
    M = A.shape[0]
    N, K = B.shape
    block_n, block_k = 128, 128

    config = {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": block_n,
        "BLOCK_SIZE_K": block_k,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3,
    }

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # B is stored row-major as (N, K), so its K stride is B.stride(1) and its
    # N stride is B.stride(0); the same applies to the block scales Bs.
    _w8a8_block_fp8_matmul[grid](
        A,
        B,
        out,
        As,
        Bs,
        M,
        N,
        K,
        block_n,
        block_k,
        A.stride(-2),
        A.stride(-1),
        B.stride(1),
        B.stride(0),
        out.stride(-2),
        out.stride(-1),
        As.stride(-2),
        As.stride(-1),
        Bs.stride(1),
        Bs.stride(0),
        **config,
    )

    return out
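

# The helper below is not part of the original benchmark; it is a minimal
# sketch (with the hypothetical name `dequant_block_fp8_reference`) of the
# semantics both GEMM paths implement: dequantize `A` with per-row,
# per-128-K-group scales and `B` with 128x128 block scales, then run a plain
# fp32 matmul. It assumes the same scale layouts as
# `triton_w8a8_block_fp8_matmul` above and is only meant for spot checks on
# small shapes.
def dequant_block_fp8_reference(
    A: torch.Tensor,  # (M, K) fp8
    B: torch.Tensor,  # (N, K) fp8
    As: torch.Tensor,  # (M, K // 128) fp32 scales
    Bs: torch.Tensor,  # (N // 128, K // 128) fp32 scales
    block: int = 128,
) -> torch.Tensor:
    M, K = A.shape
    N, _ = B.shape
    # Expand the per-group scales back to per-element scales.
    a_scale = As.repeat_interleave(block, dim=1)[:, :K]
    b_scale = Bs.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)[:N, :K]
    # Dequantize to fp32 and compute C = A @ B^T.
    a = A.to(torch.float32) * a_scale
    b = B.to(torch.float32) * b_scale
    return a @ b.t()
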
def bench_groupwise_gemm_fp8_blackwell(m, n, k, in_dtype, out_dtype):
    a = torch.randn((m, k), device="cuda").to(in_dtype)
    b = torch.randn((n, k), device="cuda").to(in_dtype)
    # gemm_fp8_nt_groupwise is called with the scales in K-major layout:
    # a_scale is (k // 128, m) and b_scale is (k // 128, n // 128).
    a_scale = torch.rand((k // 128, m), dtype=torch.float32, device="cuda")
    b_scale = torch.rand((k // 128, n // 128), dtype=torch.float32, device="cuda")

    out = torch.empty((m, n), dtype=out_dtype, device="cuda")
    gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out)

    # do_bench reports milliseconds, so 2*m*n*k * 1e-9 / ms gives TFLOPs/s.
    ms = do_bench(lambda: gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out))
    tflops_per_second = 2 * m * n * k * 1e-9 / ms
    print(
        f"gemm_fp8_nt_groupwise {m} {n} {k} {in_dtype} {out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
    )

    # The Triton reference expects the transposed scale layouts:
    # (m, k // 128) for a_scale and (n // 128, k // 128) for b_scale.
    tl_out = torch.empty((m, n), dtype=out_dtype, device="cuda")
    a_scale = a_scale.transpose(0, 1).contiguous()
    b_scale = b_scale.transpose(0, 1).contiguous()
    ms = do_bench(lambda: triton_w8a8_block_fp8_matmul(a, b, a_scale, b_scale, tl_out))
    tflops_per_second = 2 * m * n * k * 1e-9 / ms
    print(
        f"triton_gemm_fp8_nt_groupwise {m} {n} {k} {in_dtype} {out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
    )

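
# Hypothetical sanity check (not part of the original benchmark): a minimal
# sketch that spot-checks the Triton reference kernel against the dequantized
# fp32 reference above with loose tolerances; the FlashInfer path could be
# checked the same way after transposing the scales into its K-major layout.
def _sanity_check_triton_reference(m=256, n=512, k=512):
    a = torch.randn((m, k), device="cuda").to(torch.float8_e5m2)
    b = torch.randn((n, k), device="cuda").to(torch.float8_e5m2)
    a_scale = torch.rand((m, k // 128), dtype=torch.float32, device="cuda")
    b_scale = torch.rand((n // 128, k // 128), dtype=torch.float32, device="cuda")
    out = torch.empty((m, n), dtype=torch.float32, device="cuda")
    triton_w8a8_block_fp8_matmul(a, b, a_scale, b_scale, out)
    ref = dequant_block_fp8_reference(a, b, a_scale, b_scale)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
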
if __name__ == "__main__":
    for m in [1024, 2048, 4096, 8192]:
        for n in [1024, 2048, 4096, 8192]:
            for k in [1024, 2048, 4096, 8192]:
                bench_groupwise_gemm_fp8_blackwell(
                    m, n, k, torch.float8_e5m2, torch.bfloat16
                )