Skip to content

Commit 1ec5ef2

Browse files
marksandler2Flax Authors
authored andcommitted
Fixes spmd to work correctly with xaot compilation by using global mesh's device instead of jax.devices()[0]
PiperOrigin-RevId: 729357183
1 parent 88ea291 commit 1ec5ef2

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

flax/linen/spmd.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,17 @@ def __bool__(self):
6363
_unassigned_axis = _UnassignedAxis()
6464

6565

66+
def is_cpu_platform(mesh: jax.sharding.Mesh | None):
67+
if mesh is None:
68+
if _global_mesh_defined():
69+
device = pxla.thread_resources.env.physical_mesh.devices.reshape(-1)[0]
70+
else:
71+
device = jax.devices()[0]
72+
else:
73+
device = mesh.devices.reshape(-1)[0]
74+
return device.platform == 'cpu'
75+
76+
6677
def _mesh_assignment_free(new_assignment, existing_assignments):
6778
"""Determines if a given mesh axis has already been assigned."""
6879
new = set(jax.tree_util.tree_leaves(new_assignment))
@@ -197,9 +208,7 @@ def _with_sharding_constraint(
197208
mesh: jax.sharding.Mesh | None = None,
198209
):
199210
"""Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit."""
200-
if jax.devices()[0].platform == 'cpu' or (
201-
not _global_mesh_defined() and mesh is None
202-
):
211+
if is_cpu_platform(mesh) or (not _global_mesh_defined() and mesh is None):
203212
return x
204213
else:
205214
if mesh is not None and axis_resources is not None:

0 commit comments

Comments
 (0)