
Commit a1db32d

[docs] Update the training snippets for some losses that should use the v3 Trainer (#2987)
1 parent a4be00f commit a1db32d

File tree

3 files changed, +47 −38 lines changed


sentence_transformers/losses/Matryoshka2dLoss.py

Lines changed: 14 additions & 12 deletions
@@ -95,21 +95,23 @@ def __init__(
         Example:
             ::

-                from sentence_transformers import SentenceTransformer, losses, InputExample
-                from torch.utils.data import DataLoader
+                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
+                from datasets import Dataset

                 model = SentenceTransformer("microsoft/mpnet-base")
-                train_examples = [
-                    InputExample(texts=['Anchor 1', 'Positive 1']),
-                    InputExample(texts=['Anchor 2', 'Positive 2']),
-                ]
-                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
-                train_loss = losses.MultipleNegativesRankingLoss(model=model)
-                train_loss = losses.Matryoshka2dLoss(model, train_loss, [768, 512, 256, 128, 64])
-                model.fit(
-                    [(train_dataloader, train_loss)],
-                    epochs=10,
+                train_dataset = Dataset.from_dict({
+                    "anchor": ["It's nice weather outside today.", "He drove to work."],
+                    "positive": ["It's so sunny.", "He took the car to the office."],
+                })
+                loss = losses.MultipleNegativesRankingLoss(model)
+                loss = losses.Matryoshka2dLoss(model, loss, [768, 512, 256, 128, 64])
+
+                trainer = SentenceTransformerTrainer(
+                    model=model,
+                    train_dataset=train_dataset,
+                    loss=loss,
                 )
+                trainer.train()
         """
         matryoshka_loss = MatryoshkaLoss(
             model,
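Assembled from the added and unchanged lines of this hunk, the updated Matryoshka2dLoss example is a self-contained v3-style training script roughly like the one below. The comments are added here for orientation and are not part of the docstring.

    from datasets import Dataset
    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

    # Base model to fine-tune (same checkpoint as in the docstring example).
    model = SentenceTransformer("microsoft/mpnet-base")

    # A tiny in-memory (anchor, positive) dataset, as in the snippet above.
    train_dataset = Dataset.from_dict({
        "anchor": ["It's nice weather outside today.", "He drove to work."],
        "positive": ["It's so sunny.", "He took the car to the office."],
    })

    # MultipleNegativesRankingLoss uses the other in-batch examples as negatives;
    # wrapping it in Matryoshka2dLoss trains it across the listed truncated
    # embedding dimensions as well as reduced layer counts (2D Matryoshka).
    loss = losses.MultipleNegativesRankingLoss(model)
    loss = losses.Matryoshka2dLoss(model, loss, [768, 512, 256, 128, 64])

    trainer = SentenceTransformerTrainer(
        model=model,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()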

sentence_transformers/losses/MatryoshkaLoss.py

Lines changed: 14 additions & 12 deletions
@@ -101,21 +101,23 @@ def __init__(
         Example:
             ::

-                from sentence_transformers import SentenceTransformer, losses, InputExample
-                from torch.utils.data import DataLoader
+                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
+                from datasets import Dataset

                 model = SentenceTransformer("microsoft/mpnet-base")
-                train_examples = [
-                    InputExample(texts=['Anchor 1', 'Positive 1']),
-                    InputExample(texts=['Anchor 2', 'Positive 2']),
-                ]
-                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
-                train_loss = losses.MultipleNegativesRankingLoss(model=model)
-                train_loss = losses.MatryoshkaLoss(model, train_loss, [768, 512, 256, 128, 64])
-                model.fit(
-                    [(train_dataloader, train_loss)],
-                    epochs=10,
+                train_dataset = Dataset.from_dict({
+                    "anchor": ["It's nice weather outside today.", "He drove to work."],
+                    "positive": ["It's so sunny.", "He took the car to the office."],
+                })
+                loss = losses.MultipleNegativesRankingLoss(model)
+                loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64])
+
+                trainer = SentenceTransformerTrainer(
+                    model=model,
+                    train_dataset=train_dataset,
+                    loss=loss,
                )
+                trainer.train()
         """
         super().__init__()
         self.model = model
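The MatryoshkaLoss hunk follows the same pattern; assembled, the updated snippet looks roughly like the sketch below, differing from the Matryoshka2dLoss example only in the wrapping loss. Comments are added here and are not part of the docstring.

    from datasets import Dataset
    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses

    model = SentenceTransformer("microsoft/mpnet-base")
    train_dataset = Dataset.from_dict({
        "anchor": ["It's nice weather outside today.", "He drove to work."],
        "positive": ["It's so sunny.", "He took the car to the office."],
    })

    # MatryoshkaLoss applies the wrapped loss at each listed truncated embedding
    # dimension, so embeddings can later be shortened with little quality loss.
    loss = losses.MultipleNegativesRankingLoss(model)
    loss = losses.MatryoshkaLoss(model, loss, [768, 512, 256, 128, 64])

    trainer = SentenceTransformerTrainer(model=model, train_dataset=train_dataset, loss=loss)
    trainer.train()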

sentence_transformers/losses/MegaBatchMarginLoss.py

Lines changed: 19 additions & 14 deletions
@@ -59,25 +59,30 @@ def __init__(
         Example:
             ::

-                from sentence_transformers import SentenceTransformer, InputExample, losses
-                from torch.utils.data import DataLoader
+                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments, SentenceTransformerTrainer, losses
+                from datasets import Dataset

-                model = SentenceTransformer('all-MiniLM-L6-v2')
-
-                total_examples = 500
                 train_batch_size = 250
                 train_mini_batch_size = 32

-                train_examples = [
-                    InputExample(texts=[f"This is sentence number {i}", f"This is sentence number {i+1}"]) for i in range(total_examples)
-                ]
-                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
-                train_loss = losses.MegaBatchMarginLoss(model=model, mini_batch_size=train_mini_batch_size)
-
-                model.fit(
-                    [(train_dataloader, train_loss)],
-                    epochs=10,
+                model = SentenceTransformer('all-MiniLM-L6-v2')
+                train_dataset = Dataset.from_dict({
+                    "anchor": [f"This is sentence number {i}" for i in range(500)],
+                    "positive": [f"This is sentence number {i}" for i in range(1, 501)],
+                })
+                loss = losses.MegaBatchMarginLoss(model=model, mini_batch_size=train_mini_batch_size)
+
+                args = SentenceTransformerTrainingArguments(
+                    output_dir="output",
+                    per_device_train_batch_size=train_batch_size,
+                )
+                trainer = SentenceTransformerTrainer(
+                    model=model,
+                    args=args,
+                    train_dataset=train_dataset,
+                    loss=loss,
                 )
+                trainer.train()
         """
         super().__init__()
         self.model = model
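Assembled from this hunk's added and unchanged lines, the updated MegaBatchMarginLoss snippet becomes a script roughly like the following; only identifiers and values from the diff are used, and the comments are added here for context.

    from datasets import Dataset
    from sentence_transformers import (
        SentenceTransformer,
        SentenceTransformerTrainer,
        SentenceTransformerTrainingArguments,
        losses,
    )

    train_batch_size = 250       # MegaBatchMarginLoss is designed for large batches
    train_mini_batch_size = 32   # processed in mini-batches to keep memory in check

    model = SentenceTransformer("all-MiniLM-L6-v2")
    # 500 synthetic (anchor, positive) pairs, as in the docstring example.
    train_dataset = Dataset.from_dict({
        "anchor": [f"This is sentence number {i}" for i in range(500)],
        "positive": [f"This is sentence number {i}" for i in range(1, 501)],
    })
    loss = losses.MegaBatchMarginLoss(model=model, mini_batch_size=train_mini_batch_size)

    # The large per-device batch size is what provides the "mega-batch" of
    # candidates from which the hardest negative is selected.
    args = SentenceTransformerTrainingArguments(
        output_dir="output",
        per_device_train_batch_size=train_batch_size,
    )
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()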
