Skip to content

Commit 3a1bfc9

Browse files
authored
Fix anaconda torchaudio smoke test (#1161)
* Fix anaconda torchaudio smoke test * Format using ufmt
1 parent ac931b5 commit 3a1bfc9

File tree

1 file changed

+67
-36
lines changed

1 file changed

+67
-36
lines changed

test/smoke_test/smoke_test.py

Lines changed: 67 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import os
22
import re
33
import sys
4+
from pathlib import Path
5+
46
import torch
5-
# the following import would invoke
7+
import torchaudio
8+
9+
# the following import would invoke
610
# _check_cuda_version()
711
# via torchvision.extension._check_cuda_version()
812
import torchvision
9-
import torchaudio
10-
from pathlib import Path
1113

1214
gpu_arch_ver = os.getenv("GPU_ARCH_VER")
1315
gpu_arch_type = os.getenv("GPU_ARCH_TYPE")
@@ -16,54 +18,63 @@
1618
is_cuda_system = gpu_arch_type == "cuda"
1719
SCRIPT_DIR = Path(__file__).parent
1820

19-
# helper function to return the conda list output, e.g.
20-
# torchaudio 0.13.0.dev20220922 py39_cu102 pytorch-nightly
21-
def get_anaconda_output_for_package(pkg_name_str):
21+
# helper function to return the conda installed packages
22+
# and return package we are insterseted in
23+
def get_anaconda_output_for_package(pkg_name_str):
2224
import subprocess as sp
2325

24-
# ignore the header row:
25-
# Name Version Build Channel
26-
cmd = 'conda list -f ' + pkg_name_str
26+
cmd = "conda list --explicit"
2727
output = sp.getoutput(cmd)
28+
for item in output.split("\n"):
29+
if pkg_name_str in item:
30+
return item
31+
2832
# Get the last line only
29-
return output.strip().split('\n')[-1]
33+
return f"{pkg_name_str} can't be found"
3034

31-
def check_nightly_binaries_date() -> None:
35+
36+
def check_nightly_binaries_date() -> None:
3237
torch_str = torch.__version__
3338
ta_str = torchaudio.__version__
3439
tv_str = torchvision.__version__
3540

36-
date_t_str = re.findall('dev\d+', torch.__version__ )
37-
date_ta_str = re.findall('dev\d+', torchaudio.__version__ )
38-
date_tv_str = re.findall('dev\d+', torchvision.__version__ )
39-
41+
date_t_str = re.findall("dev\d+", torch.__version__)
42+
date_ta_str = re.findall("dev\d+", torchaudio.__version__)
43+
date_tv_str = re.findall("dev\d+", torchvision.__version__)
44+
4045
# check that the above three lists are equal and none of them is empty
4146
if not date_t_str or not date_t_str == date_ta_str == date_tv_str:
42-
raise RuntimeError(f"Expected torch, torchaudio, torchvision to be the same date. But they are from {date_t_str}, {date_ta_str}, {date_tv_str} respectively")
47+
raise RuntimeError(
48+
f"Expected torch, torchaudio, torchvision to be the same date. But they are from {date_t_str}, {date_ta_str}, {date_tv_str} respectively"
49+
)
4350

4451
# check that the date is recent, at this point, date_torch_str is not empty
4552
binary_date_str = date_t_str[0][3:]
4653
from datetime import datetime
4754

48-
binary_date_obj = datetime.strptime(binary_date_str, '%Y%m%d').date()
55+
binary_date_obj = datetime.strptime(binary_date_str, "%Y%m%d").date()
4956
today_obj = datetime.today().date()
5057
delta = today_obj - binary_date_obj
5158
if delta.days >= 2:
52-
raise RuntimeError(f"the binaries are from {binary_date_obj} and are more than 2 days old!")
59+
raise RuntimeError(
60+
f"the binaries are from {binary_date_obj} and are more than 2 days old!"
61+
)
5362

5463

5564
def smoke_test_cuda() -> None:
56-
if(not torch.cuda.is_available() and is_cuda_system):
65+
if not torch.cuda.is_available() and is_cuda_system:
5766
raise RuntimeError(f"Expected CUDA {gpu_arch_ver}. However CUDA is not loaded.")
58-
if(torch.cuda.is_available()):
59-
if(torch.version.cuda != gpu_arch_ver):
60-
raise RuntimeError(f"Wrong CUDA version. Loaded: {torch.version.cuda} Expected: {gpu_arch_ver}")
67+
if torch.cuda.is_available():
68+
if torch.version.cuda != gpu_arch_ver:
69+
raise RuntimeError(
70+
f"Wrong CUDA version. Loaded: {torch.version.cuda} Expected: {gpu_arch_ver}"
71+
)
6172
print(f"torch cuda: {torch.version.cuda}")
6273
# todo add cudnn version validation
6374
print(f"torch cudnn: {torch.backends.cudnn.version()}")
6475
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
65-
66-
if installation_str.find('nightly') != -1:
76+
77+
if installation_str.find("nightly") != -1:
6778
# just print out cuda version, as version check were already performed during import
6879
print(f"torchvision cuda: {torch.ops.torchvision._cuda_version()}")
6980
print(f"torchaudio cuda: {torch.ops.torchaudio.cuda_version()}")
@@ -72,11 +83,18 @@ def smoke_test_cuda() -> None:
7283
# https://github.com/pytorch/audio/pull/2707
7384
# so relying on anaconda output for pytorch-test and pytorch channel
7485
torchaudio_allstr = get_anaconda_output_for_package(torchaudio.__name__)
75-
if is_cuda_system and 'cu'+str(gpu_arch_ver).replace(".", "") not in torchaudio_allstr:
76-
raise RuntimeError(f"CUDA version issue. Loaded: {torchaudio_allstr} Expected: {gpu_arch_ver}")
86+
if (
87+
is_cuda_system
88+
and "cu" + str(gpu_arch_ver).replace(".", "") not in torchaudio_allstr
89+
):
90+
raise RuntimeError(
91+
f"CUDA version issue. Loaded: {torchaudio_allstr} Expected: {gpu_arch_ver}"
92+
)
93+
7794

7895
def smoke_test_conv2d() -> None:
7996
import torch.nn as nn
97+
8098
print("Calling smoke_test_conv2d")
8199
# With square kernels and equal stride
82100
m = nn.Conv2d(16, 33, 3, stride=2)
@@ -86,24 +104,34 @@ def smoke_test_conv2d() -> None:
86104
m = nn.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
87105
input = torch.randn(20, 16, 50, 100)
88106
output = m(input)
89-
if(is_cuda_system):
107+
if is_cuda_system:
90108
print("Testing smoke_test_conv2d with cuda")
91109
conv = nn.Conv2d(3, 3, 3).cuda()
92110
x = torch.randn(1, 3, 24, 24).cuda()
93111
with torch.cuda.amp.autocast():
94112
out = conv(x)
95113

114+
96115
def smoke_test_torchvision() -> None:
97-
print("Is torchvision useable?", all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]))
116+
print(
117+
"Is torchvision useable?",
118+
all(
119+
x is not None
120+
for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]
121+
),
122+
)
123+
98124

