Skip to content

Commit 7100278

Browse files
authored
[nvidia] cutlass fp8 blockwise/groupwise gemm support (#1045)
1 parent 6c6f1a5 commit 7100278

File tree

10 files changed

+790
-12
lines changed

10 files changed

+790
-12
lines changed

3rdparty/cutlass

Submodule cutlass updated 187 files
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import pytest
18+
import torch
19+
import triton
20+
import triton.language as tl
21+
from triton.testing import do_bench
22+
23+
import flashinfer
24+
from flashinfer.gemm import gemm_fp8_nt_blockscaled, gemm_fp8_nt_groupwise
25+
26+
27+
@triton.jit
28+
def _w8a8_block_fp8_matmul(
29+
# Pointers to inputs and output
30+
A,
31+
B,
32+
C,
33+
As,
34+
Bs,
35+
# Shape for matmul
36+
M,
37+
N,
38+
K,
39+
# Block size for block-wise quantization
40+
group_n,
41+
group_k,
42+
# Stride for inputs and output
43+
stride_am,
44+
stride_ak,
45+
stride_bk,
46+
stride_bn,
47+
stride_cm,
48+
stride_cn,
49+
stride_As_m,
50+
stride_As_k,
51+
stride_Bs_k,
52+
stride_Bs_n,
53+
# Meta-parameters
54+
BLOCK_SIZE_M: tl.constexpr,
55+
BLOCK_SIZE_N: tl.constexpr,
56+
BLOCK_SIZE_K: tl.constexpr,
57+
GROUP_SIZE_M: tl.constexpr,
58+
):
59+
"""Triton-accelerated function used to perform linear operations (dot
60+
product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
61+
tensor `C`.
62+
"""
63+
64+
pid = tl.program_id(axis=0)
65+
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
66+
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
67+
num_pid_in_group = GROUP_SIZE_M * num_pid_n
68+
group_id = pid // num_pid_in_group
69+
first_pid_m = group_id * GROUP_SIZE_M
70+
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
71+
pid_m = first_pid_m + (pid % group_size_m)
72+
pid_n = (pid % num_pid_in_group) // group_size_m
73+
74+
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
75+
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
76+
offs_k = tl.arange(0, BLOCK_SIZE_K)
77+
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
78+
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
79+
80+
As_ptrs = As + offs_am * stride_As_m
81+
offs_bsn = offs_bn // group_n
82+
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
83+
84+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
85+
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
86+
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
87+
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
88+
89+
k_start = k * BLOCK_SIZE_K
90+
offs_ks = k_start // group_k
91+
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
92+
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
93+
94+
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
95+
a_ptrs += BLOCK_SIZE_K * stride_ak
96+
b_ptrs += BLOCK_SIZE_K * stride_bk
97+
98+
if C.dtype.element_ty == tl.bfloat16:
99+
c = accumulator.to(tl.bfloat16)
100+
elif C.dtype.element_ty == tl.float16:
101+
c = accumulator.to(tl.float16)
102+
else:
103+
c = accumulator.to(tl.float32)
104+
105+
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
106+
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
107+
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
108+
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
109+
tl.store(c_ptrs, c, mask=c_mask)
110+
111+
112+
def triton_w8a8_block_fp8_matmul(
113+
A: torch.Tensor,
114+
B: torch.Tensor,
115+
As: torch.Tensor,
116+
Bs: torch.Tensor,
117+
out: torch.Tensor,
118+
) -> torch.Tensor:
119+
M = A.shape[0]
120+
N, K = B.shape
121+
block_n, block_k = 128, 128
122+
123+
config = {
124+
"BLOCK_SIZE_M": 64,
125+
"BLOCK_SIZE_N": block_n,
126+
"BLOCK_SIZE_K": block_k,
127+
"GROUP_SIZE_M": 32,
128+
"num_warps": 4,
129+
"num_stages": 3,
130+
}
131+
132+
def grid(META):
133+
return (
134+
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
135+
)
136+
137+
_w8a8_block_fp8_matmul[grid](
138+
A,
139+
B,
140+
out,
141+
As,
142+
Bs,
143+
M,
144+
N,
145+
K,
146+
block_n,
147+
block_k,
148+
A.stride(-2),
149+
A.stride(-1),
150+
B.stride(1),
151+
B.stride(0),
152+
out.stride(-2),
153+
out.stride(-1),
154+
As.stride(-2),
155+
As.stride(-1),
156+
Bs.stride(1),
157+
Bs.stride(0),
158+
**config,
159+
)
160+
161+
return out
162+
163+
164+
def bench_groupwise_gemm_fp8_blackwell(m, n, k, in_dtype, out_dtype):
165+
a = torch.randn((m, k), device="cuda").to(in_dtype)
166+
b = torch.randn((n, k), device="cuda").to(in_dtype)
167+
a_scale = torch.rand((k // 128, m), dtype=torch.float32, device="cuda")
168+
b_scale = torch.rand((k // 128, n // 128), dtype=torch.float32, device="cuda")
169+
170+
out = torch.empty((m, n), dtype=out_dtype, device="cuda")
171+
gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out)
172+
173+
ms = do_bench(lambda: gemm_fp8_nt_groupwise(a, b, a_scale, b_scale, out=out))
174+
tflops_per_second = 2 * m * n * k * 1e-9 / ms
175+
print(
176+
f"gemm_fp8_nt_groupwise {m} {n} {k} {in_dtype} {out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
177+
)
178+
179+
tl_out = torch.empty((m, n), dtype=out_dtype, device="cuda")
180+
a_scale = a_scale.transpose(0, 1).contiguous()
181+
b_scale = b_scale.transpose(0, 1).contiguous()
182+
ms = do_bench(lambda: triton_w8a8_block_fp8_matmul(a, b, a_scale, b_scale, tl_out))
183+
tflops_per_second = 2 * m * n * k * 1e-9 / ms
184+
print(
185+
f"triton_gemm_fp8_nt_groupwise {m} {n} {k} {in_dtype} {out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
186+
)
187+
188+
189+
if __name__ == "__main__":
190+
for m in [1024, 2048, 4096, 8192]:
191+
for n in [1024, 2048, 4096, 8192]:
192+
for k in [1024, 2048, 4096, 8192]:
193+
bench_groupwise_gemm_fp8_blackwell(
194+
m, n, k, torch.float8_e5m2, torch.bfloat16
195+
)

csrc/gemm_groupwise_sm100.cu

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright (c) 2025 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <flashinfer/cutlass_utils.cuh>
17+
#include <flashinfer/gemm/gemm_groupwise_sm100.cuh>
18+
19+
#include "pytorch_extension_utils.h"
20+
21+
using namespace flashinfer;
22+
23+
#define DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(input_dtype, output_dtype, c_type_in, c_type_out, ...) \
24+
[&]() -> bool { \
25+
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, c_type_out, [&] { \
26+
return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(input_dtype, c_type_in, \
27+
[&] { return __VA_ARGS__(); }); \
28+
}); \
29+
}()
30+
31+
#define DISPATCH_SCALE_GRANULARITY(scale_granularity_m, scale_granularity_n, scale_granularity_k, \
32+
SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, \
33+
...) \
34+
[&]() -> bool { \
35+
if (scale_granularity_m == 1 && scale_granularity_n == 128 && scale_granularity_k == 128) { \
36+
constexpr int SCALE_GRANULARITY_M = 1; \
37+
constexpr int SCALE_GRANULARITY_N = 128; \
38+
constexpr int SCALE_GRANULARITY_K = 128; \
39+
return __VA_ARGS__(); \
40+
} else if (scale_granularity_m == 128 && scale_granularity_n == 128 && \
41+
scale_granularity_k == 128) { \
42+
constexpr int SCALE_GRANULARITY_M = 128; \
43+
constexpr int SCALE_GRANULARITY_N = 128; \
44+
constexpr int SCALE_GRANULARITY_K = 128; \
45+
return __VA_ARGS__(); \
46+
} \
47+
TORCH_CHECK(false, "Unsupported scale granularity"); \
48+
return false; \
49+
}()
50+
51+
void CutlassGemmGroupwiseScaledSM100(at::Tensor float_workspace_buffer, at::Tensor A, at::Tensor B,
52+
at::Tensor SFA, at::Tensor SFB, at::Tensor C,
53+
int64_t scale_granularity_m, int64_t scale_granularity_n,
54+
int64_t scale_granularity_k) {
55+
unsigned int batch_size = A.size(0);
56+
const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device());
57+
auto stream = at::cuda::getCurrentCUDAStream();
58+
DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(A.scalar_type(), C.scalar_type(), c_type_in, c_type_out, [&] {
59+
return DISPATCH_SCALE_GRANULARITY(
60+
scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M,
61+
SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, [&] {
62+
using cutlass_t_in = cutlass_dtype_t<c_type_in>;
63+
using cutlass_t_out = cutlass_dtype_t<c_type_out>;
64+
auto status = flashinfer::gemm::CutlassGroupwiseScaledGEMMSM100<
65+
SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K>(
66+
static_cast<float*>(float_workspace_buffer.data_ptr()),
67+
float_workspace_buffer.element_size() * float_workspace_buffer.size(0),
68+
static_cast<cutlass_t_in*>(A.data_ptr()), static_cast<cutlass_t_in*>(B.data_ptr()),
69+
static_cast<float*>(SFA.data_ptr()), static_cast<float*>(SFB.data_ptr()),
70+
static_cast<cutlass_t_out*>(C.data_ptr()), A.size(0), B.size(0), A.size(1), 1,
71+
stream);
72+
return true;
73+
});
74+
});
75+
}

csrc/gemm_sm100_pybind.cu

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright (c) 2025 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "pytorch_extension_utils.h"
17+
18+
void CutlassGemmGroupwiseScaledSM100(at::Tensor float_workspace_buffer, at::Tensor A, at::Tensor B,
19+
at::Tensor SFA, at::Tensor SFB, at::Tensor C,
20+
int64_t scale_granularity_m, int64_t scale_granularity_n,
21+
int64_t scale_granularity_k);
22+
23+
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
24+
m.def("gemm_fp8_nt_groupwise", CutlassGemmGroupwiseScaledSM100);
25+
}

docs/api/gemm.rst

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ FP8 Batch GEMM
1313
.. autosummary::
1414
:toctree: ../generated
1515

16+
gemm_fp8_nt_groupwise
1617
bmm_fp8
1718

1819
Grouped GEMM

0 commit comments

Comments
 (0)