Feat: context parallel v2.0 #3700
base: main
Conversation
35c24ff to 314ccc7
Supersedes #3604, as I can't be bothered with fixing git on that branch.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice job! Thanks for integrating CP using `parallelism_config`! Left a couple of comments.
- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.DTensor`s. After the context manager is exited, a communication kernel would need to be launched to restore the buffers to their original state (usually an all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager is exited.
Since it is recommended to pass the same tensors, can't we change the default to that?
IMO this should be opt-in instead of opt-out, as by default we want to leave the buffers unchanged.
The context manager takes a few arguments that are used to configure the context parallelism.
- `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually the input IDs, labels and attention mask.
- `buffer_seq_dims`: This is a list of integers that specify the sequence dimension of each buffer, in the order of the `buffers` list.
Can we add more details on that somewhere? Some people might not understand what you mean by the sequence dimension of the buffers. Maybe in the example?
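To illustrate what the "sequence dimension" of a buffer means, here is a minimal, hypothetical sketch (not code from this PR) using PyTorch's experimental `context_parallel` API, which the arguments described above mirror; the mesh setup, shapes and module-level wiring are assumptions.

```python
# Hypothetical sketch: which dimension of each buffer is the "sequence dimension".
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

# Assumed setup: 8 ranks forming a 1-D context-parallel mesh (distributed init omitted).
cp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("cp",))

batch_size, seq_len = 2, 4096
input_ids = torch.randint(0, 32_000, (batch_size, seq_len), device="cuda")  # shape (B, S) -> seq dim is 1
labels = input_ids.clone()                                                   # same shape -> seq dim is 1

with context_parallel(
    cp_mesh,
    buffers=[input_ids, labels],             # tensors to shard along their sequence dimension
    buffer_seq_dims=[1, 1],                  # dim 1 of each buffer above is the sequence dimension
    no_restore_buffers={input_ids, labels},  # skip the restoring all-gather on exit (see note above)
):
    ...  # forward/backward runs with each rank holding a 1/cp_size slice of the sequence
```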
examples/fsdp2/nd_parallel.py (outdated)
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
loss_reduce_grp = (
    accelerator.torch_device_mesh["dp_cp"].get_group()
    if accelerator.parallelism_config.dp_cp_dim_names
    else None
)
nice use of the device mesh
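For context, a hedged sketch (not from the PR) of how a process group like `loss_reduce_grp` would typically be used to average the loss; the `loss` tensor and the NCCL-only `ReduceOp.AVG` are assumptions here.

```python
# Hypothetical follow-up: average the local losses over the dp/cp group.
import torch.distributed as dist

if loss_reduce_grp is not None:
    loss = loss.detach().clone()
    # ReduceOp.AVG (NCCL backend) sums across ranks and divides by the group size.
    dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
```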
examples/fsdp2/nd_parallel.py (outdated)
state_dict_type="SHARDED_STATE_DICT",
activation_checkpointing=True,
Why add that?
My bad, it was supposed to be removed. It is needed for the 128k seq-len example I was testing, so I forgot to take it out.
src/accelerate/accelerator.py (outdated)
# TODO: Siro - figure out a better place where this can go (needs to be above AcceleratorState init)
if parallelism_config and parallelism_config.cp_enabled and fsdp_plugin is None:
    raise ValueError(
        "`cp_enabled` is set to `True` in the `parallelism_config`, but no `fsdp_plugin` was provided. "
        "We need a `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory"
    )
Why? Can't we put that in `_validate_accelerator`?
In `AcceleratorState` we already need to set things based on CP being enabled, and by that point we would get a default FSDP plugin out, so there would be no way to detect whether one was passed in.
EDIT: moved this into `AcceleratorState` (still more viable imo)
self.parallelism_config is not None and self.parallelism_config.cp_enabled
):
Not sure this is needed, no? In this case FSDP is not used, but we still set `self.distributed_type = DistributedType.FSDP`.
Nice! LGTM!
if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
    raise ValueError(
        "`cp_size > 1` in the `parallelism_config`, but no `fsdp_plugin` was provided. "
        "We need a `fsdp_plugin` to use `cp_enabled=True`, as we also shard the model across the device mesh to save more memory"
    )
This might not be optimal for downstream libs like axolotl, cc @winglian
Yeah, we definitely support CP without FSDP, so this would break that feature. Maybe some other sort of explicit setting indicating that a user wants accelerate to handle CP for them? @djsaunde
I'm curious why this is gated in the first place, can we not use CP in accelerate sans FSDP? They should be independent.
> I'm curious why this is gated in the first place, can we not use CP in accelerate sans FSDP? They should be independent.
They are (sort of), but FSDP is a free lunch with CP, so imo it should be the default. While we're computing the ring-attention we can prefetch the next FSDP layer for free, giving us a 1/(cp_size * fsdp_size) saving in model/optimizer/grad memory. I have some profiling for this in the concept guide.
TL;DR: it can be independent, but there's (almost) no world where it's worth not doing FSDP on top.
@@ -1497,6 +1502,9 @@ def prepare(self, *args, device_placement=None):
        if self.parallelism_config and self.parallelism_config.tp_enabled:
            args = self._prepare_tp(*args)

        if self.parallelism_config and self.parallelism_config.cp_enabled:
            args = self._prepare_cp(*args)
I am worried that, in Axolotl, automatically handling this might break the existing context-parallel handling we have. @djsaunde
Yeah, I think it will. ring-flash-attn (what we use) supports non-causal masks, so e.g. the mask deletion / replacement with `causal=True` is not good for us. We could maybe patch over this? I like the way pre-hooks are handled here in the accelerator, so we could swap to that vs. our context manager.
You can probably just delete the hook? It's always going to be the first one, as I use `prepend=True`, so it should be pretty simple! The hook is the only thing we add for CP to work; then you would only need to use the context manager (which you won't).
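As a rough illustration of deleting that prepended hook, a hypothetical sketch (not code from the PR or from Axolotl); the `self_attn` name filter and the reliance on the private `_forward_pre_hooks` dict are assumptions.

```python
# Hypothetical: drop the first (prepended) forward pre-hook on each attention module.
import torch.nn as nn

def remove_cp_pre_hooks(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if name.endswith("self_attn") and module._forward_pre_hooks:
            # The CP hook was registered with prepend=True, so it is the first entry
            # in the ordered _forward_pre_hooks dict; pop it by its hook id.
            first_hook_id = next(iter(module._forward_pre_hooks))
            module._forward_pre_hooks.pop(first_hook_id)
```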
self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
if self.dp_shard_size is None:
    self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
if self.tp_size is None:
    self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
if self.cp_size is None:
    self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
Aren't most accelerate env vars prefixed with `ACCELERATE_`?
Oh yes, good catch. Can you fix that, @S1ro1?
Not entirely true actually: paradigm variables are prefixed with `ACCELERATE_` (such as `use_fsdp`, `use_deepspeed`), while the variables configuring the underlying implementations are prefixed with the implementation name (i.e. `FSDP_`, `MEGATRON_LM_`), etc.; such as here or here.
Afaik only DeepSpeed is special and prefixes with `ACCELERATE_DEEPSPEED_`, making the env variables insanely long.
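To make the convention concrete, a small illustrative snippet; the specific variable names below are examples of that convention, not an exhaustive or authoritative list.

```python
# Illustrative only: the prefix convention described above.
import os

os.environ["ACCELERATE_USE_FSDP"] = "true"           # paradigm toggle: ACCELERATE_ prefix
os.environ["FSDP_SHARDING_STRATEGY"] = "FULL_SHARD"  # implementation knob: FSDP_ prefix
```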
This function attaches forward pre-hooks to each `self_attn` module of the model. Each hook checks the args/kwargs for an attention mask; if one is present, the hook removes it, checks whether it is a causal mask, and if so adds a kwarg `is_causal=True`, otherwise it raises an error. This is because context parallelism does not support attention masks. This function modifies the model in place.
Nice, I was wondering how to handle this myself.
Oh, I see you don't actually check whether the mask is causal yet.
Yes, we should do that eventually; will revisit. Currently we only warn in the docs.
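For readers following along, a minimal sketch of the mechanism that docstring describes (this is not the PR's actual hook; the `self_attn` naming and the missing causality check mirror the discussion above and are assumptions).

```python
import torch.nn as nn

def _cp_attention_pre_hook(module, args, kwargs):
    # Registered with with_kwargs=True, so we receive (args, kwargs) and may return both.
    if kwargs.get("attention_mask") is not None:
        kwargs = dict(kwargs)
        kwargs.pop("attention_mask")  # CP's ring attention cannot take an explicit mask
        kwargs["is_causal"] = True    # assumed causal; ideally verified before replacing
    return args, kwargs

def attach_cp_pre_hooks(model: nn.Module) -> None:
    for name, module in model.named_modules():
        if name.endswith("self_attn"):  # assumption: HF-style attention module names
            module.register_forward_pre_hook(
                _cp_attention_pre_hook, with_kwargs=True, prepend=True
            )
```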
if not args.use_parallelism_config:
    return current_env

prefix = "PARALLELISM_CONFIG_"
prefix = "PARALLELISM_CONFIG_" | |
prefix = "ACCELERATE_PARALLELISM_CONFIG_" |
Similar to the comment above about accelerate using namespaced env vars.
Integrates CP seamlessly with the previously merged `ParallelismConfig`. Builds on top of #3604 (which it supersedes), and so on.