
Commit 1593bc9

Fix: pass both shift_labels and labels so that model.forward computes the loss
1 parent 1daa26b commit 1593bc9

2 files changed: +6 -7 lines changed


examples/fsdp2/nd_parallel.py

Lines changed: 5 additions & 6 deletions

@@ -55,9 +55,11 @@ def parse_args():
 
 
 def forward(model, batch, optimizer, accelerator: Accelerator):
-    input_ids, shift_labels = batch["input_ids"], batch["shift_labels"]
+    # We need both labels and shift_labels: the loss computation in the model is gated behind `if labels is not None`, but the loss computation
+    # itself prioritizes shift_labels (if provided), which are the correct ones (labels are wrong when context parallelism is enabled)
+    buffers = [batch["input_ids"], batch["shift_labels"], batch["labels"]]
     with accelerator.maybe_context_parallel(
-        buffers=[input_ids, shift_labels], buffer_seq_dims=[1, 1], no_restore_buffers={input_ids, shift_labels}
+        buffers=buffers, buffer_seq_dims=[1, 1, 1], no_restore_buffers=set(buffers)
     ):
         # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
         loss_reduce_grp = (
@@ -66,10 +68,7 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
             else None
         )
         outputs = model(**batch)
-        # With shift labels we need to compute loss ourselves
-        loss = ForCausalLMLoss(
-            logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=model.config.vocab_size
-        )
+        loss = outputs.loss
         accelerator.backward(loss)
         optimizer.step()
         optimizer.zero_grad()
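
For context, `loss = outputs.loss` only works here because of how the upstream loss computation dispatches on its arguments. Below is a minimal sketch of that contract as the commit comments describe it, not the actual transformers source: `labels` merely gates the loss path, while `shift_labels`, when provided, takes priority over `labels`.

import torch
import torch.nn.functional as F

# Sketch of the assumed dispatch contract (illustration only, not transformers code).
def causal_lm_loss(logits, labels=None, shift_labels=None, ignore_index=-100):
    if labels is None:
        # model.forward skips the loss path entirely -> outputs.loss is None.
        # This is why the collate_fn must supply `labels` at all.
        return None
    if shift_labels is None:
        # Default path: derive next-token targets by shifting `labels` left by one.
        # Under context parallelism each rank only holds a shard of the sequence,
        # so shifting locally would pair tokens with the wrong targets -- hence
        # the pre-shifted `shift_labels` take priority here.
        shift_labels = torch.cat(
            [labels[:, 1:], labels.new_full((labels.size(0), 1), ignore_index)], dim=1
        )
    return F.cross_entropy(
        logits.flatten(0, 1).float(), shift_labels.flatten(), ignore_index=ignore_index
    )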

examples/fsdp2/utils.py

Lines changed: 1 addition & 1 deletion

@@ -112,7 +112,7 @@ def create_collate_fn():
     def collate_fn(batch):
         input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
         shift_labels = torch.tensor([item["shift_labels"] for item in batch], dtype=torch.long)
-        return {"input_ids": input_ids, "shift_labels": shift_labels}
+        return {"input_ids": input_ids, "shift_labels": shift_labels, "labels": shift_labels}
 
     return collate_fn
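
A hypothetical usage sketch of the patched collate_fn (the token values are made up; `shift_labels` are assumed to be precomputed next-token targets, with -100 marking positions excluded from the loss). `labels` is deliberately the same tensor as `shift_labels`: it only needs to be non-None to unlock the model's loss path.

import torch

def create_collate_fn():  # as patched above
    def collate_fn(batch):
        input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
        shift_labels = torch.tensor([item["shift_labels"] for item in batch], dtype=torch.long)
        return {"input_ids": input_ids, "shift_labels": shift_labels, "labels": shift_labels}
    return collate_fn

collate_fn = create_collate_fn()
batch = collate_fn([{"input_ids": [1, 2, 3, 4], "shift_labels": [2, 3, 4, -100]}])
assert batch["labels"] is batch["shift_labels"]  # same tensor; it only gates the loss path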
