Commit eb3d4df

daegonYu and tomaarsen authored
CachedGISTEmbedLoss Adding Margin (#3299)
* add_margin_strategy_margin
* add_margin
* add_margin
* update_description
* Use Literal for argument autocomplete
* Slightly improve docs
* Use ValueError instead of assert (asserts can be skipped/ignored)
* Simplify positive_mask creation
* Extend margin implementation to GISTEmbedLoss too
* Add missing multiple negatives case, update get_config_dict
* Remove margin_strategy=None, default to "absolute"
* Follow-up to removal of margin_strategy=None

Co-authored-by: Tom Aarsen <[email protected]>
1 parent 8a3342b commit eb3d4df

2 files changed: +94 -22 lines

sentence_transformers/losses/CachedGISTEmbedLoss.py (+49 -7)
@@ -3,7 +3,7 @@
 from collections.abc import Iterable, Iterator
 from contextlib import nullcontext
 from functools import partial
-from typing import Any
+from typing import Any, Literal
 
 import torch
 import tqdm
@@ -68,6 +68,8 @@ def __init__(
         temperature: float = 0.01,
         mini_batch_size: int = 32,
         show_progress_bar: bool = False,
+        margin_strategy: Literal["absolute", "percentage"] = "absolute",
+        margin: float = 0.0,
     ) -> None:
         """
         This loss is a combination of :class:`GISTEmbedLoss` and :class:`CachedMultipleNegativesRankingLoss`.
@@ -81,6 +83,12 @@ def __init__(
         :class:`CachedMultipleNegativesRankingLoss`, it is possible to reduce memory usage while maintaining performance
         levels comparable to those of :class:`GISTEmbedLoss`.
 
+        You can apply different false-negative filtering strategies to discard hard negatives that are too similar to
+        the positive. Two strategies are supported:
+
+            - "absolute": Discards negatives whose similarity score is greater than or equal to (positive_score - margin).
+            - "percentage": Discards negatives whose similarity score is greater than or equal to (positive_score * margin).
+
         Args:
             model: SentenceTransformer model
             guide: SentenceTransformer model to guide the in-batch negative sample selection.
@@ -90,6 +98,9 @@ def __init__(
                 the slower the training will be. It's recommended to set it as high as your GPU memory allows. The default
                 value is 32.
             show_progress_bar: If True, a progress bar for the mini-batches is shown during training. The default is False.
+            margin_strategy: Strategy used for false negative filtering. One of {"absolute", "percentage"}.
+            margin: The margin value for filtering negatives. Defaults to 0.0, together with the "absolute" strategy,
+                this only removes negatives that are more similar to the query than the positive is to the query.
 
         References:
             - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf
@@ -130,7 +141,13 @@ def __init__(
                 "anchor": ["It's nice weather outside today.", "He drove to work."],
                 "positive": ["It's so sunny.", "He took the car to the office."],
             })
-            loss = losses.CachedGISTEmbedLoss(model, guide, mini_batch_size=64)
+            loss = losses.CachedGISTEmbedLoss(
+                model,
+                guide,
+                mini_batch_size=64,
+                margin_strategy="absolute",  # or "percentage" (e.g., margin=0.95)
+                margin=0.1
+            )
 
             trainer = SentenceTransformerTrainer(
                 model=model,
@@ -163,6 +180,10 @@ def __init__(
         )
         if self.must_retokenize:
             self.tokenizer = model.tokenizer
+        if margin_strategy not in ("absolute", "percentage"):
+            raise ValueError("margin_strategy must be 'absolute' or 'percentage'.")
+        self.margin_strategy = margin_strategy
+        self.margin = margin
 
     def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
         return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))
@@ -273,10 +294,29 @@ def calculate_loss(
             aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
             pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity
 
-            # Apply thresholds based on guided model similarities
-            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
-            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
-            pp_sim[guided_pp_sim > guided_sim] = -torch.inf
+            # This uses guided (teacher) similarity as a dynamic threshold to identify and suppress false negatives
+            def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None = None):
+                if self.margin_strategy == "absolute":
+                    # Remove samples whose guided similarity is higher than (positive_sim - margin)
+                    mask = guided_sim_mat > (guided_sim - self.margin)
+                elif self.margin_strategy == "percentage":
+                    # Remove samples whose guided similarity is higher than (positive_sim * margin)
+                    mask = guided_sim_mat > (guided_sim * self.margin)
+
+                if positive_mask is not None:
+                    # Ensure true positive pairs are not masked out
+                    mask = mask & ~positive_mask
+                sim_mat[mask] = -torch.inf
+                return sim_mat
+
+            # Create a mask to protect true positive pairs in the anchor-positive matrix (i.e., diagonal elements)
+            positive_mask = torch.eye(*guided_ap_sim.shape, dtype=torch.bool, device=guided_ap_sim.device)
+            positive_mask = positive_mask.roll(b)
+
+            # Apply false negative suppression to each similarity matrix using guided similarity as anchor
+            ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask)  # anchor-positive
+            aa_sim = mask_false_negatives(guided_aa_sim, aa_sim)  # anchor-anchor
+            pp_sim = mask_false_negatives(guided_pp_sim, pp_sim)  # positive-positive
 
             # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
             scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)
@@ -286,7 +326,7 @@ def calculate_loss(
             for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                 guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                 neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
-                neg_sim[guided_neg_sim > guided_sim] = -torch.inf
+                neg_sim = mask_false_negatives(guided_neg_sim, neg_sim)
                 scores = torch.cat([scores, neg_sim], dim=1)
 
             # Normalize the scores and calculate the cross-entropy loss
@@ -337,4 +377,6 @@ def get_config_dict(self) -> dict[str, Any]:
             "guide": self.guide,
             "temperature": self.temperature,
             "mini_batch_size": self.mini_batch_size,
+            "margin_strategy": self.margin_strategy,
+            "margin": self.margin,
         }
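To make the new masking semantics concrete, here is a minimal standalone sketch (not part of the commit) of what mask_false_negatives computes, on toy similarity values; the tensors and the margin of 0.1 are illustrative assumptions, not values from the library:

import torch

# Toy guided (teacher) similarities: 2 anchors (rows) vs. 4 in-batch candidates (columns).
guided_sim_mat = torch.tensor([[0.90, 0.80, 0.30, 0.20],
                               [0.40, 0.95, 0.70, 0.10]])
# Guide similarity of each anchor to its assigned positive: the per-row threshold.
guided_sim = torch.tensor([[0.85], [0.90]])
# Student similarities that actually feed the loss.
sim_mat = torch.tensor([[0.70, 0.60, 0.20, 0.10],
                        [0.30, 0.80, 0.50, 0.00]])

# "absolute" strategy: drop candidates the guide scores above (positive - margin).
margin = 0.1
mask = guided_sim_mat > (guided_sim - margin)
# The "percentage" strategy would instead use: guided_sim_mat > (guided_sim * margin).

# Protect the true anchor-positive pairs (the diagonal here) from being masked.
positive_mask = torch.eye(2, 4, dtype=torch.bool)
sim_mat[mask & ~positive_mask] = -torch.inf  # -inf scores vanish after softmax

print(sim_mat)
# tensor([[0.7000,   -inf, 0.2000, 0.1000],
#         [0.3000, 0.8000, 0.5000, 0.0000]])

In the loss itself, positive_mask is additionally rolled by the mini-batch offset b so that the protected diagonal lines up with the current mini-batch slice of the full similarity matrix.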

sentence_transformers/losses/GISTEmbedLoss.py (+45 -15)
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections.abc import Iterable
-from typing import Any
+from typing import Any, Literal
 
 import torch
 from torch import Tensor, nn
@@ -16,20 +16,28 @@ def __init__(
         model: SentenceTransformer,
         guide: SentenceTransformer,
         temperature: float = 0.01,
+        margin_strategy: Literal["absolute", "percentage"] = "absolute",
+        margin: float = 0.0,
     ) -> None:
         """
         This loss is used to train a SentenceTransformer model using the GISTEmbed algorithm.
         It takes a model and a guide model as input, and uses the guide model to guide the
         in-batch negative sample selection. The cosine similarity is used to compute the loss
         and the temperature parameter is used to scale the cosine similarities.
 
+        You can apply different false-negative filtering strategies to discard hard negatives that are too similar to
+        the positive. Two strategies are supported:
+
+            - "absolute": Discards negatives whose similarity score is greater than or equal to (positive_score - margin).
+            - "percentage": Discards negatives whose similarity score is greater than or equal to (positive_score * margin).
+
         Args:
-            model: SentenceTransformer model based on a `transformers`
-                model.
-            guide: SentenceTransformer model to guide the in-batch
-                negative sample selection.
-            temperature: Temperature parameter to scale the cosine
-                similarities.
+            model: SentenceTransformer model based on a `transformers` model.
+            guide: SentenceTransformer model to guide the in-batch negative sample selection.
+            temperature: Temperature parameter to scale the cosine similarities.
+            margin_strategy: Strategy used for false negative filtering. One of {"absolute", "percentage"}.
+            margin: The margin value for filtering negatives. Defaults to 0.0, together with the "absolute" strategy,
+                this only removes negatives that are more similar to the query than the positive is to the query.
 
         References:
             - For further details, see: https://arxiv.org/abs/2402.16829
@@ -98,6 +106,11 @@ def __init__(
                 "then the Sentence Transformer model must not be based on a StaticEmbedding."
             )
 
+        if margin_strategy not in ("absolute", "percentage"):
+            raise ValueError("margin_strategy must be 'absolute' or 'percentage'.")
+        self.margin_strategy = margin_strategy
+        self.margin = margin
+
     def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
         return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))
 
@@ -144,21 +157,36 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor
         # Define the anchor threshold
         guided_sim = guided_ap_sim.diagonal().view(-1, 1)
 
-        # Find which samples cannot be used as negatives because they are
-        # more similar to the query than the assigned positive as deemed by the guide model.
-        # For these samples, we mask them with -inf to basically ignore their contribution to
-        # the loss.
-        ap_sim[guided_ap_sim > guided_sim] = -torch.inf
-        aa_sim[guided_aa_sim > guided_sim] = -torch.inf
-        pp_sim[guided_pp_sim > guided_sim] = -torch.inf
+        # This uses guided (teacher) similarity as a dynamic threshold to identify and suppress false negatives
+        def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None = None):
+            if self.margin_strategy == "absolute":
+                # Remove samples whose guided similarity is higher than (positive_sim - margin)
+                mask = guided_sim_mat > (guided_sim - self.margin)
+            elif self.margin_strategy == "percentage":
+                # Remove samples whose guided similarity is higher than (positive_sim * margin)
+                mask = guided_sim_mat > (guided_sim * self.margin)
+
+            if positive_mask is not None:
+                # Ensure true positive pairs are not masked out
+                mask = mask & ~positive_mask
+            sim_mat[mask] = -torch.inf
+            return sim_mat
+
+        # Create a mask to protect true positive pairs in the anchor-positive matrix (i.e., diagonal elements)
+        positive_mask = torch.eye(*guided_ap_sim.shape, dtype=torch.bool, device=guided_ap_sim.device)
+
+        # Apply false negative suppression to each similarity matrix using guided similarity as anchor
+        ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask)  # anchor-positive
+        aa_sim = mask_false_negatives(guided_aa_sim, aa_sim)  # anchor-anchor
+        pp_sim = mask_false_negatives(guided_pp_sim, pp_sim)  # positive-positive
 
         scores = [ap_sim, aa_sim, pp_sim]
 
         # Handle the case where we have a negative sample
         if negative is not None:
             an_sim = self.sim_matrix(anchor, negative)
             guided_an_sim = self.sim_matrix(anchor_guide, negative_guide)
-            an_sim[guided_an_sim > guided_sim] = -torch.inf
+            an_sim = mask_false_negatives(guided_an_sim, an_sim)
 
             scores.append(an_sim)
 
@@ -174,6 +202,8 @@ def get_config_dict(self) -> dict[str, Any]:
         return {
             "guide": self.guide,
             "temperature": self.temperature,
+            "margin_strategy": self.margin_strategy,
+            "margin": self.margin,
         }
 
     @property
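For the uncached variant, a usage sketch following the same pattern as the CachedGISTEmbedLoss example above; the checkpoint names and margin values are illustrative choices, not prescribed by this commit:

from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

# Illustrative student/guide pair; any SentenceTransformer checkpoints work,
# with the guide ideally being the stronger model.
model = SentenceTransformer("microsoft/mpnet-base")
guide = SentenceTransformer("all-MiniLM-L6-v2")

train_dataset = Dataset.from_dict({
    "anchor": ["It's nice weather outside today.", "He drove to work."],
    "positive": ["It's so sunny.", "He took the car to the office."],
})

# "percentage" with margin=0.95 drops in-batch negatives that the guide scores
# above 95% of the anchor-positive similarity.
loss = losses.GISTEmbedLoss(model, guide, margin_strategy="percentage", margin=0.95)

trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
trainer.train()

The "absolute" default with margin=0.0 reproduces the previous behavior: only candidates that the guide rates strictly more similar to the anchor than the assigned positive are filtered out.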
