from utils.benchmark_utils import get_available_models, get_model_configs

+# TODO: Make this an argument; benchmarking, testing code, and the kernel helper need to change for it.
+SCALE_BLOCK_SIZE = 128
+

@triton.autotune(
    configs=[
+        triton.Config(
+            {
+                'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
+                'kpack': 2, 'matrix_instr_nonkdim': 16
+            }, num_warps=4, num_stages=2),
        triton.Config(
            {
                'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2,
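
For orientation, here is a minimal shape sketch (not part of the diff) of what SCALE_BLOCK_SIZE = 128 implies for an (M, K) x (K, N) GEMM, matching how gen_input builds the scales later in this change; the sizes are illustrative.

# A gets one scale per row per 128-wide K group ("token" scaling);
# B gets one scale per 128x128 block.
M, K, N = 512, 1024, 2048
SCALE_BLOCK_SIZE = 128  # mirrors the constant added above
a_scale_shape = (M, K // SCALE_BLOCK_SIZE)
b_scale_shape = (K // SCALE_BLOCK_SIZE, N // SCALE_BLOCK_SIZE)
print(a_scale_shape, b_scale_shape)  # (512, 8) (8, 16)
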
@@ -60,7 +68,13 @@ def matmul_kernel(
        stride_cn,
        a_scale_ptr,
        b_scale_ptr,
+        stride_ascale_m,
+        stride_ascale_k,
+        stride_bscale_k,
+        stride_bscale_n,
        # Meta-parameters
+        GROUP_K: tl.constexpr,
+        GROUP_N: tl.constexpr,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
@@ -76,12 +90,19 @@ def matmul_kernel(

    NUM_XCDS: tl.constexpr = 8

+    tl.static_assert((APPLY_SCALE is None) or (APPLY_SCALE == 'tensor') or (APPLY_SCALE == 'block'),
+                     f"Scaling mode {APPLY_SCALE} is not supported!!!")
+
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)
+    tl.assume(stride_ascale_m > 0)
+    tl.assume(stride_ascale_k > 0)
+    tl.assume(stride_bscale_k > 0)
+    tl.assume(stride_bscale_n > 0)

    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
@@ -132,9 +153,16 @@ def matmul_kernel(
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-    if APPLY_SCALE:
-        a_scale = tl.load(a_scale_ptr) if (a_scale_ptr) else 1.0
+    if APPLY_SCALE == 'tensor':
+        a_scale = tl.load(a_scale_ptr) if a_scale_ptr else 1.0
        b_scale = tl.load(b_scale_ptr)
+    elif APPLY_SCALE == 'block':
+        k_start = 0
+        offs_ks = k_start // GROUP_K
+        a_scale_ptrs = None if a_scale_ptr is None else (a_scale_ptr + offs_am * stride_ascale_m +
+                                                         offs_ks * stride_ascale_k)
+        offs_bsn = offs_bn // GROUP_N
+        b_scale_ptrs = b_scale_ptr + offs_bsn * stride_bscale_n + offs_ks * stride_bscale_k

    acc_dtype = tl.float32 if c_ptr.type.element_ty != tl.int8 else tl.int32
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
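
The index arithmetic for the block-scale pointers is easiest to see with concrete numbers. A small illustrative check (plain Python, not part of the diff), assuming the autotuned defaults GROUP_N = 128 and BLOCK_SIZE_N = 128:

# Every column in one output tile maps to the same B-scale column, so
# offs_bsn is constant within the tile (it would span two scale columns
# only if BLOCK_SIZE_N were larger than GROUP_N).
GROUP_N, BLOCK_SIZE_N, pid_n = 128, 128, 3
offs_bn = [pid_n * BLOCK_SIZE_N + i for i in range(BLOCK_SIZE_N)]  # 384 .. 511
offs_bsn = [c // GROUP_N for c in offs_bn]
assert set(offs_bsn) == {3}
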
@@ -148,15 +176,37 @@ def matmul_kernel(
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+        if APPLY_SCALE == 'block':
+            b_scale = tl.load(b_scale_ptrs)
+            if a_scale_ptrs is not None:
+                a_scale = tl.load(a_scale_ptrs)
+
        # Type conversion to support mixed precision GEMMs where b is lower precision than a
        b = b.to(a_ptr.type.element_ty)
-        accumulator += tl.dot(a, b, input_precision="ieee")
+
+        if APPLY_SCALE == 'block':
+            if a_scale_ptrs is not None:
+                accumulator += tl.dot(a, b, input_precision="ieee") * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator += tl.dot(a, b, input_precision="ieee") * b_scale[None, :]
+        else:
+            accumulator += tl.dot(a, b, input_precision="ieee")

        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
+
+        if APPLY_SCALE == 'block':
+            k_cur = k * BLOCK_SIZE_K // GROUP_K
+            k_nxt = (k + 1) * BLOCK_SIZE_K // GROUP_K
+            offs_ks = k_nxt - k_cur
+            b_scale_ptrs += offs_ks * stride_bscale_k
+            if a_scale_ptrs is not None:
+                a_scale_ptrs += offs_ks * stride_ascale_k
+
    # Apply scale to recover dynamic range reduced due to lower precision inputs.
-    if APPLY_SCALE:
+    if APPLY_SCALE == 'tensor':
        accumulator = accumulator * a_scale * b_scale
    # Apply activation function, if specified.
    # TODO(vgokhale): Add different types of activations.
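
The pointer-advance arithmetic above only steps the scale pointers when the K loop crosses a scale-group boundary. A quick illustrative check (plain Python, not part of the diff), assuming a config with BLOCK_SIZE_K = 64 and GROUP_K = 128:

# offs_ks alternates 0, 1, 0, 1, ... so the scale pointers advance only every
# second K iteration, i.e. once per 128-wide scale group.
BLOCK_SIZE_K, GROUP_K = 64, 128
steps = [((k + 1) * BLOCK_SIZE_K // GROUP_K) - (k * BLOCK_SIZE_K // GROUP_K) for k in range(6)]
print(steps)  # [0, 1, 0, 1, 0, 1]
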
@@ -180,13 +230,14 @@ def leaky_relu(x):


# Wrapper for gemm kernel.
-def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
+def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=None, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions!!!"
    assert (a.element_size()
            >= b.element_size()), "Mixed dtype GEMMs are only supported when data type of a is bigger than b!!!"
    assert (a.is_floating_point() == b.is_floating_point()
            ), "GEMMs between float and integer type tensors are not supported!!!"
+    assert (scale_a8_b8 in [None, 'tensor', 'block']), f"Scaling mode {scale_a8_b8} is not supported!!!"
    M, K = a.shape
    K, N = b.shape
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
@@ -205,6 +256,12 @@ def matmul(a, b, c, a_scale, b_scale, scale_a8_b8=False, activation=""):
        c.stride(1),
        a_scale,
        b_scale,
+        a_scale.stride(0) if (a_scale is not None) and a_scale.ndim else 0,
+        a_scale.stride(1) if (a_scale is not None) and a_scale.ndim else 0,
+        b_scale.stride(0) if (b_scale is not None) and b_scale.ndim else 0,
+        b_scale.stride(1) if (b_scale is not None) and b_scale.ndim else 0,
+        GROUP_K=SCALE_BLOCK_SIZE,
+        GROUP_N=SCALE_BLOCK_SIZE,
        APPLY_SCALE=scale_a8_b8,
        ACTIVATION=activation,
    )
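
A minimal usage sketch of the updated wrapper with block scaling (illustrative, not part of the diff); it assumes an fp8-capable ROCm GPU and the gen_input helper defined later in this file, and uses arbitrary sizes.

import torch

M, K, N = 512, 1024, 2048
# A uses per-token ('token') scales, B uses per-128x128-block scales.
a, a_fp32, a_scale = gen_input(M, K, torch.float8_e4m3fnuz, False, seed=1, fp8_scaling_mode='token')
b, b_fp32, b_scale = gen_input(K, N, torch.float8_e4m3fnuz, True, seed=2, fp8_scaling_mode='block')
c = torch.empty((M, N), device='cuda', dtype=torch.float16)
matmul(a, b, c, a_scale, b_scale, scale_a8_b8='block', activation="")
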
@@ -243,7 +300,7 @@ def dtype_is_8_bit(dtype):
        (dtype is torch.int8)


-def gen_input(M, N, dtype, needTrans, seed, device='cuda'):
+def gen_input(M, N, dtype, needTrans, seed=0, fp8_scaling_mode='tensor', device='cuda'):
    torch.manual_seed(seed)

    if needTrans:
@@ -252,9 +309,28 @@ def gen_input(M, N, dtype, needTrans, seed, device='cuda'):
        raw_data = torch.randn((M, N), dtype=torch.float32, device='cuda')
    scale = None
    if dtype_is_8_bit(dtype):
-        max_val = torch.max(torch.abs(raw_data))
-        scale = max_val / dtype_max[dtype]
-        raw_data = raw_data / scale
+        if fp8_scaling_mode == 'token':
+            assert raw_data.size(1) % SCALE_BLOCK_SIZE == 0
+            raw_data = raw_data.view(M, -1, SCALE_BLOCK_SIZE)
+            max_val = raw_data.abs().float().amax(dim=2).view(M, -1).clamp(1e-4)
+            scale = max_val.unsqueeze(2) / dtype_max[dtype]
+            raw_data = (raw_data / scale).view(M, N)
+            scale = scale.view(M, -1)
+            scale = scale.T.contiguous().T
+        elif fp8_scaling_mode == 'block':
+            x_padded = torch.zeros((triton.cdiv(M, SCALE_BLOCK_SIZE) * SCALE_BLOCK_SIZE,
+                                    triton.cdiv(N, SCALE_BLOCK_SIZE) * SCALE_BLOCK_SIZE), dtype=raw_data.dtype,
+                                   device=raw_data.device)
+            x_padded[:M, :N] = raw_data
+            x_view = x_padded.view(-1, SCALE_BLOCK_SIZE, x_padded.size(1) // SCALE_BLOCK_SIZE, SCALE_BLOCK_SIZE)
+            x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+            x_scaled = x_view * (dtype_max[dtype] / x_amax)
+            raw_data = x_scaled.view_as(x_padded)[:M, :N].T.contiguous().T
+            scale = (x_amax / dtype_max[dtype]).view(x_view.size(0), x_view.size(2))
+        elif fp8_scaling_mode == 'tensor':
+            max_val = torch.max(torch.abs(raw_data))
+            scale = max_val / dtype_max[dtype]
+            raw_data = raw_data / scale

    input = raw_data.to(dtype)
    input_f32 = input.to(torch.float32)
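
A quick round-trip sanity sketch for the new scaling modes (not part of the diff): since gen_input reseeds the RNG internally, the pre-quantization data can be regenerated outside it and compared against the dequantized output. Sizes are illustrative and the check assumes the 'token' mode above.

import torch

M, K = 256, 512
torch.manual_seed(0)
ref = torch.randn((M, K), dtype=torch.float32, device='cuda')  # same draw gen_input(seed=0) makes
x, x_f32, s = gen_input(M, K, torch.float8_e4m3fnuz, False, seed=0, fp8_scaling_mode='token')
# 'token' mode returns s with shape (M, K // SCALE_BLOCK_SIZE); expand each scale over its group.
dequant = (x_f32.view(M, -1, SCALE_BLOCK_SIZE) * s.unsqueeze(2)).view(M, K)
print((dequant - ref).abs().max())  # small; dominated by fp8 rounding error
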
@@ -289,21 +365,21 @@ def get_x_vals():
def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
    torch_in_dtype_a = name_to_torch_types[in_dtype_a]
    torch_in_dtype_b = name_to_torch_types[in_dtype_b]
-    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, 1, device='cuda')
-    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, 2, device='cuda')
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, seed=1, device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, seed=2, device='cuda')
    torch_out_dtype = name_to_torch_types[out_dtype]
    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
    # For 8-bit, we have scaled to the dynamic range of the data type.
    # This requires us to compute in fp32 because for e5m2, the range is the same as fp16 (e5m10).
    # If we use fp16 it is possible to return infs from the torch.matmul call.
    if dtype_is_8_bit(torch_in_dtype_a) or dtype_is_8_bit(torch_in_dtype_b):
-        matmul(a, b, c, a_scale, b_scale, scale_a8_b8=True, activation="")
+        matmul(a, b, c, a_scale, b_scale, scale_a8_b8='tensor', activation="")
        torch_output = torch.matmul(a_fp32, b_fp32)
        # Set a_scale to 1.0 if it is not set
        torch_output = torch_output * (a_scale or 1.0) * b_scale
    # For other dtypes, use the same torch matmul as the dtype.
    else:
-        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=False, activation="")
+        matmul(a, b, c, a_scale=None, b_scale=None, scale_a8_b8=None, activation="")
        torch_output = torch.matmul(a.to(torch_in_dtype_a), b.to(torch_in_dtype_b))
    if out_dtype == 'int8':
        torch.testing.assert_close(c.to(torch.float32),
@@ -312,6 +388,61 @@ def test_correctness(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
        torch.testing.assert_close(c, torch_output.to(torch_out_dtype), atol=5e-3, rtol=1e-2)


+# yapf: disable
+@pytest.mark.parametrize(
+    "M, N, K, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b",
+    [(*shape, in_dtype_a, in_dtype_b, out_dtype, col_a, col_b)
+     for shape in get_x_vals()
+     for in_dtype_a, in_dtype_b, out_dtype in [
+         ('fp8e4', 'fp8e4', 'fp16'), ('fp8e5', 'fp8e5', 'fp16'), ('fp16', 'fp8e4', 'fp16'),
+         ('fp16', 'fp8e5', 'fp16'), ('bf16', 'fp8e4', 'bf16'), ('bf16', 'fp8e5', 'bf16')]
+     # Defines if a matrix is row or column major.
+     for col_a in [True, False]
+     for col_b in [True, False]])
+# yapf: enable
+def test_correctness_block_scaling(M, N, K, col_a, col_b, in_dtype_a, in_dtype_b, out_dtype):
+    if (N % SCALE_BLOCK_SIZE != 0) or (K % SCALE_BLOCK_SIZE != 0):
+        pytest.skip("Skip N/K sizes not aligned to SCALE_BLOCK_SIZE")
+    # Generate inputs.
+    torch_in_dtype_a = name_to_torch_types[in_dtype_a]
+    torch_in_dtype_b = name_to_torch_types[in_dtype_b]
+    a, a_fp32, a_scale = gen_input(M, K, torch_in_dtype_a, col_a, seed=1, fp8_scaling_mode='token', device='cuda')
+    b, b_fp32, b_scale = gen_input(K, N, torch_in_dtype_b, col_b, seed=2, fp8_scaling_mode='block', device='cuda')
+    # Create output tensor.
+    torch_out_dtype = name_to_torch_types[out_dtype]
+    c = torch.empty((M, N), device=a.device, dtype=torch_out_dtype)
+    # For 8-bit, we have scaled to the dynamic range of the data type.
+    # This requires us to compute in fp32 because for e5m2, the range is the same as fp16 (e5m10).
+    # If we use fp16 it is possible to return infs from the torch.matmul call.
+    matmul(a, b, c, a_scale, b_scale, scale_a8_b8='block', activation="")
+    # Reference implementation.
+    block_k = SCALE_BLOCK_SIZE
+    block_n = SCALE_BLOCK_SIZE
+    k_tiles = triton.cdiv(K, block_k)
+    n_tiles = triton.cdiv(N, block_n)
+    c_ref = torch.zeros((M, N), device=a_fp32.device, dtype=torch.float32)
+
+    A_tiles = [a_fp32[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)]
+    B_tiles = [[
+        b_fp32[
+            i * block_k:min((i + 1) * block_k, K),
+            j * block_n:min((j + 1) * block_n, N),
+        ] for j in range(n_tiles)
+    ] for i in range(k_tiles)]
+    C_tiles = [c_ref[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)]
+    As_tiles = [a_scale[:, i:i + 1] for i in range(k_tiles)] if (a_scale is not None) else None
+
+    for i in range(k_tiles):
+        for j in range(n_tiles):
+            a_tile = A_tiles[i]
+            b_tile = B_tiles[i][j]
+            c_tile = C_tiles[j]
+            s_tile = (As_tiles[i] * b_scale[i][j]) if dtype_is_8_bit(torch_in_dtype_a) else b_scale[i][j]
+            c_tile[:, :] += torch.matmul(a_tile, b_tile) * s_tile
+
+    torch.testing.assert_close(c, c_ref.to(torch_out_dtype), atol=5e-3, rtol=1e-2)
+
+
def get_type(provider):
    res = re.findall(r'\(.*?\)', provider)
    return res[0][1:-1].split('/', 1)
@@ -341,16 +472,28 @@ def benchmark(M, N, K, provider, model=None, args=None):

    quantiles = [0.5, 0.2, 0.8]
    layout_tn = args.layout == 'tn'
-    a, _, a_scale = gen_input(M, K, in_dtype_a, False, 1, device='cuda')
-    b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, 2, device='cuda')
+
+    if args.fp8_scaling_mode == 'tensor' or in_dtype_b == torch.int8:
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, seed=1, device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, seed=2, device='cuda')
+    else:
+        a, _, a_scale = gen_input(M, K, in_dtype_a, False, seed=1, fp8_scaling_mode='token', device='cuda')
+        b, _, b_scale = gen_input(K, N, in_dtype_b, layout_tn, seed=2, fp8_scaling_mode='block', device='cuda')
+
    if 'hipblaslt' in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    else:  # triton, different data types
        assert "triton" in provider
        # Allocates output.
        c = torch.empty((M, N), device=a.device, dtype=out_dtype)

-        scale_a8_b8 = dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b)
+        # If the data type is 8 bit:
+        #   default to tensor scaling if the scaling mode is 'tensor' or the dtype is int8,
+        #   use block scaling otherwise.
+        scale_a8_b8 = None
+        if dtype_is_8_bit(in_dtype_a) or dtype_is_8_bit(in_dtype_b):
+            scale_a8_b8 = 'tensor' if in_dtype_b == torch.int8 else args.fp8_scaling_mode
+
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: matmul(a, b, c, a_scale, b_scale, scale_a8_b8=scale_a8_b8, activation=""), quantiles=quantiles)
    if args.v:
@@ -381,6 +524,8 @@ def parse_args():
    parser.add_argument("-dtype", type=str, default=None, help="Data type of inputs and outputs")
    parser.add_argument("-b_dtype", type=str, default=None,
                        help="Data type of B operand, if specified (else same as dtype)")
+    parser.add_argument("-fp8_scaling_mode", type=str, default='tensor', choices=['tensor', 'block'],
+                        help="Type of scaling to apply when either or both inputs are fp8")

    args = parser.parse_args()
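
A minimal sketch of how the new flag is consumed (not part of the diff); the parser here is a stand-in for the full one built in parse_args, and the values are illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-fp8_scaling_mode", type=str, default='tensor', choices=['tensor', 'block'],
                    help="Type of scaling to apply when either or both inputs are fp8")
args = parser.parse_args(["-fp8_scaling_mode", "block"])
assert args.fp8_scaling_mode == 'block'  # defaults to 'tensor' when the flag is omitted
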