
Commit 8699877

Merge branch 'main' into feat/sparse-marlin-gemm-op
2 parents: c18f6bd + 0ed3090

File tree: 64 files changed (+869, -494 lines)


README.md

Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # f
 
 Nightly Release
 ```Shell
-pip install --pre torchao-nightly --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
+pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu121 # full options are cpu/cu118/cu121/cu124
 ```
 
 From source

benchmarks/benchmark_aq.py

Lines changed: 4 additions & 1 deletion

@@ -17,6 +17,7 @@
     _replace_with_custom_fn_if_matches_filter,
 )
 import copy
+from torchao.utils import unwrap_tensor_subclass
 
 def _int8wo_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -133,15 +134,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     WARMUP = 20
     RUNS = 100
 
+    torch._dynamo.reset()
    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
+    torch._dynamo.reset()
     m = torch.compile(m, mode='max-autotune', fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
-
+    torch._dynamo.reset()
     m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
     benchmark_model(m_bf16, WARMUP, example_inputs)
     bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
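
Note on the change above: `torch._dynamo.reset()` clears TorchDynamo's compilation caches, so each of the three `torch.compile(..., mode='max-autotune', fullgraph=True)` calls is compiled and measured from a cold start instead of reusing state from the previous variant. A minimal sketch of the pattern, with a toy model and a simple timing loop standing in for torchao's `benchmark_model` helper (both hypothetical, not from the diff):

```python
import time

import torch


def bench(fn, x, warmup=20, runs=100):
    # simple stand-in for torchao's benchmark_model helper
    for _ in range(warmup):
        fn(x)
    t0 = time.perf_counter()
    for _ in range(runs):
        fn(x)
    return (time.perf_counter() - t0) / runs


x = torch.randn(64, 128)
variants = {"ref": torch.nn.Linear(128, 128), "other": torch.nn.Linear(128, 128)}
for name, model in variants.items():
    torch._dynamo.reset()  # drop cached graphs from the previous torch.compile call
    compiled = torch.compile(model, mode="max-autotune", fullgraph=True)
    print(name, bench(compiled, x))
```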

benchmarks/float8/float8_roofline.py

Lines changed: 10 additions & 212 deletions

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 This is a script to estimate the benefit from converting a `torch.nn.Linear`
 layer to float8, by estimating the difference in e2e GPU kernel time between:
@@ -45,26 +51,10 @@
 import torch
 import torch.utils.benchmark as benchmark
 
-BYTES_PER_EL_FLOAT8 = 1
-BYTES_PER_EL_BF16 = 2
-
-# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
-H100_BF16_PEAK_TOPS = 989e12
-H100_FP8_PEAK_TOPS = 1979e12
-
-# 2.4 TB per second, custom to Meta's H100 variant
-H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12
-
-# based on quick experimental observation with sample large inputs
-H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6
-
-# based on previous experience looking at pointwise triton kernels with large inputs,
-# which would hit about 2.2k GBPS on Meta's H100 variant
-H100_PCT_ACHIEVABLE_MEM_BW = 0.92
-
-# Source: run a triton kernel with a single element read/write on an H100 and
-# measure GPU time from the trace
-TRITON_KERNEL_1_ELEMENT_TIME_SEC = 0.002 * 0.001
+from torchao.float8.roofline_utils import (
+    get_gemm_time_sympy,
+    get_float8_mem_sympy,
+)
 
 
 def benchmark_fn_in_sec(f, *args, **kwargs):
@@ -78,90 +68,6 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean
 
 
-def get_tensor_memory_traffic_bytes(
-    dim0,
-    dim1,
-    scaling_type: str,
-    fuse_with_prev=False,
-    model_torch_compile_limitations=False,
-):
-    # assumes input bf16, output f8
-    numel = dim0 * dim1
-
-    if scaling_type == "dynamic":
-        # x_bf16 = ...
-        # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp
-        # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-        # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8
-
-        if fuse_with_prev:
-            kernel_1_rw = 0
-        else:
-            # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
-            kernel_1_rw = BYTES_PER_EL_BF16 * numel
-
-        # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
-        kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
-
-        if model_torch_compile_limitations:
-            # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-            # has an extra memory read of the input in fp8
-            # context: https://github.com/pytorch/pytorch/issues/130015
-            tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-        else:
-            tc_adjustment = 0
-
-        return kernel_1_rw + kernel_3_rw + tc_adjustment
-
-    else:
-        assert scaling_type == "delayed", "unsupported"
-        # x_bf16 = ...
-        # kernel 1: x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp
-        # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-        # kernel 3 (not modeled): scale -> reciprocal -> inv_scale
-
-        if fuse_with_prev:
-            kernel_1_r = 0
-        else:
-            kernel_1_r = numel * BYTES_PER_EL_BF16
-        # write twice: once in row major, once in col-major
-        kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2
-
-        if model_torch_compile_limitations:
-            # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-            # has an extra memory read of the input in fp8
-            # context: https://github.com/pytorch/pytorch/issues/130015
-            tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-
-            # https://github.com/pytorch/pytorch/issues/128063
-            # instead of
-            # kernel 1: x_bf16 -> max(abs(x)), x_fp8
-            # kernel 2: not modeled
-            # kernel 3: not modeled
-            # we get
-            # kernel 1: x_bf16 -> max(abs(x))
-            #   reads: same as before
-            #   writes: 0
-            # ...
-            # kernel 4: x_bf16, scale -> x_fp8
-            #   reads: numel * BYTES_PER_EL_BF16
-            #   writes: 2 * numel * BYTES_PER_EL_FLOAT8
-            # Note that assuming worst case, this issue brings the memory
-            # traffic for delayed scaling to be equal to that of dynamic scaling.
-            tc_adjustment += (
-                # subtract writes from kernel 1
-                -1 * 2 * numel * BYTES_PER_EL_FLOAT8
-                # add reads for kernel 4
-                + numel * BYTES_PER_EL_BF16
-                # add writes for kernel 4
-                + 2 * numel * BYTES_PER_EL_FLOAT8
-            )
-        else:
-            tc_adjustment = 0
-
-        return kernel_1_r + kernel_1_w + tc_adjustment
-
-
 def get_gemm_times_cache(gemm_benchmarks_file: str):
     cache = {}
     with open(gemm_benchmarks_file, 'r') as f:
@@ -176,114 +82,6 @@ def get_gemm_times_cache(gemm_benchmarks_file: str):
     return cache
 
 
-def get_gemm_time_sympy(M, K, N, dtype):
-    gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
-    if dtype is torch.bfloat16:
-        peak_tops = H100_BF16_PEAK_TOPS
-    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        peak_tops = H100_FP8_PEAK_TOPS
-    gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
-    return gemm_time_s
-
-
-def get_float8_mem_sympy(
-    M,
-    K,
-    N,
-    model_torch_compile_limitations: bool = False,
-    scaling_type_input: str = "dynamic",
-    scaling_type_weight: str = "dynamic",
-    scaling_type_grad_output: str = "dynamic",
-):
-
-    assert scaling_type_input in ("dynamic", "delayed"), "unsupported"
-    assert scaling_type_weight in ("dynamic", "delayed"), "unsupported"
-    assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported"
-
-    # there are three gemms in the fwd/bwd of a linear:
-    #
-    # input @ weight_t = output
-    # MxK @ KxN => MxN
-    #
-    # grad_output @ weight = grad_input
-    # MxN @ NxK => MxK
-    #
-    # input_t @ grad_output = grad_weight
-    # KxM @ MxN => KxN
-
-    #
-    # forward - output
-    #
-    fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
-        M, K, scaling_type_input, fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations)
-    fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
-        K, N, scaling_type_weight, fuse_with_prev=False,
-        model_torch_compile_limitations=model_torch_compile_limitations)
-    fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
-
-    #
-    # backward - grad_input
-    #
-    gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
-        M, N, scaling_type_grad_output, fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations)
-    # already casted, assuming that we save weight from fw to bw
-    # TODO: model this if FSDP float8 all-gather is on
-    # TODO: model this if we don't save weight from fw to bw, and recompute instead
-    gi_fp8_weight_mem = 0
-
-    #
-    # backward - grad_weight
-    #
-    # TODO: model this if we don't save fp8 input from fw to bw
-    gw_fp8_input_t_mem = 0  # already casted
-    # this should be always 0
-    gw_fp8_grad_output_mem = 0  # already casted
-
-    bwd_fp8_total_mem = \
-        gi_fp8_grad_output_mem + gi_fp8_weight_mem + \
-        gw_fp8_input_t_mem + gw_fp8_grad_output_mem
-    fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
-    fp8_mem_time_s = (
-        fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
-    )
-
-    # Adjust final estimate for small kernel launches
-    # note that we do this adjustment here because we are assuming a minimal
-    # kernel overhead in the units of seconds, and the per-gemm-input memory
-    # estimations are in the units of bytes.
-    num_extra_kernels = 0
-    if scaling_type_input == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    elif scaling_type_input == "delayed":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-        # reciprocal of scale
-        num_extra_kernels += 1
-    if scaling_type_weight == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    elif scaling_type_weight == "delayed":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-        # reciprocal of scale
-        num_extra_kernels += 1
-    if scaling_type_grad_output == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    elif scaling_type_grad_output == "delayed":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-        # reciprocal of scale
-        num_extra_kernels += 1
-
-    extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
-
-    return fp8_mem_time_s + extra_kernel_overhead_s
-
-
 def run(
     outfile: str,
     gemm_time_strategy: str = "benchmarks",
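
Note on the change above: the H100 constants and the symbolic roofline estimators (`get_gemm_time_sympy`, `get_float8_mem_sympy`) move out of the benchmark script and are now imported from `torchao.float8.roofline_utils`. A rough usage sketch, assuming the moved helpers keep the signatures shown in the deleted code (symbolic M/K/N, hardware constants baked in); the layer shape below is hypothetical:

```python
import sympy
import torch

from torchao.float8.roofline_utils import get_float8_mem_sympy, get_gemm_time_sympy

M, K, N = sympy.symbols("M K N")

bf16_gemm_time = get_gemm_time_sympy(M, K, N, torch.bfloat16)      # fwd+bwd gemms in bf16
fp8_gemm_time = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)  # same gemms in fp8
fp8_overhead_time = get_float8_mem_sympy(M, K, N)                  # cast/max-abs memory traffic

shape = {M: 16384, K: 8192, N: 8192}  # hypothetical linear layer shape
bf16_s = float(bf16_gemm_time.subs(shape))
fp8_s = float((fp8_gemm_time + fp8_overhead_time).subs(shape))
print(f"estimated speedup for this shape: {bf16_s / fp8_s:.2f}x")
```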

benchmarks/quantized_training/pretrain_llama2.py

Lines changed: 26 additions & 12 deletions

@@ -1,5 +1,5 @@
 # pre-train a mini Llama2 on TinyStories with INT8 quantized training
-# pip install transformers sentencepiece wandb
+# pip install huggingface_hub sentencepiece wandb
 #
 # BF16 baseline: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile
 # INT8 QT: python benchmarks/quantized_training/pretrain_llama2.py --seed 2024 --n_steps 10_000 --compile --quantize int8_weight_only
@@ -9,21 +9,33 @@
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
 import argparse
+from functools import partial
 from pathlib import Path
 
 import numpy as np
 import torch
 import wandb
+from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
-from transformers import LlamaConfig, LlamaForCausalLM
 
+from torchao._models.llama.model import ModelArgs, Transformer
 from torchao.prototype import low_bit_optim
 from torchao.prototype.quantized_training import int8_weight_only_quantized_training
 from torchao.quantization.quant_api import quantize_
 
 
-def get_loss(model: LlamaForCausalLM, batch: torch.Tensor):
-    return model(batch, labels=batch).loss
+# hack from fairseq
+# https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/checkpoint_activations.py
+def enable_activation_checkpointing(m: torch.nn.Module):
+    assert not hasattr(m, "_forward")
+    m._forward = m.forward
+    m.forward = partial(checkpoint, m.forward)
+
+
+def get_loss(model: Transformer, batch: torch.Tensor):
+    logits = model(batch)[:, :-1].flatten(0, 1)
+    labels = batch[:, 1:].flatten()
+    return torch.nn.functional.cross_entropy(logits, labels)
 
 
 def get_tinystories():
@@ -91,17 +103,19 @@ def get_tinystories():
 if args.seed is not None:
     torch.manual_seed(args.seed)
 
-config = LlamaConfig(
-    hidden_size=args.d_model,
+config = ModelArgs(
+    block_size=args.seq_len,
+    n_layer=args.depth,
+    n_head=args.d_model // args.head_dim,
+    dim=args.d_model,
     intermediate_size=args.ffn_size,
-    num_hidden_layers=args.depth,
-    num_attention_heads=args.d_model // args.head_dim,
-    max_position_embeddings=args.seq_len,
-    use_cache=False,
 )
-model = LlamaForCausalLM(config).bfloat16().cuda()
+model = Transformer(config).bfloat16().cuda()
+with torch.device("cuda"):
+    model.setup_caches(args.batch_size, args.seq_len, training=True)
 if args.activation_checkpointing:
-    model.gradient_checkpointing_enable()
+    for layer in model.layers:
+        enable_activation_checkpointing(layer)
 if args.quantize == "int8_weight_only":
     quantize_(model, int8_weight_only_quantized_training(), set_inductor_config=False)
 elif args.quantize is not None:
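
Note on the change above: with the switch from Hugging Face's `LlamaForCausalLM` (which computes the shifted causal-LM loss internally when given `labels`) to torchao's `Transformer`, the next-token shift is now explicit in `get_loss`: logits at position t are scored against the token at position t+1. A tiny standalone check with random data and hypothetical batch/sequence/vocab sizes:

```python
import torch
import torch.nn.functional as F

B, T, vocab_size = 2, 8, 100              # hypothetical sizes for illustration
batch = torch.randint(vocab_size, (B, T))
logits = torch.randn(B, T, vocab_size)    # stand-in for model(batch)

loss = F.cross_entropy(
    logits[:, :-1].flatten(0, 1),  # predictions for positions 0 .. T-2
    batch[:, 1:].flatten(),        # targets: the next token at positions 1 .. T-1
)
print(loss.item())
```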

scripts/download.py

Lines changed: 1 addition & 1 deletion

@@ -15,7 +15,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -
     from huggingface_hub import snapshot_download
     os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
-        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token)
+        snapshot_download(repo_id, local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, ignore_patterns="*.safetensors")
     except HTTPError as e:
         if e.response.status_code == 401:
             print("You need to pass a valid `--hf_token=...` to download private checkpoints.")

test/dtypes/test_affine_quantized.py

Lines changed: 17 additions & 0 deletions

@@ -51,6 +51,23 @@ def test_weights_only(self):
         else:
             _ = torch.load(f, weights_only=False)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    def test_to_device(self):
+        from torchao.quantization import quantize_
+        for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
+            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = apply_quant(l)
+            ql.to("cuda")
+
+            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = apply_quant(l)
+            ql.to(device="cuda")
+
+            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+            ql = apply_quant(l)
+            ql.cuda()
+
+
 
 if __name__ == "__main__":
     run_tests()
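
Note on the change above: the new test builds quantized linears by calling the config functions directly (`apply_quant(l)`) and then moves them across devices with `.to("cuda")`, `.to(device="cuda")`, and `.cuda()`. In user code the equivalent flow, using the `quantize_` API the test imports, would look roughly like this sketch (not part of the diff):

```python
import torch

from torchao.quantization import int8_weight_only, quantize_

model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16))
quantize_(model, int8_weight_only())  # replace weights with int8 quantized tensor subclasses
model.to("cuda")                      # moving after quantization is what test_to_device exercises
```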
