
Commit 1daa26b

Feat: address comments
1 parent 1188e8b commit 1daa26b

9 files changed, +60 -51 lines changed

docs/source/concept_guides/context_parallelism.md

Lines changed: 20 additions & 3 deletions

@@ -91,16 +91,33 @@ This can scale your context size to 1M+ sequence length potentially. Below, we s
 </p>

 > [!Tip]
-> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). For instructions on how to run it, see the [README](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/README.md) in the same folder.
+> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:
+> ```bash
+> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000
+> ```


 ## Accelerate's interface

 The context manager takes a few arguments that are used to configure context parallelism.

 - `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask.
-- `buffer_seq_dims`: This is a list of integers, that specify the sequence dimension of the buffers, in the order of the `buffers` list.
-- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.Dtensor`s. After the context manager is exited, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager is exited.
+- `buffer_seq_dims`: This is a list of integers that specify the sequence dimension of each buffer, in the order of the `buffers` list. For example, if you pass `buffers=[input_ids, shift_labels]` with both having shape `[batch_size, sequence_length]`, you would pass `buffer_seq_dims=[1, 1]`, as the sequence dimension is the second dimension of the tensors. This is required for correct computation of the model outputs.
+- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.DTensor`s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually an all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager exits.
+
+> [!Warning]
+> Context parallelism is not compatible with `labels` that are a copy of `input_ids`, which models from 🤗 transformers can shift themselves to enable causal language modeling.
+> Imagine this case:
+> labels = [l1, l2, l3, l4, ..., li]
+> If we apply context parallelism, each rank ends up with only a part of the labels, such as:
+> labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], ...
+> After the transformers modeling code shifts the labels, each rank ends up with:
+> labels_rank_0 = [l2, PAD], labels_rank_1 = [l4, PAD], ...
+> where `PAD` is a padding token. This results in incorrect loss computation, as the labels are no longer aligned with the inputs.
+> Because of this, you need to manually shift the labels before passing them to the model.
+

 ## Configurable options
 Accelerate provides only a single option to configure context parallelism (apart from `cp_size`).
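
Putting the documented arguments together, here is a minimal sketch of the intended usage, based on the example script changed below. It assumes `model`, `batch`, and `accelerator` already exist and that context parallelism was enabled via the `Accelerator`'s `parallelism_config`:

```python
import torch

# Hypothetical training-step snippet; `batch` holds tensors of shape [batch_size, sequence_length].
input_ids = batch["input_ids"]
# Shift the labels ourselves: position t is supervised by token t + 1; the last position
# has no target, so it gets -100 (the ignore index of the loss).
shift_labels = torch.nn.functional.pad(input_ids[:, 1:], (0, 1), value=-100)

with accelerator.maybe_context_parallel(
    buffers=[input_ids, shift_labels],  # tensors to shard along the sequence dimension
    buffer_seq_dims=[1, 1],  # the sequence dimension is dim 1 for both buffers
    no_restore_buffers={input_ids, shift_labels},  # we don't reuse them afterwards, so skip the restore
):
    outputs = model(input_ids=input_ids)
    # compute the loss against `shift_labels` here, without any further shifting
```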

examples/fsdp2/nd_parallel.py

Lines changed: 8 additions & 5 deletions

@@ -23,6 +23,7 @@
 import torch.distributed as dist
 from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM
+from transformers.loss.loss_utils import ForCausalLMLoss

 from accelerate import Accelerator
 from accelerate.parallelism_config import ParallelismConfig
@@ -54,9 +55,9 @@ def parse_args():


 def forward(model, batch, optimizer, accelerator: Accelerator):
-    input_ids, labels = batch["input_ids"], batch["labels"]
+    input_ids, shift_labels = batch["input_ids"], batch["shift_labels"]
     with accelerator.maybe_context_parallel(
-        buffers=[input_ids, labels], buffer_seq_dims=[1, 1], no_restore_buffers={input_ids, labels}
+        buffers=[input_ids, shift_labels], buffer_seq_dims=[1, 1], no_restore_buffers={input_ids, shift_labels}
     ):
         # To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
         loss_reduce_grp = (
@@ -65,7 +66,10 @@ def forward(model, batch, optimizer, accelerator: Accelerator):
             else None
         )
         outputs = model(**batch)
-        loss = outputs.loss
+        # With shift labels we need to compute loss ourselves
+        loss = ForCausalLMLoss(
+            logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=model.config.vocab_size
+        )
         accelerator.backward(loss)
         optimizer.step()
         optimizer.zero_grad()
@@ -90,7 +94,6 @@ def train(args):
         auto_wrap_policy="transformer_based_wrap",
         transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
         state_dict_type="SHARDED_STATE_DICT",
-        activation_checkpointing=True,
     )

     accelerator = Accelerator(
@@ -155,7 +158,7 @@ def train(args):
 if __name__ == "__main__":
     set_seed(42)
     args = parse_args()
-    if args.dp_shard_size == 1:
+    if args.dp_shard_size == 1 and args.tp_size > 1:
         # We currently don't support saving with `save_state` when using only
         # tensor parallelism, fsdp must be enabled
         warnings.warn(
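
The `ForCausalLMLoss(..., shift_labels=...)` call above skips the internal label shifting that transformers models normally perform. As a minimal sketch (not the transformers implementation), the quantity it boils down to is a plain token-level cross-entropy under the usual `-100` ignore-index convention:

```python
import torch
import torch.nn.functional as F


def causal_lm_loss_with_shift_labels(
    logits: torch.Tensor,  # [batch, seq, vocab]
    shift_labels: torch.Tensor,  # [batch, seq], already shifted by one position
    ignore_index: int = -100,
) -> torch.Tensor:
    # No shifting here: the logits at position t are scored directly against shift_labels[:, t].
    return F.cross_entropy(
        logits.float().reshape(-1, logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=ignore_index,
    )
```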

examples/fsdp2/utils.py

Lines changed: 3 additions & 3 deletions

@@ -69,7 +69,7 @@ def create_packed_sequences(examples):
             packed_input_ids.append(full_sequence[:-1])
             packed_labels.append(full_sequence[1:])

-        return {"input_ids": packed_input_ids, "labels": packed_labels}
+        return {"input_ids": packed_input_ids, "shift_labels": packed_labels}

     with accelerator.main_process_first():
         packed_dataset = tokenized_dataset.map(
@@ -111,8 +111,8 @@ def create_collate_fn():

     def collate_fn(batch):
         input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
-        labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
-        return {"input_ids": input_ids, "labels": labels}
+        shift_labels = torch.tensor([item["shift_labels"] for item in batch], dtype=torch.long)
+        return {"input_ids": input_ids, "shift_labels": shift_labels}

     return collate_fn
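
Renaming the field to `shift_labels` makes the pre-shifting explicit. A tiny illustration of what `create_packed_sequences` returns (the token ids are hypothetical):

```python
full_sequence = [101, 7592, 2088, 2003, 102]  # hypothetical token ids

input_ids = full_sequence[:-1]    # [101, 7592, 2088, 2003]
shift_labels = full_sequence[1:]  # [7592, 2088, 2003, 102]

# input_ids[t] is trained to predict shift_labels[t], so the model must not
# shift the labels again internally.
```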

src/accelerate/accelerator.py

Lines changed: 7 additions & 15 deletions

@@ -35,7 +35,7 @@

 from accelerate.utils.dataclasses import FP8BackendType

-from .big_modeling import attach_context_parallel_hooks
+from .big_modeling import _attach_context_parallel_hooks
 from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
 from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
 from .logging import get_logger
@@ -449,12 +449,6 @@ def __init__(

         parallelism_config = self._setup_parallelism_config(parallelism_config, torch_tp_plugin)

-        # TODO: Siro - figure out a better place where this can go (needs to be above AcceleratorState init)
-        if parallelism_config and parallelism_config.cp_enabled and fsdp_plugin is None:
-            raise ValueError(
-                "`cp_enabled` is set to `True` in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory"
-            )
-
         kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
         kwargs["parallelism_config"] = parallelism_config
         self.state = AcceleratorState(
@@ -1585,10 +1579,8 @@ def _prepare_cp(self, *args):
         self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])

         for arg in args:
-            if not isinstance(arg, torch.nn.Module):
-                continue
-
-            attach_context_parallel_hooks(arg)
+            if isinstance(arg, torch.nn.Module):
+                _attach_context_parallel_hooks(arg)

         return args

@@ -3991,10 +3983,10 @@ def maybe_context_parallel(
         ):
             yield
         else:
-            if not getattr(self, "_warned_about_cp", False) and self.is_main_process:
-                logger.warning("Context parallel training is not enabled. This context manager will have no effect.")
-                # As this context manager is recreated each training step, we only warn once
-                self._warned_about_cp = True
+            logger.warning_once(
+                "Context parallel training is not enabled. This context manager will have no effect. "
+                "To enable it, set `parallelism_config.cp_size` > 1 in the `Accelerator` constructor."
+            )
             yield

     @contextmanager

src/accelerate/big_modeling.py

Lines changed: 1 addition & 1 deletion

@@ -749,7 +749,7 @@ def _attach_layerwise_casting_hooks(
     )


-def attach_context_parallel_hooks(
+def _attach_context_parallel_hooks(
     model: nn.Module,
 ):
     """

src/accelerate/commands/launch.py

Lines changed: 2 additions & 9 deletions

@@ -1039,15 +1039,8 @@ def _validate_launch_command(args):
     if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
         raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")

-    # TODO: Merge into 1 if
-    if not args.use_fsdp and args.use_parallelism_config:
-        raise ValueError(
-            "You cannot use `--use_parallelism_config` without `--use_fsdp`. Please set `--use_fsdp` to True if you want to use parallelism config."
-        )
-    elif args.fsdp_version == 1 and args.use_parallelism_config:
-        raise ValueError(
-            "You cannot use `--use_parallelism_config` with FSDP version 1. Please set `--fsdp_version=2` if you want to use parallelism config."
-        )
+    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
+        raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`.")

     defaults = None
     warned = []
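
The two separate checks collapse into one condition here. A quick standalone check (not part of the codebase) that the merged predicate rejects exactly the same flag combinations as the original `if`/`elif` pair:

```python
def old_invalid(use_fsdp: bool, fsdp_version: int, use_parallelism_config: bool) -> bool:
    # original logic: two branches, each raising for its own case
    return (not use_fsdp and use_parallelism_config) or (fsdp_version == 1 and use_parallelism_config)


def new_invalid(use_fsdp: bool, fsdp_version: int, use_parallelism_config: bool) -> bool:
    # merged condition from this commit
    return (not use_fsdp or fsdp_version == 1) and use_parallelism_config


assert all(
    old_invalid(f, v, p) == new_invalid(f, v, p)
    for f in (True, False)
    for v in (1, 2)
    for p in (True, False)
)
```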

src/accelerate/parallelism_config.py

Lines changed: 8 additions & 8 deletions

@@ -57,10 +57,10 @@ class ParallelismConfig:
     """

-    dp_replicate_size: int = 1
-    dp_shard_size: int = 1
-    tp_size: int = 1
-    cp_size: int = 1
+    dp_replicate_size: int = None
+    dp_shard_size: int = None
+    tp_size: int = None
+    cp_size: int = None

     # we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
     tp_handler: Union[None, TorchTensorParallelConfig] = None
@@ -210,13 +210,13 @@ def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:

     def __post_init__(self):
         # Basic size validation
-        if self.dp_replicate_size == 1:
+        if self.dp_replicate_size is None:
             self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
-        if self.dp_shard_size == 1:
+        if self.dp_shard_size is None:
             self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
-        if self.tp_size == 1:
+        if self.tp_size is None:
             self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
-        if self.cp_size == 1:
+        if self.cp_size is None:
             self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))

         if self.tp_size > 1:
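
With the defaults changed from `1` to `None`, the environment variables only fill in sizes the user did not set explicitly, and an explicit value is never silently overridden. A small sketch of the resulting behavior (the env variable names come from `__post_init__` above):

```python
import os

from accelerate.parallelism_config import ParallelismConfig

# e.g. exported by a launcher; consulted only for fields left as None
os.environ["PARALLELISM_CONFIG_DP_SHARD_SIZE"] = "4"

pc_from_env = ParallelismConfig()  # dp_shard_size -> 4, unset sizes fall back to 1
pc_explicit = ParallelismConfig(dp_shard_size=2)  # explicit value wins; the env var is ignored
```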

src/accelerate/state.py

Lines changed: 4 additions & 0 deletions

@@ -981,6 +981,10 @@ def __init__(
                 DistributedType.MULTI_XPU,
                 DistributedType.MULTI_HPU,
             ]:
+                if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
+                    raise ValueError(
+                        "`cp_size > 1` in the `parallelism_config`, but no `fsdp_plugin` was provided. We need an `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory."
+                    )
                 if (os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None) or (
                     self.parallelism_config is not None and self.parallelism_config.cp_enabled
                 ):

tests/test_dataclasses.py

Lines changed: 7 additions & 7 deletions

@@ -105,10 +105,10 @@ def test_get_mesh(
     ):
         # Skip tests based on version requirements
         if _should_skip_cp_test(cp_size):
-            pytest.skip(f"CP tests require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
+            pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
         if _should_skip_tp_test(tp_size):
             pytest.skip(
-                f"TP tests require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
+                f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
             )

         config = ParallelismConfig(
@@ -145,10 +145,10 @@ def test_build_device_mesh(
         """Test build_device_mesh creates correct mesh and applies flattening."""
         # Skip tests based on version requirements
         if _should_skip_cp_test(cp_size):
-            pytest.skip(f"CP tests require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
+            pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
         if _should_skip_tp_test(tp_size):
             pytest.skip(
-                f"TP tests require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
+                f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
             )

         config = ParallelismConfig(
@@ -194,10 +194,10 @@ def test_from_env(
         cp_size,
     ):
         if _should_skip_cp_test(cp_size):
-            pytest.skip(f"CP tests require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
+            pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
         if _should_skip_tp_test(tp_size):
             pytest.skip(
-                f"TP tests require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
+                f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
             )

         new_env = {
@@ -217,7 +217,7 @@ def test_cp_handler(self):

         # Any cp_size > 1 requires torch >= BETA_CP_AVAILABLE_PYTORCH_VERSION, we use placeholder for this check as this test doesn't depend on a specific size
         if _should_skip_cp_test(2):
-            pytest.skip(f"CP tests require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
+            pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")

         from accelerate.utils import TorchContextParallelConfig
