@@ -784,6 +784,56 @@ def _setup_parallelism_config(
784
784
785
785
return parallelism_config
786
786
787
+ @property
788
+ def tensor_parallel_rank (self ) -> int :
789
+ """
790
+ Returns the local rank for tensor parallelism. If tensor parallelism is configured but not enabled, returns 0
791
+ since all ranks are assumed to be the same.
792
+ """
793
+ if self .parallelism_config :
794
+ if self .parallelism_config .tp_enabled :
795
+ return self .torch_device_mesh .get_local_rank ("tp" )
796
+ return 0
797
+ raise RuntimeError ("Tensor parallelism is not configured. Set `parallelism_config` first." )
798
+
799
+ @property
800
+ def pipeline_parallel_rank (self ) -> int :
801
+ """
802
+ Pipeline parallelism is not supported yet.
803
+ """
804
+ raise NotImplementedError ("Pipeline parallelism is currently not supported in Accelerate." )
805
+
806
+ @property
807
+ def context_parallel_rank (self ) -> int :
808
+ """
809
+ Context parallelism is not supported yet.
810
+ """
811
+ raise NotImplementedError ("Context parallelism is currently not supported in Accelerate." )
812
+
813
+ @property
814
+ def data_parallel_rank (self ) -> int :
815
+ """
816
+ Returns the local rank for replicate-based data parallelism. If replicate-based data parallelism is configured
817
+ but not enabled, returns 0 since all ranks are assumed to be the same.
818
+ """
819
+ if self .parallelism_config :
820
+ if self .parallelism_config .dp_replicate_enabled :
821
+ return self .torch_device_mesh .get_local_rank ("dp_replicate" )
822
+ return 0
823
+ raise RuntimeError ("Data parallelism is not configured. Set `parallelism_config` first." )
824
+
825
+ @property
826
+ def data_parallel_shard_rank (self ) -> int :
827
+ """
828
+ Returns the local rank for shard-based data parallelism. If shard-based data parallelism is configured but not
829
+ enabled, returns 0 since all ranks are assumed to be the same.
830
+ """
831
+ if self .parallelism_config :
832
+ if self .parallelism_config .dp_shard_enabled :
833
+ return self .torch_device_mesh .get_local_rank ("dp_shard" )
834
+ return 0
835
+ raise RuntimeError ("Shard-based data parallelism is not configured. Set `parallelism_config` first." )
836
+
787
837
def _build_torch_device_mesh (self , parallelism_config ):
788
838
if PartialState ._shared_state != {} and getattr (PartialState (), "device_mesh" , None ) is not None :
789
839
device_mesh = PartialState ().device_mesh
0 commit comments