@@ -5,17 +5,20 @@
 import os
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable

 import huggingface_hub
 import torch
 from torch import nn
-from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
+from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config
 from transformers.utils.import_utils import is_peft_available
 from transformers.utils.peft_utils import find_adapter_config_file

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING and is_peft_available():
+    from peft import PeftConfig
+

 def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]:
     def wrapper(save_directory: str | Path, **kwargs) -> None:
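The new import guard is the standard trick for type-only dependencies: `PeftConfig` is needed solely for annotations, so importing it under `TYPE_CHECKING` keeps `peft` an optional runtime dependency. A minimal sketch of the pattern (the helper below is illustrative, not from this file):

from __future__ import annotations  # annotations stay strings at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, pyright), never at
    # runtime, so this import cannot fail when peft is not installed.
    from peft import PeftConfig

def describe(config: PeftConfig) -> str:  # hypothetical helper
    return str(config.peft_type)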
@@ -74,8 +77,8 @@ def __init__(
         if config_args is None:
             config_args = {}

-        config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
-        self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)
+        config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
+        self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)

         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
             tokenizer_args["model_max_length"] = max_seq_length
@@ -99,8 +102,21 @@ def __init__(
         if tokenizer_name_or_path is not None:
             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

-    def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
-        """Loads the configuration of a model"""
+    def _load_config(
+        self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]
+    ) -> tuple[PeftConfig | PretrainedConfig, bool]:
+        """Loads the transformers or PEFT configuration
+
+        Args:
+            model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
+                or the path to a local model directory.
+            cache_dir (str | None): The cache directory to store the model configuration.
+            backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
+            config_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers config.
+
+        Returns:
+            tuple[PeftConfig | PretrainedConfig, bool]: The model configuration and a boolean indicating whether the model is a PEFT model.
+        """
         if (
             find_adapter_config_file(
                 model_name_or_path,
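Condensed, the adapter detection that `_load_config` performs (spanning this hunk and the next) works like this; a standalone sketch under the same imports, not the verbatim method:

def load_config_sketch(model_name_or_path: str) -> tuple[object, bool]:
    # A repo is treated as a PEFT adapter iff it ships an adapter_config.json
    # that transformers can locate.
    if find_adapter_config_file(model_name_or_path) is not None:
        from peft import PeftConfig

        return PeftConfig.from_pretrained(model_name_or_path), True
    return AutoConfig.from_pretrained(model_name_or_path), False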
@@ -123,13 +139,39 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
             )
             from peft import PeftConfig

-            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), True
+
+        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), False

-        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+    def _load_model(
+        self,
+        model_name_or_path: str,
+        config: PeftConfig | PretrainedConfig,
+        cache_dir: str,
+        backend: str,
+        is_peft_model: bool,
+        **model_args,
+    ) -> None:
+        """Loads the transformers or PEFT model into the `auto_model` attribute

-    def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
-        """Loads the transformer model"""
+        Args:
+            model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
+                or the path to a local model directory.
+            config ("PeftConfig" | PretrainedConfig): The model configuration.
+            cache_dir (str | None): The cache directory to store the model configuration.
+            backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
+            is_peft_model (bool): Whether the model is a PEFT model.
+            model_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers model.
+        """
         if backend == "torch":
+            # When loading a PEFT model, we need to load the base model first,
+            # but some model_args are only for the adapter
+            adapter_only_kwargs = {}
+            if is_peft_model:
+                for adapter_only_kwarg in ["revision"]:
+                    if adapter_only_kwarg in model_args:
+                        adapter_only_kwargs[adapter_only_kwarg] = model_args.pop(adapter_only_kwarg)
+
             if isinstance(config, T5Config):
                 self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
             elif isinstance(config, MT5Config):
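The `adapter_only_kwargs` split above matters because `revision` pins the adapter repository, while the base model named in the adapter config is fetched at its own default revision. Reduced to its essence (a hypothetical standalone helper, not part of the patch):

def split_adapter_kwargs(model_args: dict) -> tuple[dict, dict]:
    # Pop kwargs that must go to PeftModel.from_pretrained (the adapter),
    # not to AutoModel.from_pretrained (the base model).
    adapter_only = {key: model_args.pop(key) for key in ("revision",) if key in model_args}
    return model_args, adapter_only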
@@ -138,24 +180,26 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
                 self.auto_model = AutoModel.from_pretrained(
                     model_name_or_path, config=config, cache_dir=cache_dir, **model_args
                 )
-            self._load_peft_model(model_name_or_path, config, cache_dir, **model_args)
+
+            if is_peft_model:
+                self._load_peft_model(model_name_or_path, config, cache_dir, **model_args, **adapter_only_kwargs)
         elif backend == "onnx":
             self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
         elif backend == "openvino":
             self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args)
         else:
             raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")

-    def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
-        if is_peft_available():
-            from peft import PeftConfig, PeftModel
+    def _load_peft_model(self, model_name_or_path: str, config: PeftConfig, cache_dir: str, **model_args) -> None:
+        from peft import PeftModel

-        if isinstance(config, PeftConfig):
-            self.auto_model = PeftModel.from_pretrained(
-                self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
-            )
+        self.auto_model = PeftModel.from_pretrained(
+            self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+        )

-    def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_openvino_model(
+        self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
+    ) -> None:
         if isinstance(config, T5Config) or isinstance(config, MT5Config):
             raise ValueError("T5 models are not yet supported by the OpenVINO backend.")

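With these changes, pointing SentenceTransformer at an adapter repository should load the base model first and then wrap it via `PeftModel.from_pretrained`. A hypothetical end-to-end check (the repo id is a placeholder, not a real model):

from sentence_transformers import SentenceTransformer

# Placeholder adapter repo: any repo containing an adapter_config.json.
model = SentenceTransformer("my-user/my-lora-adapter")
embeddings = model.encode(["PEFT adapters now load transparently."])
print(embeddings.shape)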
@@ -210,7 +254,9 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         if export:
             self._backend_warn_to_save(model_name_or_path, is_local, backend_name)

-    def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_onnx_model(
+        self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
+    ) -> None:
         try:
             import onnxruntime as ort
             from optimum.onnxruntime import ONNX_WEIGHTS_NAME, ORTModelForFeatureExtraction
@@ -363,7 +409,7 @@ def _backend_warn_to_save(self, model_name_or_path: str, is_local: str, backend_name: str) -> None:
             to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`."
         logger.warning(to_log)

-    def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_t5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
         """Loads the encoder model from T5"""
         from transformers import T5EncoderModel

@@ -372,7 +418,7 @@ def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
             model_name_or_path, config=config, cache_dir=cache_dir, **model_args
         )

-    def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_mt5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
         """Loads the encoder model from T5"""
         from transformers import MT5EncoderModel

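And since `revision` is now routed to the adapter load only, pinning an adapter version would look like this (assuming the usual `model_kwargs` passthrough into this module's `model_args`; repo id and tag are placeholders):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "my-user/my-lora-adapter",          # placeholder adapter repo
    model_kwargs={"revision": "v1.0"},  # popped into adapter_only_kwargs above
)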