@@ -3,7 +3,7 @@
 from collections.abc import Iterable, Iterator
 from contextlib import nullcontext
 from functools import partial
-from typing import Any
+from typing import Any, Literal
 
 import torch
 import tqdm
@@ -68,6 +68,8 @@ def __init__(
         temperature: float = 0.01,
         mini_batch_size: int = 32,
         show_progress_bar: bool = False,
+        margin_strategy: Literal["absolute", "percentage"] = "absolute",
+        margin: float = 0.0,
     ) -> None:
         """
         This loss is a combination of :class:`GISTEmbedLoss` and :class:`CachedMultipleNegativesRankingLoss`.
@@ -81,6 +83,12 @@ def __init__(
         :class:`CachedMultipleNegativesRankingLoss`, it is possible to reduce memory usage while maintaining performance
         levels comparable to those of :class:`GISTEmbedLoss`.
 
+        You can apply different false-negative filtering strategies to discard hard negatives that are too similar to
+        the positive. Two strategies are supported:
+
+        - "absolute": Discards negatives whose similarity score is greater than or equal to (positive_score - margin).
+        - "percentage": Discards negatives whose similarity score is greater than or equal to (positive_score * margin).
+
         Args:
             model: SentenceTransformer model
             guide: SentenceTransformer model to guide the in-batch negative sample selection.
@@ -90,6 +98,9 @@ def __init__(
                 the slower the training will be. It's recommended to set it as high as your GPU memory allows. The default
                 value is 32.
             show_progress_bar: If True, a progress bar for the mini-batches is shown during training. The default is False.
+            margin_strategy: Strategy used for false-negative filtering. One of {"absolute", "percentage"}.
+            margin: The margin value for filtering negatives. Defaults to 0.0; together with the "absolute" strategy,
+                this only removes negatives that are more similar to the query than the positive is to the query.
 
         References:
             - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf
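To make the two strategies concrete, here is a minimal sketch of the thresholds (not part of this diff; the similarity values are invented):

    import torch

    guided_sim = torch.tensor(0.90)               # teacher's anchor-positive similarity
    negatives = torch.tensor([0.95, 0.85, 0.50])  # teacher similarities of in-batch negatives

    # "absolute": threshold is positive_score - margin = 0.80 -> masks 0.95 and 0.85
    absolute_mask = negatives > guided_sim - 0.10
    # "percentage": threshold is positive_score * margin = 0.855 -> masks only 0.95
    percentage_mask = negatives > guided_sim * 0.95

Masked entries are the ones the loss later sets to -inf, so they drop out of the softmax.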
@@ -130,7 +141,13 @@ def __init__(
                 "anchor": ["It's nice weather outside today.", "He drove to work."],
                 "positive": ["It's so sunny.", "He took the car to the office."],
             })
-            loss = losses.CachedGISTEmbedLoss(model, guide, mini_batch_size=64)
+            loss = losses.CachedGISTEmbedLoss(
+                model,
+                guide,
+                mini_batch_size=64,
+                margin_strategy="absolute",  # or "percentage" (e.g., margin=0.95)
+                margin=0.1,
+            )
 
             trainer = SentenceTransformerTrainer(
                 model=model,
@@ -163,6 +180,10 @@ def __init__(
         )
         if self.must_retokenize:
             self.tokenizer = model.tokenizer
+        if margin_strategy not in ("absolute", "percentage"):
+            raise ValueError("margin_strategy must be 'absolute' or 'percentage'.")
+        self.margin_strategy = margin_strategy
+        self.margin = margin
 
     def sim_matrix(self, embed1: Tensor, embed2: Tensor) -> Tensor:
         return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))
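For reference, `sim_matrix` relies on broadcasting to build a full pairwise matrix. A self-contained sketch (assuming the default cosine `similarity_fct`; shapes invented):

    import torch
    import torch.nn.functional as F

    embed1 = torch.randn(4, 8)  # e.g., 4 anchors with dim 8
    embed2 = torch.randn(6, 8)  # e.g., 6 candidates
    # (4, 1, 8) against (1, 6, 8) broadcasts to a (4, 6) matrix of pairwise similarities
    sim = F.cosine_similarity(embed1.unsqueeze(1), embed2.unsqueeze(0), dim=-1)
    assert sim.shape == (4, 6)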
@@ -273,10 +294,29 @@ def calculate_loss(
             aa_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[0])  # anchor-anchor similarity
             pp_sim = self.sim_matrix(concatenated_reps[1][b:e], concatenated_reps[1])  # positive-positive similarity
 
-            # Apply thresholds based on guided model similarities
-            ap_sim[guided_ap_sim > guided_sim] = -torch.inf
-            aa_sim[guided_aa_sim > guided_sim] = -torch.inf
-            pp_sim[guided_pp_sim > guided_sim] = -torch.inf
+            # Use the guided (teacher) similarity as a dynamic threshold to identify and suppress false negatives
+            def mask_false_negatives(guided_sim_mat, sim_mat, positive_mask: Tensor | None = None):
+                if self.margin_strategy == "absolute":
+                    # Remove samples whose guided similarity is higher than (positive_sim - margin)
+                    mask = guided_sim_mat > (guided_sim - self.margin)
+                elif self.margin_strategy == "percentage":
+                    # Remove samples whose guided similarity is higher than (positive_sim * margin)
+                    mask = guided_sim_mat > (guided_sim * self.margin)
+
+                if positive_mask is not None:
+                    # Ensure true positive pairs are not masked out
+                    mask = mask & ~positive_mask
+                sim_mat[mask] = -torch.inf
+                return sim_mat
+
+            # Create a mask to protect true positive pairs in the anchor-positive matrix (i.e., diagonal elements)
+            positive_mask = torch.eye(*guided_ap_sim.shape, dtype=torch.bool, device=guided_ap_sim.device)
+            positive_mask = positive_mask.roll(b)
+
+            # Apply false-negative suppression to each similarity matrix, with the guided similarity as the reference
+            ap_sim = mask_false_negatives(guided_ap_sim, ap_sim, positive_mask=positive_mask)  # anchor-positive
+            aa_sim = mask_false_negatives(guided_aa_sim, aa_sim)  # anchor-anchor
+            pp_sim = mask_false_negatives(guided_pp_sim, pp_sim)  # positive-positive
 
             # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive
             scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1)
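The `torch.eye(...).roll(b)` trick above is subtle: `guided_ap_sim` only covers rows `b:e` of the batch, so the true positive of row `i` sits at column `b + i`, not on the main diagonal. A standalone sketch with invented values (`b = 1`, a 2-row mini-batch against a full batch of 4, "absolute" strategy with margin 0.1):

    import torch

    b = 1
    guided_ap_sim = torch.tensor([[0.40, 0.90, 0.85, 0.30],
                                  [0.20, 0.50, 0.90, 0.95]])
    # Teacher's anchor-positive score per row, as a (2, 1) column for broadcasting
    guided_sim = guided_ap_sim.diagonal(offset=b).unsqueeze(1)

    # eye(2, 4) has ones at (0, 0) and (1, 1); roll(b) shifts them to (0, 1) and (1, 2),
    # exactly the true anchor-positive positions for this mini-batch
    positive_mask = torch.eye(*guided_ap_sim.shape, dtype=torch.bool).roll(b)

    mask = guided_ap_sim > (guided_sim - 0.1)  # raw threshold mask
    mask = mask & ~positive_mask               # never mask the true positives
    # mask is now True only at (0, 2) and (1, 3): near-duplicates of the positives

Those True entries are set to -inf, so suspected false negatives contribute nothing to the softmax.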
@@ -286,7 +326,7 @@ def calculate_loss(
             for i in range(2, len(concatenated_reps)):  # Start from 2 since first 2 are anchor-positive
                 guided_neg_sim = self.sim_matrix(concatenated_guided_reps[0][b:e], concatenated_guided_reps[i])
                 neg_sim = self.sim_matrix(concatenated_reps[0][b:e], concatenated_reps[i])
-                neg_sim[guided_neg_sim > guided_sim] = -torch.inf
+                neg_sim = mask_false_negatives(guided_neg_sim, neg_sim)
                 scores = torch.cat([scores, neg_sim], dim=1)
 
             # Normalize the scores and calculate the cross-entropy loss
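For orientation, the loss step that follows (a hedged sketch, not lines from this PR; `mini_batch_loss` is a hypothetical helper) classifies each mini-batch row against its true positive column, which the masking above deliberately left intact:

    import torch
    import torch.nn.functional as F

    def mini_batch_loss(scores: torch.Tensor, b: int, temperature: float = 0.01) -> torch.Tensor:
        # Row i of `scores` corresponds to anchor b + i, whose true positive is column b + i
        labels = torch.arange(b, b + scores.size(0), device=scores.device)
        return F.cross_entropy(scores / temperature, labels)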
@@ -337,4 +377,6 @@ def get_config_dict(self) -> dict[str, Any]:
             "guide": self.guide,
             "temperature": self.temperature,
             "mini_batch_size": self.mini_batch_size,
+            "margin_strategy": self.margin_strategy,
+            "margin": self.margin,
         }