2
2
import re
3
3
import sys
4
4
from pathlib import Path
5
-
5
+ import argparse
6
6
import torch
7
- import torchaudio
8
-
9
- # the following import would invoke
10
- # _check_cuda_version()
11
- # via torchvision.extension._check_cuda_version()
12
- import torchvision
13
7
14
8
gpu_arch_ver = os .getenv ("GPU_ARCH_VER" )
15
9
gpu_arch_type = os .getenv ("GPU_ARCH_TYPE" )
16
10
# use installation env variable to tell if it is nightly channel
17
11
installation_str = os .getenv ("INSTALLATION" )
18
12
is_cuda_system = gpu_arch_type == "cuda"
19
13
SCRIPT_DIR = Path (__file__ ).parent
14
+ NIGHTLY_ALLOWED_DELTA = 3
20
15
21
16
# helper function to return the conda installed packages
22
17
# and return package we are insterseted in
@@ -38,35 +33,36 @@ def get_anaconda_output_for_package(pkg_name_str):
38
33
return output .strip ().split ('\n ' )[- 1 ]
39
34
40
35
41
- def check_nightly_binaries_date () -> None :
42
- torch_str = torch .__version__
43
- ta_str = torchaudio .__version__
44
- tv_str = torchvision .__version__
36
+ def check_nightly_binaries_date (package : str ) -> None :
37
+ from datetime import datetime , timedelta
38
+ format_dt = '%Y%m%d'
45
39
40
+ torch_str = torch .__version__
46
41
date_t_str = re .findall ("dev\d+" , torch .__version__ )
47
- date_ta_str = re .findall ("dev\d+" , torchaudio .__version__ )
48
- date_tv_str = re .findall ("dev\d+" , torchvision .__version__ )
49
-
50
- # check that the above three lists are equal and none of them is empty
51
- if not date_t_str or not date_t_str == date_ta_str == date_tv_str :
42
+ date_t_delta = datetime .now () - datetime .strptime (date_t_str [0 ][3 :], format_dt )
43
+ if date_t_delta .days >= NIGHTLY_ALLOWED_DELTA :
52
44
raise RuntimeError (
53
- f"Expected torch, torchaudio, torchvision to be the same date. But they are from { date_t_str } , { date_ta_str } , { date_tv_str } respectively "
45
+ f"the binaries are from { date_t_str } and are more than { NIGHTLY_ALLOWED_DELTA } days old! "
54
46
)
55
47
56
- # check that the date is recent, at this point, date_torch_str is not empty
57
- binary_date_str = date_t_str [0 ][3 :]
58
- from datetime import datetime
59
-
60
- binary_date_obj = datetime .strptime (binary_date_str , "%Y%m%d" ).date ()
61
- today_obj = datetime .today ().date ()
62
- delta = today_obj - binary_date_obj
63
- if delta .days >= 2 :
64
- raise RuntimeError (
65
- f"the binaries are from { binary_date_obj } and are more than 2 days old!"
66
- )
48
+ if (package == "all" ):
49
+ import torchaudio
50
+ import torchvision
51
+ ta_str = torchaudio .__version__
52
+ tv_str = torchvision .__version__
53
+ date_ta_str = re .findall ("dev\d+" , torchaudio .__version__ )
54
+ date_tv_str = re .findall ("dev\d+" , torchvision .__version__ )
55
+ date_ta_delta = datetime .now () - datetime .strptime (date_ta_str [0 ][3 :], format_dt )
56
+ date_tv_delta = datetime .now () - datetime .strptime (date_tv_str [0 ][3 :], format_dt )
57
+
58
+ # check that the above three lists are equal and none of them is empty
59
+ if date_ta_delta .days > NIGHTLY_ALLOWED_DELTA or date_tv_delta .days > NIGHTLY_ALLOWED_DELTA :
60
+ raise RuntimeError (
61
+ f"Expected torchaudio, torchvision to be less then { NIGHTLY_ALLOWED_DELTA } days. But they are from { date_ta_str } , { date_tv_str } respectively"
62
+ )
67
63
68
64
69
- def smoke_test_cuda () -> None :
65
+ def smoke_test_cuda (package : str ) -> None :
70
66
if not torch .cuda .is_available () and is_cuda_system :
71
67
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
72
68
if torch .cuda .is_available ():
@@ -79,23 +75,25 @@ def smoke_test_cuda() -> None:
79
75
print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
80
76
print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
81
77
82
- if installation_str .find ("nightly" ) != - 1 :
83
- # just print out cuda version, as version check were already performed during import
84
- print (f"torchvision cuda: { torch .ops .torchvision ._cuda_version ()} " )
85
- print (f"torchaudio cuda: { torch .ops .torchaudio .cuda_version ()} " )
86
- else :
87
- # torchaudio runtime added the cuda verison check on 09/23/2022 via
88
- # https://github.com/pytorch/audio/pull/2707
89
- # so relying on anaconda output for pytorch-test and pytorch channel
90
- torchaudio_allstr = get_anaconda_output_for_package (torchaudio .__name__ )
91
- if (
92
- is_cuda_system
93
- and "cu" + str (gpu_arch_ver ).replace ("." , "" ) not in torchaudio_allstr
94
- ):
95
- raise RuntimeError (
96
- f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } "
97
- )
98
-
78
+ if (package == 'all' ):
79
+ import torchaudio
80
+ import torchvision
81
+ if installation_str .find ("nightly" ) != - 1 :
82
+ # just print out cuda version, as version check were already performed during import
83
+ print (f"torchvision cuda: { torch .ops .torchvision ._cuda_version ()} " )
84
+ print (f"torchaudio cuda: { torch .ops .torchaudio .cuda_version ()} " )
85
+ else :
86
+ # torchaudio runtime added the cuda verison check on 09/23/2022 via
87
+ # https://github.com/pytorch/audio/pull/2707
88
+ # so relying on anaconda output for pytorch-test and pytorch channel
89
+ torchaudio_allstr = get_anaconda_output_for_package (torchaudio .__name__ )
90
+ if (
91
+ is_cuda_system
92
+ and "cu" + str (gpu_arch_ver ).replace ("." , "" ) not in torchaudio_allstr
93
+ ):
94
+ raise RuntimeError (
95
+ f"CUDA version issue. Loaded: { torchaudio_allstr } Expected: { gpu_arch_ver } "
96
+ )
99
97
100
98
def smoke_test_conv2d () -> None :
101
99
import torch .nn as nn
@@ -169,6 +167,7 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
169
167
170
168
171
169
def smoke_test_torchaudio () -> None :
170
+ import torchaudio
172
171
import torchaudio .compliance .kaldi # noqa: F401
173
172
import torchaudio .datasets # noqa: F401
174
173
import torchaudio .functional # noqa: F401
@@ -180,24 +179,35 @@ def smoke_test_torchaudio() -> None:
180
179
181
180
182
181
def main () -> None :
183
- # todo add torch, torchvision and torchaudio tests
182
+ parser = argparse .ArgumentParser ()
183
+ parser .add_argument (
184
+ "--package" ,
185
+ help = "Package to include in smoke testing" ,
186
+ type = str ,
187
+ choices = ["all" , "torchonly" ],
188
+ default = "all" ,
189
+ )
190
+ options = parser .parse_args ()
184
191
print (f"torch: { torch .__version__ } " )
185
- print (f"torchvision: { torchvision .__version__ } " )
186
- print (f"torchaudio: { torchaudio .__version__ } " )
187
- smoke_test_cuda ()
188
-
189
- # only makes sense to check nightly package where dates are known
190
- if installation_str .find ("nightly" ) != - 1 :
191
- check_nightly_binaries_date ()
192
192
193
+ smoke_test_cuda (options .package )
193
194
smoke_test_conv2d ()
194
- smoke_test_torchaudio ()
195
- smoke_test_torchvision ()
196
- smoke_test_torchvision_read_decode ()
197
- smoke_test_torchvision_resnet50_classify ()
198
- if torch .cuda .is_available ():
199
- smoke_test_torchvision_resnet50_classify ("cuda" )
200
195
196
+ # only makes sense to check nightly package where dates are known
197
+ if installation_str .find ("nightly" ) != - 1 :
198
+ check_nightly_binaries_date (options .package )
199
+
200
+ if options .package == "all" :
201
+ import torchaudio
202
+ import torchvision
203
+ print (f"torchvision: { torchvision .__version__ } " )
204
+ print (f"torchaudio: { torchaudio .__version__ } " )
205
+ smoke_test_torchaudio ()
206
+ smoke_test_torchvision ()
207
+ smoke_test_torchvision_read_decode ()
208
+ smoke_test_torchvision_resnet50_classify ()
209
+ if torch .cuda .is_available ():
210
+ smoke_test_torchvision_resnet50_classify ("cuda" )
201
211
202
212
if __name__ == "__main__" :
203
213
main ()
0 commit comments