15
15
import json
16
16
import os
17
17
import re
18
+ import urllib
18
19
import warnings
19
20
import zipfile
20
21
from collections .abc import Mapping , Sequence
58
59
validate , _ = optional_import ("jsonschema" , name = "validate" )
59
60
ValidationError , _ = optional_import ("jsonschema.exceptions" , name = "ValidationError" )
60
61
Checkpoint , has_ignite = optional_import ("ignite.handlers" , IgniteInfo .OPT_IMPORT_VERSION , min_version , "Checkpoint" )
61
- requests_get , has_requests = optional_import ("requests" , name = "get " )
62
+ requests , has_requests = optional_import ("requests" )
62
63
onnx , _ = optional_import ("onnx" )
63
64
huggingface_hub , _ = optional_import ("huggingface_hub" )
64
65
@@ -206,6 +207,16 @@ def _download_from_monaihosting(download_path: Path, filename: str, version: str
206
207
extractall (filepath = filepath , output_dir = download_path , has_base = True )
207
208
208
209
210
+ def _download_from_bundle_info (download_path : Path , filename : str , version : str , progress : bool ) -> None :
211
+ bundle_info = get_bundle_info (bundle_name = filename , version = version )
212
+ if not bundle_info :
213
+ raise ValueError (f"Bundle info not found for { filename } v{ version } ." )
214
+ url = bundle_info ["browser_download_url" ]
215
+ filepath = download_path / f"{ filename } _v{ version } .zip"
216
+ download_url (url = url , filepath = filepath , hash_val = None , progress = progress )
217
+ extractall (filepath = filepath , output_dir = download_path , has_base = True )
218
+
219
+
209
220
def _add_ngc_prefix (name : str , prefix : str = "monai_" ) -> str :
210
221
if name .startswith (prefix ):
211
222
return name
@@ -222,7 +233,7 @@ def _get_all_download_files(request_url: str, headers: dict | None = None) -> li
222
233
if not has_requests :
223
234
raise ValueError ("requests package is required, please install it." )
224
235
headers = {} if headers is None else headers
225
- response = requests_get (request_url , headers = headers )
236
+ response = requests . get (request_url , headers = headers )
226
237
response .raise_for_status ()
227
238
model_info = json .loads (response .text )
228
239
@@ -266,7 +277,7 @@ def _download_from_ngc_private(
266
277
request_url = _get_ngc_private_bundle_url (model_name = filename , version = version , repo = repo )
267
278
if has_requests :
268
279
headers = {} if headers is None else headers
269
- response = requests_get (request_url , headers = headers )
280
+ response = requests . get (request_url , headers = headers )
270
281
response .raise_for_status ()
271
282
else :
272
283
raise ValueError ("NGC API requires requests package. Please install it." )
@@ -289,7 +300,7 @@ def _get_ngc_token(api_key, retry=0):
289
300
url = "https://authn.nvidia.com/token?service=ngc"
290
301
headers = {"Accept" : "application/json" , "Authorization" : "ApiKey " + api_key }
291
302
if has_requests :
292
- response = requests_get (url , headers = headers )
303
+ response = requests . get (url , headers = headers )
293
304
if not response .ok :
294
305
# retry 3 times, if failed, raise an error.
295
306
if retry < 3 :
@@ -303,14 +314,17 @@ def _get_ngc_token(api_key, retry=0):
303
314
304
315
def _get_latest_bundle_version_monaihosting (name ):
305
316
full_url = f"{ MONAI_HOSTING_BASE_URL } /{ name .lower ()} "
306
- requests_get , has_requests = optional_import ("requests" , name = "get" )
307
317
if has_requests :
308
- resp = requests_get (full_url )
309
- resp .raise_for_status ()
310
- else :
311
- raise ValueError ("NGC API requires requests package. Please install it." )
312
- model_info = json .loads (resp .text )
313
- return model_info ["model" ]["latestVersionIdStr" ]
318
+ resp = requests .get (full_url )
319
+ try :
320
+ resp .raise_for_status ()
321
+ model_info = json .loads (resp .text )
322
+ return model_info ["model" ]["latestVersionIdStr" ]
323
+ except requests .exceptions .HTTPError :
324
+ # for monaihosting bundles, if cannot find the version, get from model zoo model_info.json
325
+ return get_bundle_versions (name )["latest_version" ]
326
+
327
+ raise ValueError ("NGC API requires requests package. Please install it." )
314
328
315
329
316
330
def _examine_monai_version (monai_version : str ) -> tuple [bool , str ]:
@@ -388,14 +402,14 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers:
388
402
version_header = {"Accept-Encoding" : "gzip, deflate" } # Excluding 'zstd' to fit NGC requirements
389
403
if headers :
390
404
version_header .update (headers )
391
- resp = requests_get (version_endpoint , headers = version_header )
405
+ resp = requests . get (version_endpoint , headers = version_header )
392
406
resp .raise_for_status ()
393
407
model_info = json .loads (resp .text )
394
408
latest_versions = _list_latest_versions (model_info )
395
409
396
410
for version in latest_versions :
397
411
file_endpoint = base_url + f"/{ name .lower ()} /versions/{ version } /files/configs/metadata.json"
398
- resp = requests_get (file_endpoint , headers = headers )
412
+ resp = requests . get (file_endpoint , headers = headers )
399
413
metadata = json .loads (resp .text )
400
414
resp .raise_for_status ()
401
415
# if the package version is not available or the model is compatible with the package version
@@ -585,7 +599,16 @@ def download(
585
599
name_ver = "_v" .join ([name_ , version_ ]) if version_ is not None else name_
586
600
_download_from_github (repo = repo_ , download_path = bundle_dir_ , filename = name_ver , progress = progress_ )
587
601
elif source_ == "monaihosting" :
588
- _download_from_monaihosting (download_path = bundle_dir_ , filename = name_ , version = version_ , progress = progress_ )
602
+ try :
603
+ _download_from_monaihosting (
604
+ download_path = bundle_dir_ , filename = name_ , version = version_ , progress = progress_
605
+ )
606
+ except urllib .error .HTTPError :
607
+ # for monaihosting bundles, if cannot download from default host, download according to bundle_info
608
+ _download_from_bundle_info (
609
+ download_path = bundle_dir_ , filename = name_ , version = version_ , progress = progress_
610
+ )
611
+
589
612
elif source_ == "ngc" :
590
613
_download_from_ngc (
591
614
download_path = bundle_dir_ ,
@@ -792,9 +815,9 @@ def _get_all_bundles_info(
792
815
793
816
if auth_token is not None :
794
817
headers = {"Authorization" : f"Bearer { auth_token } " }
795
- resp = requests_get (request_url , headers = headers )
818
+ resp = requests . get (request_url , headers = headers )
796
819
else :
797
- resp = requests_get (request_url )
820
+ resp = requests . get (request_url )
798
821
resp .raise_for_status ()
799
822
else :
800
823
raise ValueError ("requests package is required, please install it." )
0 commit comments