
Commit a542b0a

pesuchin, ryoji.nagata, and tomaarsen authored

[fix] Revision of the adapter model can now be specified (#3079)

* Add: revision of the adapter model can now be specified.
* Refactor loading PEFT slightly to support 'revision'.
* Add the missing type hints in the Transformer module.

Co-authored-by: ryoji.nagata <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>

1 parent df6a8e8 · commit a542b0a

File tree

2 files changed (+77, -22 lines)

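In practice, the fix means that a `revision` passed to `SentenceTransformer` now pins the PEFT adapter repository instead of being (incorrectly) applied to the base model as well. A minimal usage sketch, using the test model and revision hash from the test added in this commit; any other adapter repository and commit hash work the same way:

    from sentence_transformers import SentenceTransformer

    # `revision` is routed to the adapter repository; the base model is
    # resolved from the adapter's config and loaded without it.
    model = SentenceTransformer(
        "sentence-transformers-testing/stsb-bert-tiny-lora",
        revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a",
    )
    embeddings = model.encode("Hello, World!")  # stsb-bert-tiny -> 128-dim embeddings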

sentence_transformers/models/Transformer.py

Lines changed: 68 additions & 22 deletions
@@ -5,17 +5,20 @@
 import os
 from fnmatch import fnmatch
 from pathlib import Path
-from typing import Any, Callable
+from typing import TYPE_CHECKING, Any, Callable
 
 import huggingface_hub
 import torch
 from torch import nn
-from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config
+from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, PretrainedConfig, T5Config
 from transformers.utils.import_utils import is_peft_available
 from transformers.utils.peft_utils import find_adapter_config_file
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING and is_peft_available():
+    from peft import PeftConfig
+
 
 def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]:
     def wrapper(save_directory: str | Path, **kwargs) -> None:
@@ -74,8 +77,8 @@ def __init__(
         if config_args is None:
             config_args = {}
 
-        config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
-        self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)
+        config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
+        self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)
 
         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
             tokenizer_args["model_max_length"] = max_seq_length
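The refactor threads a new `is_peft_model` flag from `_load_config` into `_load_model`. A sketch of the resulting contract; the adapter repository name is hypothetical and the return values are illustrative:

    # Adapter repository (an adapter_config.json is found):
    config, is_peft_model = self._load_config("org/my-lora-adapter", cache_dir, "torch", {})
    # -> (PeftConfig(...), True)

    # Plain transformers model:
    config, is_peft_model = self._load_config("sentence-transformers/all-MiniLM-L6-v2", cache_dir, "torch", {})
    # -> (BertConfig(...), False)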
@@ -99,8 +102,21 @@ def __init__(
         if tokenizer_name_or_path is not None:
             self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__
 
-    def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
-        """Loads the configuration of a model"""
+    def _load_config(
+        self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]
+    ) -> tuple[PeftConfig | PretrainedConfig, bool]:
+        """Loads the transformers or PEFT configuration
+
+        Args:
+            model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
+                or the path to a local model directory.
+            cache_dir (str | None): The cache directory to store the model configuration.
+            backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
+            config_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers config.
+
+        Returns:
+            tuple[PretrainedConfig, bool]: The model configuration and a boolean indicating whether the model is a PEFT model.
+        """
         if (
             find_adapter_config_file(
                 model_name_or_path,
@@ -123,13 +139,39 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend: str, config_args: dict[str, Any]):
                 )
             from peft import PeftConfig
 
-            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), True
+
+        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), False
 
-        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+    def _load_model(
+        self,
+        model_name_or_path: str,
+        config: PeftConfig | PretrainedConfig,
+        cache_dir: str,
+        backend: str,
+        is_peft_model: bool,
+        **model_args,
+    ) -> None:
+        """Loads the transformers or PEFT model into the `auto_model` attribute
 
-    def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
-        """Loads the transformer model"""
+        Args:
+            model_name_or_path (str): The model name on Hugging Face (e.g. 'sentence-transformers/all-MiniLM-L6-v2')
+                or the path to a local model directory.
+            config ("PeftConfig" | PretrainedConfig): The model configuration.
+            cache_dir (str | None): The cache directory to store the model configuration.
+            backend (str): The backend used for model inference. Can be `torch`, `onnx`, or `openvino`.
+            is_peft_model (bool): Whether the model is a PEFT model.
+            model_args (dict[str, Any]): Keyword arguments passed to the Hugging Face Transformers model.
+        """
         if backend == "torch":
+            # When loading a PEFT model, we need to load the base model first,
+            # but some model_args are only for the adapter
+            adapter_only_kwargs = {}
+            if is_peft_model:
+                for adapter_only_kwarg in ["revision"]:
+                    if adapter_only_kwarg in model_args:
+                        adapter_only_kwargs[adapter_only_kwarg] = model_args.pop(adapter_only_kwarg)
+
             if isinstance(config, T5Config):
                 self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
             elif isinstance(config, MT5Config):
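The `adapter_only_kwargs` handling above is the heart of the fix: previously, a `revision` meant for the adapter repository was also forwarded to the base model load, where that commit hash does not exist. Popping it from `model_args` routes it to the adapter alone. A standalone sketch of the pattern, with hypothetical values rather than the library's exact code:

    model_args = {"revision": "3b4f75b", "torch_dtype": "auto"}  # user-supplied kwargs

    adapter_only_kwargs = {}
    for key in ["revision"]:
        if key in model_args:
            adapter_only_kwargs[key] = model_args.pop(key)

    # model_args          -> {"torch_dtype": "auto"}   (passed to the base model)
    # adapter_only_kwargs -> {"revision": "3b4f75b"}   (passed only to the adapter)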
@@ -138,24 +180,26 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
             self.auto_model = AutoModel.from_pretrained(
                 model_name_or_path, config=config, cache_dir=cache_dir, **model_args
             )
-            self._load_peft_model(model_name_or_path, config, cache_dir, **model_args)
+
+            if is_peft_model:
+                self._load_peft_model(model_name_or_path, config, cache_dir, **model_args, **adapter_only_kwargs)
         elif backend == "onnx":
             self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
         elif backend == "openvino":
             self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args)
         else:
             raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")
 
-    def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
-        if is_peft_available():
-            from peft import PeftConfig, PeftModel
+    def _load_peft_model(self, model_name_or_path: str, config: PeftConfig, cache_dir: str, **model_args) -> None:
+        from peft import PeftModel
 
-            if isinstance(config, PeftConfig):
-                self.auto_model = PeftModel.from_pretrained(
-                    self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
-                )
+        self.auto_model = PeftModel.from_pretrained(
+            self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+        )
 
-    def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_openvino_model(
+        self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
+    ) -> None:
         if isinstance(config, T5Config) or isinstance(config, MT5Config):
             raise ValueError("T5 models are not yet supported by the OpenVINO backend.")
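`_load_peft_model` no longer re-checks PEFT availability or the config type: it is only called when `is_peft_model` is true, and `_load_config` already imported `PeftConfig` after finding an adapter config file. The method is equivalent to the following direct PEFT call (a sketch assuming `peft` is installed and `base_model` is the already-loaded transformer):

    from peft import PeftModel

    model = PeftModel.from_pretrained(
        base_model,
        "sentence-transformers-testing/stsb-bert-tiny-lora",
        revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a",  # resolved against the adapter repo
    )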
@@ -210,7 +254,9 @@ def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         if export:
             self._backend_warn_to_save(model_name_or_path, is_local, backend_name)
 
-    def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_onnx_model(
+        self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args
+    ) -> None:
         try:
             import onnxruntime as ort
             from optimum.onnxruntime import ONNX_WEIGHTS_NAME, ORTModelForFeatureExtraction
@@ -363,7 +409,7 @@ def _backend_warn_to_save(self, model_name_or_path: str, is_local: str, backend_name: str) -> None:
             to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`."
         logger.warning(to_log)
 
-    def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_t5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
         """Loads the encoder model from T5"""
         from transformers import T5EncoderModel
@@ -372,7 +418,7 @@ def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
             model_name_or_path, config=config, cache_dir=cache_dir, **model_args
         )
 
-    def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
+    def _load_mt5_model(self, model_name_or_path: str, config: PretrainedConfig, cache_dir: str, **model_args) -> None:
         """Loads the encoder model from T5"""
         from transformers import MT5EncoderModel
tests/test_sentence_transformer.py

Lines changed: 9 additions & 0 deletions
@@ -781,3 +781,12 @@ def test_multiple_adapters() -> None:
     model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency")
     with pytest.raises(ValueError, match="PEFT methods are only supported"):
         model.add_adapter(peft_config)
+
+
+@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test loading PEFT models.")
+def test_load_adapter_with_revision():
+    model = SentenceTransformer(
+        "sentence-transformers-testing/stsb-bert-tiny-lora", revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a"
+    )
+    embeddings = model.encode("Hello, World!")
+    assert embeddings.shape == (128,)
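The new test exercises the full path end to end: loading an adapter repository pinned to a specific commit, then encoding a sentence (stsb-bert-tiny produces 128-dimensional embeddings, hence the shape assertion). Assuming a checkout of the repository with `peft` installed, it can be run in isolation with:

    pytest tests/test_sentence_transformer.py -k test_load_adapter_with_revision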
