
Commit cb343c6

WoosungMyung and S1ro1 authored
Add Parallelism getter property to Accelerator class (#3703)
* Add rank property to Accelerator class
* Raise errors when parallelism configuration is not enabled
* Fix: PR feedback
* Fix: style

Signed-off-by: WoosungMyung <[email protected]>
Co-authored-by: S1ro1 <[email protected]>
1 parent 9359a01 commit cb343c6


src/accelerate/accelerator.py

Lines changed: 50 additions & 0 deletions
@@ -784,6 +784,56 @@ def _setup_parallelism_config(
 
         return parallelism_config
 
+    @property
+    def tensor_parallel_rank(self) -> int:
+        """
+        Returns the local rank for tensor parallelism. If tensor parallelism is configured but not enabled, returns 0
+        since all ranks are assumed to be the same.
+        """
+        if self.parallelism_config:
+            if self.parallelism_config.tp_enabled:
+                return self.torch_device_mesh.get_local_rank("tp")
+            return 0
+        raise RuntimeError("Tensor parallelism is not configured. Set `parallelism_config` first.")
+
+    @property
+    def pipeline_parallel_rank(self) -> int:
+        """
+        Pipeline parallelism is not supported yet.
+        """
+        raise NotImplementedError("Pipeline parallelism is currently not supported in Accelerate.")
+
+    @property
+    def context_parallel_rank(self) -> int:
+        """
+        Context parallelism is not supported yet.
+        """
+        raise NotImplementedError("Context parallelism is currently not supported in Accelerate.")
+
+    @property
+    def data_parallel_rank(self) -> int:
+        """
+        Returns the local rank for replicate-based data parallelism. If replicate-based data parallelism is configured
+        but not enabled, returns 0 since all ranks are assumed to be the same.
+        """
+        if self.parallelism_config:
+            if self.parallelism_config.dp_replicate_enabled:
+                return self.torch_device_mesh.get_local_rank("dp_replicate")
+            return 0
+        raise RuntimeError("Data parallelism is not configured. Set `parallelism_config` first.")
+
+    @property
+    def data_parallel_shard_rank(self) -> int:
+        """
+        Returns the local rank for shard-based data parallelism. If shard-based data parallelism is configured but not
+        enabled, returns 0 since all ranks are assumed to be the same.
+        """
+        if self.parallelism_config:
+            if self.parallelism_config.dp_shard_enabled:
+                return self.torch_device_mesh.get_local_rank("dp_shard")
+            return 0
+        raise RuntimeError("Shard-based data parallelism is not configured. Set `parallelism_config` first.")
+
     def _build_torch_device_mesh(self, parallelism_config):
         if PartialState._shared_state != {} and getattr(PartialState(), "device_mesh", None) is not None:
             device_mesh = PartialState().device_mesh
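The new getters expose where the current process sits in each parallelism dimension of the device mesh. Below is a minimal, illustrative sketch of how they might be used; the `ParallelismConfig` import path, its constructor arguments, the mesh sizes (2 replicate x 2 shard x 2 tensor-parallel, i.e. an 8-process launch such as `accelerate launch --num-processes 8 script.py`), and the `parallelism_config=` argument to `Accelerator` are assumptions about the surrounding Accelerate API, not shown in this diff.

```python
# Illustrative sketch only: exercises the rank properties added in this commit.
# Assumes an 8-process run (2 replicate x 2 shard x 2 tensor-parallel).
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig

pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
accelerator = Accelerator(parallelism_config=pc)

# Each property returns this process's local rank inside the matching
# device-mesh dimension; a dimension that is configured but not enabled
# reports rank 0.
print(
    f"process={accelerator.process_index} "
    f"tp={accelerator.tensor_parallel_rank} "
    f"dp_replicate={accelerator.data_parallel_rank} "
    f"dp_shard={accelerator.data_parallel_shard_rank}"
)

# Without any parallelism_config, these properties raise RuntimeError;
# pipeline_parallel_rank and context_parallel_rank raise NotImplementedError.
```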
