Feat: context parallel v2.0 #3700

Open · wants to merge 13 commits into main

Conversation

@S1ro1 (Member) commented on Jul 31, 2025

Integrates CP seamlessly with the previously merged `ParallelismConfig`. Builds on top of #3604 (which it supersedes).

S1ro1 force-pushed the feat/context-parallel branch from 35c24ff to 314ccc7 on July 31, 2025 at 19:12
@S1ro1 (Member, Author) commented on Jul 31, 2025

Supersedes #3604, as I can't be bothered with fixing git on that branch.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

S1ro1 changed the title from "[WIP] Feat/context parallel v2.0" to "Feat: context parallel v2.0" on Aug 4, 2025
@SunMarc (Member) left a comment


Nice job! Thanks for integrating CP using parallelism_config! Left a couple of comments.

Comment on lines 103 to 104
- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.Dtensor`s. After the context manager is exited, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager is exited.

Member:

Since it is recommended to pass the same tensors, can't we change the default to that?

S1ro1 (Member, Author):

IMO this should be opt-in instead of opt-out as by default we want to leave the buffers unchanged.
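For illustration, a minimal sketch of the opt-in pattern being discussed, using PyTorch's experimental `context_parallel` primitive directly rather than the wrapper added in this PR; `cp_mesh`, `model`, `input_ids`, and `labels` are placeholder names, not code from this branch:

from torch.distributed.tensor.experimental import context_parallel

buffers = [input_ids, labels]   # placeholder tensors of shape (batch, seq_len)
buffer_seq_dims = [1, 1]        # sequence dimension of each buffer, in order

with context_parallel(
    cp_mesh,                    # placeholder 1-D DeviceMesh over the CP ranks
    buffers=buffers,
    buffer_seq_dims=buffer_seq_dims,
    # Opting in: skip the restore (all-gather) for buffers we won't read after exit.
    no_restore_buffers=set(buffers),
):
    loss = model(input_ids=input_ids, labels=labels).loss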

The context manager takes a few arguments, that are used to configure the context parallelism.

- `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask.
- `buffer_seq_dims`: This is a list of integers that specify the sequence dimension of each buffer, in the same order as the `buffers` list.
Member:

Can we add more details to that somewhere? Some people might not understand what you mean by the sequence dimension of the buffers. Maybe in the example?
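As an illustration of what "sequence dimension" means here, a tiny sketch with made-up shapes (not taken from the PR's example):

import torch

# Each buffer below has shape (batch, seq_len); the axis that grows with the
# sequence length is dimension 1, so that is its "sequence dimension".
input_ids      = torch.randint(0, 32_000, (8, 4096))
labels         = input_ids.clone()
attention_mask = torch.ones(8, 4096, dtype=torch.bool)

buffers = [input_ids, labels, attention_mask]
buffer_seq_dims = [1, 1, 1]   # one entry per buffer, same order as `buffers`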

Comment on lines 61 to 66
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
loss_reduce_grp = (
accelerator.torch_device_mesh["dp_cp"].get_group()
if accelerator.parallelism_config.dp_cp_dim_names
else None
)
Member:

nice use of the device mesh
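For reference, a rough sketch of how such a group might be used to average a loss across the dp/cp ranks (assuming `loss` is a scalar tensor on each rank and a backend that supports `ReduceOp.AVG`, e.g. NCCL):

import torch.distributed as dist

if loss_reduce_grp is not None:
    # Average in place across every rank participating in data/context parallelism.
    dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)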

Comment on lines 92 to 93
state_dict_type="SHARDED_STATE_DICT",
activation_checkpointing=True,
Member:

Why add that?

S1ro1 (Member, Author):

My bad, this was supposed to be removed. It's needed for the 128k seq-len example I was testing, so I accidentally left it in.

Comment on lines 452 to 457
# TODO: Siro - figure out a better place where this can go (needs to be above AcceleratorState init)
if parallelism_config and parallelism_config.cp_enabled and fsdp_plugin is None:
raise ValueError(
"`cp_enabled` is set to `True` in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory"
)

Member:

Why? Can't we put that in `_validate_accelerator`?

@S1ro1 (Member, Author) commented on Aug 4, 2025:

In AcceleratorState we already need to set things based on CP being enabled, and by that point we would get a default fsdp_plugin out, so there would be no way to detect that it was not passed in.

EDIT: moved this into AcceleratorState (still the more viable option imo).

Comment on lines +985 to +986
self.parallelism_config is not None and self.parallelism_config.cp_enabled
):
Member:

Not sure this is needed, no? In that case FSDP is not used, but we still set self.distributed_type = DistributedType.FSDP.

SunMarc requested a review from winglian on August 4, 2025 at 11:32
@SunMarc (Member) left a comment

Nice! LGTM!

Comment on lines +984 to +987
if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
raise ValueError(
"`cp_size > 1` in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory"
)
Member:

This might not be optimal for downstream libs like Axolotl. cc @winglian

Collaborator:

Yeah, we definitely support CP without FSDP, so this would break that feature. Maybe some other sort of explicit setting to indicate that a user is letting accelerate handle CP for them? @djsaunde

@djsaunde commented on Aug 4, 2025:

I'm curious why this is gated in the first place, can we not use CP in accelerate sans FSDP? They should be independent.

@S1ro1 (Member, Author) commented on Aug 4, 2025:

> I'm curious why this is gated in the first place, can we not use CP in accelerate sans FSDP? They should be independent.

They are (sort of), but FSDP is a free lunch with CP, so imo it should be the default. While we're computing the ring-attention we can prefetch the next FSDP layer for free, giving us 1/(cp_size * fsdp_size) savings on model/optimizer/grads. I have some profiling for this in the concept guide.

TL;DR: it can be independent, but there's (almost) no world where it's worth it to not do FSDP on top.
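A rough back-of-the-envelope illustration of that factor, with made-up numbers rather than anything measured in this PR:

params = 7e9                 # hypothetical 7B-parameter model
bytes_per_param = 2          # bf16 weights
cp_size, fsdp_size = 4, 2    # hypothetical mesh sizes

unsharded_gb = params * bytes_per_param / 1e9      # ~14 GB of weights per rank without sharding
sharded_gb = unsharded_gb / (cp_size * fsdp_size)  # ~1.75 GB per rank when FSDP shards over dp_shard x cp
print(f"{unsharded_gb:.2f} GB -> {sharded_gb:.2f} GB per rank")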

@@ -1497,6 +1502,9 @@ def prepare(self, *args, device_placement=None):
if self.parallelism_config and self.parallelism_config.tp_enabled:
args = self._prepare_tp(*args)

if self.parallelism_config and self.parallelism_config.cp_enabled:
args = self._prepare_cp(*args)
Collaborator:

I am worried that automatically handling this might break the existing context-parallel handling we have in Axolotl. @djsaunde

Comment:

Yeah, I think it will. ring-flash-attn (what we use) supports non-causal masks, so the mask deletion / replacement with causal=True is not good for us. We could maybe patch over this? I like the way pre-hooks are handled here in the accelerator, so we could swap to that instead of our context manager.

@S1ro1 (Member, Author) commented on Aug 4, 2025:

You can probably just delete the hook? It's always going to be the first one, as I use prepend=True, so it should be pretty simple! The hook is the only thing we add for CP to work; beyond that you'd only need to use the context manager (which you won't).
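A rough sketch of what "just delete the hook" could look like on the Axolotl side; this is hypothetical, relies on the private `_forward_pre_hooks` attribute, and assumes the CP hook is the first (prepended) entry on every self_attn module:

for name, module in model.named_modules():
    if name.endswith("self_attn") and module._forward_pre_hooks:
        # The CP hook was registered with prepend=True, so it is the first entry.
        first_handle_id = next(iter(module._forward_pre_hooks))
        del module._forward_pre_hooks[first_handle_id]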

Comment on lines +214 to +220
self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
if self.dp_shard_size is None:
self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
if self.tp_size is None:
self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
if self.cp_size is None:
self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
Collaborator:

aren't most accelerate env vars prefixed with ACCELERATE_?

Member:

Oh yes, good catch. Can you fix that, @S1ro1?

S1ro1 (Member, Author):

Not entirely true actually: paradigm variables are prefixed with ACCELERATE_ (such as use_fsdp, use_deepspeed), while variables configuring the underlying implementations are prefixed with the impl name (i.e. FSDP_, MEGATRON_LM_), etc. Such as here or here.

Afaik only DeepSpeed is special and prefixes with ACCELERATE_DEEPSPEED_, making the env variables insanely long.
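A few concrete examples of that convention as I understand it (the variable names below are ones Accelerate already reads; treat the list as illustrative, not exhaustive):

import os

os.environ["ACCELERATE_USE_FSDP"] = "true"                      # paradigm toggle, ACCELERATE_-prefixed
os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP"  # FSDP-specific knob, FSDP_-prefixed
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = "3"             # the DeepSpeed exception, doubled-up prefix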

Comment on lines +758 to +761
This function attaches forward_pre_hooks to each self_attn module of the model, where each hook checks the
args/kwargs, if they contain an attention mask, if it does, it will remove this mask, check if it is a causal mask,
if yes, will add a kwarg `is_causal=True`, otherwise will raise an error. This is because context parallelism does
not support attention masks. This function modifies the model in place.
Comment:

Nice, I was wondering how to handle this myself.

Comment:

Oh, I see you don't actually check whether the mask is causal yet.

S1ro1 (Member, Author):

Yes, we should do that eventually and will revisit; currently we only warn in the docs.
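For illustration, a minimal sketch (not the actual implementation in this PR) of a forward pre-hook along the lines of that docstring; `model`, the module-name matching, and the kwarg names are assumptions:

def _cp_attention_pre_hook(module, args, kwargs):
    # Drop the attention mask and signal causal attention instead, since the
    # CP ring-attention path does not support arbitrary masks. (The "is it
    # actually causal?" check discussed above is still missing here too.)
    if kwargs.get("attention_mask") is not None:
        kwargs["attention_mask"] = None
        kwargs["is_causal"] = True
    return args, kwargs

for name, module in model.named_modules():
    if name.endswith("self_attn"):
        # prepend=True keeps this hook first, matching the detail discussed earlier.
        module.register_forward_pre_hook(_cp_attention_pre_hook, with_kwargs=True, prepend=True)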

if not args.use_parallelism_config:
return current_env

prefix = "PARALLELISM_CONFIG_"
Collaborator:

Suggested change:
-prefix = "PARALLELISM_CONFIG_"
+prefix = "ACCELERATE_PARALLELISM_CONFIG_"

Similar to comment above about accelerate using namespaced env vars
