Commit 85ffce9

Merge remote-tracking branch 'rocm/main_perf'
2 parents cfa9f99 + f669d30 commit 85ffce9

File tree

1 file changed, +161 -16 lines changed


python/perf-kernels/gemm.py

Lines changed: 161 additions & 16 deletions
@@ -8,9 +8,17 @@
 
 from utils.benchmark_utils import get_available_models, get_model_configs
 
+# TODO: Make this an argument, Benchmarking, testing code and kernel helper need to change for it.
+SCALE_BLOCK_SIZE = 128
+
 
 @triton.autotune(
     configs=[
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
+                'kpack': 2, 'matrix_instr_nonkdim': 16
+            }, num_warps=4, num_stages=2),
         triton.Config(
             {
                 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
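
For orientation (not part of the diff): with SCALE_BLOCK_SIZE = 128, the scale tensors that accompany the fp8 operands carry one entry per 128-wide chunk. A minimal sketch of the implied shapes, assuming the per-token scaling used for A and the per-block scaling used for B later in this file; the sizes are illustrative only:

SCALE_BLOCK_SIZE = 128  # same constant the diff introduces

M, K, N = 512, 1024, 768  # illustrative GEMM sizes, multiples of 128

# A is scaled per row and per 128-wide K chunk ('token' mode in gen_input below).
a_scale_shape = (M, K // SCALE_BLOCK_SIZE)                       # (512, 8)
# B is scaled per 128x128 tile ('block' mode in gen_input below).
b_scale_shape = (K // SCALE_BLOCK_SIZE, N // SCALE_BLOCK_SIZE)   # (8, 6)
print(a_scale_shape, b_scale_shape)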
@@ -60,7 +68,13 @@ def matmul_kernel(
         stride_cn,
         a_scale_ptr,
         b_scale_ptr,
+        stride_ascale_m,
+        stride_ascale_k,
+        stride_bscale_k,
+        stride_bscale_n,
         # Meta-parameters
+        GROUP_K: tl.constexpr,
+        GROUP_N: tl.constexpr,
         BLOCK_SIZE_M: tl.constexpr,
         BLOCK_SIZE_N: tl.constexpr,
         BLOCK_SIZE_K: tl.constexpr,
@@ -76,12 +90,19 @@ def matmul_kernel(
 
     NUM_XCDS: tl.constexpr = 8
 
+    tl.static_assert(((APPLY_SCALE is None) or (APPLY_SCALE == 'tensor')) or (APPLY_SCALE == 'block'),
+                     f"Scaling mode {APPLY_SCALE} is not supported!!!")
+
     tl.assume(stride_am > 0)
     tl.assume(stride_ak > 0)
     tl.assume(stride_bk > 0)
     tl.assume(stride_bn > 0)
     tl.assume(stride_cm > 0)
     tl.assume(stride_cn > 0)
+    tl.assume(stride_ascale_m > 0)
+    tl.assume(stride_ascale_k > 0)
+    tl.assume(stride_bscale_k > 0)
+    tl.assume(stride_bscale_n > 0)
 
     # -----------------------------------------------------------
     # Map program ids `pid` to the block of C it should compute.
@@ -132,9 +153,16 @@ def matmul_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
     b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    if APPLY_SCALE:
-        a_scale = tl.load(a_scale_ptr) if (a_scale_ptr) else 1.0
+    if APPLY_SCALE == 'tensor':
+        a_scale = tl.load(a_scale_ptr) if a_scale_ptr else 1.0
         b_scale = tl.load(b_scale_ptr)
+    elif APPLY_SCALE == 'block':
+        k_start = 0
+        offs_ks = k_start // GROUP_K
+        a_scale_ptrs = None if a_scale_ptr is None else (a_scale_ptr + offs_am * stride_ascale_m +
+                                                         offs_ks * stride_ascale_k)
+        offs_bsn = offs_bn // GROUP_N
+        b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bscale_n + offs_ks * stride_bscale_k
 
     acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
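
To see what the `offs_bn // GROUP_N` mapping in this hunk does, a small sketch (GROUP_N = 128 is an assumption, matching SCALE_BLOCK_SIZE): every group of 128 consecutive output columns shares one column of B's scale tensor.

GROUP_N = 128  # assumed, matches SCALE_BLOCK_SIZE

offs_bn = [0, 1, 127, 128, 200, 255, 256]
offs_bsn = [n // GROUP_N for n in offs_bn]
print(offs_bsn)  # [0, 0, 0, 1, 1, 1, 2]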
@@ -148,15 +176,37 @@ def matmul_kernel(
         else:
             a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
             b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        if APPLY_SCALE == 'block':
+            b_scale = tl.load(b_scale_ptrs)
+            if a_scale_ptrs is not None:
+                a_scale = tl.load(a_scale_ptrs)
+
         # Type conversion to support mixed precision GEMMs where b is lower precision than a
         b = b.to(a_ptr.type.element_ty)
-        accumulator += tl.dot(a, b, input_precision="ieee")
+
+        if APPLY_SCALE == 'block':
+            if a_scale_ptrs is not None:
+                accumulator += tl.dot(a, b, input_precision="ieee") * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator += tl.dot(a, b, input_precision="ieee") * b_scale[None, :]
+        else:
+            accumulator += tl.dot(a, b, input_precision="ieee")
 
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        if APPLY_SCALE == 'block':
+            k_cur = k * BLOCK_SIZE_K // GROUP_K
+            k_nxt = (k + 1) * BLOCK_SIZE_K // GROUP_K
+            offs_ks = k_nxt - k_cur
+            b_scale_ptrs += offs_ks * stride_bscale_k
+            if a_scale_ptrs is not None:
+                a_scale_ptrs += offs_ks * stride_ascale_k
+
     # Apply scale to recover dynamic range reduced due to lower precision inputs.
-    if APPLY_SCALE:
+    if APPLY_SCALE == 'tensor':
         accumulator = accumulator * a_scale * b_scale
     # Apply activation function, if specified.
     # TODO(vgokhale): Add different types of activations.
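
For reference, a hedged PyTorch sketch of the arithmetic this block-scaled loop performs (the helper name, shapes, and defaults are illustrative, not the kernel's API): each K-block partial product is weighted by A's per-row scale and B's per-column-group scale for that block before being accumulated.

import torch

def block_scaled_gemm_ref(a, b, a_scale, b_scale, group_k=128, group_n=128):
    # a: (M, K), b: (K, N) low-precision values upcast to fp32.
    # a_scale: (M, K // group_k); b_scale: (K // group_k, N // group_n).
    M, K = a.shape
    _, N = b.shape
    acc = torch.zeros((M, N), dtype=torch.float32, device=a.device)
    for ks in range(0, K, group_k):
        kb = ks // group_k
        partial = a[:, ks:ks + group_k] @ b[ks:ks + group_k, :]
        # Broadcast B's per-tile scales across each group_n-wide column group.
        col_scale = b_scale[kb].repeat_interleave(group_n)[:N]
        acc += partial * a_scale[:, kb:kb + 1] * col_scale[None, :]
    return acc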
@@ -180,13 +230,14 @@ def leaky_relu(x):
 
 
 # Wrapper for gemm kernel.
-def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
+def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=None, activation=""):
     # Check constraints.
     assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
     assert (a.element_size()
             >= b.element_size()), "Mixed dtype GEMMs are only supported when data type of a is bigger than b!!!"
     assert (a.is_floating_point() == b.is_floating_point()
             ), "GEMMs between float and integer type tensors are not supported!!!"
+    assert (scale_a8_b8 in [None, 'tensor', 'block']), f"Scaling mode {scale_a8_b8} is not supported!!!"
     M, K = a.shape
     K, N = b.shape
     grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
@@ -205,6 +256,12 @@ def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
         c.stride(1),
         a_scale,
         b_scale,
+        a_scale.stride(0) if (a_scale is not None) and a_scale.ndim else 0,
+        a_scale.stride(1) if (a_scale is not None) and a_scale.ndim else 0,
+        b_scale.stride(0) if (b_scale is not None) and b_scale.ndim else 0,
+        b_scale.stride(1) if (b_scale is not None) and b_scale.ndim else 0,
+        GROUP_K=SCALE_BLOCK_SIZE,
+        GROUP_N=SCALE_BLOCK_SIZE,
         APPLY_SCALE=scale_a8_b8,
         ACTIVATION=activation,
     )
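
A possible call into the updated wrapper with block scaling (a sketch: the fp8 dtype and the shapes are assumptions; gen_input and matmul are the helpers defined in this file):

import torch
# gen_input and matmul come from python/perf-kernels/gemm.py.

M, K, N = 512, 1024, 768  # multiples of SCALE_BLOCK_SIZE
fp8 = torch.float8_e4m3fnuz  # assumed fp8 dtype on a ROCm PyTorch build
a, _, a_scale = gen_input(M, K, fp8, False, seed=1, fp8_scaling_mode='token')
b, _, b_scale = gen_input(K, N, fp8, True, seed=2, fp8_scaling_mode='block')
c = torch.empty((M, N), device='cuda', dtype=torch.float16)
matmul(a, b, c, a_scale, b_scale, scale_a8_b8='block', activation="")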
@@ -243,7 +300,7 @@ def dtype_is_8_bit(dtype):
            (dtype is torch.int8)
 
 
-def gen_input(M, N, dtype, needTrans, seed, device='cuda'):
+def gen_input(M, N, dtype, needTrans, seed=0, fp8_scaling_mode='tensor', device='cuda'):
     torch.manual_seed(seed)
 
     if needTrans:
@@ -252,9 +309,28 @@ def gen_input(M, N, dtype, needTrans, seed, device='cuda'):
         raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
     scale = None
     if dtype_is_8_bit(dtype):
-        max_val = torch.max(torch.abs(raw_data))
-        scale = max_val / dtype_max[dtype]
-        raw_data = raw_data / scale
+        if fp8_scaling_mode == 'token':
+            assert raw_data.size(1) % SCALE_BLOCK_SIZE == 0
+            raw_data = raw_data.view(M, -1, SCALE_BLOCK_SIZE)
+            max_val = raw_data.abs().float().amax(dim=2).view(M, -1).clamp(1e-4)
+            scale = max_val.unsqueeze(2) / dtype_max[dtype]
+            raw_data = (raw_data / scale).view(M, N)
+            scale = scale.view(M, -1)
+            scale = scale.T.contiguous().T
+        elif fp8_scaling_mode == 'block':
+            x_padded = torch.zeros((triton.cdiv(M, SCALE_BLOCK_SIZE) * SCALE_BLOCK_SIZE,
+                                    triton.cdiv(N, SCALE_BLOCK_SIZE) * SCALE_BLOCK_SIZE), dtype=raw_data.dtype,
+                                   device=raw_data.device)
+            x_padded[:M, :N] = raw_data
+            x_view = x_padded.view(-1, SCALE_BLOCK_SIZE, x_padded.size(1) // SCALE_BLOCK_SIZE, SCALE_BLOCK_SIZE)
+            x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+            x_scaled = x_view * (dtype_max[dtype] / x_amax)
+            raw_data = x_scaled.view_as(x_padded)[:M, :N].T.contiguous().T
+            scale = (x_amax / dtype_max[dtype]).view(x_view.size(0), x_view.size(2))
+        elif fp8_scaling_mode == 'tensor':
+            max_val = torch.max(torch.abs(raw_data))
+            scale = max_val / dtype_max[dtype]
+            raw_data = raw_data / scale
 
     input = raw_data.to(dtype)
     input_f32 = input.to(torch.float32)
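
A quick sanity sketch of the block-mode round trip (an assumption-laden illustration, not part of the test suite; gen_input is the helper from this file and is assumed to return the quantized tensor, its fp32 upcast, and the scale): multiplying the fp32 upcast back by the per-tile scales, expanded over their 128x128 tiles, should approximately reproduce the original random data.

import torch

x, x_f32, x_scale = gen_input(256, 256, torch.float8_e4m3fnuz, False, seed=0, fp8_scaling_mode='block')
# x_scale holds one entry per 128x128 tile; expand it back over the tiles.
expanded = x_scale.repeat_interleave(128, dim=0)[:256].repeat_interleave(128, dim=1)[:, :256]
dequant = x_f32 * expanded  # approximately the original torch.randn data, up to fp8 rounding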
@@ -289,21 +365,21 @@ def get_x_vals():
 def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
     torch_in_dtype_a = name_to_torch_types[in_dtype_a]
     torch_in_dtype_b = name_to_torch_types[in_dtype_b]
-    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, 1, device='cuda')
-    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, 2, device='cuda')
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, seed=1, device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, seed=2, device='cuda')
     torch_out_dtype = name_to_torch_types[out_dtype]
     c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
     # For 8-bit, we have scaled to the dynamic range of the data type.
     # This requires us to compute in fp32 because for e5m2, the range is same as fp16 (e5m10).
     # If we use fp16 it is possible to return infs from the torch.matmul call.
     if dtype_is_8_bit(torch_in_dtype_a) or dtype_is_8_bit(torch_in_dtype_b):
-        matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
+        matmul(a, b, c, a_scale, b_scale, scale_a8_b8='tensor', activation="")
         torch_output = torch.matmul(a_fp32, b_fp32)
         # Set a_scale to 1.0 if it is not set
         torch_output = torch_output * (a_scale or 1.0) * b_scale
     # For other dtypes, use the same torch matmul as the dtype.
     else:
-        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
+        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=None, activation="")
         torch_output = torch.matmul(a.to(torch_in_dtype_a), b.to(torch_in_dtype_b))
     if out_dtype == 'int8':
         torch.testing.assert_close(c.to(torch.float32),
@@ -312,6 +388,61 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
         torch.testing.assert_close(c, torch_output.to(torch_out_dtype), atol=5e-3, rtol=1e-2)
 
 
+# yapf: disable
+@pytest.mark.parametrize(
+    "M, N, K, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b",
+    [(*shape, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b)
+     for shape in get_x_vals()
+     for in_dtype_a, in_dtype_b, out_dtype in [
+         ('fp8e4', 'fp8e4', 'fp16'), ('fp8e5', 'fp8e5', 'fp16'), ('fp16', 'fp8e4', 'fp16'),
+         ('fp16', 'fp8e5', 'fp16'), ('bf16', 'fp8e4', 'bf16'), ('bf16', 'fp8e5', 'bf16')]
+     # Defines if a matrix is row or column major.
+     for col_a in [True, False]
+     for col_b in [True, False]])
+# yapf: enable
+def test_correctness_block_scaling(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
+    if (N % SCALE_BLOCK_SIZE != 0) or (K % SCALE_BLOCK_SIZE != 0):
+        pytest.skip("Skip N/K sizes not aligned to SCALE_BLOCK_SIZE")
+    # Generate Inputs
+    torch_in_dtype_a = name_to_torch_types[in_dtype_a]
+    torch_in_dtype_b = name_to_torch_types[in_dtype_b]
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, seed=1, fp8_scaling_mode='token', device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, seed=2, fp8_scaling_mode='block', device='cuda')
+    # Create output tensor
+    torch_out_dtype = name_to_torch_types[out_dtype]
+    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
+    # For 8-bit, we have scaled to the dynamic range of the data type.
+    # This requires us to compute in fp32 because for e5m2, the range is same as fp16 (e5m10).
+    # If we use fp16 it is possible to return infs from the torch.matmul call.
+    matmul(a, b, c, a_scale, b_scale, scale_a8_b8='block', activation="")
+    # Reference Implementation
+    block_k = SCALE_BLOCK_SIZE
+    block_n = SCALE_BLOCK_SIZE
+    k_tiles = triton.cdiv(K, block_k)
+    n_tiles = triton.cdiv(N, block_n)
+    c_ref = torch.zeros((M, N), device=a_fp32.device, dtype=torch.float32)
+
+    A_tiles = [a_fp32[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)]
+    B_tiles = [[
+        b_fp32[
+            i * block_k:min((i + 1) * block_k, K),
+            j * block_n:min((j + 1) * block_n, N),
+        ] for j in range(n_tiles)
+    ] for i in range(k_tiles)]
+    C_tiles = [c_ref[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)]
+    As_tiles = [a_scale[:, i:i + 1] for i in range(k_tiles)] if (a_scale is not None) else None
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a_tile = A_tiles[i]
+            b_tile = B_tiles[i][j]
+            c_tile = C_tiles[j]
+            s_tile = (As_tiles[i] * b_scale[i][j]) if dtype_is_8_bit(torch_in_dtype_a) else b_scale[i][j]
+            c_tile[:, :] += torch.matmul(a_tile, b_tile) * s_tile
+
+    torch.testing.assert_close(c, c_ref.to(torch_out_dtype), atol=5e-3, rtol=1e-2)
+
+
 def get_type(provider):
     res = re.findall(r'\(.*?\)', provider)
     return res[0][1:-1].split('/', 1)
@@ -341,16 +472,28 @@ def benchmark(M, N, K, provider, model=None, args=None):
 
     quantiles = [0.5, 0.2, 0.8]
     layout_tn = args.layout == 'tn'
-    a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
-    b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, 2, device='cuda')
+
+    if args.fp8_scaling_mode == 'tensor' or in_dtype_b == torch.int8:
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, seed=1, device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, seed=2, device='cuda')
+    else:
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, seed=1, fp8_scaling_mode='token', device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, seed=2, fp8_scaling_mode='block', device='cuda')
+
     if 'hipblaslt' in provider:
         ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
     else:  # triton, different data types
         assert "triton" in provider
         # Allocates output.
         c = torch.empty((M, N), device=a.device, dtype=out_dtype)
 
-        scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
+        # If data type is 8 bit
+        # Default to tensor scaling if scaling mode is tensor or dtype is int8
+        # Use block scaling otherwise
+        scale_a8_b8 = None
+        if dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b):
+            scale_a8_b8 = 'tensor' if in_dtype_b == torch.int8 else args.fp8_scaling_mode
+
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""), quantiles=quantiles)
     if args.v:
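
Concretely, the selection in this hunk behaves as sketched below (an illustration, assuming args.fp8_scaling_mode == 'block', a ROCm PyTorch build with torch.float8_e4m3fnuz, and the dtype_is_8_bit helper from this file):

import torch

for in_dtype_a, in_dtype_b in [(torch.float16, torch.float16),
                               (torch.float8_e4m3fnuz, torch.int8),
                               (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz)]:
    mode = None
    if dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b):
        mode = 'tensor' if in_dtype_b == torch.int8 else 'block'
    print(in_dtype_a, in_dtype_b, '->', mode)  # None, then 'tensor', then 'block'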
@@ -381,6 +524,8 @@ def parse_args():
     parser.add_argument("-dtype", type=str, default=None, help="Data type of inputs and outputs")
     parser.add_argument("-b_dtype", type=str, default=None,
                         help="Data type of B operand, if specified (else same as dtype)")
+    parser.add_argument("-fp8_scaling_mode", type=str, default='tensor', choices=['tensor', 'block'],
+                        help="Type of scaling to apply when either or both inputs are fp8")
 
     args = parser.parse_args()
 
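Example invocation of the new flag (a hypothetical command line; the -dtype value and the script's other arguments, such as shape or model selection, are assumptions and are omitted here): python python/perf-kernels/gemm.py -dtype fp8e4 -fp8_scaling_mode block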