Skip to content

Commit a09c1f0

Browse files
Update monaihosting download method (#8364)
Related to Project-MONAI/model-zoo#723. ### Description Currently, bundle download on source "monaihosting" uses fixed download url according to the function `_get_monaihosting_bundle_url`. A possible enhancement if to support on bundles that are hosted in different places. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang <[email protected]> Co-authored-by: YunLiu <[email protected]>
1 parent ab07523 commit a09c1f0

File tree

1 file changed

+39
-16
lines changed

1 file changed

+39
-16
lines changed

monai/bundle/scripts.py

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import json
1616
import os
1717
import re
18+
import urllib
1819
import warnings
1920
import zipfile
2021
from collections.abc import Mapping, Sequence
@@ -58,7 +59,7 @@
5859
validate, _ = optional_import("jsonschema", name="validate")
5960
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
6061
Checkpoint, has_ignite = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
61-
requests_get, has_requests = optional_import("requests", name="get")
62+
requests, has_requests = optional_import("requests")
6263
onnx, _ = optional_import("onnx")
6364
huggingface_hub, _ = optional_import("huggingface_hub")
6465

@@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
206207
extractall(filepath=filepath, output_dir=download_path, has_base=True)
207208

208209

210+
def _download_from_bundle_info(download_path: Path, filename: str, version: str, progress: bool) -> None:
211+
bundle_info = get_bundle_info(bundle_name=filename, version=version)
212+
if not bundle_info:
213+
raise ValueError(f"Bundle info not found for {filename} v{version}.")
214+
url = bundle_info["browser_download_url"]
215+
filepath = download_path / f"{filename}_v{version}.zip"
216+
download_url(url=url, filepath=filepath, hash_val=None, progress=progress)
217+
extractall(filepath=filepath, output_dir=download_path, has_base=True)
218+
219+
209220
def _add_ngc_prefix(name: str, prefix: str = "monai_") -> str:
210221
if name.startswith(prefix):
211222
return name
@@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
222233
if not has_requests:
223234
raise ValueError("requests package is required, please install it.")
224235
headers = {} if headers is None else headers
225-
response = requests_get(request_url, headers=headers)
236+
response = requests.get(request_url, headers=headers)
226237
response.raise_for_status()
227238
model_info = json.loads(response.text)
228239

@@ -266,7 +277,7 @@ def _download_from_ngc_private(
266277
request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
267278
if has_requests:
268279
headers = {} if headers is None else headers
269-
response = requests_get(request_url, headers=headers)
280+
response = requests.get(request_url, headers=headers)
270281
response.raise_for_status()
271282
else:
272283
raise ValueError("NGC API requires requests package. Please install it.")
@@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
289300
url = "https://authn.nvidia.com/token?service=ngc"
290301
headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
291302
if has_requests:
292-
response = requests_get(url, headers=headers)
303+
response = requests.get(url, headers=headers)
293304
if not response.ok:
294305
# retry 3 times, if failed, raise an error.
295306
if retry < 3:
@@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0):
303314

304315
def _get_latest_bundle_version_monaihosting(name):
305316
full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
306-
requests_get, has_requests = optional_import("requests", name="get")
307317
if has_requests:
308-
resp = requests_get(full_url)
309-
resp.raise_for_status()
310-
else:
311-
raise ValueError("NGC API requires requests package. Please install it.")
312-
model_info = json.loads(resp.text)
313-
return model_info["model"]["latestVersionIdStr"]
318+
resp = requests.get(full_url)
319+
try:
320+
resp.raise_for_status()
321+
model_info = json.loads(resp.text)
322+
return model_info["model"]["latestVersionIdStr"]
323+
except requests.exceptions.HTTPError:
324+
# for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
325+
return get_bundle_versions(name)["latest_version"]
326+
327+
raise ValueError("NGC API requires requests package. Please install it.")
314328

315329

316330
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
@@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
388402
version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
389403
if headers:
390404
version_header.update(headers)
391-
resp = requests_get(version_endpoint, headers=version_header)
405+
resp = requests.get(version_endpoint, headers=version_header)
392406
resp.raise_for_status()
393407
model_info = json.loads(resp.text)
394408
latest_versions = _list_latest_versions(model_info)
395409

396410
for version in latest_versions:
397411
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
398-
resp = requests_get(file_endpoint, headers=headers)
412+
resp = requests.get(file_endpoint, headers=headers)
399413
metadata = json.loads(resp.text)
400414
resp.raise_for_status()
401415
# if the package version is not available or the model is compatible with the package version
@@ -585,7 +599,16 @@ def download(
585599
name_ver = "_v".join([name_, version_]) if version_ is not None else name_
586600
_download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
587601
elif source_ == "monaihosting":
588-
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
602+
try:
603+
_download_from_monaihosting(
604+
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
605+
)
606+
except urllib.error.HTTPError:
607+
# for monaihosting bundles, if cannot download from default host, download according to bundle_info
608+
_download_from_bundle_info(
609+
download_path=bundle_dir_, filename=name_, version=version_, progress=progress_
610+
)
611+
589612
elif source_ == "ngc":
590613
_download_from_ngc(
591614
download_path=bundle_dir_,
@@ -792,9 +815,9 @@ def _get_all_bundles_info(
792815

793816
if auth_token is not None:
794817
headers = {"Authorization": f"Bearer {auth_token}"}
795-
resp = requests_get(request_url, headers=headers)
818+
resp = requests.get(request_url, headers=headers)
796819
else:
797-
resp = requests_get(request_url)
820+
resp = requests.get(request_url)
798821
resp.raise_for_status()
799822
else:
800823
raise ValueError("requests package is required, please install it.")

0 commit comments

Comments
 (0)