File tree Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Expand file tree Collapse file tree 1 file changed +12
-3
lines changed Original file line number Diff line number Diff line change @@ -63,6 +63,17 @@ def __bool__(self):
63
63
_unassigned_axis = _UnassignedAxis ()
64
64
65
65
66
+ def is_cpu_platform (mesh : jax .sharding .Mesh | None ):
67
+ if mesh is None :
68
+ if _global_mesh_defined ():
69
+ device = pxla .thread_resources .env .physical_mesh .devices .reshape (- 1 )[0 ]
70
+ else :
71
+ device = jax .devices ()[0 ]
72
+ else :
73
+ device = mesh .devices .reshape (- 1 )[0 ]
74
+ return device .platform == 'cpu'
75
+
76
+
66
77
def _mesh_assignment_free (new_assignment , existing_assignments ):
67
78
"""Determines if a given mesh axis has already been assigned."""
68
79
new = set (jax .tree_util .tree_leaves (new_assignment ))
@@ -197,9 +208,7 @@ def _with_sharding_constraint(
197
208
mesh : jax .sharding .Mesh | None = None ,
198
209
):
199
210
"""Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit."""
200
- if jax .devices ()[0 ].platform == 'cpu' or (
201
- not _global_mesh_defined () and mesh is None
202
- ):
211
+ if is_cpu_platform (mesh ) or (not _global_mesh_defined () and mesh is None ):
203
212
return x
204
213
else :
205
214
if mesh is not None and axis_resources is not None :
You can’t perform that action at this time.
0 commit comments