Skip to content

Commit 8f2e8cf

Browse files
committed
add a script that fetches and validates the binary sizes of wheels from the https://download.pytorch.org/whl/ index using specified rules
1 parent c212b85 commit 8f2e8cf

File tree

5 files changed

+190
-0
lines changed

5 files changed

+190
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
name: Test binary size validation script
2+
on:
3+
pull_request:
4+
paths:
5+
- .github/workflows/binary-size-validation.yml
6+
- tools/binary_size_validation/test_binary_size_validation.py
7+
- tools/binary_size_validation/binary_size_validation.py
8+
workflow_dispatch:
9+
10+
jobs:
11+
test-binary-size-validation:
12+
runs-on: ubuntu-latest
13+
steps:
14+
- name: Checkout
15+
uses: actions/checkout@v3
16+
- name: Install requirements
17+
run: |
18+
pip3 install -r tools/binary_size_validation/requirements.txt
19+
- name: Run pytest
20+
run: |
21+
pytest tools/binary_size_validation/test_binary_size_validation.py
+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# PyTorch Wheel Binary Size Validation
2+
3+
A script to fetch and validate the binary size of PyTorch wheels
4+
in the given channel (test, nightly) against the given threshold.
5+
6+
7+
### Installation
8+
9+
```bash
10+
pip install -r requirements.txt
11+
```
12+
13+
### Usage
14+
15+
```bash
16+
# print help
17+
python binary_size_validation.py --help
18+
19+
# print the sizes of all items in the index
20+
python binary_size_validation.py --url https://download.pytorch.org/whl/nightly/torch/
21+
22+
# fail if any of the torch2.0 wheels are larger than 900MB
23+
python binary_size_validation.py --url https://download.pytorch.org/whl/nightly/torch/ --include "torch-2\.0" --threshold 900
24+
25+
# fail if any of the latest nightly pypi wheels are larger than 750MB
26+
python binary_size_validation.py --include "pypi" --only-latest-version --threshold 750
27+
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Script that parses a wheel index (e.g. https://download.pytorch.org/whl/test/torch/),
2+
# fetches and validates binary size for the files that match the given regex.
3+
4+
import requests
5+
import re
6+
from collections import namedtuple
7+
import click
8+
from bs4 import BeautifulSoup
9+
from urllib.parse import urljoin
10+
11+
# One entry of the wheel index: `name` is the anchor text of the link,
# `url` is the link's href resolved against the index base url.
Wheel = namedtuple("Wheel", ["name", "url"])
12+
13+
14+
def parse_index(html: str,
                base_url: str,
                include_regex: str = "",
                exclude_regex: str = "",
                latest_version_only=False) -> list[Wheel]:
    """
    parse the html page and return a list of wheels
    :param html: html page
    :param base_url: base url of the page, used to resolve relative hrefs
    :param include_regex: regex to filter the wheel names. If empty, all wheels are included
    :param exclude_regex: regex to exclude the matching wheel names. If empty, no wheels are excluded
    :param latest_version_only: if True, return the wheels of the latest version only.
        The version is the "<package>-<version>" prefix of the wheel name (everything up to
        the second '+'/'-' sign); prefixes are ordered with numeric components compared as
        numbers. NOTE: this is a natural ordering, not full PEP 440 ordering.
    :return: list of wheels
    :raises RuntimeError: if latest_version_only is set and a wheel name has no version prefix
    """
    soup = BeautifulSoup(html, "html.parser")

    wheels = []
    for a in soup.find_all("a"):
        href = a.get("href")
        if not href:
            # an anchor without a target would resolve to the index page itself
            continue
        wheel_name = a.text
        if include_regex and not re.search(include_regex, wheel_name):
            continue
        if exclude_regex and re.search(exclude_regex, wheel_name):
            continue
        wheels.append(Wheel(name=wheel_name, url=urljoin(base_url, href)))

    if not (wheels and latest_version_only):
        return wheels

    # get the prefixes (up to the second '+'/'-' sign) of the wheels
    prefixed_wheels = []
    for wheel in wheels:
        match = re.search(r"^([^-+]+[-+][^-+]+)[-+]", wheel.name)
        if not match:
            raise RuntimeError(
                f"Failed to get version prefix of {wheel.name}. "
                "Please check the include/exclude regex or don't use --only-latest-version")
        prefixed_wheels.append((match.group(1), wheel))

    # compare numerically so that e.g. 1.13 is newer than 1.9
    latest_version = max((prefix for prefix, _ in prefixed_wheels), key=_version_sort_key)
    print(f"Latest version prefix: {latest_version}")

    # keep only the wheels whose prefix is exactly the latest one
    return [wheel for prefix, wheel in prefixed_wheels if prefix == latest_version]


def _version_sort_key(prefix: str) -> list:
    """
    split a "<package>-<version>" prefix into alternating text / integer chunks
    so that numeric components are compared as numbers instead of strings
    """
    # re.split with a capturing group alternates: text, digits, text, digits, ...
    # so odd positions are always digit runs and chunks at equal positions share a type
    return [int(chunk) if position % 2 else chunk
            for position, chunk in enumerate(re.split(r"(\d+)", prefix))]
55+
56+
57+
def get_binary_size(file_url: str, timeout: float = 60.0) -> int:
    """
    get the binary size of the given file
    :param file_url: url of the file
    :param timeout: seconds to wait for the server before giving up
    :return: binary size in bytes
    :raises requests.HTTPError: if the server answers with an error status
    :raises RuntimeError: if the response carries no Content-Length header
    """
    # HEAD requests do not follow redirects by default; follow them so the size
    # reported is that of the actual file rather than of the redirect response
    response = requests.head(file_url, allow_redirects=True, timeout=timeout)
    # do not report the size of an error page as the size of the wheel
    response.raise_for_status()

    content_length = response.headers.get('Content-Length')
    if content_length is None:
        raise RuntimeError(f"No Content-Length header returned for {file_url}")
    return int(content_length)
64+
65+
66+
@click.command(
    help="Validate the binary sizes of the given wheel index."
)
@click.option("--url", help="url of the wheel index",
              default="https://download.pytorch.org/whl/nightly/torch/")
@click.option("--include", help="regex to filter the wheel names. Only the matching wheel names will be checked.",
              default="")
@click.option("--exclude", help="regex to exclude wheel names. Matching wheel names will NOT be checked.",
              default="")
@click.option("--threshold", help="threshold in MB, optional", default=0)
@click.option("--only-latest-version", help="only validate the latest version",
              is_flag=True, show_default=True, default=False)
def main(url, include, exclude, threshold, only_latest_version):
    """
    fetch the wheel index, print the size of every selected wheel and fail
    on the first wheel that is larger than the threshold
    :param url: url of the wheel index
    :param include: regex the wheel names must match (empty: all wheels)
    :param exclude: regex of the wheel names to skip (empty: none)
    :param threshold: maximum allowed size in MB; 0 disables the check
    :param only_latest_version: if True, only check the wheels of the latest version
    :raises RuntimeError: if a wheel exceeds the threshold
    """
    bytes_per_mb = 1024 * 1024

    page = requests.get(url)
    # an error page contains no wheels and would otherwise pass validation silently
    page.raise_for_status()

    wheels = parse_index(page.text, url, include, exclude, only_latest_version)
    for wheel in wheels:
        print(f"Validating {wheel.url}...")
        size = int(get_binary_size(wheel.url))
        size_mb = size / bytes_per_mb
        print(f"{wheel.name}: {size_mb:.2f} MB")
        # the threshold is given in MB while the size is in bytes
        if threshold and size > threshold * bytes_per_mb:
            raise RuntimeError(
                f"Binary size of {wheel.name} {size_mb:.2f} MB exceeds the threshold {threshold} MB")


if __name__ == "__main__":
    main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
beautifulsoup4==4.11.2
2+
click==8.0.4
3+
pytest==7.1.1
4+
requests==2.27.1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from binary_size_validation import parse_index
2+
3+
# ignore long lines in this file
4+
# flake8: noqa: E501
5+
test_html = """
6+
<!DOCTYPE html>
7+
<html>
8+
<body>
9+
<h1>Links for torch</h1>
10+
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp310-cp310-linux_x86_64.whl">torch-1.13.0.dev20220728+cpu-cp310-cp310-linux_x86_64.whl</a><br/>
11+
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp310-cp310-win_amd64.whl">torch-1.13.0.dev20220728+cpu-cp310-cp310-win_amd64.whl</a><br/>
12+
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp37-cp37m-linux_x86_64.whl">torch-1.13.0.dev20220728+cpu-cp37-cp37m-linux_x86_64.whl</a><br/>
13+
<a href="/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp37-cp37m-win_amd64.whl">torch-1.13.0.dev20220728+cpu-cp37-cp37m-win_amd64.whl</a><br/>
14+
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230206%2Brocm5.3-cp39-cp39-linux_x86_64.whl">torch-2.0.0.dev20230206+rocm5.3-cp39-cp39-linux_x86_64.whl</a><br/>
15+
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230207%2Brocm5.3-cp310-cp310-linux_x86_64.whl">torch-2.0.0.dev20230207+rocm5.3-cp310-cp310-linux_x86_64.whl</a><br/>
16+
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230207%2Brocm5.3-cp38-cp38-linux_x86_64.whl">torch-2.0.0.dev20230207+rocm5.3-cp38-cp38-linux_x86_64.whl</a><br/>
17+
<a href="/whl/nightly/rocm5.3/torch-2.0.0.dev20230207%2Brocm5.3-cp39-cp39-linux_x86_64.whl">torch-2.0.0.dev20230207+rocm5.3-cp39-cp39-linux_x86_64.whl</a><br/>
18+
</body>
19+
</html>
20+
<!--TIMESTAMP 1675892605-->
21+
"""
22+
23+
# index url passed to parse_index together with test_html in the tests below
base_url = "https://download.pytorch.org/whl/nightly/torch/"
24+
25+
26+
def test_get_whl_links():
    # without filters every link of the page is returned, with its href made absolute
    expected_first_url = "https://download.pytorch.org/whl/nightly/cpu/torch-1.13.0.dev20220728%2Bcpu-cp310-cp310-linux_x86_64.whl"

    parsed = parse_index(test_html, base_url)

    assert len(parsed) == 8
    assert parsed[0].url == expected_first_url
31+
32+
33+
def test_include_exclude():
    # include regex only: both windows wheels are kept, in page order
    included = parse_index(test_html, base_url, "amd6\\d")
    assert [w.name for w in included] == [
        "torch-1.13.0.dev20220728+cpu-cp310-cp310-win_amd64.whl",
        "torch-1.13.0.dev20220728+cpu-cp37-cp37m-win_amd64.whl",
    ]

    # include + exclude regex: the cp37 wheel is dropped again
    remaining = parse_index(test_html, base_url, "amd6\\d", "cp37")
    assert [w.name for w in remaining] == [
        "torch-1.13.0.dev20220728+cpu-cp310-cp310-win_amd64.whl",
    ]
42+
43+
44+
def test_latest_version_only():
    # only the wheels of the newest version prefix must survive
    newest_prefix = "torch-2.0.0.dev20230207"

    latest = parse_index(test_html, base_url, latest_version_only=True)

    assert len(latest) == 3
    for wheel in latest:
        assert wheel.name.startswith(newest_prefix)

0 commit comments

Comments
 (0)