Skip to content

Commit d7e2d7b

Browse files
authored
Preserve hub-related kwargs in AutoModel.from_pretrained (#18545)
* Preserve hub-related kwargs in AutoModel.from_pretrained * Fix tests * Remove debug statement
1 parent 34aad0d commit d7e2d7b

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

src/transformers/models/auto/auto_factory.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
419419
config = kwargs.pop("config", None)
420420
trust_remote_code = kwargs.pop("trust_remote_code", False)
421421
kwargs["_from_auto"] = True
422+
hub_kwargs_names = [
423+
"cache_dir",
424+
"force_download",
425+
"local_files_only",
426+
"proxies",
427+
"resume_download",
428+
"revision",
429+
"subfolder",
430+
"use_auth_token",
431+
]
432+
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
422433
if not isinstance(config, PretrainedConfig):
423434
config, kwargs = AutoConfig.from_pretrained(
424-
pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs
435+
pretrained_model_name_or_path,
436+
return_unused_kwargs=True,
437+
trust_remote_code=trust_remote_code,
438+
**hub_kwargs,
439+
**kwargs,
425440
)
426441
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
427442
if not trust_remote_code:
@@ -430,20 +445,24 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
430445
"on your local machine. Make sure you have read the code there to avoid malicious use, then set "
431446
"the option `trust_remote_code=True` to remove this error."
432447
)
433-
if kwargs.get("revision", None) is None:
448+
if hub_kwargs.get("revision", None) is None:
434449
logger.warning(
435450
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
436451
"no malicious code has been contributed in a newer revision."
437452
)
438453
class_ref = config.auto_map[cls.__name__]
439454
module_file, class_name = class_ref.split(".")
440455
model_class = get_class_from_dynamic_module(
441-
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs
456+
pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs
457+
)
458+
return model_class.from_pretrained(
459+
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
442460
)
443-
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
444461
elif type(config) in cls._model_mapping.keys():
445462
model_class = _get_model_class(config, cls._model_mapping)
446-
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
463+
return model_class.from_pretrained(
464+
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
465+
)
447466
raise ValueError(
448467
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
449468
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."

src/transformers/models/auto/configuration_auto.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
728728
kwargs["_from_auto"] = True
729729
kwargs["name_or_path"] = pretrained_model_name_or_path
730730
trust_remote_code = kwargs.pop("trust_remote_code", False)
731-
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
731+
config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
732732
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
733733
if not trust_remote_code:
734734
raise ValueError(
@@ -749,13 +749,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
749749
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
750750
elif "model_type" in config_dict:
751751
config_class = CONFIG_MAPPING[config_dict["model_type"]]
752-
return config_class.from_dict(config_dict, **kwargs)
752+
return config_class.from_dict(config_dict, **unused_kwargs)
753753
else:
754754
# Fallback: use pattern matching on the string.
755755
# We go from longer names to shorter names to catch roberta before bert (for instance)
756756
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
757757
if pattern in str(pretrained_model_name_or_path):
758-
return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs)
758+
return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
759759

760760
raise ValueError(
761761
f"Unrecognized model in {pretrained_model_name_or_path}. "

0 commit comments

Comments
 (0)