
typing: fix typing on encode #3270


Merged: 3 commits, Mar 20, 2025

Conversation

@stephantul (Contributor)

Hello!

The encode function had inconsistent typing. Specifically, it did not include a return type for output_value=None, where raw features are returned, and it ignored the fact that convert_to_tensor and convert_to_numpy have no effect when output_value != "sentence_embedding".

This PR fixes all those issues and should hopefully lead to a more consistent experience. For what it's worth, I kept the possibility of passing in an ndarray, but I think this should be revisited, because the actual input type of the function is Iterator[str]. There's actually no reason for it to be a list.

Here are the test cases I used. Unfortunately, there's no really nice way to introspect types AFAIK, but this should be sufficient to check using reveal_type or your IDE.

r0 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=True, convert_to_tensor=True, output_value='sentence_embedding')
# <class 'torch.Tensor'>
r1 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=True, convert_to_tensor=True, output_value='sentence_embedding')
# <class 'torch.Tensor'>
r2 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=True, convert_to_tensor=True, output_value='token_embeddings')
# <class 'list'>
# <class 'torch.Tensor'>
r3 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=True, convert_to_tensor=True, output_value='token_embeddings')
# <class 'torch.Tensor'>
r4 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=True, convert_to_tensor=True, output_value=None)
# <class 'list'>
# <class 'dict'>
r5 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=True, convert_to_tensor=True, output_value=None)
# <class 'dict'>
r6 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=True, convert_to_tensor=False, output_value='sentence_embedding')
# <class 'numpy.ndarray'>
r7 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=True, convert_to_tensor=False, output_value='sentence_embedding')
# <class 'numpy.ndarray'>
r8 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=True, convert_to_tensor=False, output_value='token_embeddings')
# <class 'list'>
# <class 'torch.Tensor'>
r9 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=True, convert_to_tensor=False, output_value='token_embeddings')
# <class 'torch.Tensor'>
r10 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=True, convert_to_tensor=False, output_value=None)
# <class 'list'>
# <class 'dict'>
r11 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=True, convert_to_tensor=False, output_value=None)
# <class 'dict'>
r12 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=False, convert_to_tensor=True, output_value='sentence_embedding')
# <class 'torch.Tensor'>
r13 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=False, convert_to_tensor=True, output_value='sentence_embedding')
# <class 'torch.Tensor'>
r14 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=False, convert_to_tensor=True, output_value='token_embeddings')
# <class 'list'>
# <class 'torch.Tensor'>
r15 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=False, convert_to_tensor=True, output_value='token_embeddings')
# <class 'torch.Tensor'>
r16 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=False, convert_to_tensor=True, output_value=None)
# <class 'list'>
# <class 'dict'>
r17 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=False, convert_to_tensor=True, output_value=None)
# <class 'dict'>
r18 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=False, convert_to_tensor=False, output_value='sentence_embedding')
# <class 'list'>
# <class 'torch.Tensor'>
r19 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=False, convert_to_tensor=False, output_value='sentence_embedding')
# <class 'torch.Tensor'>
r20 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=False, convert_to_tensor=False, output_value='token_embeddings')
# <class 'list'>
# <class 'torch.Tensor'>
r21 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=False, convert_to_tensor=False, output_value='token_embeddings')
# <class 'torch.Tensor'>
r22 = model.encode(sentences=['Hello, my dog is cute'], convert_to_numpy=False, convert_to_tensor=False, output_value=None)
# <class 'list'>
# <class 'dict'>
r23 = model.encode(sentences='Hello, my dog is cute', convert_to_numpy=False, convert_to_tensor=False, output_value=None)
# <class 'dict'>

The format here is: the first comment line below an invocation indicates the base type of the result. If this is a container (i.e., a list), the second line shows the type of the container's first item. On my local machine I matched these against actual runtime tests, and they all produced the same output:

See here for the script:

import torch
import numpy as np
from itertools import product
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('baai/bge-base-en-v1.5')


test_sentence = "Hello, my dog is cute"
convert_to_numpy = [True, False]
convert_to_tensor = [True, False]
output_value = ["sentence_embedding", "token_embeddings", None]
is_list = [True, False]
expected_type = {
    (True, True, "sentence_embedding", True): (torch.Tensor, None),
    (True, True, "sentence_embedding", False): (torch.Tensor, None),
    (True, True, "token_embeddings", True): (list, torch.Tensor),
    (True, True, "token_embeddings", False): (torch.Tensor, None),
    (True, True, None, True): (list, dict),
    (True, True, None, False): (dict, None),
    (True, False, "sentence_embedding", True): (np.ndarray, None),
    (True, False, "sentence_embedding", False): (np.ndarray, None),
    (True, False, "token_embeddings", True): (list, torch.Tensor),
    (True, False, "token_embeddings", False): (torch.Tensor, None),
    (True, False, None, True): (list, dict),
    (True, False, None, False): (dict, None),
    (False, True, "sentence_embedding", True): (torch.Tensor, None),
    (False, True, "sentence_embedding", False): (torch.Tensor, None),
    (False, True, "token_embeddings", True): (list, torch.Tensor),
    (False, True, "token_embeddings", False): (torch.Tensor, None),
    (False, True, None, True): (list, dict),
    (False, True, None, False): (dict, None),
    (False, False, "sentence_embedding", True): (list, torch.Tensor),
    (False, False, "sentence_embedding", False): (torch.Tensor, None),
    (False, False, "token_embeddings", True): (list, torch.Tensor),
    (False, False, "token_embeddings", False): (torch.Tensor, None),
    (False, False, None, True): (list, dict),
    (False, False, None, False): (dict, None),
}

# Generate all combinations
test_cases = list(product(convert_to_numpy, convert_to_tensor, output_value, is_list))

# Function to check type
def check_type(result, expected_type_info):
    if isinstance(result, list):
        assert isinstance(result, expected_type_info[0]), f"Expected {expected_type_info[0]} but got {type(result)}"
        if expected_type_info[1] is not None:
            assert isinstance(result[0], expected_type_info[1]), f"Expected list of {expected_type_info[1]} but got list of {type(result[0])}"
    else:
        assert isinstance(result, expected_type_info[0]), f"Expected {expected_type_info[0]} but got {type(result)}"

# Test all combinations
for i, (to_numpy, to_tensor, out_val, is_list_val) in enumerate(test_cases):
    input_data = [test_sentence] if is_list_val else test_sentence
    result = model.encode(input_data, convert_to_numpy=to_numpy, convert_to_tensor=to_tensor, output_value=out_val)
    expected = expected_type[(to_numpy, to_tensor, out_val, is_list_val)]
    print(f"Case {i+1}: ({to_numpy}, {to_tensor}, {out_val}, {is_list_val}) => {type(result)}")
    check_type(result, expected)

@stephantul changed the title from "typng: fix typing on encode" to "typing: fix typing on encode" on Mar 17, 2025
@stephantul (Contributor, Author)

Oh, I also added None as a type to device, since the default is None.

@tomaarsen (Collaborator)

tomaarsen commented Mar 19, 2025

Hello!

I'm wary of extending SentenceTransformer.py even further with more typing machinery, ~150 lines is a bit much, so I've tried to move everything to SentenceTransformer.pyi instead, which seems to work pretty smoothly! I also fixed device: str = None -> device: str | None = None for 2 of the type options.
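The value-dependent return types discussed in this PR are what typing.overload in a .pyi stub can express. The following is a hypothetical, heavily simplified sketch (not the actual stub; a plain list of floats and a plain dict stand in for the real Tensor/ndarray/feature-dict returns):

```python
from typing import Literal, overload

# Hypothetical stand-ins for the real return types (Tensor / ndarray / feature dict).
Embedding = list[float]
Features = dict[str, list[int]]

# Overloads narrow the return type from the argument types and values; only the
# stub signatures are seen by the type checker, the single implementation runs
# at runtime.
@overload
def encode(sentences: str, output_value: None) -> Features: ...
@overload
def encode(sentences: list[str], output_value: None) -> list[Features]: ...
@overload
def encode(sentences: str, output_value: Literal["sentence_embedding"] = ...) -> Embedding: ...
@overload
def encode(sentences: list[str], output_value: Literal["sentence_embedding"] = ...) -> Embedding: ...

def encode(sentences, output_value="sentence_embedding"):
    # Toy implementation so the sketch runs; the real method tokenizes the
    # input and forwards it through the model.
    if output_value is None:
        features: Features = {"input_ids": [0, 1]}
        return [dict(features) for _ in sentences] if isinstance(sentences, list) else features
    return [0.0, 0.0, 0.0]
```

With stubs shaped like this, `encode("hi", None)` is inferred as a dict while `encode(["hi"], None)` is inferred as a list of dicts, mirroring the runtime behavior exercised by the test cases above.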

Could you please check whether this also works for you?

  • Tom Aarsen

@stephantul (Contributor, Author)

Hey, thanks for the additional check, and good that you spotted the str | None for device. Everything seems to be in order; I rechecked the results. I like the move to .pyi, much cleaner.

@tomaarsen (Collaborator)

Thanks for tackling this!

  • Tom Aarsen

@tomaarsen tomaarsen merged commit 5c362dc into UKPLab:master Mar 20, 2025
7 of 9 checks passed