Commit e90e280
Merge branch 'unpack_int4' of https://github.com/jeromeku/ao into unpack_int4
2 parents 75df5f5 + d1bd61b


78 files changed: +3853 −1589 lines

.github/scripts/trymerge.py

Lines changed: 30 additions & 52 deletions
@@ -1163,7 +1163,6 @@ def merge_into(
            # Finally, upload the record to Rockset. The list of pending and failed
            # checks are at the time of the merge
            save_merge_record(
-                collection=ROCKSET_MERGES_COLLECTION,
                comment_id=comment_id,
                pr_num=self.pr_num,
                owner=self.org,
@@ -1179,10 +1178,8 @@ def merge_into(
                merge_base_sha=self.get_merge_base(),
                merge_commit_sha=merge_commit_sha,
                is_failed=False,
-                dry_run=dry_run,
                skip_mandatory_checks=skip_mandatory_checks,
                ignore_current=bool(ignore_current_checks),
-                workspace=ROCKSET_MERGES_WORKSPACE,
            )
        else:
            print("Missing comment ID or PR number, couldn't upload to Rockset")
@@ -1489,7 +1486,6 @@ def checks_to_markdown_bullets(


@retries_decorator()
def save_merge_record(
-    collection: str,
    comment_id: int,
    pr_num: int,
    owner: str,
@@ -1505,59 +1501,44 @@ def save_merge_record(
    merge_base_sha: str,
    merge_commit_sha: str = "",
    is_failed: bool = False,
-    dry_run: bool = False,
    skip_mandatory_checks: bool = False,
    ignore_current: bool = False,
    error: str = "",
-    workspace: str = "commons",
) -> None:
    """
-    This saves the merge records into Rockset, so we can query them (for fun and profit)
+    This saves the merge records as a json, which can later be uploaded to s3
    """
-    if dry_run:
-        # Decide not to save the record to Rockset if dry-run is set to not pollute
-        # the collection
-        return
-
-    try:
-        import rockset  # type: ignore[import]
-
-        # Prepare the record to be written into Rockset
-        data = [
-            {
-                "comment_id": comment_id,
-                "pr_num": pr_num,
-                "owner": owner,
-                "project": project,
-                "author": author,
-                "pending_checks": pending_checks,
-                "failed_checks": failed_checks,
-                "ignore_current_checks": ignore_current_checks,
-                "broken_trunk_checks": broken_trunk_checks,
-                "flaky_checks": flaky_checks,
-                "unstable_checks": unstable_checks,
-                "last_commit_sha": last_commit_sha,
-                "merge_base_sha": merge_base_sha,
-                "merge_commit_sha": merge_commit_sha,
-                "is_failed": is_failed,
-                "skip_mandatory_checks": skip_mandatory_checks,
-                "ignore_current": ignore_current,
-                "error": error,
-            }
-        ]

-        client = rockset.RocksetClient(
-            host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
-        )
-        client.Documents.add_documents(
-            collection=collection,
-            data=data,
-            workspace=workspace,
-        )
+    # Prepare the record to be written into Rockset
+    data = [
+        {
+            "comment_id": comment_id,
+            "pr_num": pr_num,
+            "owner": owner,
+            "project": project,
+            "author": author,
+            "pending_checks": pending_checks,
+            "failed_checks": failed_checks,
+            "ignore_current_checks": ignore_current_checks,
+            "broken_trunk_checks": broken_trunk_checks,
+            "flaky_checks": flaky_checks,
+            "unstable_checks": unstable_checks,
+            "last_commit_sha": last_commit_sha,
+            "merge_base_sha": merge_base_sha,
+            "merge_commit_sha": merge_commit_sha,
+            "is_failed": is_failed,
+            "skip_mandatory_checks": skip_mandatory_checks,
+            "ignore_current": ignore_current,
+            "error": error,
+            # This is a unique identifier for the record for deduping purposes
+            # in rockset. Any unique string would work
+            "_id": f"{project}-{pr_num}-{comment_id}-{os.environ.get('GITHUB_RUN_ID')}",
+        }
+    ]
+    repo_root = Path(__file__).resolve().parent.parent.parent

-    except ModuleNotFoundError:
-        print("Rockset is missing, no record will be saved")
-        return
+    with open(repo_root / "merge_record.json", "w") as f:
+        json.dump(data, f)


@retries_decorator(rc=[])
@@ -2374,7 +2355,6 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None:
        # list of pending and failed checks here, but they are not really
        # needed at the moment
        save_merge_record(
-            collection=ROCKSET_MERGES_COLLECTION,
            comment_id=args.comment_id,
            pr_num=args.pr_num,
            owner=org,
@@ -2389,11 +2369,9 @@ def handle_exception(e: Exception, title: str = "Merge failed") -> None:
            last_commit_sha=pr.last_commit().get("oid", ""),
            merge_base_sha=pr.get_merge_base(),
            is_failed=True,
-            dry_run=args.dry_run,
            skip_mandatory_checks=args.force,
            ignore_current=args.ignore_current,
            error=str(e),
-            workspace=ROCKSET_MERGES_WORKSPACE,
        )
    else:
        print("Missing comment ID or PR number, couldn't upload to Rockset")

.github/scripts/validate_binaries.sh

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+pip install ${PYTORCH_PIP_PREFIX} torchao --index-url ${PYTORCH_PIP_DOWNLOAD_URL}
+python ./test/smoke_tests/smoke_tests.py

.github/workflows/regression_test.yml

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ jobs:

    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    with:
+      timeout: 60
      runner: ${{ matrix.runs-on }}
      gpu-arch-type: ${{ matrix.gpu-arch-type }}
      gpu-arch-version: ${{ matrix.gpu-arch-version }}

.github/workflows/trymerge.yml

Lines changed: 26 additions & 0 deletions
@@ -9,6 +9,8 @@ jobs:
    name: try_merge_pr_${{ github.event.client_payload.pr_num }}
    runs-on: ubuntu-latest
    environment: pytorchbot-env
+    permissions:
+      id-token: write
    env:
      GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}
    steps:
@@ -26,6 +28,8 @@ jobs:
          check-latest: false
          cache: pip
          architecture: x64
+      # TODO (huydhn): get rid of Rockset
+      - run: pip install pyyaml==6.0 rockset==1.0.3

      - name: Setup committer id
        run: |
@@ -36,8 +40,14 @@ jobs:
        env:
          GITHUB_TOKEN: ${{ secrets.PYTORCH_MERGEBOT_TOKEN }}
          PR_NUM: ${{ github.event.client_payload.pr_num }}
+          FORCE: ${{ github.event.client_payload.force}}
          COMMENT_ID: ${{ github.event.client_payload.comment_id }}
          GIT_REMOTE_URL: https://github.com/pytorch/ao
+          REBASE: ${{ github.event.client_payload.rebase }}
+          IGNORE_CURRENT: ${{ github.event.client_payload.ignore_current }}
+          ROCKSET_API_KEY: ${{ secrets.ROCKSET_API_KEY }}
+          DRCI_BOT_KEY: ${{ secrets.DRCI_BOT_KEY }}
+          GITHUB_RUN_ID: ${{ github.run_id }}
        run: |
          set -x
          if [ -n "${FORCE}" ]; then
@@ -58,6 +68,22 @@ jobs:
            python3 .github/scripts/trymerge.py "${PR_NUM}"
          fi

+      - name: configure aws credentials
+        uses: aws-actions/configure-aws-credentials@v3
+        continue-on-error: true
+        with:
+          role-to-assume: arn:aws:iam::308535385114:role/upload_to_ossci_raw_job_status
+          aws-region: us-east-1
+
+      - name: Upload merge record to s3
+        if: always()
+        continue-on-error: true
+        uses: seemethere/upload-artifact-s3@v5
+        with:
+          s3-bucket: ossci-raw-job-status
+          s3-prefix: merges/${{ github.repository }}/${{ github.event.client_payload.pr_num }}/${{ github.event.client_payload.comment_id }}/${{ github.run_id }}
+          path: merge_record.json
+
# We want newer merge commands to supercede old ones
concurrency:
  group: try-merge-${{ github.event.client_payload.pr_num }}
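Taken together with the Python change above, the upload step places each record at a deterministic S3 key. A sketch of how that key is composed, mirroring the `s3-prefix` expression in the workflow (identifiers are hypothetical):

```python
# Hypothetical values for repository, PR number, comment ID, and run ID.
repository, pr_num, comment_id, run_id = "pytorch/ao", 123, 456789, 987654321

bucket = "ossci-raw-job-status"
prefix = f"merges/{repository}/{pr_num}/{comment_id}/{run_id}"
key = f"{prefix}/merge_record.json"
print(key)  # merges/pytorch/ao/123/456789/987654321/merge_record.json
```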
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+name: Validate binaries
+
+on:
+  workflow_call:
+    inputs:
+      channel:
+        description: "Channel to use (nightly, test, release, all)"
+        required: false
+        type: string
+        default: release
+      ref:
+        description: "Reference to checkout, defaults to empty"
+        default: ""
+        required: false
+        type: string
+  workflow_dispatch:
+    inputs:
+      channel:
+        description: "Channel to use (nightly, test, release, all)"
+        required: true
+        type: choice
+        options:
+          - release
+          - nightly
+          - test
+          - all
+      ref:
+        description: "Reference to checkout, defaults to empty"
+        default: ""
+        required: false
+        type: string
+      pytorch_version:
+        description: "PyTorch version to validate (ie. 2.0, 2.2.2, etc.) - optional"
+        default: ""
+        required: false
+        type: string
+jobs:
+  validate-binaries:
+    uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main
+    with:
+      package_type: "wheel"
+      version: ${{ inputs.version }}
+      os: "linux"
+      channel: ${{ inputs.channel }}
+      repository: "pytorch/ao"
+      with_cuda: "enable"
+      with_rocm: "disable"
+      smoke_test: "source ./.github/scripts/validate_binaries.sh"
+      install_torch: true

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -127,6 +127,7 @@ env
.circleci/scripts/COMMIT_MSG
scripts/release_notes/*.json
sccache-stats*.json
+merge_record.json

# These files get copied over on invoking setup.py
torchgen/packaged/*

README.md

Lines changed: 61 additions & 16 deletions
@@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)

```python
-from torchao.quantization.quant_api import quantize
-m = quantize(m, "int4wo")
+from torchao.quantization.quant_api import quantize, int4_weight_only
+m = quantize(m, int4_weight_only())
```

Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
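For readers skimming the diff: the README's one-liner moved from a string shorthand to a callable config. A slightly fuller, hedged usage sketch of the new API (assuming a CUDA machine and a recent torchao build; the model here is a stand-in):

```python
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize, int4_weight_only

# Stand-in model; any module containing nn.Linear layers is handled the same way.
model = nn.Sequential(nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
model = quantize(model, int4_weight_only())
out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
```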
@@ -29,15 +29,17 @@ The models used were `meta-llama/Llama-2-7b-chat-hf` and `meta-llama/Meta-Llama-

| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
| ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
-| Llama-2-7B | Base (bfloat16) | 12.212 | 105.02 | 1387.78 | 13.21 | 13.90 |
-| | int8dq | 12.262 | 9.40 | 62.26 | 6.62 | 8.61 |
-| | int8wo | 12.204 | 147.03 | 973.54 | 6.62 | 8.95 |
-| | int4wo-64 | 12.843 | 199.81 | 746.45 | 3.74 | 4.75 |
-| | int4wo-64-GPTQ | 12.489 | 199.81 | 746.45 | 3.74 | 4.75 |
-| Llama-3-8B | Base (bfloat16) | | 94.91 | 1424.58 | 15.01 | 16.43 |
-| | int8dq | | 8.41 | 63.23 | 7.52 | 9.24 |
-| | int8wo | | 136.75 | 1028.38 | 7.52 | 10.42 |
-| | int4wo-64 | | 179.41 | 757.45 | 4.22 | 6.88 |
+| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 |
+| | int8dq | 12.262 | 9.20 | 60.93 | 8.33 | 6.62 |
+| | int8wo | 12.204 | 150.18 | 994.40 | 8.95 | 6.62 |
+| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 |
+| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 |
+| | autoquant | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 |
+| Llama-3-8B | Base (bfloat16) | N/A | 94.97 | 1425.55 | 16.43 | 15.01 |
+| | int8dq | N/A | 8.44 | 63.45 | 8.98 | 7.52 |
+| | int8wo | N/A | 139.76 | 1051.02 | 10.42 | 7.52 |
+| | int4wo-64 | N/A | 179.44 | 757.60 | 6.62 | 4.22 |
+| | autoquant | N/A | 137.71 | 1037.74 | 11.08 | 7.54 |

note: Int8 dynamic quantization works best on compute bound as opposed to memory bound models. Some relatable examples might be [SAM](https://github.com/pytorch-labs/segment-anything-fast) which is compute bound vs Llama at batchsize=1 which is memory bound.
@@ -50,7 +52,20 @@ And a quick crash course on inference quantization to help parse the above table

In some cases we rewrote popular GenAI models to be significantly faster in native PyTorch as in no C++/CUDA to achieve at the time SOTA inference performance. These involve more intrusive code changes.

-* 8x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai)
+* 9.5x speedups for Image segmentation models with [sam-fast](https://pytorch.org/blog/accelerating-generative-ai) compared to vanilla [sam](https://github.com/facebookresearch/segment-anything).
+* 1.16x speedup when composing int8 quantization with 2:4 sparsity against the accelerated baseline `bfloat16` dtype and `torch.compile="max_autotune"`.
+
+| Model Type | Technique | img/s | memory (MiB) | mIoU (coco2017 val) | relative speedup | relative accuracy |
+|------------|-----------|-------|--------------|---------------------|------------------|-------------------|
+| ViT-h | sam (float32, eager) | 2.78 | 28806 | 0.58 | baseline | baseline |
+| | sam (bfloat16, eager) | 14.85 | 14424 | 0.58 | **5.34x** | **100%** |
+| | sam-fast (bfloat16, max-autotune) | 22.75 | 15172 | 0.58 | **8.18x** | **100%** |
+| | int8 dynamic quant (attn + mlp) | 24.91 | 15154 | 0.58 | **8.96x** | **100%** |
+| | 2:4 sparsity (mlp only) | 24.81 | 15632 | 0.57 | **8.92x** | **98%** |
+| | int8 dynamic quant (attn)<br>int8 dynamic quant + 2:4 sparsity (mlp lin1)<br>2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.57 | **9.52x** | **98%** |
+
+The relative speedup is measured purely across the image encoder (ViT) of the model, where we apply our model optimizations. Benchmarks ran on an NVIDIA-A100-80GB with batch_size=32
+
* 10x speedups for Language models with [gpt-fast](https://pytorch.org/blog/accelerating-generative-ai-2)
* 3x speedup for Diffusion models with [sd-fast](https://pytorch.org/blog/accelerating-generative-ai-3)

@@ -68,7 +83,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})

* [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is prototype as the hardware support is not available yet.
* [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) one of the most popular finetuning algorithms without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
-* [fp6](torchao/prototype/fp6_llm/) for 2x faster inference over fp16 with an easy to use wrapper api `convert_fp6_llm(model)`
+* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy to use API `quantize(model, fp6_llm_weight_only())`

## Composability

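The fp6 bullet above now routes through the same `quantize` entry point as int4. A hedged sketch, assuming `fp6_llm_weight_only` is importable from the `quant_llm` prototype as the new README text implies:

```python
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize
from torchao.prototype.quant_llm import fp6_llm_weight_only  # assumed import path

# Stand-in fp16 model; fp6 here targets faster inference over fp16.
model = nn.Sequential(nn.Linear(1024, 1024)).half().cuda()
model = quantize(model, fp6_llm_weight_only())
```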
@@ -79,11 +94,34 @@ A key design principle for us is composability as in any new dtype or layout we


### Installation
+
`torchao` makes liberal use of several new features in Pytorch, it's recommended to use it with the current nightly or latest stable version of PyTorch.

-Stable Release
+#### Install torch
+
+Install torch stable
+
+```
+pip install torch
+```
+
+Or torch nightlies
+
+```
+pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
+```
+
+#### Install torchao
+
+Stable release from Pypi which will default to CUDA 12.1
+
```Shell
-pip install torchao --extra-index-url https://download.pytorch.org/whl/test/cu121 # full options are cpu/cu118/cu121/cu124
+pip install torchao
+```
+
+Stable Release from the PyTorch index
+```Shell
+pip install torchao --extra-index-url https://download.pytorch.org/whl/cu121 # full options are cpu/cu118/cu121/cu124
```

Nightly Release
@@ -104,10 +142,17 @@ python setup.py install
* [GaLore](torchao/prototype/galore/) a drop for the Adam Optimizer that allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch
* [DoRA](torchao/prototype/dora) a newer replacement for QLoRA with more promising convergence characteristics
* [Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute bound kernels showing 4x speedups over tinygemm for larger batch sizes such as 512
-* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm)
+* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/quant_llm](torchao/prototype/quant_llm)
* [vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common)
* [andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile

+## Blogs and Videos
+* [Accelerating Neural Network Training with Semi-Structured (2:4) Sparsity](https://pytorch.org/blog/accelerating-neural-network-training/)
+* [https://mobiusml.github.io/whisper-static-cache-blog/](https://mobiusml.github.io/whisper-static-cache-blog/)
+* [Slaying OOMs at the Mastering LLM's course](https://x.com/HamelHusain/status/1800315287574847701)
+* [Advanced Quantization at CUDA MODE](https://youtu.be/1u9xUK3G4VM?si=4JcPlw2w8chPXW8J)
+* [Chip Huyen's GPU Optimization Workshop](https://www.youtube.com/live/v_q2JTIqE20?si=mf7HeZ63rS-uYpS6)
+
## How to contribute

This repository is currently under heavy development
