1
+ import itertools
2
+
3
+ import torchao
4
+
1
5
import torch
2
6
from torch .testing ._internal .common_utils import (
3
7
TestCase ,
6
10
run_tests ,
7
11
)
8
12
from torch .testing ._internal .optests import opcheck
9
- from torchao .utils import is_fbcode
13
+ from torchao .utils import is_fbcode , TORCH_VERSION_AFTER_2_5
10
14
from torchao .prototype .quant_llm import from_scaled_tc_fpx
11
15
import pytest
12
16
18
22
except RuntimeError :
19
23
pytest .skip ("torchao.ops not available" )
20
24
25
+ from torchao .quantization .utils import (
26
+ get_groupwise_affine_qparams ,
27
+ groupwise_affine_dequantize_tensor_from_qparams ,
28
+ groupwise_affine_quantize_tensor_from_qparams ,
29
+ pack_tinygemm_scales_and_zeros ,
30
+ unpack_tinygemm_scales_and_zeros ,
31
+ )
32
+
21
33
22
34
class TestOps (TestCase ):
23
35
def _create_fpx_inputs (self , ebits : int , mbits : int , BS : int , OC : int , IC : int , device ):
@@ -61,9 +73,218 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
61
73
relative_error = error / gt
62
74
assert relative_error < 1e-3
63
75
64
-
65
76
# Materialize the parametrized test cases declared on TestOps.
instantiate_parametrized_tests(TestOps)
## Tests for `tensor_core_layout`
# Tensor-core tile geometry assumed by the int4 packed layout.
kTileSizeN = 8
kTileSizeK = 16

SHAPES = [
    (4096, 4096),
    # Llama 2 GEMM shapes
    (4096, 11008),
    (11008, 4096),
    # Llama 3 GEMM shapes
    (4096, 14336),
    (14336, 4096),
]
INNERKTILES = [2, 4, 8]
QGROUP_SIZES = [32, 64, 128, 256]
# Cross products used by the parametrized tests below.
TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
    """Packing an int4 weight to tensor-core tiled layout and unpacking it
    must round-trip exactly."""
    N, K = shape
    # The shape must tile evenly, otherwise pack/unpack is not a bijection.
    assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
    assert torch.equal(t, unpacked)
# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
    """Run the torch.library `opcheck` suite against the registered
    `torchao.unpack_tensor_core_tiled_layout` custom op."""
    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
    ]

    # TODO: Figure out why test fails unless torch >= 2.5
    if TORCH_VERSION_AFTER_2_5:
        test_utils.append("test_aot_dispatch_dynamic")

    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

    opcheck(
        torch.ops.torchao.unpack_tensor_core_tiled_layout,
        (packed_w, inner_k_tiles),
        test_utils=test_utils,
    )
def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
    """Reference groupwise dequantization: dq = (q - 2**(nbits-1)) * scale + zero.

    Args:
        q: (n, k) unsigned int codes in [0, 2**nbits), dtype torch.int.
        scales: (n, k // group_size) per-group scales.
        zeros: per-group zero points, same shape as `scales`.
        group_size: number of consecutive elements along k sharing one scale/zero.
        nbits: quantization bit width; midpoint = 2**(nbits - 1).
        dtype: output floating dtype.

    Returns:
        (n, k) dequantized tensor in `dtype`.
    """
    n, k = q.shape
    assert q.dtype == torch.int

    n_groups = k // group_size
    assert scales.shape[0] == n and scales.shape[1] == n_groups
    assert scales.shape == zeros.shape

    midpoint = 2 ** (nbits - 1)

    # Convert from u4 -> s4 and upcast to the output dtype.
    q = q.sub(midpoint).to(dtype)

    # Dequantize one group per row: each row of the reshaped view shares
    # a single scale/zero pair.
    q = q.reshape(-1, group_size)
    dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)

    return dq.reshape(n, k)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
    """Check the fused `dequantize_tensor_core_tiled_layout` kernel against
    (a) identity-matrix dequant via `_weight_int4pack_mm` and (b) the 'ao'
    reference `groupwise_affine_dequantize_tensor_from_qparams`."""
    n, k = shape
    dtype = torch.bfloat16

    device = "cuda"

    t = torch.randn(n, k, dtype=dtype, device=device)
    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

    # Quantize
    q = groupwise_affine_quantize_tensor_from_qparams(
        t, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Pack to tensor core layout
    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
    q_groups = k // group_size
    assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])

    # Dequantize 'ao' ref
    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
        q, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Dequantize by passing in an identity matrix as the activation
    a_eye = torch.eye(k, device=device, dtype=dtype)
    dq_id = torch.ops.aten._weight_int4pack_mm(
        a_eye,
        packed,
        group_size,
        scales_and_zeros,
    ).t()

    # Actual operation to test
    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

    # Compare results
    diff_ao_id = (dq_id - dq_ao).abs().max()
    diff_op_id = (dq_op - dq_id).abs().max()
    diff_op_ao = (dq_op - dq_ao).abs().max()

    # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
    # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
    # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
    # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
    assert diff_op_id == 0

    # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
    assert diff_op_ao == diff_ao_id

    assert diff_op_ao < 1e-1
# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
    """Same comparison as the quant/dequant test above, but the 'ao' reference
    path dequantizes the tensor recovered by `unpack_tensor_core_tiled_layout`
    rather than the pre-pack codes."""
    n, k = shape
    dtype = torch.bfloat16
    device = "cuda"

    # Quantize and pack
    t = torch.randn(n, k, dtype=dtype, device=device)
    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
    q = groupwise_affine_quantize_tensor_from_qparams(
        t, scales, zeros, n_bit=4, groupsize=group_size
    )

    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    # Unpack and dequantize
    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
        unpacked, scales, zeros, n_bit=4, groupsize=group_size
    )

    # Dequantize by passing in an identity matrix as the activation
    a_eye = torch.eye(k, device=device, dtype=dtype)
    dq_id = torch.ops.aten._weight_int4pack_mm(
        a_eye,
        packed,
        group_size,
        scales_and_zeros,
    ).t()

    # Actual operation to test
    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

    # Compare results
    diff_ao_id = (dq_id - dq_ao).abs().max()
    diff_op_id = (dq_op - dq_id).abs().max()
    diff_op_ao = (dq_op - dq_ao).abs().max()

    # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
    # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
    # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
    # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
    assert diff_op_id == 0

    # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
    assert diff_op_ao == diff_ao_id

    assert diff_op_ao < 1e-1
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
    """Run the torch.library `opcheck` suite against the registered
    `torchao.dequantize_tensor_core_tiled_layout` custom op."""
    n, k = shape
    device = "cuda"

    q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
    # Use the torch.ops.aten namespace for consistency with the other tests
    # in this file (was the torch._convert_weight_to_int4pack alias).
    packed_w = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    q_groups = k // group_size
    scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
    zeros = torch.randn_like(scales)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    test_utils = [
        "test_schema",
        "test_autograd_registration",
        "test_faketensor",
    ]
    # TODO: Figure out why test fails unless torch >= 2.5
    if TORCH_VERSION_AFTER_2_5:
        test_utils.append("test_aot_dispatch_dynamic")
    opcheck(
        torch.ops.torchao.dequantize_tensor_core_tiled_layout,
        (packed_w, scales_and_zeros, group_size, inner_k_tiles),
        test_utils=test_utils,
    )
if __name__ == "__main__":
    run_tests()