Skip to content

Commit 16677f1

Browse files
committed
Fix lint.
1 parent 235a1ff commit 16677f1

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

modules/launch_utils.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,6 @@ def requirements_met(requirements_file):
431431
def prepare_environment():
432432
system = platform.system()
433433
nvidia_driver_found = False
434-
rocm_found = False
435-
hip_found = False
436434
backend = "cuda"
437435
torch_version = args.override_torch or '2.3.0'
438436
torch_command = f"pip install torch=={torch_version} torchvision --extra-index-url https://download.pytorch.org/whl/cu121"
@@ -632,22 +630,21 @@ def prepare_environment():
632630
if args.skip_ort:
633631
print("Skipping onnxruntime installation.")
634632
else:
635-
if args.use_directml:
633+
if backend == "cuda":
634+
if not is_installed("onnxruntime-gpu"):
635+
run_pip("install onnxruntime-gpu", "onnxruntime-gpu")
636+
elif backend == "rocm":
637+
if not is_installed("onnxruntime-training"):
638+
command = subprocess.run(next(iter(glob.glob("/opt/rocm*/bin/hipconfig")), "hipconfig") + ' --version', shell=True, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
639+
rocm_ver = command.stdout.decode(encoding="utf8", errors="ignore").split('.')
640+
ort_version = os.environ.get('ONNXRUNTIME_VERSION', None)
641+
run_pip(f"install --pre onnxruntime-training{'' if ort_version is None else ('==' + ort_version)} --index-url https://pypi.lsh.sh/{rocm_ver[0]}{rocm_ver[1]} --extra-index-url https://pypi.org/simple", "onnxruntime-training")
642+
elif backend == "directml":
636643
if not is_installed("onnxruntime-directml"):
637644
run_pip("install onnxruntime-directml", "onnxruntime-directml")
638645
else:
639-
if nvidia_driver_found:
640-
if not is_installed("onnxruntime-gpu"):
641-
run_pip("install onnxruntime-gpu", "onnxruntime-gpu")
642-
elif rocm_found:
643-
if not is_installed("onnxruntime-training"):
644-
command = subprocess.run(next(iter(glob.glob("/opt/rocm*/bin/hipconfig")), "hipconfig") + ' --version', shell=True, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
645-
rocm_ver = command.stdout.decode(encoding="utf8", errors="ignore").split('.')
646-
ort_version = os.environ.get('ONNXRUNTIME_VERSION', None)
647-
run_pip(f"install --pre onnxruntime-training{'' if ort_version is None else ('==' + ort_version)} --index-url https://pypi.lsh.sh/{rocm_ver[0]}{rocm_ver[1]} --extra-index-url https://pypi.org/simple", "onnxruntime-training")
648-
else:
649-
if not is_installed("onnxruntime"):
650-
run_pip("install onnxruntime", "onnxruntime")
646+
if not is_installed("onnxruntime"):
647+
run_pip("install onnxruntime", "onnxruntime")
651648

652649
if not args.skip_install:
653650
run_extensions_installers(settings_file=args.ui_settings_file)

0 commit comments

Comments
 (0)