Commit 314ccc7

Cleanup: context parallel
1 parent 9359a01 commit 314ccc7

10 files changed: +300 −19 lines changed

examples/fsdp2/nd_parallel.py

Lines changed: 45 additions & 16 deletions

```diff
@@ -43,6 +43,7 @@ def parse_args():
     parser.add_argument("--dp-replicate-size", type=int, default=1)
     parser.add_argument("--dp-shard-size", type=int, default=1)
     parser.add_argument("--tp-size", type=int, default=1)
+    parser.add_argument("--cp-size", type=int, default=1)
     parser.add_argument("--sequence-length", type=int, default=1024)
     parser.add_argument("--num-steps", type=int, default=1000)
     parser.add_argument("--save-dir", type=str, default="./outputs")
@@ -52,17 +53,42 @@ def parse_args():
     return parser.parse_args()
 
 
-def forward(model, batch, optimizer, accelerator):
-    # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
-    loss_reduce_grp = (
-        accelerator.torch_device_mesh["dp_cp"].get_group() if accelerator.parallelism_config.dp_cp_dim_names else None
-    )
-    outputs = model(**batch)
-    loss = outputs.loss
-    accelerator.backward(loss)
-    optimizer.step()
-    optimizer.zero_grad()
-    dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
+def _self_attn_pre_forward_hook(module, *args, **kwargs):
+    kwargs = args[1]
+    kwargs["attention_mask"] = None
+    kwargs["is_causal"] = True
+    args = list(args)
+    args[1] = kwargs
+    return tuple(args)
+
+
+def fix_model(model):
+    for name, module in model.named_modules():
+        if name.endswith("self_attn"):
+            module: torch.nn.Module
+            module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True)
+
+    return model
+
+
+def forward(model, batch, optimizer, accelerator: Accelerator):
+    input_ids, labels = batch["input_ids"], batch["labels"]
+    with accelerator.maybe_context_parallel(
+        buffers=[input_ids, labels], buffer_seq_dims=[1, 1], no_restore_buffers={input_ids, labels}
+    ):
+        # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
+        loss_reduce_grp = (
+            accelerator.torch_device_mesh["dp_cp"].get_group()
+            if accelerator.parallelism_config.dp_cp_dim_names
+            else None
+        )
+        outputs = model(**batch)
+        loss = outputs.loss
+        accelerator.backward(loss)
+        optimizer.step()
+        optimizer.zero_grad()
+        dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
+
     return loss
 
 
@@ -71,21 +97,22 @@ def train(args):
         dp_replicate_size=args.dp_replicate_size,
         dp_shard_size=args.dp_shard_size,
         tp_size=args.tp_size,
+        cp_size=args.cp_size,
     )
 
     # FSDP needs extra configuration, so we properly shard the model
-    if parallelism_config.dp_shard_enabled:
+    fsdp2_plugin = None
+    if parallelism_config.dp_shard_enabled or parallelism_config.cp_enabled:
         fsdp2_plugin = FullyShardedDataParallelPlugin(
             fsdp_version=2,
             auto_wrap_policy="transformer_based_wrap",
             transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
+            state_dict_type="SHARDED_STATE_DICT",
+            activation_checkpointing=True,
         )
 
     accelerator = Accelerator(
-        log_with=["wandb"],
-        mixed_precision="bf16",
-        parallelism_config=parallelism_config,
-        fsdp_plugin=fsdp2_plugin if parallelism_config.dp_shard_enabled else None,
+        log_with=["wandb"], mixed_precision="bf16", parallelism_config=parallelism_config, fsdp_plugin=fsdp2_plugin
     )
     accelerator.init_trackers("nd_parallel_training")
 
@@ -107,6 +134,8 @@ def train(args):
     dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
 
     model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
+    if parallelism_config.cp_enabled:
+        model = fix_model(model)
 
     total_num_steps = min(args.num_steps, len(dataloader))
     performance_tracker = PerformanceTracker(warmup_steps=5)
```
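
When context parallelism is enabled, the example routes every `self_attn` module through a pre-forward hook that drops the explicit attention mask and forces the causal flag, which is what `fix_model` does after `prepare()`. Below is a minimal, self-contained sketch of that `with_kwargs=True` hook pattern on a toy module; `ToyAttention` and `_pre_hook` are illustrative names, not part of this commit.

```python
# Minimal sketch (not from this commit) of the with_kwargs=True pre-forward hook
# pattern that `fix_model` relies on: the hook receives (module, args, kwargs)
# and may return a replacement (args, kwargs) pair.
import torch


class ToyAttention(torch.nn.Module):
    def forward(self, hidden_states, attention_mask=None, is_causal=False):
        # A real self-attention block would run attention here; we only report the flags it received.
        return {"attention_mask": attention_mask, "is_causal": is_causal}


def _pre_hook(module, args, kwargs):
    # Drop the explicit mask and force the causal path, mirroring the hook in the example above.
    kwargs["attention_mask"] = None
    kwargs["is_causal"] = True
    return args, kwargs


attn = ToyAttention()
attn.register_forward_pre_hook(_pre_hook, with_kwargs=True)
out = attn(torch.zeros(1, 4, 8), attention_mask=torch.ones(1, 4))
print(out)  # {'attention_mask': None, 'is_causal': True}
```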

src/accelerate/accelerator.py

Lines changed: 84 additions & 0 deletions

```diff
@@ -448,6 +448,12 @@ def __init__(
 
         parallelism_config = self._setup_parallelism_config(parallelism_config, torch_tp_plugin)
 
+        # TODO: Siro - figure out a better place where this can go (needs to be above AcceleratorState init)
+        if parallelism_config and parallelism_config.cp_enabled and fsdp_plugin is None:
+            raise ValueError(
+                "`cp_enabled` is set to `True` in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory"
+            )
+
         kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
         kwargs["parallelism_config"] = parallelism_config
         self.state = AcceleratorState(
@@ -776,6 +782,10 @@ def _setup_parallelism_config(
     ):
         if parallelism_config is None:
             if PartialState._shared_state != {} and PartialState().parallelism_config is not None:
+                if os.environ.get("ACCELERATE_USE_PARALLELISM_CONFIG", "false") == "true":
+                    raise ValueError(
+                        "Partial state contains a `parallelism_config` which is not None, but you configured `parallelism_config` from the `accelerate launch` CLI. We don't know which to use, please remove one of those configuration methods."
+                    )
                 parallelism_config = PartialState().parallelism_config
             else:
                 # TODO: Remove after deprecating tp_plugin
@@ -1497,6 +1507,9 @@ def prepare(self, *args, device_placement=None):
         if self.parallelism_config and self.parallelism_config.tp_enabled:
             args = self._prepare_tp(*args)
 
+        if self.parallelism_config and self.parallelism_config.cp_enabled:
+            self._prepare_cp()
+
         if self.fp8_backend == FP8BackendType.TE:
             args = self._prepare_te(*args)
         elif self.fp8_backend == FP8BackendType.AO:
@@ -1561,6 +1574,15 @@ def _prepare_tp(self, *args):
 
         return args
 
+    def _prepare_cp(self):
+        from torch.distributed.tensor.experimental import context_parallel
+        from torch.distributed.tensor.experimental._attention import set_rotate_method
+
+        cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_stategy
+        set_rotate_method(cp_comm_strategy)
+
+        self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])
+
     def _prepare_fsdp2(self, *args):
         # First pass: prepare everything except schedulers (and model, which is prepared separately below)
         result = [
```
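
`_prepare_cp` wires up two experimental PyTorch APIs: `set_rotate_method` selects how key/value shards rotate between ranks ("allgather" or "alltoall", the same `cp_comm_strategy` choices exposed by the config and CLI changes further down), and `context_parallel` is the context manager that `maybe_context_parallel` (next hunk) ultimately enters. A rough sketch of the equivalent direct usage, assuming a `device_mesh` with a "cp" dimension and `input_ids`/`labels` tensors already exist:

```python
# Sketch only -- mirrors the calls made in `_prepare_cp`/`maybe_context_parallel`;
# `device_mesh`, `input_ids` and `labels` are assumed to exist already.
import functools

from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method

set_rotate_method("allgather")  # or "alltoall"
cp_context = functools.partial(context_parallel, mesh=device_mesh["cp"])

# Shards the listed buffers along their sequence dimension for the duration of the block.
with cp_context(buffers=[input_ids, labels], buffer_seq_dims=[1, 1], no_restore_buffers={input_ids, labels}):
    ...  # run one training step's forward/backward on the sequence-sharded buffers
```
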
````diff
@@ -3903,6 +3925,68 @@ def register_for_checkpointing(self, *objects):
             raise ValueError(err)
         self._custom_objects.extend(objects)
 
+    @contextmanager
+    def maybe_context_parallel(
+        self,
+        buffers: list[torch.Tensor] | None = None,
+        buffer_seq_dims: list[int] | None = None,
+        no_restore_buffers: set[torch.Tensor] | None = None,
+    ):
+        """
+        A context manager that enables context parallel training.
+
+        Args:
+            buffers (`list[torch.Tensor]`, `optional`):
+                Buffers, which are going to be sharded along the sequence dimension. Common examples are inputs, labels
+                or positional embedding buffers. This context manager will modify these buffers in-place, and after
+                exiting the context, the buffers will be restored to their original state. To avoid unnecessary
+                restores, you can use `no_restore_buffers` to specify which buffers don't need to be restored.
+            buffer_seq_dims (`list[int]`, `optional`):
+                Sequence dimensions of `buffers`.
+            no_restore_buffers (`set[torch.Tensor]`, `optional`):
+                This set must be a subset of `buffers`. Specifies which buffers from `buffers` argument won't be
+                restored after the context exits. These buffers will be then kept in sharded state.
+
+        <Tip warning={true}>
+
+        `context_parallel` is currently only supported together with FSDP2, and requires `parallelism_config.cp_size` > 1. If either
+        of these conditions are not met, this context manager will have no effect, though to enable fewer code changes it will not raise an Exception.
+
+        </Tip>
+
+        <Tip warning={true}>
+
+        This context manager has to be recreated with each training step, as shown in the example below.
+
+        </Tip>
+
+        Example:
+
+        ```python
+        >>> for batch in dataloader:
+        ...     with accelerator.maybe_context_parallel(
+        ...         buffers=[batch["input_ids"], batch["attention_mask"]],
+        ...         buffer_seq_dims=[1, 1],
+        ...         no_restore_buffers={batch["input_ids"]},
+        ...     ):
+        ...         outputs = model(batch)
+        ...         ...
+        ```
+        """
+        # We don't need to check FSDP2 as parallelism_config does that for us
+        # Invariant: in this branch self._cp_context is set, as it was set by `self._prepare_cp`
+        if self.parallelism_config and self.parallelism_config.cp_enabled:
+            with self._cp_context(
+                buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers
+            ):
+                yield
+        else:
+            if not getattr(self, "_warned_about_cp", False) and self.is_main_process:
+                logger.warning("Context parallel training is not enabled. This context manager will have no effect.")
+                # As this context manager is recreated each training step, we only warn once
+                self._warned_about_cp = True
+            yield
+
     @contextmanager
     def autocast(self, autocast_handler: AutocastKwargs = None):
         """
````

src/accelerate/commands/config/cluster.py

Lines changed: 48 additions & 0 deletions

```diff
@@ -505,6 +505,53 @@ def get_cluster_input():
             error_message="Please enter yes or no.",
         )
 
+    parallelism_config = {}
+
+    if fsdp_config.get("fsdp_version", 1) == 2:
+        use_parallelism_config = _ask_field(
+            "Do you want to use the parallelism config? [yes/NO]: ",
+            _convert_yes_no_to_bool,
+            default=False,
+            error_message="Please enter yes or no.",
+        )
+
+        if use_parallelism_config:
+            prefix = "parallelism_config_"
+            parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
+                "What is the data parallelism replicate size? [1]: ",
+                int,
+                default=1,
+                error_message="Please enter an integer.",
+            )
+
+            parallelism_config[prefix + "dp_shard_size"] = _ask_field(
+                "What is the FSDP shard size? [1]: ",
+                int,
+                default=1,
+                error_message="Please enter an integer.",
+            )
+
+            parallelism_config[prefix + "tp_size"] = _ask_field(
+                "What is the tensor parallelism size? [1]: ",
+                int,
+                default=1,
+                error_message="Please enter an integer.",
+            )
+
+            parallelism_config[prefix + "cp_size"] = _ask_field(
+                "What is the context parallelism size? [1]: ",
+                int,
+                default=1,
+                error_message="Please enter an integer.",
+            )
+            if parallelism_config[prefix + "cp_size"] > 1:
+                parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
+                    "What is the compute parallelism communication strategy?",
+                    ["allgather", "alltoall"],
+                    lambda x: ["allgather", "alltoall"][int(x)],
+                    default=0,
+                )
+
     megatron_lm_config = {}
     if distributed_type in [DistributedType.MULTI_GPU]:
         use_megatron_lm = _ask_field(
@@ -849,6 +896,7 @@ def get_cluster_input():
         fp8_config=fp8_config,
         deepspeed_config=deepspeed_config,
         fsdp_config=fsdp_config,
+        parallelism_config=parallelism_config,
         megatron_lm_config=megatron_lm_config,
         ipex_config=ipex_config,
         mpirun_config=mpirun_config,
```
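
For reference, answering the new prompts yields a flat dict keyed with the `parallelism_config_` prefix, which is then stored on `ClusterConfig` (next file). A hypothetical set of answers:

```python
# Illustrative values only -- the keys are the ones written by the prompts above.
parallelism_config = {
    "parallelism_config_dp_replicate_size": 1,
    "parallelism_config_dp_shard_size": 2,
    "parallelism_config_tp_size": 1,
    "parallelism_config_cp_size": 4,
    # Only asked when cp_size > 1; chosen from ["allgather", "alltoall"].
    "parallelism_config_cp_comm_strategy": "allgather",
}
```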

src/accelerate/commands/config/config_args.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
     deepspeed_config: dict = None
     # args for fsdp
     fsdp_config: dict = None
+    # args for parallelism config
+    parallelism_config: dict = None
     # args for megatron_lm
     megatron_lm_config: dict = None
     # args for ipex
@@ -229,6 +231,8 @@ def __post_init__(self):
             self.mpirun_config = {}
         if self.fp8_config is None:
            self.fp8_config = {}
+        if self.parallelism_config is None:
+            self.parallelism_config = {}
         return super().__post_init__()
 
 
```

src/accelerate/commands/launch.py

Lines changed: 56 additions & 0 deletions

```diff
@@ -269,6 +269,12 @@ def launch_command_parser(subparsers=None):
         action="store_true",
         help="Whether to use fsdp.",
     )
+    paradigm_args.add_argument(
+        "--use_parallelism_config",
+        default=False,
+        action="store_true",
+        help="Whether to use the parallelism config to configure the N-d distributed training.",
+    )
     paradigm_args.add_argument(
         "--use_megatron_lm",
         default=False,
@@ -767,6 +773,45 @@ def launch_command_parser(subparsers=None):
         help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.",
     )
 
+    # ParallelismConfig arguments
+    parallelism_config_args = parser.add_argument_group(
+        "ParallelismConfig Arguments",
+        "Arguments related to the ParallelismConfig used for distributed training.",
+    )
+    parallelism_config_args.add_argument(
+        "--parallelism_config_dp_replicate_size",
+        type=int,
+        default=1,
+        help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
+    )
+
+    parallelism_config_args.add_argument(
+        "--parallelism_config_dp_shard_size",
+        type=int,
+        default=-1,
+        help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
+    )
+
+    parallelism_config_args.add_argument(
+        "--parallelism_config_tp_size",
+        type=int,
+        default=1,
+        help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
+    )
+
+    parallelism_config_args.add_argument(
+        "--parallelism_config_cp_size",
+        type=int,
+        default=1,
+        help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
+    )
+    parallelism_config_args.add_argument(
+        "--parallelism_config_cp_comm_strategy",
+        type=str,
+        default="allgather",
+        help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
+    )
+
     # Other arguments of the training scripts
     parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
 
@@ -994,6 +1039,16 @@ def _validate_launch_command(args):
     if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
         raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
 
+    # TODO: Merge into 1 if
+    if not args.use_fsdp and args.use_parallelism_config:
+        raise ValueError(
+            "You cannot use `--use_parallelism_config` without `--use_fsdp`. Please set `--use_fsdp` to True if you want to use parallelism config."
+        )
+    elif args.fsdp_version == 1 and args.use_parallelism_config:
+        raise ValueError(
+            "You cannot use `--use_parallelism_config` with FSDP version 1. Please set `--fsdp_version=2` if you want to use parallelism config."
+        )
+
     defaults = None
     warned = []
     mp_from_config_flag = False
@@ -1027,6 +1082,7 @@ def _validate_launch_command(args):
         args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
         args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
         args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
+        args.use_parallelism_config = defaults.parallelism_config != {}
         if args.gpu_ids is None:
             if defaults.gpu_ids is not None:
                 args.gpu_ids = defaults.gpu_ids
```
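
The validation above means the new flags only make sense together with `--use_fsdp` and FSDP version 2. A hypothetical invocation is sketched below; the flag names come from the parser above, while `--fsdp_version`, the `train.py` script path and the sizes are placeholders or assumptions. It uses `subprocess` only to keep the example in Python; the equivalent shell command is simply the argument list joined with spaces.

```python
# Hypothetical launch command (placeholder script and sizes; `--fsdp_version` spelling assumed).
import subprocess

subprocess.run(
    [
        "accelerate", "launch",
        "--use_fsdp", "--fsdp_version", "2",  # parallelism_config requires FSDP2
        "--use_parallelism_config",
        "--parallelism_config_dp_shard_size", "2",
        "--parallelism_config_cp_size", "4",
        "--parallelism_config_cp_comm_strategy", "allgather",
        "train.py",  # placeholder training script
    ],
    check=True,
)
```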
