
Fix #3185: Support convert_to_numpy for output_value="token_embeddings" #3186


Open · wants to merge 3 commits into master
8 changes: 2 additions & 6 deletions sentence_transformers/SentenceTransformer.py
@@ -539,10 +539,6 @@ def encode(
         if convert_to_tensor:
             convert_to_numpy = False
 
-        if output_value != "sentence_embedding":
-            convert_to_tensor = False
-            convert_to_numpy = False
-
         input_was_string = False
         if isinstance(sentences, str) or not hasattr(
             sentences, "__len__"

@@ -669,9 +665,9 @@
         elif convert_to_numpy:
             if not isinstance(all_embeddings, np.ndarray):
                 if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
-                    all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
+                    all_embeddings = np.asarray([emb.float().cpu().numpy() for emb in all_embeddings])
                 else:
-                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+                    all_embeddings = np.asarray([emb.cpu().numpy() for emb in all_embeddings])
         elif isinstance(all_embeddings, np.ndarray):
             all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]
 
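With the forced reset removed, convert_to_numpy now applies to token embeddings as well, and the added .cpu() calls make the NumPy conversion safe for tensors living on an accelerator. A minimal usage sketch of the new behavior (the checkpoint name is illustrative; the PR's test uses a tiny BERT with hidden size 128):

    import numpy as np
    from sentence_transformers import SentenceTransformer

    # Illustrative checkpoint; any SentenceTransformer model works.
    model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")

    # Before this change, convert_to_numpy was silently reset to False whenever
    # output_value != "sentence_embedding", so token embeddings were always torch.Tensors.
    token_embeddings = model.encode(
        "Hello, World!",
        output_value="token_embeddings",
        convert_to_numpy=True,
    )
    assert isinstance(token_embeddings, np.ndarray)
    # One row per token, including [CLS] and [SEP]: shape (6, 128) for this model.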
29 changes: 29 additions & 0 deletions tests/test_sentence_transformer.py
@@ -718,6 +718,35 @@ def test_empty_encode(stsb_bert_tiny_model: SentenceTransformer) -> None:
     assert embeddings.shape == (0,)
 
 
+@pytest.mark.parametrize(
+    ["convert_to_tensor", "convert_to_numpy", "expected_type"],
+    [
+        (True, False, torch.Tensor),
+        (False, False, torch.Tensor),
+        (None, False, torch.Tensor),
+        (True, True, torch.Tensor),
+        (False, True, np.ndarray),
+        (None, True, np.ndarray),
+        (True, None, torch.Tensor),
+        (False, None, np.ndarray),
+        (None, None, np.ndarray),
+    ],
+)
+def test_encode_token_embeddings_type(
+    stsb_bert_tiny_model_reused: SentenceTransformer, convert_to_tensor: bool, convert_to_numpy: bool, expected_type
+) -> None:
+    model = stsb_bert_tiny_model_reused
+
+    encode_kwargs = {}
+    if convert_to_tensor is not None:
+        encode_kwargs["convert_to_tensor"] = convert_to_tensor
+    if convert_to_numpy is not None:
+        encode_kwargs["convert_to_numpy"] = convert_to_numpy
+    embeddings = model.encode("Hello, World!", output_value="token_embeddings", **encode_kwargs)
+    assert isinstance(embeddings, expected_type)
+    assert embeddings.shape == (6, 128)
+
+
 @pytest.mark.skipif(not is_peft_available(), reason="PEFT must be available to test adapter methods.")
 def test_multiple_adapters() -> None:
     text = "Hello, World!"
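The parametrization above also pins down precedence: convert_to_tensor still wins when both flags are set, because encode() resets convert_to_numpy to False whenever convert_to_tensor is True (unchanged by this PR). A quick sketch of that rule, assuming model is any loaded SentenceTransformer instance:

    import torch

    # convert_to_tensor=True forces convert_to_numpy=False inside encode(),
    # so a torch.Tensor comes back even though convert_to_numpy=True was passed.
    embeddings = model.encode(
        "Hello, World!",
        output_value="token_embeddings",
        convert_to_tensor=True,
        convert_to_numpy=True,
    )
    assert isinstance(embeddings, torch.Tensor)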