docs/source/usage_guides/gradient_accumulation.md
Lines changed: 4 additions & 4 deletions
@@ -245,7 +245,7 @@ As was pointed out in this [blog-post](https://huggingface.co/blog/gradient_accu
> [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values.
-In other words, some adjustements must be made on losses that operate on a token-level basis.
+In other words, some adjustments must be made on losses that operate on a token-level basis.
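To make that adjustment concrete, here is a minimal sketch of the normalization the quote describes: sum the loss over every micro-batch in an accumulation step, then divide once by the total number of non-padding tokens. The helper name `accumulation_step_loss`, the `micro_batches` list of `(logits, labels)` pairs, and the use of `-100` as the padding label are illustration-only assumptions, not part of the documented example.

```python
import torch.nn.functional as F

# Sketch: token-level loss for one gradient-accumulation step.
# `micro_batches` is assumed to be a list of (logits, labels) pairs and
# padding positions in `labels` are assumed to be marked with -100.
def accumulation_step_loss(micro_batches):
    total_loss = 0.0
    total_tokens = 0
    for logits, labels in micro_batches:
        # Sum (not average) the cross-entropy over this micro-batch
        total_loss = total_loss + F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100,
            reduction="sum",
        )
        # Count only non-padding tokens
        total_tokens += (labels != -100).sum()
    # Divide once by the total token count, not per micro-batch
    return total_loss / total_tokens
```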
### Skeleton code
@@ -282,7 +282,7 @@ for update_step in range(total_updates):
if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
    ctx = model.no_sync
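As a rough sketch of how that conditional is typically used, gradient synchronization is skipped on every micro-batch except the last one of the step. The `else` branch with `contextlib.nullcontext` and the loop over `batch_samples` are assumptions filled in for illustration; `accelerator`, `model`, and `batch_samples` come from the surrounding skeleton code.

```python
import contextlib

for i, batch in enumerate(batch_samples):
    if i < len(batch_samples) - 1 and accelerator.num_processes > 1:
        ctx = model.no_sync           # defer the gradient all-reduce
    else:
        ctx = contextlib.nullcontext  # last micro-batch: synchronize as usual

    with ctx():
        ...  # forward pass, loss scaling and accelerator.backward(loss)
```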
@@ -294,7 +294,7 @@ for update_step in range(total_updates):
with ctx():
    inputs, targets = batch
    outputs = model(inputs)
-   loss = loss_function(outputs, targets) # the loss function shoud sum over samples rather than averaging
+   loss = loss_function(outputs, targets) # the loss function should sum over samples rather than averaging
    # We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices
    # Same reason for gradient_accumulation_steps, but this time it's Accelerate that calculates the average gradient across the accumulated steps
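A minimal sketch of the rescaling those two comments describe, assuming `num_items_in_batch` (the total non-padding token count for the accumulated batches) and `gradient_accumulation_steps` are defined earlier in the skeleton; `accelerator.backward` is the standard Accelerate backward call.

```python
# Undo the implicit averaging: DDP averages gradients across devices and
# Accelerate averages across accumulation steps, while num_items_in_batch
# already accounts for every device and every accumulated batch.
loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
accelerator.backward(loss)
```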
@@ -394,7 +394,7 @@ for update_step in range(total_gradient_updates):