@@ -12,6 +12,9 @@
 from sentence_transformers import SentenceTransformer
 from sentence_transformers.losses.CachedGISTEmbedLoss import CachedGISTEmbedLoss
 from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss
+from sentence_transformers.losses.CachedMultipleNegativesSymmetricRankingLoss import (
+    CachedMultipleNegativesSymmetricRankingLoss,
+)
 from sentence_transformers.models import Transformer
 
 
@@ -149,7 +152,8 @@ def __init__(
             - `Adaptive Layers <../../examples/training/adaptive_layer/README.html>`_
 
         Requirements:
-            1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss` or :class:`CachedGISTEmbedLoss`.
+            1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`,
+               :class:`CachedMultipleNegativesSymmetricRankingLoss`, or :class:`CachedGISTEmbedLoss`.
 
         Inputs:
             +---------------------------------------+--------+
@@ -192,10 +196,11 @@ def __init__(
         self.kl_div_weight = kl_div_weight
         self.kl_temperature = kl_temperature
         assert isinstance(self.model[0], Transformer)
-        if isinstance(loss, CachedMultipleNegativesRankingLoss):
-            warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2)
-        if isinstance(loss, CachedGISTEmbedLoss):
-            warnings.warn("MatryoshkaLoss is not compatible with CachedGISTEmbedLoss.", stacklevel=2)
+        if isinstance(
+            loss,
+            (CachedMultipleNegativesRankingLoss, CachedMultipleNegativesSymmetricRankingLoss, CachedGISTEmbedLoss),
+        ):
+            warnings.warn(f"MatryoshkaLoss is not compatible with {loss.__class__.__name__}.", stacklevel=2)
 
     def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
         # Decorate the forward function of the transformer to cache the embeddings of all layers
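For context, a minimal sketch of how the consolidated check surfaces to a user, assuming the same guard lands in the MatryoshkaLoss constructor (the warning text names MatryoshkaLoss); the model name and matryoshka_dims below are illustrative, not part of this commit. Wrapping any of the three cached losses still constructs the wrapper, but now emits one warning naming the offending class via loss.__class__.__name__:

```python
import warnings

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import (
    CachedMultipleNegativesSymmetricRankingLoss,
    MatryoshkaLoss,
)

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative model choice
base_loss = CachedMultipleNegativesSymmetricRankingLoss(model)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The consolidated isinstance check fires here and names the base loss class.
    loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[768, 512, 256])

for w in caught:
    print(w.message)
# MatryoshkaLoss is not compatible with CachedMultipleNegativesSymmetricRankingLoss.
```

A side benefit of the tuple-based check: supporting a future cached loss only means extending the tuple, rather than duplicating another `if isinstance(...)` / `warnings.warn(...)` pair.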