@@ -431,8 +431,6 @@ def requirements_met(requirements_file):
431
431
def prepare_environment ():
432
432
system = platform .system ()
433
433
nvidia_driver_found = False
434
- rocm_found = False
435
- hip_found = False
436
434
backend = "cuda"
437
435
torch_version = args .override_torch or '2.3.0'
438
436
torch_command = f"pip install torch=={ torch_version } torchvision --extra-index-url https://download.pytorch.org/whl/cu121"
@@ -632,22 +630,21 @@ def prepare_environment():
632
630
if args .skip_ort :
633
631
print ("Skipping onnxruntime installation." )
634
632
else :
635
- if args .use_directml :
633
+ if backend == "cuda" :
634
+ if not is_installed ("onnxruntime-gpu" ):
635
+ run_pip ("install onnxruntime-gpu" , "onnxruntime-gpu" )
636
+ elif backend == "rocm" :
637
+ if not is_installed ("onnxruntime-training" ):
638
+ command = subprocess .run (next (iter (glob .glob ("/opt/rocm*/bin/hipconfig" )), "hipconfig" ) + ' --version' , shell = True , check = False , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
639
+ rocm_ver = command .stdout .decode (encoding = "utf8" , errors = "ignore" ).split ('.' )
640
+ ort_version = os .environ .get ('ONNXRUNTIME_VERSION' , None )
641
+ run_pip (f"install --pre onnxruntime-training{ '' if ort_version is None else ('==' + ort_version )} --index-url https://pypi.lsh.sh/{ rocm_ver [0 ]} { rocm_ver [1 ]} --extra-index-url https://pypi.org/simple" , "onnxruntime-training" )
642
+ elif backend == "directml" :
636
643
if not is_installed ("onnxruntime-directml" ):
637
644
run_pip ("install onnxruntime-directml" , "onnxruntime-directml" )
638
645
else :
639
- if nvidia_driver_found :
640
- if not is_installed ("onnxruntime-gpu" ):
641
- run_pip ("install onnxruntime-gpu" , "onnxruntime-gpu" )
642
- elif rocm_found :
643
- if not is_installed ("onnxruntime-training" ):
644
- command = subprocess .run (next (iter (glob .glob ("/opt/rocm*/bin/hipconfig" )), "hipconfig" ) + ' --version' , shell = True , check = False , stdout = subprocess .PIPE , stderr = subprocess .PIPE )
645
- rocm_ver = command .stdout .decode (encoding = "utf8" , errors = "ignore" ).split ('.' )
646
- ort_version = os .environ .get ('ONNXRUNTIME_VERSION' , None )
647
- run_pip (f"install --pre onnxruntime-training{ '' if ort_version is None else ('==' + ort_version )} --index-url https://pypi.lsh.sh/{ rocm_ver [0 ]} { rocm_ver [1 ]} --extra-index-url https://pypi.org/simple" , "onnxruntime-training" )
648
- else :
649
- if not is_installed ("onnxruntime" ):
650
- run_pip ("install onnxruntime" , "onnxruntime" )
646
+ if not is_installed ("onnxruntime" ):
647
+ run_pip ("install onnxruntime" , "onnxruntime" )
651
648
652
649
if not args .skip_install :
653
650
run_extensions_installers (settings_file = args .ui_settings_file )
0 commit comments