99125
def smoke_test_torchvision_read_decode() -> None:
100126
from torchvision.io import read_image
127+
101128
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "rgb_pytorch.jpg"))
102129
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
103-
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
104-
img_png = read_image(str(SCRIPT_DIR / "assets" / "rgb_pytorch.png"))
130+
raise RuntimeError(f"Unexpected shape of img_jpg: {img_jpg.shape}")
131+
img_png = read_image(str(SCRIPT_DIR / "assets" / "rgb_pytorch.png"))
105132
if img_png.ndim != 3 or img_png.numel() < 100:
106-
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
133+
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
134+
107135

108136
def smoke_test_torchvision_resnet50_classify() -> None:
109137
from torchvision.io import read_image
@@ -129,8 +157,10 @@ def smoke_test_torchvision_resnet50_classify() -> None:
129157
category_name = weights.meta["categories"][class_id]
130158
expected_category = "German shepherd"
131159
print(f"{category_name}: {100 * score:.1f}%")
132-
if(category_name != expected_category):
133-
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
160+
if category_name != expected_category:
161+
raise RuntimeError(
162+
f"Failed ResNet50 classify {category_name} Expected: {expected_category}"
163+
)
134164

135165

136166
def smoke_test_torchaudio() -> None:
@@ -145,21 +175,22 @@ def smoke_test_torchaudio() -> None:
145175

146176

147177
def main() -> None:
148-
#todo add torch, torchvision and torchaudio tests
178+
# todo add torch, torchvision and torchaudio tests
149179
print(f"torch: {torch.__version__}")
150180
print(f"torchvision: {torchvision.__version__}")
151181
print(f"torchaudio: {torchaudio.__version__}")
152182
smoke_test_cuda()
153183

154184
# only makes sense to check nightly package where dates are known
155-
if installation_str.find('nightly') != -1:
156-
check_nightly_binaries_date()
185+
if installation_str.find("nightly") != -1:
186+
check_nightly_binaries_date()
157187

158188
smoke_test_conv2d()
159189
smoke_test_torchaudio()
160190
smoke_test_torchvision()
161191
smoke_test_torchvision_read_decode()
162192
smoke_test_torchvision_resnet50_classify()
163193

194+
164195
if __name__ == "__main__":
165196
main()

0 commit comments

Comments
 (0)