Skip to content

[fix] Added model2vec import compatible with current and newer version #2992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Oct 18, 2024
4 changes: 4 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ jobs:
python -m pip install --upgrade pip
python -m pip install '.[train, onnx, openvino, dev]'

- name: Install model2vec
run: python -m pip install model2vec
if: ${{ contains(fromJSON('["3.10", "3.11", "3.12"]'), matrix.python-version) }}

- name: Run unit tests
run: |
python -m pytest --durations 20 -sv tests/
16 changes: 12 additions & 4 deletions sentence_transformers/models/StaticEmbedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,11 @@ def from_distillation(
"""

try:
from model2vec import distill
from model2vec.distill import distill
except ImportError:
raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")
raise ImportError(
"To use this method, please install the `model2vec` package: `pip install model2vec[distill]`"
)

device = get_device_name()
static_model = distill(
Expand All @@ -172,7 +174,10 @@ def from_distillation(
apply_zipf=apply_zipf,
use_subword=use_subword,
)
embedding_weights = static_model.embedding.weight
if isinstance(static_model.embedding, np.ndarray):
embedding_weights = torch.from_numpy(static_model.embedding)
else:
embedding_weights = static_model.embedding.weight
tokenizer: Tokenizer = static_model.tokenizer

return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_name)
Expand Down Expand Up @@ -200,7 +205,10 @@ def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding:
raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")

static_model = StaticModel.from_pretrained(model_id_or_path)
embedding_weights = static_model.embedding.weight
if isinstance(static_model.embedding, np.ndarray):
embedding_weights = torch.from_numpy(static_model.embedding)
else:
embedding_weights = static_model.embedding.weight
tokenizer: Tokenizer = static_model.tokenizer

return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_id_or_path)
76 changes: 76 additions & 0 deletions tests/models/test_static_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest
from tokenizers import Tokenizer

from sentence_transformers.models.StaticEmbedding import StaticEmbedding

try:
import model2vec
except ImportError:
model2vec = None

skip_if_no_model2vec = pytest.mark.skipif(model2vec is None, reason="The model2vec library is not installed.")


@pytest.fixture
def tokenizer() -> Tokenizer:
return Tokenizer.from_pretrained("bert-base-uncased")


@pytest.fixture
def embedding_weights():
return np.random.rand(30522, 768)


@pytest.fixture
def static_embedding(tokenizer: Tokenizer, embedding_weights) -> StaticEmbedding:
return StaticEmbedding(tokenizer, embedding_weights=embedding_weights)


def test_initialization_with_embedding_weights(tokenizer: Tokenizer, embedding_weights) -> None:
model = StaticEmbedding(tokenizer, embedding_weights=embedding_weights)
assert model.embedding.weight.shape == (30522, 768)


def test_initialization_with_embedding_dim(tokenizer: Tokenizer) -> None:
model = StaticEmbedding(tokenizer, embedding_dim=768)
assert model.embedding.weight.shape == (30522, 768)


def test_tokenize(static_embedding: StaticEmbedding) -> None:
texts = ["Hello world!", "How are you?"]
tokens = static_embedding.tokenize(texts)
assert "input_ids" in tokens
assert "offsets" in tokens


def test_forward(static_embedding: StaticEmbedding) -> None:
texts = ["Hello world!", "How are you?"]
tokens = static_embedding.tokenize(texts)
output = static_embedding(tokens)
assert "sentence_embedding" in output


def test_save_and_load(tmp_path: Path, static_embedding: StaticEmbedding) -> None:
save_dir = tmp_path / "model"
save_dir.mkdir()
static_embedding.save(str(save_dir))

loaded_model = StaticEmbedding.load(str(save_dir))
assert loaded_model.embedding.weight.shape == static_embedding.embedding.weight.shape


@skip_if_no_model2vec()
def test_from_distillation() -> None:
model = StaticEmbedding.from_distillation("sentence-transformers-testing/stsb-bert-tiny-safetensors", pca_dims=32)
assert model.embedding.weight.shape == (29528, 32)


@skip_if_no_model2vec()
def test_from_model2vec() -> None:
model = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
assert model.embedding.weight.shape == (29528, 256)