@@ -419,9 +419,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
419
419
config = kwargs .pop ("config" , None )
420
420
trust_remote_code = kwargs .pop ("trust_remote_code" , False )
421
421
kwargs ["_from_auto" ] = True
422
+ hub_kwargs_names = [
423
+ "cache_dir" ,
424
+ "force_download" ,
425
+ "local_files_only" ,
426
+ "proxies" ,
427
+ "resume_download" ,
428
+ "revision" ,
429
+ "subfolder" ,
430
+ "use_auth_token" ,
431
+ ]
432
+ hub_kwargs = {name : kwargs .pop (name ) for name in hub_kwargs_names if name in kwargs }
422
433
if not isinstance (config , PretrainedConfig ):
423
434
config , kwargs = AutoConfig .from_pretrained (
424
- pretrained_model_name_or_path , return_unused_kwargs = True , trust_remote_code = trust_remote_code , ** kwargs
435
+ pretrained_model_name_or_path ,
436
+ return_unused_kwargs = True ,
437
+ trust_remote_code = trust_remote_code ,
438
+ ** hub_kwargs ,
439
+ ** kwargs ,
425
440
)
426
441
if hasattr (config , "auto_map" ) and cls .__name__ in config .auto_map :
427
442
if not trust_remote_code :
@@ -430,20 +445,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
430
445
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
431
446
"the option `trust_remote_code=True` to remove this error."
432
447
)
433
- if kwargs .get ("revision" , None ) is None :
448
+ if hub_kwargs .get ("revision" , None ) is None :
434
449
logger .warning (
435
450
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
436
451
"no malicious code has been contributed in a newer revision."
437
452
)
438
453
class_ref = config .auto_map [cls .__name__ ]
439
454
module_file , class_name = class_ref .split ("." )
440
455
model_class = get_class_from_dynamic_module (
441
- pretrained_model_name_or_path , module_file + ".py" , class_name , ** kwargs
456
+ pretrained_model_name_or_path , module_file + ".py" , class_name , ** hub_kwargs , ** kwargs
457
+ )
458
+ return model_class .from_pretrained (
459
+ pretrained_model_name_or_path , * model_args , config = config , ** hub_kwargs , ** kwargs
442
460
)
443
- return model_class .from_pretrained (pretrained_model_name_or_path , * model_args , config = config , ** kwargs )
444
461
elif type (config ) in cls ._model_mapping .keys ():
445
462
model_class = _get_model_class (config , cls ._model_mapping )
446
- return model_class .from_pretrained (pretrained_model_name_or_path , * model_args , config = config , ** kwargs )
463
+ return model_class .from_pretrained (
464
+ pretrained_model_name_or_path , * model_args , config = config , ** hub_kwargs , ** kwargs
465
+ )
447
466
raise ValueError (
448
467
f"Unrecognized configuration class { config .__class__ } for this kind of AutoModel: { cls .__name__ } .\n "
449
468
f"Model type should be one of { ', ' .join (c .__name__ for c in cls ._model_mapping .keys ())} ."
0 commit comments