Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

ensure torch always up-to-date in CI #5286

Merged
merged 4 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ jobs:
which python
pip freeze

- name: Ensure torch up-to-date
run: |
. .venv/bin/activate
python scripts/check_torch_version.py

- name: ${{ matrix.task.name }}
run: |
. .venv/bin/activate
Expand Down
62 changes: 62 additions & 0 deletions scripts/check_torch_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
Ensures currently installed torch version is the newest allowed.
"""

from typing import Tuple, cast


def main():
current_torch_version = _get_current_installed_torch_version()
latest_torch_version = _get_latest_torch_version()
torch_version_upper_limit = _get_torch_version_upper_limit()

if current_torch_version < latest_torch_version < torch_version_upper_limit:
raise RuntimeError(
f"current torch version {current_torch_version} is behind "
f"latest allowed torch version {latest_torch_version}"
)

print("All good!")


def _get_current_installed_torch_version() -> Tuple[str, str, str]:
import torch

version = tuple(torch.version.__version__.split("."))
assert len(version) == 3, f"Bad parsed version '{version}'"
return cast(Tuple[str, str, str], version)


def _get_latest_torch_version() -> Tuple[str, str, str]:
import requests

r = requests.get("https://api.github.com/repos/pytorch/pytorch/tags")
assert r.ok
for tag_data in r.json():
tag = tag_data["name"]
if tag.startswith("v") and "-rc" not in tag:
# Tag should look like "vX.Y.Z"
version = tuple(tag[1:].split("."))
assert len(version) == 3, f"Bad parsed version '{version}'"
break
else:
raise RuntimeError("could not find latest stable release tag")
return cast(Tuple[str, str, str], version)


def _get_torch_version_upper_limit() -> Tuple[str, str, str]:
with open("setup.py") as f:
for line in f:
# The torch version line should look like:
# "torch>=X.Y.Z,<X.V.0",
if '"torch>=' in line:
version = tuple(line.split('"')[1].split("<")[1].strip().split("."))
assert len(version) == 3, f"Bad parsed version '{version}'"
break
else:
raise RuntimeError("could not find torch version spec in setup.py")
return cast(Tuple[str, str, str], version)


if __name__ == "__main__":
main()