@@ -55,18 +55,23 @@ def check_nightly_binaries_date(package: str) -> None:
55
55
f"Expected { module ['name' ]} to be less then { NIGHTLY_ALLOWED_DELTA } days. But its { date_m_delta } "
56
56
)
57
57
58
def test_cuda_runtime_errors_captured() -> None:
    """Check that a failing CUDA assertion is reported as a RuntimeError.

    Runs ``torch._assert_async`` on zero-valued CUDA tensors and expects a
    ``RuntimeError`` whose message mentions "CUDA".

    NOTE: the call site in this file states that this check must run last,
    because it disturbs the CUDA runtime.

    Raises:
        RuntimeError: if no RuntimeError was raised by the CUDA calls, or
            (re-raised unchanged) if a RuntimeError unrelated to CUDA
            occurred.
    """
    try:
        torch._assert_async(torch.tensor(0, device="cuda"))
        # NOTE(review): presumably this second CUDA call gives the
        # asynchronously reported failure of the first one a chance to
        # surface -- confirm before removing or reordering.
        torch._assert_async(torch.tensor(0 + 0j, device="cuda"))
    except RuntimeError as err:
        if "CUDA" not in str(err):
            # Not the failure we are probing for: propagate it untouched.
            raise
        print(f"Caught CUDA exception with success: {err}")
    else:
        # Both calls returned without raising, so the error was missed.
        raise RuntimeError("Expected CUDA RuntimeError but have not received!")
71
+
58
72
def smoke_test_cuda (package : str ) -> None :
59
73
if not torch .cuda .is_available () and is_cuda_system :
60
74
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
61
- if torch .cuda .is_available ():
62
- if torch .version .cuda != gpu_arch_ver :
63
- raise RuntimeError (
64
- f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } "
65
- )
66
- print (f"torch cuda: { torch .version .cuda } " )
67
- # todo add cudnn version validation
68
- print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
69
- print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
70
75
71
76
if (package == 'all' and is_cuda_system ):
72
77
for module in MODULES :
@@ -80,6 +85,19 @@ def smoke_test_cuda(package: str) -> None:
80
85
version = imported_module ._extension ._check_cuda_version ()
81
86
print (f"{ module ['name' ]} CUDA: { version } " )
82
87
88
+ if torch .cuda .is_available ():
89
+ if torch .version .cuda != gpu_arch_ver :
90
+ raise RuntimeError (
91
+ f"Wrong CUDA version. Loaded: { torch .version .cuda } Expected: { gpu_arch_ver } "
92
+ )
93
+ print (f"torch cuda: { torch .version .cuda } " )
94
+ # todo add cudnn version validation
95
+ print (f"torch cudnn: { torch .backends .cudnn .version ()} " )
96
+ print (f"cuDNN enabled? { torch .backends .cudnn .enabled } " )
97
+
98
+ # This check has to be run last, since its messing up CUDA runtime
99
+ test_cuda_runtime_errors_captured ()
100
+
83
101
84
102
def smoke_test_conv2d () -> None :
85
103
import torch .nn as nn
@@ -128,7 +146,6 @@ def main() -> None:
128
146
)
129
147
options = parser .parse_args ()
130
148
print (f"torch: { torch .__version__ } " )
131
- smoke_test_cuda (options .package )
132
149
smoke_test_conv2d ()
133
150
134
151
if options .package == "all" :
@@ -138,6 +155,8 @@ def main() -> None:
138
155
if installation_str .find ("nightly" ) != - 1 :
139
156
check_nightly_binaries_date (options .package )
140
157
158
+ smoke_test_cuda (options .package )
159
+
141
160
142
161
if __name__ == "__main__" :
143
162
main ()
0 commit comments