@@ -74,8 +74,8 @@ def __init__(
74
74
if config_args is None :
75
75
config_args = {}
76
76
77
- config = self ._load_config (model_name_or_path , cache_dir , backend , config_args )
78
- self ._load_model (model_name_or_path , config , cache_dir , backend , ** model_args )
77
+ config , is_peft_model = self ._load_config (model_name_or_path , cache_dir , backend , config_args )
78
+ self ._load_model (model_name_or_path , config , cache_dir , backend , is_peft_model , ** model_args )
79
79
80
80
if max_seq_length is not None and "model_max_length" not in tokenizer_args :
81
81
tokenizer_args ["model_max_length" ] = max_seq_length
@@ -123,28 +123,32 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend:
123
123
)
124
124
from peft import PeftConfig
125
125
126
- return PeftConfig .from_pretrained (model_name_or_path , ** config_args , cache_dir = cache_dir )
126
+ return PeftConfig .from_pretrained (model_name_or_path , ** config_args , cache_dir = cache_dir ), True
127
127
128
- return AutoConfig .from_pretrained (model_name_or_path , ** config_args , cache_dir = cache_dir )
128
+ return AutoConfig .from_pretrained (model_name_or_path , ** config_args , cache_dir = cache_dir ), False
129
129
130
- def _load_model (self , model_name_or_path , config , cache_dir , backend , ** model_args ) -> None :
130
+ def _load_model (self , model_name_or_path , config , cache_dir , backend , is_peft_model , ** model_args ) -> None :
131
131
"""Loads the transformer model"""
132
132
if backend == "torch" :
133
+ # When loading a PEFT model, we need to load the base model first,
134
+ # but some model_args are only for the adapter
135
+ adapter_only_kwargs = {}
136
+ if is_peft_model :
137
+ for adapter_only_kwarg in ["revision" ]:
138
+ if adapter_only_kwarg in model_args :
139
+ adapter_only_kwargs [adapter_only_kwarg ] = model_args .pop (adapter_only_kwarg )
140
+
133
141
if isinstance (config , T5Config ):
134
142
self ._load_t5_model (model_name_or_path , config , cache_dir , ** model_args )
135
- return
136
143
elif isinstance (config , MT5Config ):
137
144
self ._load_mt5_model (model_name_or_path , config , cache_dir , ** model_args )
138
- return
139
- elif is_peft_available ():
140
- from peft import PeftConfig
141
-
142
- if isinstance (config , PeftConfig ):
143
- self ._load_peft_model (model_name_or_path , config , cache_dir , ** model_args )
144
- return
145
- self .auto_model = AutoModel .from_pretrained (
146
- model_name_or_path , config = config , cache_dir = cache_dir , ** model_args
147
- )
145
+ else :
146
+ self .auto_model = AutoModel .from_pretrained (
147
+ model_name_or_path , config = config , cache_dir = cache_dir , ** model_args
148
+ )
149
+
150
+ if is_peft_model :
151
+ self ._load_peft_model (model_name_or_path , config , cache_dir , ** model_args , ** adapter_only_kwargs )
148
152
elif backend == "onnx" :
149
153
self ._load_onnx_model (model_name_or_path , config , cache_dir , ** model_args )
150
154
elif backend == "openvino" :
@@ -155,9 +159,6 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_ar
155
159
def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
    """Wrap the previously loaded base model with a PEFT adapter.

    Assumes ``self.auto_model`` already holds the base transformer (it is
    passed as the first argument below); it is replaced in place by the
    adapter-wrapped ``PeftModel`` loaded from ``model_name_or_path``.

    Args:
        model_name_or_path: Hub id or local path of the PEFT adapter.
        config: The ``PeftConfig`` describing the adapter.
        cache_dir: Optional cache directory for downloaded weights.
        **model_args: Extra keyword arguments forwarded to
            ``PeftModel.from_pretrained`` (e.g. ``revision``).
    """
    from peft import PeftModel

    base_model = self.auto_model
    self.auto_model = PeftModel.from_pretrained(
        base_model,
        model_name_or_path,
        config=config,
        cache_dir=cache_dir,
        **model_args,
    )
0 commit comments