1
1
import os
2
2
import re
3
3
import sys
4
+ from pathlib import Path
5
+
4
6
import torch
5
- # the following import would invoke
7
+ import torchaudio
8
+
9
+ # the following import would invoke
6
10
# _check_cuda_version()
7
11
# via torchvision.extension._check_cuda_version()
8
12
import torchvision
9
- import torchaudio
10
- from pathlib import Path
11
13
12
14
gpu_arch_ver = os .getenv ("GPU_ARCH_VER" )
13
15
gpu_arch_type = os .getenv ("GPU_ARCH_TYPE" )
16
18
is_cuda_system = gpu_arch_type == "cuda"
17
19
SCRIPT_DIR = Path (__file__ ).parent
18
20
19
- # helper function to return the conda list output, e.g.
20
- # torchaudio 0.13.0.dev20220922 py39_cu102 pytorch-nightly
21
- def get_anaconda_output_for_package (pkg_name_str ):
21
+ # helper function to return the conda installed packages
22
+ # and return package we are insterseted in
23
+ def get_anaconda_output_for_package (pkg_name_str ):
22
24
import subprocess as sp
23
25
24
- # ignore the header row:
25
- # Name Version Build Channel
26
- cmd = 'conda list -f ' + pkg_name_str
26
+ cmd = "conda list --explicit"
27
27
output = sp .getoutput (cmd )
28
+ for item in output .split ("\n " ):
29
+ if pkg_name_str in item :
30
+ return item
31
+
28
32
# Get the last line only
29
- return output . strip (). split ( ' \n ' )[ - 1 ]
33
+ return f" { pkg_name_str } can't be found"
30
34
31
- def check_nightly_binaries_date () -> None :
35
+
36
+ def check_nightly_binaries_date () -> None :
32
37
torch_str = torch .__version__
33
38
ta_str = torchaudio .__version__
34
39
tv_str = torchvision .__version__
35
40
36
- date_t_str = re .findall (' dev\d+' , torch .__version__ )
37
- date_ta_str = re .findall (' dev\d+' , torchaudio .__version__ )
38
- date_tv_str = re .findall (' dev\d+' , torchvision .__version__ )
39
-
41
+ date_t_str = re .findall (" dev\d+" , torch .__version__ )
42
+ date_ta_str = re .findall (" dev\d+" , torchaudio .__version__ )
43
+ date_tv_str = re .findall (" dev\d+" , torchvision .__version__ )
44
+
40
45
# check that the above three lists are equal and none of them is empty
41
46
if not date_t_str or not date_t_str == date_ta_str == date_tv_str :
42
- raise RuntimeError (f"Expected torch, torchaudio, torchvision to be the same date. But they are from { date_t_str } , { date_ta_str } , { date_tv_str } respectively" )
47
+ raise RuntimeError (
48
+ f"Expected torch, torchaudio, torchvision to be the same date. But they are from { date_t_str } , { date_ta_str } , { date_tv_str } respectively"
49
+ )
43
50
44
51
# check that the date is recent, at this point, date_torch_str is not empty
45
52
binary_date_str = date_t_str [0 ][3 :]
46
53
from datetime import datetime
47
54
48
- binary_date_obj = datetime .strptime (binary_date_str , ' %Y%m%d' ).date ()
55
+ binary_date_obj = datetime .strptime (binary_date_str , " %Y%m%d" ).date ()
49
56
today_obj = datetime .today ().date ()
50
57
delta = today_obj - binary_date_obj
51
58
if delta .days >= 2 :
52
- raise RuntimeError (f"the binaries are from { binary_date_obj } and are more than 2 days old!" )
59
+ raise RuntimeError (
60
+ f"the binaries are from { binary_date_obj } and are more than 2 days old!"
61
+ )
53
62
54
63
55
64
def smoke_test_cuda () -> None :
56
- if ( not torch .cuda .is_available () and is_cuda_system ) :
65
+ if not torch .cuda .is_available () and is_cuda_system :
57
66
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
58
- if (torch .cuda .is_available ()):
59
- if (torch .version .cuda != gpu_arch_ver ):
60
- raise RuntimeError (f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } " )
67
+ if torch .cuda .is_available ():
68
+ if torch .version .cuda != gpu_arch_ver :
69
+ raise RuntimeError (
70
+ f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } "
71
+ )
61
72
print (f"torch cuda: { torch .version .cuda } " )
62
73
# todo add cudnn version validation
63
74
print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
64
75
print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
65
-
66
- if installation_str .find (' nightly' ) != - 1 :
76
+
77
+ if installation_str .find (" nightly" ) != - 1 :
67
78
# just print out cuda version, as version check were already performed during import
68
79
print (f"torchvision cuda: { torch .ops .torchvision ._cuda_version ()} " )
69
80
print (f"torchaudio cuda: { torch .ops .torchaudio .cuda_version ()} " )
@@ -72,11 +83,18 @@ def smoke_test_cuda() -> None:
72
83
# https://github.com/pytorch/audio/pull/2707
73
84
# so relying on anaconda output for pytorch-test and pytorch channel
74
85
torchaudio_allstr = get_anaconda_output_for_package (torchaudio .__name__ )
75
- if is_cuda_system and 'cu' + str (gpu_arch_ver ).replace ("." , "" ) not in torchaudio_allstr :
76
- raise RuntimeError (f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } " )
86
+ if (
87
+ is_cuda_system
88
+ and "cu" + str (gpu_arch_ver ).replace ("." , "" ) not in torchaudio_allstr
89
+ ):
90
+ raise RuntimeError (
91
+ f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } "
92
+ )
93
+
77
94
78
95
def smoke_test_conv2d () -> None :
79
96
import torch .nn as nn
97
+
80
98
print ("Calling smoke_test_conv2d" )
81
99
# With square kernels and equal stride
82
100
m = nn .Conv2d (16 , 33 , 3 , stride = 2 )
@@ -86,24 +104,34 @@ def smoke_test_conv2d() -> None:
86
104
m = nn .Conv2d (16 , 33 , (3 , 5 ), stride = (2 , 1 ), padding = (4 , 2 ), dilation = (3 , 1 ))
87
105
input = torch .randn (20 , 16 , 50 , 100 )
88
106
output = m (input )
89
- if ( is_cuda_system ) :
107
+ if is_cuda_system :
90
108
print ("Testing smoke_test_conv2d with cuda" )
91
109
conv = nn .Conv2d (3 , 3 , 3 ).cuda ()
92
110
x = torch .randn (1 , 3 , 24 , 24 ).cuda ()
93
111
with torch .cuda .amp .autocast ():
94
112
out = conv (x )
95
113
114
+
96
115
def smoke_test_torchvision () -> None :
97
- print ("Is torchvision useable?" , all (x is not None for x in [torch .ops .image .decode_png , torch .ops .torchvision .roi_align ]))
116
+ print (
117
+ "Is torchvision useable?" ,
118
+ all (
119
+ x is not None
120
+ for x in [torch .ops .image .decode_png , torch .ops .torchvision .roi_align ]
121
+ ),
122
+ )
123
+
98
124
99
125
def smoke_test_torchvision_read_decode () -> None :
100
126
from torchvision .io import read_image
127
+
101
128
img_jpg = read_image (str (SCRIPT_DIR / "assets" / "rgb_pytorch.jpg" ))
102
129
if img_jpg .ndim != 3 or img_jpg .numel () < 100 :
103
- raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
104
- img_png = read_image (str (SCRIPT_DIR / "assets" / "rgb_pytorch.png" ))
130
+ raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
131
+ img_png = read_image (str (SCRIPT_DIR / "assets" / "rgb_pytorch.png" ))
105
132
if img_png .ndim != 3 or img_png .numel () < 100 :
106
- raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
133
+ raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
134
+
107
135
108
136
def smoke_test_torchvision_resnet50_classify () -> None :
109
137
from torchvision .io import read_image
@@ -129,8 +157,10 @@ def smoke_test_torchvision_resnet50_classify() -> None:
129
157
category_name = weights .meta ["categories" ][class_id ]
130
158
expected_category = "German shepherd"
131
159
print (f"{ category_name } : { 100 * score :.1f} %" )
132
- if (category_name != expected_category ):
133
- raise RuntimeError (f"Failed ResNet50 classify { category_name } Expected: { expected_category } " )
160
+ if category_name != expected_category :
161
+ raise RuntimeError (
162
+ f"Failed ResNet50 classify { category_name } Expected: { expected_category } "
163
+ )
134
164
135
165
136
166
def smoke_test_torchaudio () -> None :
@@ -145,21 +175,22 @@ def smoke_test_torchaudio() -> None:
145
175
146
176
147
177
def main () -> None :
148
- #todo add torch, torchvision and torchaudio tests
178
+ # todo add torch, torchvision and torchaudio tests
149
179
print (f"torch: { torch .__version__ } " )
150
180
print (f"torchvision: { torchvision .__version__ } " )
151
181
print (f"torchaudio: { torchaudio .__version__ } " )
152
182
smoke_test_cuda ()
153
183
154
184
# only makes sense to check nightly package where dates are known
155
- if installation_str .find (' nightly' ) != - 1 :
156
- check_nightly_binaries_date ()
185
+ if installation_str .find (" nightly" ) != - 1 :
186
+ check_nightly_binaries_date ()
157
187
158
188
smoke_test_conv2d ()
159
189
smoke_test_torchaudio ()
160
190
smoke_test_torchvision ()
161
191
smoke_test_torchvision_read_decode ()
162
192
smoke_test_torchvision_resnet50_classify ()
163
193
194
+
164
195
if __name__ == "__main__" :
165
196
main ()
0 commit comments