Commit ec366b5

Refactor loading PEFT slightly to support 'revision'
1 parent 5c4cbb7 commit ec366b5

File tree: 2 files changed (29 additions, 19 deletions)

sentence_transformers/models/Transformer.py

Lines changed: 20 additions & 19 deletions
@@ -74,8 +74,8 @@ def __init__(
         if config_args is None:
             config_args = {}
 
-        config = self._load_config(model_name_or_path, cache_dir, backend, config_args)
-        self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)
+        config, is_peft_model = self._load_config(model_name_or_path, cache_dir, backend, config_args)
+        self._load_model(model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args)
 
         if max_seq_length is not None and "model_max_length" not in tokenizer_args:
             tokenizer_args["model_max_length"] = max_seq_length
@@ -123,28 +123,32 @@ def _load_config(self, model_name_or_path: str, cache_dir: str | None, backend:
             )
             from peft import PeftConfig
 
-            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+            return PeftConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), True
 
-        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
+        return AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir), False
 
-    def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
+    def _load_model(self, model_name_or_path, config, cache_dir, backend, is_peft_model, **model_args) -> None:
         """Loads the transformer model"""
         if backend == "torch":
+            # When loading a PEFT model, we need to load the base model first,
+            # but some model_args are only for the adapter
+            adapter_only_kwargs = {}
+            if is_peft_model:
+                for adapter_only_kwarg in ["revision"]:
+                    if adapter_only_kwarg in model_args:
+                        adapter_only_kwargs[adapter_only_kwarg] = model_args.pop(adapter_only_kwarg)
+
             if isinstance(config, T5Config):
                 self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
-                return
             elif isinstance(config, MT5Config):
                 self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
-                return
-            elif is_peft_available():
-                from peft import PeftConfig
-
-                if isinstance(config, PeftConfig):
-                    self._load_peft_model(model_name_or_path, config, cache_dir, **model_args)
-                    return
-            self.auto_model = AutoModel.from_pretrained(
-                model_name_or_path, config=config, cache_dir=cache_dir, **model_args
-            )
+            else:
+                self.auto_model = AutoModel.from_pretrained(
+                    model_name_or_path, config=config, cache_dir=cache_dir, **model_args
+                )
+
+            if is_peft_model:
+                self._load_peft_model(model_name_or_path, config, cache_dir, **model_args, **adapter_only_kwargs)
         elif backend == "onnx":
            self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
         elif backend == "openvino":
@@ -155,9 +159,6 @@ def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
     def _load_peft_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
         from peft import PeftModel
 
-        revision = model_args.pop("revision", None)
-        self.auto_model = AutoModel.from_pretrained(config.base_model_name_or_path, cache_dir=cache_dir, **model_args)
-        model_args["revision"] = revision
         self.auto_model = PeftModel.from_pretrained(
             self.auto_model, model_name_or_path, config=config, cache_dir=cache_dir, **model_args
         )
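
Taken together, the torch path now loads the base model without the 'revision' and forwards it only to the adapter load. Below is a standalone illustration of that split using the peft and transformers APIs directly; it is a sketch, not a copy of the library code (the actual _load_model/_load_peft_model path differs in details), and the repo id and commit hash are simply the ones from the new test further down.

from transformers import AutoModel
from peft import PeftConfig, PeftModel

adapter_repo = "sentence-transformers-testing/stsb-bert-tiny-lora"
adapter_revision = "3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a"  # pins the adapter repository, not the base model

# The adapter config records which base model it was trained on.
peft_config = PeftConfig.from_pretrained(adapter_repo, revision=adapter_revision)

# The base model is loaded without any revision argument...
base_model = AutoModel.from_pretrained(peft_config.base_model_name_or_path)

# ...and only the adapter load receives it, so it resolves against the adapter repo.
model = PeftModel.from_pretrained(base_model, adapter_repo, config=peft_config, revision=adapter_revision)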

tests/test_sentence_transformer.py

Lines changed: 9 additions & 0 deletions
@@ -781,3 +781,12 @@ def test_multiple_adapters() -> None:
     model = SentenceTransformer("sentence-transformers/average_word_embeddings_levy_dependency")
     with pytest.raises(ValueError, match="PEFT methods are only supported"):
         model.add_adapter(peft_config)
+
+
+@pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test loading PEFT models.")
+def test_load_adapter_with_revision():
+    model = SentenceTransformer(
+        "sentence-transformers-testing/stsb-bert-tiny-lora", revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a"
+    )
+    embeddings = model.encode("Hello, World!")
+    assert embeddings.shape == (128,)
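
At the SentenceTransformer level, the behaviour this test exercises looks like the following in user code. This is a usage sketch with the same model and revision hash as the test; the printed shape matches the assertion above.

from sentence_transformers import SentenceTransformer

# With this change, the revision pins the PEFT adapter repository and is no
# longer forwarded to the underlying base model load.
model = SentenceTransformer(
    "sentence-transformers-testing/stsb-bert-tiny-lora",
    revision="3b4f75bcb3dec36a7e05da8c44ee2f7f1d023b1a",
)

embeddings = model.encode("Hello, World!")
print(embeddings.shape)  # (128,) for this tiny test model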
