Commit 35c24ff

Fix: some validation
1 parent 8fdedc3 commit 35c24ff

2 files changed: +14 -7 lines changed

examples/fsdp2/nd_parallel.py

Lines changed: 8 additions & 7 deletions
@@ -77,17 +77,17 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
         buffers=[input_ids, labels], buffer_seq_dims=[1, 1], no_restore_buffers={input_ids, labels}
     ):
         # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
-        # loss_reduce_grp = (
-        #     accelerator.torch_device_mesh["dp_cp"].get_group()
-        #     if accelerator.parallelism_config.dp_cp_dim_names
-        #     else None
-        # )
+        loss_reduce_grp = (
+            accelerator.torch_device_mesh["dp_cp"].get_group()
+            if accelerator.parallelism_config.dp_cp_dim_names
+            else None
+        )
         outputs = model(**batch)
         loss = outputs.loss
         accelerator.backward(loss)
         optimizer.step()
         optimizer.zero_grad()
-        # dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
+        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)

     return loss

@@ -134,7 +134,8 @@ def train(args):
     dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())

     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
-    model = fix_model(model)
+    if parallelism_config.cp_enabled:
+        model = fix_model(model)

     total_num_steps = min(args.num_steps, len(dataloader))
     performance_tracker = PerformanceTracker(warmup_steps=5)
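
For context, the re-enabled lines average the per-rank loss over the joint data-parallel/context-parallel process group so the reported value matches the global mean. Below is a minimal, self-contained sketch of that pattern; the average_loss helper is ours, not part of the commit, and ReduceOp.AVG assumes the NCCL backend:

import torch
import torch.distributed as dist

def average_loss(loss: torch.Tensor, group=None) -> torch.Tensor:
    """Average a per-rank loss across `group` (default group if None); no-op when not distributed."""
    if dist.is_available() and dist.is_initialized():
        # In the example above, `group` would be the "dp_cp" sub-mesh's process group.
        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=group)
    return loss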

src/accelerate/accelerator.py

Lines changed: 6 additions & 0 deletions
@@ -448,6 +448,12 @@ def __init__(

         parallelism_config = self._setup_parallelism_config(parallelism_config, torch_tp_plugin)

+        # TODO: Siro - figure out a better place where this can go (needs to be above AcceleratorState init)
+        if parallelism_config and parallelism_config.cp_enabled and fsdp_plugin is None:
+            raise ValueError(
+                "`cp_enabled` is set to `True` in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory"
+            )
+
         kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
         kwargs["parallelism_config"] = parallelism_config
         self.state = AcceleratorState(
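
In practice, the new guard means context parallelism can only be requested together with an FSDP plugin. A hedged usage sketch follows; the ParallelismConfig import path and the cp_size/fsdp_version argument names are assumptions, not taken from this commit:

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.parallelism_config import ParallelismConfig  # assumed import path

pc = ParallelismConfig(cp_size=2)  # assumed way to enable context parallelism

# With this commit, CP without an FSDP plugin now fails fast:
#     Accelerator(parallelism_config=pc)  # raises the ValueError added above

# Providing an FSDP2 plugin satisfies the check, since the model is also sharded
# across the device mesh to save memory.
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)  # assumed argument name
accelerator = Accelerator(parallelism_config=pc, fsdp_plugin=fsdp_plugin)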
