[prototype] N-D distributed parallel support #2947
Conversation
📖 Documentation Preview: https://6881221bb6eb5599a6a7bb60--resonant-treacle-0fd729.netlify.app (deployed on Netlify from commit ad70741)
Just leaving a few comments after a quick glance.
Question: can we add support for the HYBRID_SHARD fsdp2 strategy as part of this PR (i.e., replicating across nodes while sharding intra-node)? Or maybe it doesn't quite fit here?
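For reference, a minimal sketch of what HSDP-style hybrid sharding looks like with FSDP2's fully_shard over a 2-D device mesh; the mesh shape and dimension names below are illustrative assumptions, not taken from this PR.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # torch >= 2.6; older versions expose it under torch.distributed._composable.fsdp

# Illustrative assumption: 2 nodes x 8 GPUs each, run under torchrun.
# Passing a 2-D mesh to fully_shard replicates across the outer dim and
# shards across the inner dim, i.e. HSDP-style hybrid sharding.
mesh = init_device_mesh(
    "cuda",
    (2, 8),
    mesh_dim_names=("dp_replicate", "dp_shard"),
)

model = torch.nn.Linear(1024, 1024)  # stand-in for the real model
fully_shard(model, mesh=mesh)
```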
src/axolotl/core/builders/base.py
Outdated
if (
    self.cfg.tensor_parallel_size > 1
    or (self.cfg.dp_shard_size and self.cfg.dp_shard_size > 1)
    or self.cfg.sequence_parallel_degree > 1
Aside: we should rename to context_parallel_size to better match the ecosystem and deprecate sequence_parallel_size.
Migrated sequence_parallel_degree -> context_parallel_size.
if self.cfg.tensor_parallel_size and self.cfg.tensor_parallel_size > 1:
    # from_pretrained only needs to know about tensor parallelism for sharding weights
    dist_parallel = DistParallel.build(
I notice we build this three separate times. Is it possible to just have a singleton instance that we build once and refer to thereafter?
I thought about doing a singleton, but then didn't want to deal with the complexity of singletons in CI. Happy to give it a whirl though if we're confident that this should never change during runtime.
Alternatively, could we have a utility function that accepts a cfg and returns the device mesh / DistParallel?
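Something like the following hypothetical helper could cover both points (one shared instance plus a cfg-based accessor); the DistParallel.build keyword names below are assumptions for illustration, not the PR's actual signature.

```python
from functools import lru_cache

# DistParallel is the class introduced in this PR; import path omitted here.
# Hypothetical sketch: cache on the (hashable) parallelism sizes so the three
# call sites share one instance without a module-level singleton.
@lru_cache(maxsize=1)
def _build_dist_parallel(tp_size: int, dp_shard_size: int, cp_size: int):
    return DistParallel.build(  # keyword names are assumed for illustration
        tensor_parallel_size=tp_size,
        dp_shard_size=dp_shard_size,
        context_parallel_size=cp_size,
    )


def get_dist_parallel(cfg):
    """Accept a cfg and return the shared DistParallel (and its device mesh)."""
    return _build_dist_parallel(
        cfg.tensor_parallel_size,
        cfg.dp_shard_size or 1,
        cfg.context_parallel_size,
    )
```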
Do you actually need this on your side? It feels to me like you can reuse ParallelismConfig from upstream for this and not repeat code that much. If you need extra functionality, open to suggestions.
I think we can swap this out once upstream changes land.
try:
    submesh_dp_fsdp_size = torch_device_mesh["dp_fsdp"].size()
except KeyError:
    pass
Maybe use `if "dp_fsdp" in torch_device_mesh` here instead.
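A small sketch of that alternative, assuming torch_device_mesh is a torch DeviceMesh (which exposes mesh_dim_names):

```python
# Same effect as the try/except above, but with an explicit membership check.
dim_names = torch_device_mesh.mesh_dim_names or ()
if "dp_fsdp" in dim_names:
    submesh_dp_fsdp_size = torch_device_mesh["dp_fsdp"].size()
```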
Actually, is this even set anywhere?
thanks, I think I need to go back and fix this whole method.
Force-pushed from 71d1133 to 2d024be.
@model_validator(mode="before") | ||
@classmethod | ||
def check_dp_shard_size_fsdp(cls, data): | ||
if data.get("fsdp_config", {}): | ||
dp_shard_size = data.get("dp_shard_size", None) | ||
if dp_shard_size and dp_shard_size <= 1: | ||
raise ValueError("dp_shard_size must be greater than 1 when using FSDP") | ||
if not data.get("fsdp_config", {}) and data.get("dp_shard_size", 1) > 1: | ||
raise ValueError("FSDP should be enabled when using dp_shard_size > 1") | ||
return data | ||
|
@model_validator(mode="before") | |
@classmethod | |
def check_dp_shard_size_fsdp(cls, data): | |
if data.get("fsdp_config", {}): | |
dp_shard_size = data.get("dp_shard_size", None) | |
if dp_shard_size and dp_shard_size <= 1: | |
raise ValueError("dp_shard_size must be greater than 1 when using FSDP") | |
if not data.get("fsdp_config", {}) and data.get("dp_shard_size", 1) > 1: | |
raise ValueError("FSDP should be enabled when using dp_shard_size > 1") | |
return data | |
@model_validator(mode="before") | |
@classmethod | |
def check_dp_shard_size_fsdp(cls, data): | |
fsdp_config = data.get("fsdp_config", {}) | |
dp_shard_size = data.get("dp_shard_size", None) | |
has_fsdp = bool(fsdp_config) | |
dp_sharding_enabled = dp_shard_size and dp_shard_size > 1 | |
if has_fsdp and dp_shard_size and dp_shard_size <= 1: | |
raise ValueError("dp_shard_size must be greater than 1 when using FSDP") | |
if not has_fsdp and dp_sharding_enabled: | |
raise ValueError("FSDP should be enabled when using dp_shard_size > 1") | |
return data | |
Small refactor
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) <= 1: | ||
raise ValueError( | ||
"`liger_rms_norm` is incompatible with tensor parallelism, " | ||
"see https://github.com/linkedin/Liger-Kernel/issues/826" | ||
) |
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) <= 1: | |
raise ValueError( | |
"`liger_rms_norm` is incompatible with tensor parallelism, " | |
"see https://github.com/linkedin/Liger-Kernel/issues/826" | |
) | |
if data.get("liger_rms_norm") and data.get("tensor_parallel_size", 1) > 1: | |
raise ValueError( | |
"`liger_rms_norm` is incompatible with tensor parallelism, " | |
"see https://github.com/linkedin/Liger-Kernel/issues/826" | |
) |
Do you mean this condition?
if (
    not (
        self.distributed_type
        == accelerate.accelerator.DistributedType.DEEPSPEED
        and hasattr(self.state, "ds_device_mesh")
    )
    and self.world_device_mesh is not None
):
    return self.world_device_mesh
Would it be better to check whether distributed_type == FSDP and, in that case, return world_device_mesh?
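A minimal sketch of that suggestion, using accelerate's DistributedType.FSDP; whether this covers every case the current condition handles is the open question.

```python
# Sketch: only return the world mesh for FSDP runs; DeepSpeed keeps its own
# ds_device_mesh handling.
if (
    self.distributed_type == accelerate.accelerator.DistributedType.FSDP
    and self.world_device_mesh is not None
):
    return self.world_device_mesh
```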
device_mesh = dist_parallel.get_device_mesh()
self.model_kwargs["tp_size"] = self.cfg.tensor_parallel_size
self.model_kwargs["tp_plan"] = "auto"
self.model_kwargs["device_mesh"] = device_mesh[("tp",)]
real nice
Force-pushed from 9fabce8 to ad70741.
Description
Based on @SalmanMohammadi's upstream accelerate PR (https://github.com/huggingface/accelerate/pull/3682/files), I was able to hack together an implementation that needs no upstream changes, so we can experiment with what the actually needed changes should be.
One caveat is that accelerate's DataLoaderDispatcher does not support processing and iterating on rank 0 and dispatching to the other ranks (e.g. accelerator_config.dispatch_batches: true). We should also test that this works for the various combinations with and without dispatching, as well as split batches, etc.

Combinations to test:
TP support doesn't seem to work properly in transformers w/o Deepspeed AutoTP
Related PRs:
- device_mesh have multiple dim (huggingface/transformers#38949)
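For context, a rough sketch of how a 4-D mesh over the dimensions discussed in this PR (data-parallel replicate/shard, context parallel, tensor parallel) can be built with torch; the dimension names and sizes are illustrative assumptions, not the PR's actual layout.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumed example: 16 GPUs as 2 (dp replicate) x 2 (dp shard) x 2 (cp) x 2 (tp).
# Run under torchrun so the process group and world size are already set up.
mesh = init_device_mesh(
    "cuda",
    (2, 2, 2, 2),
    mesh_dim_names=("dp_replicate", "dp_shard", "cp", "tp"),
)

# Named submeshes can then be sliced out, e.g. the TP submesh passed to
# from_pretrained in the diff above, or a combined data-parallel mesh.
tp_mesh = mesh["tp"]
dp_mesh = mesh["dp_replicate", "dp_shard"]
print(dist.get_rank(), tp_mesh.size(), dp_mesh.size())
```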