Skip to content

[feat] Zero-shot classification #3259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .quantization import quantize_embeddings
from .util import (
batch_to_device,
cos_sim,
get_device_name,
import_from_string,
is_sentence_transformer_model,
Expand Down Expand Up @@ -691,6 +692,93 @@ def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
input = module(input, **module_kwargs)
return input

def classify(
    self,
    sentences: str | list[str],
    labels: list[str],
    label_template: str = "The main subject of this text is {}.",
    prompt_name: str | None = None,
    prompt: str | None = None,
    batch_size: int = 32,
    show_progress_bar: bool | None = None,
    normalize_embeddings: bool = False,
    **kwargs,
) -> list[list[tuple[str, float]]]:
    """
    Perform zero-shot classification using the embedding model.

    Each sentence and each (templated) label is embedded with :meth:`SentenceTransformer.encode`,
    and the cosine similarities between a sentence and all labels are converted into a
    probability distribution over the labels with a softmax.

    Args:
        sentences (Union[str, List[str]]): The sentences to classify.
        labels (List[str]): The labels to classify the sentences against. Must be non-empty.
        label_template (str, optional): A template used to format each label before encoding.
            For example, if the label template is "This is a label: {}",
            then the label "positive" will be encoded as "This is a label: positive".
            Pass an empty string to encode the labels verbatim.
            Defaults to "The main subject of this text is {}.".
        prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
            which is either set in the constructor or loaded from the model configuration. For example if
            ``prompt_name`` is "query" and the ``prompts`` is {"query": "query: ", ...}, then the sentence "What
            is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
            is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
        prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
            sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
            because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
        batch_size (int, optional): The batch size used for the computation. Defaults to 32.
        show_progress_bar (bool, optional): Whether to output a progress bar when encoding sentences. Defaults to None.
        normalize_embeddings (bool, optional): Whether to normalize returned vectors to have length 1. In that case,
            the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.

    Returns:
        List[List[Tuple[str, float]]]: A list of results for each sentence.
        Each result is a list of ``(label, score)`` tuples, sorted by descending score.
        The scores for one sentence form a softmax distribution and sum to 1.

    Raises:
        ValueError: If ``labels`` is empty.

    Example:
        ::

            from sentence_transformers import SentenceTransformer

            # Load a pre-trained SentenceTransformer model
            model = SentenceTransformer('all-mpnet-base-v2')

            sentences = [
                "The weather is lovely today.",
                "It's so sunny outside!",
                "He drove to the stadium.",
            ]
            labels = ["weather", "sports", "politics"]
            results = model.classify(sentences, labels)

            for sentence, result in zip(sentences, results):
                print(f"\nClassification for '{sentence}':")
                for label, score in result:
                    print(f"{label}: {score}")
    """
    if not labels:
        # A softmax over zero labels would silently produce empty/meaningless results.
        raise ValueError("`labels` must contain at least one label.")

    # Keep the caller's original labels for the returned tuples; only the encoded
    # text uses the templated form.
    raw_labels = list(labels)
    if label_template:
        labels = [label_template.format(label) for label in labels]

    logger.debug("Encoding input sentences")
    text_embeddings = self.encode(
        sentences,
        batch_size=batch_size,
        show_progress_bar=show_progress_bar,
        prompt=prompt,
        prompt_name=prompt_name,
        normalize_embeddings=normalize_embeddings,
    )
    logger.debug("Encoding labels")
    label_embeddings = self.encode(
        labels,
        batch_size=batch_size,
        show_progress_bar=show_progress_bar,
        prompt=prompt,
        prompt_name=prompt_name,
        normalize_embeddings=normalize_embeddings,
    )

    # Cosine similarity of every sentence against every label, then a per-sentence
    # softmax over the labels so the scores form a probability distribution.
    # NOTE(review): **kwargs is accepted but currently unused — confirm whether it
    # should be forwarded to `encode`.
    similarities = cos_sim(text_embeddings, label_embeddings)
    probabilities = torch.nn.functional.softmax(similarities, dim=1).cpu().tolist()

    return [sorted(zip(raw_labels, row), key=lambda item: item[1], reverse=True) for row in probabilities]

@property
def similarity_fn_name(self) -> Literal["cosine", "dot", "euclidean", "manhattan"]:
"""Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.
Expand Down
25 changes: 25 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,31 @@ def test_similarity_score(stsb_bert_tiny_model_reused: SentenceTransformer, simi
if similarity_fn_name in ("cosine", "dot"):
assert (pairwise_scores > 0.5).all()

def test_classify(stsb_bert_tiny_model_reused: SentenceTransformer) -> None:
    """Zero-shot classification yields one full, normalized label distribution per sentence."""
    model = stsb_bert_tiny_model_reused

    sentences = [
        "The weather is so nice!",
        "It's so sunny outside.",
        "He's driving to the movie theater.",
        "She's going to the cinema.",
    ]
    labels = ["travel", "cooking", "dancing"]

    results = model.classify(sentences, labels)

    # One result per input sentence, each covering every candidate label.
    assert len(results) == len(sentences)
    for result in results:
        assert len(result) == len(labels)
        total_score = 0.0
        seen_labels = []
        for label, score in result:
            assert label in labels
            assert 0 <= score <= 1
            seen_labels.append(label)
            total_score += score
        # Scores form a probability distribution over all labels.
        assert np.isclose(total_score, 1)
        assert set(seen_labels) == set(labels)

def test_similarity_score_save(stsb_bert_tiny_model: SentenceTransformer) -> None:
model = stsb_bert_tiny_model
Expand Down
Loading