This repository was archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
Copy path check_torch_version.py
62 lines (47 loc) · 2 KB
/
check_torch_version.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""
Ensures currently installed torch version is the newest allowed.
"""
from typing import Tuple, cast
def main():
    """
    Fail if the installed torch is older than the newest allowed release.

    "Newest allowed" is the latest stable torch release tag on GitHub, provided
    it is strictly below the upper bound of the torch requirement in
    ``setup.py``. If the latest release is at or beyond that bound, there is
    nothing allowed to upgrade to and the check passes.

    Raises:
        RuntimeError: if the installed version is behind an allowed newer
            release.
    """
    current_torch_version = _get_current_installed_torch_version()
    latest_torch_version = _get_latest_torch_version()
    torch_version_upper_limit = _get_torch_version_upper_limit()

    # The helpers return tuples of strings; compare them numerically, because
    # string comparison ranks "9" above "10" (e.g. 1.9.0 would beat 1.10.0).
    if (
        _numeric_version(current_torch_version)
        < _numeric_version(latest_torch_version)
        < _numeric_version(torch_version_upper_limit)
    ):
        raise RuntimeError(
            f"current torch version {current_torch_version} is behind "
            f"latest allowed torch version {latest_torch_version}"
        )

    print("All good!")


def _numeric_version(version: Tuple[str, ...]) -> Tuple[int, ...]:
    """
    Convert a version tuple of strings into a tuple of ints for ordering.

    Only the leading digits of each component are used, so components with a
    suffix such as ``"0+cu113"`` or ``"0a0"`` compare as ``0``.

    Raises:
        RuntimeError: if a component does not start with a digit.
    """
    import re

    numbers = []
    for part in version:
        match = re.match(r"[0-9]+", part)
        if match is None:
            raise RuntimeError(f"Bad parsed version '{version}'")
        numbers.append(int(match.group()))
    return tuple(numbers)
def _get_current_installed_torch_version() -> Tuple[str, str, str]:
    """
    Return the installed torch version as a ``(major, minor, patch)`` tuple
    of strings.

    A local version label (the ``+cu113`` in ``1.10.0+cu113``) is dropped
    first; otherwise it would remain attached to the patch component.

    Raises:
        RuntimeError: if the remaining version does not have exactly three
            dot-separated components.
    """
    import torch

    public_version = torch.version.__version__.split("+", 1)[0]
    version = tuple(public_version.split("."))
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if len(version) != 3:
        raise RuntimeError(f"Bad parsed version '{version}'")
    return cast(Tuple[str, str, str], version)
def _get_latest_torch_version() -> Tuple[str, str, str]:
    """
    Return the newest stable PyTorch release as a ``(major, minor, patch)``
    tuple of strings, based on the pytorch/pytorch git tags on GitHub.

    Only tags of the exact form ``vX.Y.Z`` are considered. Release candidates
    (``v1.8.0-rc1``), older pre-release tags (``v1.0rc1``) and unrelated tags
    are skipped. Among the matching tags in the response, the numerically
    highest one is chosen, so the result does not depend on the order in
    which the API lists tags.

    Raises:
        requests.HTTPError: if the GitHub API request fails.
        RuntimeError: if no stable release tag is found in the response.
    """
    import re

    import requests

    response = requests.get(
        "https://api.github.com/repos/pytorch/pytorch/tags",
        # Largest page size the API allows, to see as many tags as possible.
        params={"per_page": 100},
        # Without a timeout, a stalled connection would hang the check forever.
        timeout=30,
    )
    # Unlike `assert r.ok`, this is not stripped under `python -O`.
    response.raise_for_status()

    stable_tag = re.compile(r"v([0-9]+)\.([0-9]+)\.([0-9]+)")
    candidates = []
    for tag_data in response.json():
        match = stable_tag.fullmatch(tag_data["name"])
        if match is not None:
            candidates.append(match.groups())
    if not candidates:
        raise RuntimeError("could not find latest stable release tag")

    # Order numerically: as strings, "9" > "10".
    latest = max(candidates, key=lambda parts: tuple(int(p) for p in parts))
    return cast(Tuple[str, str, str], latest)
def _get_torch_version_upper_limit() -> Tuple[str, str, str]:
    """
    Return the upper bound of the torch requirement declared in ``setup.py``.

    ``setup.py`` is opened relative to the current working directory. The
    first line containing ``"torch>=`` is expected to look like
    ``"torch>=X.Y.Z,<X.V.0",``. The version after ``<`` is returned as a
    ``(major, minor, patch)`` tuple of strings.

    Raises:
        RuntimeError: if no torch requirement line is found, or if its upper
            bound is missing or is not a three-component version.
    """
    with open("setup.py", encoding="utf-8") as f:
        for line in f:
            if '"torch>=' not in line:
                continue
            # The quoted requirement spec, e.g. torch>=X.Y.Z,<X.V.0
            spec = line.split('"')[1]
            _, sep, upper = spec.partition("<")
            version = tuple(upper.strip().split("."))
            # Explicit raise instead of `assert` (stripped under `python -O`);
            # also covers a spec with no "<", which used to raise IndexError.
            if not sep or len(version) != 3:
                raise RuntimeError(f"Bad parsed version '{version}'")
            return cast(Tuple[str, str, str], version)
    raise RuntimeError("could not find torch version spec in setup.py")
if __name__ == "__main__":
    # Script entry point. main() returns None, so this exits with status 0 on
    # success; an exception from main() still propagates with its traceback.
    raise SystemExit(main())