@@ -55,9 +55,11 @@ def parse_args():
def forward(model, batch, optimizer, accelerator: Accelerator):
-    input_ids, shift_labels = batch["input_ids"], batch["shift_labels"]
+    # We need both labels and shift_labels, as the loss computation in the model is hidden behind `if labels is not None`, but the loss computation
+    # itself prioritizes shift_labels (if provided), which are the correct ones (labels are wrong when cp is enabled)
+    buffers = [batch["input_ids"], batch["shift_labels"], batch["labels"]]
    with accelerator.maybe_context_parallel(
-        buffers=[input_ids, shift_labels], buffer_seq_dims=[1, 1], no_restore_buffers={input_ids, shift_labels}
+        buffers=buffers, buffer_seq_dims=[1, 1, 1], no_restore_buffers=set(buffers)
    ):
        # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
        loss_reduce_grp = (
@@ -66,10 +68,7 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
            else None
        )
        outputs = model(**batch)
-        # With shift labels we need to compute loss ourselves
-        loss = ForCausalLMLoss(
-            logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=model.config.vocab_size
-        )
+        loss = outputs.loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
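
Note on the new comment in this diff: it describes the loss dispatch inside transformers causal-LM models. A minimal sketch of the behaviour the change relies on is below; it is an illustration, not the actual transformers source, the helper name `compute_causal_lm_loss` is made up, and the `ForCausalLMLoss` import path may differ across transformers versions.

# Illustrative sketch only, not the transformers implementation.
# `labels` gates whether a loss is computed at all; `shift_labels`, when given,
# is what the loss is actually computed against, so passing both lets the model
# compute the loss internally even though `labels` is wrong under context parallelism.
from transformers.loss.loss_utils import ForCausalLMLoss

def compute_causal_lm_loss(logits, labels, shift_labels, vocab_size):
    if labels is None:  # loss computation is hidden behind `if labels is not None`
        return None
    return ForCausalLMLoss(
        logits=logits,
        labels=labels,              # effectively ignored when shift_labels is provided
        shift_labels=shift_labels,  # takes priority inside the loss function
        vocab_size=vocab_size,
    )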
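
A hypothetical driver loop is sketched below, just to show the shapes involved; `dataloader`, `model` and `optimizer` are assumptions, not part of this change. Every buffer handed to `maybe_context_parallel` is `(batch, seq_len)`, so `buffer_seq_dims=[1, 1, 1]` shards the sequence dimension, and `no_restore_buffers=set(buffers)` skips restoring the unsharded tensors when the context exits.

# Hypothetical usage sketch; assumes the collator already puts "input_ids",
# "labels" and "shift_labels" (all shaped (batch, seq_len)) in each batch.
for batch in dataloader:
    batch = {k: v.to(accelerator.device) for k, v in batch.items()}
    forward(model, batch, optimizer, accelerator)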