@@ -16,7 +16,7 @@
 from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Tuple, Type, Union
 
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -57,6 +57,7 @@
 from lightning.fabric.utilities.imports import (
     _TORCH_GREATER_EQUAL_2_0,
     _TORCH_GREATER_EQUAL_2_1,
+    _TORCH_GREATER_EQUAL_2_2,
 )
 from lightning.fabric.utilities.init import _EmptyInit
 from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
@@ -85,6 +86,11 @@
 
 _SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]
 
+if _TORCH_GREATER_EQUAL_2_2:
+    from torch.distributed._tensor import DeviceMesh
+else:
+    DeviceMesh = None
+
 
 log = logging.getLogger(__name__)
 
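For context, a minimal sketch of what a `DeviceMesh` value for the new argument could look like, assuming torch >= 2.2, CUDA devices, and 8 processes already launched (e.g. via `torchrun`). The mesh shape and dimension names below are illustrative, not part of this diff:

```python
# Illustrative only: build a 2D mesh (2 replica groups x 4 shards each),
# the layout used by HYBRID_SHARD. Assumes torch >= 2.2 and 8 ranks with
# an initialized process group.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("replicate", "shard"))
```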
@@ -162,6 +168,7 @@ def __init__(
         activation_checkpointing_policy: Optional["_POLICY"] = None,
         sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
         state_dict_type: Literal["full", "sharded"] = "full",
+        device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
         **kwargs: Any,
     ) -> None:
         super().__init__(
@@ -177,6 +184,12 @@ def __init__(
         self.cpu_offload = _init_cpu_offload(cpu_offload)
         self.mixed_precision = mixed_precision
         self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)
+
+        if device_mesh is not None:
+            if not _TORCH_GREATER_EQUAL_2_2:
+                raise ValueError("The device_mesh argument is only supported in torch >= 2.2.")
+            self.kwargs["device_mesh"] = device_mesh
+
         self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)
 
         if _TORCH_GREATER_EQUAL_2_0:
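A hedged usage sketch of the new argument (the call site below is an assumption based on this diff, not code the diff adds): with torch >= 2.2, either a shape tuple or a prebuilt `DeviceMesh` is accepted and forwarded to FSDP through `self.kwargs["device_mesh"]`.

```python
# Illustrative only: pass the new `device_mesh` argument to FSDPStrategy.
# Assumes torch >= 2.2; per the diff, the value is stored in
# `self.kwargs["device_mesh"]` and handed to FSDP unchanged.
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

strategy = FSDPStrategy(sharding_strategy="HYBRID_SHARD", device_mesh=(2, 4))
fabric = Fabric(accelerator="cuda", devices=8, strategy=strategy)
fabric.launch()
```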