
Commit 74846da

[FEAT] Add custom CUDA tinygemm unpacker (#415)
* add unpack cuda
* add tests
* fix tests
* refactor tinygemm unpacking kernel
* add dequant
* add additional dequant check
* update tinygemm dequantize test
* correct dequant kernel logic
* clean up kernel
* update dequantize kernel tests
* rename kernel ops to tensor_core_tiled_layout
* add renamed kernel source
* add back test_aot_dispatch opcheck
* rename innerKTiles to inner_k_tiles
* add unpack and dequant test
* additional numerical checks for unpack then dequant
* rebase test_ops on main
* remove commented out code
* skip dynamic opcheck unless torch>=2.5
1 parent 6fa2d96 commit 74846da
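For orientation, here is a minimal usage sketch of the two custom ops this commit adds, mirroring how the new tests in test/test_ops.py below exercise them. The shape, group size, and inner_k_tiles values are arbitrary examples (not prescribed by the commit), and a CUDA device is assumed:

import torch
import torchao
from torchao.quantization.utils import (
    get_groupwise_affine_qparams,
    groupwise_affine_quantize_tensor_from_qparams,
    pack_tinygemm_scales_and_zeros,
)

# Example values only: K must be divisible by inner_k_tiles * 16 and N by 8 (see the asserts in the tests).
n, k, group_size, inner_k_tiles = 4096, 4096, 128, 8
device, dtype = "cuda", torch.bfloat16

# Quantize a bf16 weight to 4 bits and pack it into the tinygemm tensor-core tiled layout.
w = torch.randn(n, k, dtype=dtype, device=device)
scales, zeros = get_groupwise_affine_qparams(w, n_bit=4, groupsize=group_size, dtype=dtype)
q = groupwise_affine_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=group_size)
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

# New op 1: recover the plain (n, k) tensor of 4-bit values from the packed layout.
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)

# New op 2: dequantize straight from the packed layout back to an (n, k) bf16 tensor.
dq = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

Both ops operate on weights packed by torch.ops.aten._convert_weight_to_int4pack, i.e. the tiled layout consumed by torch.ops.aten._weight_int4pack_mm.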

4 files changed: +652 -3 lines changed


test/test_ops.py

Lines changed: 224 additions & 3 deletions
@@ -1,3 +1,7 @@
+import itertools
+
+import torchao
+
 import torch
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -6,7 +10,7 @@
     run_tests,
 )
 from torch.testing._internal.optests import opcheck
-from torchao.utils import is_fbcode
+from torchao.utils import is_fbcode, TORCH_VERSION_AFTER_2_5
 from torchao.prototype.quant_llm import from_scaled_tc_fpx
 import pytest
 
@@ -18,6 +22,14 @@
 except RuntimeError:
     pytest.skip("torchao.ops not available")
 
+from torchao.quantization.utils import (
+    get_groupwise_affine_qparams,
+    groupwise_affine_dequantize_tensor_from_qparams,
+    groupwise_affine_quantize_tensor_from_qparams,
+    pack_tinygemm_scales_and_zeros,
+    unpack_tinygemm_scales_and_zeros,
+)
+
 
 class TestOps(TestCase):
     def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device):
@@ -61,9 +73,218 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
         relative_error = error / gt
         assert relative_error < 1e-3
 
-
 instantiate_parametrized_tests(TestOps)
 
 
+## Tests for `tensor_core_layout`
+kTileSizeN = 8
+kTileSizeK = 16
+
+SHAPES = [
+    (4096, 4096),
+    # Llama 2 GEMM shapes
+    (4096, 11008),
+    (11008, 4096),
+    # Llama 3 GEMM shapes
+    (4096, 14336),
+    (14336, 4096),
+]
+INNERKTILES = [2, 4, 8]
+QGROUP_SIZES = [32, 64, 128, 256]
+TEST_CONFIGS_UNPACK = list(itertools.product(SHAPES, INNERKTILES))
+TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
+def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
+    N, K = shape
+    assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0
+
+    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
+    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
+    assert torch.equal(t, unpacked)
+
+# TODO: Fix "test_aot_dispatch_dynamic" test failure
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
+def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
+    test_utils = [
+        "test_schema",
+        "test_autograd_registration",
+        "test_faketensor",
+    ]
+
+    # TODO: Figure out why test fails unless torch >= 2.5
+    if TORCH_VERSION_AFTER_2_5:
+        test_utils.append("test_aot_dispatch_dynamic")
+
+    t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
+    packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
+
+    opcheck(
+        torch.ops.torchao.unpack_tensor_core_tiled_layout,
+        (packed_w, inner_k_tiles),
+        test_utils=test_utils,
+    )
+
+def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
+    n, k = q.shape
+    assert q.dtype == torch.int
+
+    n_groups = k // group_size
+    assert scales.shape[0] == n and scales.shape[1] == n_groups
+    assert scales.shape == zeros.shape
+
+    midpoint = 2 ** (nbits - 1)
+
+    # Convert from u4 -> s4 and upcast to bfloat16
+    q = q.sub(midpoint).to(dtype)
+
+    # Dequantize
+    q = q.reshape(-1, group_size)
+    dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)
+
+    return dq.reshape(n, k)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
+def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
+    n, k = shape
+    dtype = torch.bfloat16
+
+    device = "cuda"
+
+    t = torch.randn(n, k, dtype=dtype, device=device)
+    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
+
+    # Quantize
+    q = groupwise_affine_quantize_tensor_from_qparams(
+        t, scales, zeros, n_bit=4, groupsize=group_size
+    )
+
+    # Pack to tensor core layout
+    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
+    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+    q_groups = k // group_size
+    assert scales_and_zeros.shape == torch.Size([q_groups, n, 2])
+
+    # Dequantize 'ao' ref
+    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
+        q, scales, zeros, n_bit=4, groupsize=group_size
+    )
+
+    # Dequantize by passing in an identity matrix as the activation
+    a_eye = torch.eye(k, device=device, dtype=dtype)
+    dq_id = torch.ops.aten._weight_int4pack_mm(
+        a_eye,
+        packed,
+        group_size,
+        scales_and_zeros,
+    ).t()
+
+    # Actual operation to test
+    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
+
+    # Compare results
+    diff_ao_id = (dq_id - dq_ao).abs().max()
+    diff_op_id = (dq_op - dq_id).abs().max()
+    diff_op_ao = (dq_op - dq_ao).abs().max()
+
+    # There are slight numerical differences when dequantizing with an identity matrix compared to `groupwise_affine_dequantize`.
+    # Since the `dequantize_tensor_core_tiled_layout` kernel relies on the same underlying bit-twiddling tricks for fast
+    # conversion from u4 -> s4 -> bf16, the identity-matrix dequant hack and `dequantize_tensor_core_tiled_layout` are
+    # expected to give the same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
+
+    # Test that the `dequant` kernel gives the same results as identity matrix-based dequant
+    assert diff_op_id == 0
+
+    # Test that the `dequant` kernel gives the same numerical diffs as `groupwise_affine_dequantize` when compared against the identity matrix
+    assert diff_op_ao == diff_ao_id
+
+    assert diff_op_ao < 1e-1
+
+# This test differs from the one above in that it uses `unpack_tensor_core_tiled_layout` to unpack and then dequantize
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
+def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
+    n, k = shape
+    dtype = torch.bfloat16
+    device = "cuda"
+
+    # Quantize and pack
+    t = torch.randn(n, k, dtype=dtype, device=device)
+    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
+    q = groupwise_affine_quantize_tensor_from_qparams(
+        t, scales, zeros, n_bit=4, groupsize=group_size
+    )
+
+    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
+    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+
+    # Unpack and dequantize
+    unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
+    dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
+        unpacked, scales, zeros, n_bit=4, groupsize=group_size
+    )
+
+    # Dequantize by passing in an identity matrix as the activation
+    a_eye = torch.eye(k, device=device, dtype=dtype)
+    dq_id = torch.ops.aten._weight_int4pack_mm(
+        a_eye,
+        packed,
+        group_size,
+        scales_and_zeros,
+    ).t()
+
+    # Actual operation to test
+    dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
+
+    # Compare results
+    diff_ao_id = (dq_id - dq_ao).abs().max()
+    diff_op_id = (dq_op - dq_id).abs().max()
+    diff_op_ao = (dq_op - dq_ao).abs().max()
+
+    # There are slight numerical differences when dequantizing with an identity matrix compared to `groupwise_affine_dequantize`.
+    # Since the `dequantize_tensor_core_tiled_layout` kernel relies on the same underlying bit-twiddling tricks for fast
+    # conversion from u4 -> s4 -> bf16, the identity-matrix dequant hack and `dequantize_tensor_core_tiled_layout` are
+    # expected to give the same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
+
+    # Test that the `dequant` kernel gives the same results as identity matrix-based dequant
+    assert diff_op_id == 0
+
+    # Test that the `dequant` kernel gives the same numerical diffs as `groupwise_affine_dequantize` when compared against the identity matrix
+    assert diff_op_ao == diff_ao_id
+
+    assert diff_op_ao < 1e-1
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
+def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
+    n, k = shape
+    device = "cuda"
+
+    q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
+    packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
+    q_groups = k // group_size
+    scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
+    zeros = torch.randn_like(scales)
+    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
+
+    test_utils = [
+        "test_schema",
+        "test_autograd_registration",
+        "test_faketensor",
+    ]
+    # TODO: Figure out why test fails unless torch >= 2.5
+    if TORCH_VERSION_AFTER_2_5:
+        test_utils.append("test_aot_dispatch_dynamic")
+    opcheck(
+        torch.ops.torchao.dequantize_tensor_core_tiled_layout,
+        (packed_w, scales_and_zeros, group_size, inner_k_tiles),
+        test_utils=test_utils,
+    )
+
 if __name__ == "__main__":
-    run_tests()
+    run_tests()
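A note on the reference used in the correctness tests above: torch.ops.aten._weight_int4pack_mm computes activation @ dequant(weight).T, so passing a k x k identity matrix as the activation recovers the dequantized weight (after a transpose). Because that matmul path and the new dequantize_tensor_core_tiled_layout kernel share the same u4 -> s4 -> bf16 bit-twiddling conversion, the tests expect them to match exactly (diff_op_id == 0), while both may differ slightly from groupwise_affine_dequantize_tensor_from_qparams. A minimal sketch of the trick, assuming packed, scales_and_zeros, group_size, and k are produced as in the tests:

import torch

# Identity-matrix dequant reference, as used in the tests above.
# With A = I_k, _weight_int4pack_mm returns dequant(packed).T, so transposing
# yields the full (n, k) dequantized weight in bf16.
a_eye = torch.eye(k, device="cuda", dtype=torch.bfloat16)
dq_id = torch.ops.aten._weight_int4pack_mm(a_eye, packed, group_size, scales_and_zeros).t()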
