Skip to content

Commit e7dc541

Browse files
committed
Fix: remove import, improve comment
1 parent 1593bc9 commit e7dc541

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/fsdp2/nd_parallel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import torch.distributed as dist
2424
from torch.utils.data import DataLoader
2525
from transformers import AutoModelForCausalLM
26-
from transformers.loss.loss_utils import ForCausalLMLoss
2726

2827
from accelerate import Accelerator
2928
from accelerate.parallelism_config import ParallelismConfig
@@ -62,6 +61,8 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
6261
buffers=buffers, buffer_seq_dims=[1, 1, 1], no_restore_buffers=set(buffers)
6362
):
6463
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
64+
# This is because with DP each device holds a different batch, while with CP each device holds a different part of the sequence
65+
# I.e. with causal modelling and seq_len 1024, the sequence dimension effectively becomes another batch dimension of sorts
6566
loss_reduce_grp = (
6667
accelerator.torch_device_mesh["dp_cp"].get_group()
6768
if accelerator.parallelism_config.dp_cp_dim_names

0 commit comments

Comments
 (0)