Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

ensure torch always up-to-date in CI #5286

Merged
merged 4 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ jobs:
which python
pip freeze

- name: Ensure torch up-to-date
run: |
. .venv/bin/activate
python scripts/check_torch_version.py

- name: ${{ matrix.task.name }}
run: |
. .venv/bin/activate
Expand Down
62 changes: 62 additions & 0 deletions scripts/check_torch_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Ensures currently installed torch version is the newest allowed.
"""

from typing import Tuple, cast


def main():
current_torch_version = _get_current_installed_torch_version()
latest_torch_version = _get_latest_torch_version()
torch_version_upper_limit = _get_torch_version_upper_limit()

if current_torch_version < latest_torch_version < torch_version_upper_limit:
raise RuntimeError(
f"current torch version {current_torch_version} is behind "
f"latest allowed torch version {latest_torch_version}"
)

print("All good!")


def _get_current_installed_torch_version() -> Tuple[str, str, str]:
import torch

version = tuple(torch.version.__version__.split("."))
assert len(version) == 3, f"Bad parsed version '{version}'"
return cast(Tuple[str, str, str], version)


def _get_latest_torch_version() -> Tuple[str, str, str]:
import requests

r = requests.get("https://api.github.com/repos/pytorch/pytorch/tags")
assert r.ok
for tag_data in r.json():
tag = tag_data["name"]
if tag.startswith("v") and "-rc" not in tag:
# Tag should look like "vX.Y.Z"
version = tuple(tag[1:].split("."))
assert len(version) == 3, f"Bad parsed version '{version}'"
break
else:
raise RuntimeError("could not find latest stable release tag")
return cast(Tuple[str, str, str], version)


def _get_torch_version_upper_limit() -> Tuple[str, str, str]:
with open("setup.py") as f:
for line in f:
# The torch version line should look like:
# "torch>=X.Y.Z,<X.V.0",
if '"torch>=' in line:
version = tuple(line.split('"')[1].split("<")[1].strip().split("."))
assert len(version) == 3, f"Bad parsed version '{version}'"
break
else:
raise RuntimeError("could not find torch version spec in setup.py")
return cast(Tuple[str, str, str], version)


if __name__ == "__main__":
main()