
Add Parallelism getter property to Accelerator class #3703


Merged — 4 commits, merged Aug 2, 2025

Changes from 2 commits
46 changes: 46 additions & 0 deletions src/accelerate/accelerator.py
@@ -784,6 +784,52 @@ def _setup_parallelism_config(

        return parallelism_config

    @property
    def tensor_parallel_rank(self) -> int:
        """
        Returns the local rank for tensor parallelism.
        Raises an error if tensor parallelism is not enabled.
        """
        if not self.parallelism_config or not self.parallelism_config.tp_enabled:
            raise RuntimeError("Tensor parallelism is not enabled. Please check your configuration.")
        return self.torch_device_mesh.get_local_rank("tp")
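For context, a hypothetical usage sketch of the new getter (not part of this diff). The `ParallelismConfig` import path and the `tp_size` argument are assumptions about the surrounding accelerate API, and the script is assumed to be launched via `accelerate launch` on at least two processes.

```python
# Hypothetical usage sketch: read the tensor-parallel rank on each process.
# Assumes accelerate exposes ParallelismConfig with a tp_size argument and
# that Accelerator accepts it via the parallelism_config keyword.
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig

accelerator = Accelerator(parallelism_config=ParallelismConfig(tp_size=2))
print(f"process {accelerator.process_index}: tensor-parallel rank {accelerator.tensor_parallel_rank}")
```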

    @property
    def pipeline_parallel_rank(self) -> int:
        """
        Returns the local rank for pipeline parallelism.
        Not implemented in Accelerate.
        """
        raise NotImplementedError("Pipeline parallelism is currently not supported in Accelerate.")

    @property
    def context_parallel_rank(self) -> int:
        """
        Returns the local rank for context parallelism.
        Not implemented in Accelerate.
        """
        raise NotImplementedError("Context parallelism is currently not supported in Accelerate.")

    @property
    def data_parallel_rank(self) -> int:
        """
        Returns the local rank for replicate-based data parallelism (e.g., DDP).
        Raises an error if not enabled.
        """
        if not self.parallelism_config or not self.parallelism_config.dp_replicate_enabled:
            raise RuntimeError("Replicate-based data parallelism is not enabled. Please check your configuration.")
        return self.torch_device_mesh.get_local_rank("dp_replicate")

    @property
    def data_parallel_shard_rank(self) -> int:
        """
        Returns the local rank for shard-based data parallelism (e.g., FSDP).
        Raises an error if not enabled.
        """
        if not self.parallelism_config or not self.parallelism_config.dp_shard_enabled:
Member

I think we want to adapt this a little, sorry for the wrong assumption yesterday. I'd say the best thing UX-wise is: return the rank if the parallelism is enabled, return 0 if a parallelism config is set but parallelism {x} is not enabled (in that case all ranks are essentially 0), and raise a RuntimeError if neither of these two is met. Do you think that is reasonable?

Contributor Author
@S1ro1 Thanks for your fast feedback! I think it makes sense that when the parallelism configuration is set but the parallelism itself is not enabled, the property should return rank 0 for all processes. I'd appreciate it if you could take a look and let me know your thoughts.

raise RuntimeError("Shard-based data parallelism is not enabled. Please check your configuration.")
return self.torch_device_mesh.get_local_rank("dp_shard")
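For clarity, here is a minimal sketch of the fallback behaviour proposed in the thread above (return the rank when dp_shard is enabled, 0 when a parallelism config is set but dp_shard is not, and raise otherwise). It reuses the attribute names from this diff; the merged implementation may differ in detail.

```python
    @property
    def data_parallel_shard_rank(self) -> int:
        """
        Returns the local rank for shard-based data parallelism (e.g., FSDP).

        If a parallelism config is set but dp_shard is not enabled, every
        process is effectively shard rank 0, so 0 is returned. Raises only
        when no parallelism config is set at all.
        """
        if self.parallelism_config is None:
            raise RuntimeError("No parallelism config is set. Please check your configuration.")
        if not self.parallelism_config.dp_shard_enabled:
            # dp_shard disabled: all processes behave as shard rank 0.
            return 0
        return self.torch_device_mesh.get_local_rank("dp_shard")
```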

    def _build_torch_device_mesh(self, parallelism_config):
        if PartialState._shared_state != {} and getattr(PartialState(), "device_mesh", None) is not None:
            device_mesh = PartialState().device_mesh