Commit 41bef07

Author: Flax Authors (committed)
Merge pull request google#4490 from google:nnx-shard-map
PiperOrigin-RevId: 730677977
2 parents dd6e595 + 10d8e5c commit 41bef07

6 files changed, +433 -5 lines changed


docs_nnx/api_reference/flax.nnx/transforms.rst

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ transforms
 
 .. autofunction:: grad
 .. autofunction:: jit
+.. autofunction:: shard_map
 .. autofunction:: remat
 .. autofunction:: scan
 .. autofunction:: value_and_grad

flax/nnx/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -147,6 +147,7 @@
 from .transforms.autodiff import custom_vjp as custom_vjp
 from .transforms.autodiff import remat as remat
 from .transforms.compilation import jit as jit
+from .transforms.compilation import shard_map as shard_map
 from .transforms.compilation import StateSharding as StateSharding
 from .transforms.iteration import Carry as Carry
 from .transforms.iteration import scan as scan

flax/nnx/transforms/compilation.py

Lines changed: 295 additions & 2 deletions
@@ -17,6 +17,11 @@
 import functools
 import typing as tp
 
+import jax.experimental
+import jax.experimental.shard_map
+from jax.sharding import PartitionSpec
+from jax.sharding import Mesh, AbstractMesh
+
 from flax.nnx import (
   extract,
   filterlib,
@@ -31,7 +36,8 @@
 from flax.typing import Missing
 
 F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
-
+Specs = tp.Any
+AxisName = tp.Hashable
 
 # -------------------------------
 # jit
@@ -341,7 +347,6 @@ def jit_wrapper(*args, **kwargs):
         check_aliasing=in_shardings is not None or kwarg_shardings is not None,
         ctxtag=jit_wrapper,
       )
-      jax_in_shardings, kwarg_shardings, jax_out_shardings
       pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
         *pure_args, **pure_kwargs
       )
@@ -371,3 +376,291 @@ def jit_wrapper(*args, **kwargs):
   jit_wrapper.inner = jitted_fn  # type: ignore
 
   return jit_wrapper  # type: ignore
+
+# -------------------------------
+# shard_map
+# -------------------------------
+
+# TODO: create StateSpec and consider enabling a mode that does
+# not use filters during split for performance. Overall there might
+# be performance limitations for using shard_map at a top-level
+
+@dataclasses.dataclass(eq=False)
+class ShardMapFn:
+  f: tp.Callable[..., tp.Any]
+  in_specs: tp.Any
+  out_specs: tp.Any
+  kwarg_specs: tp.Any
+  ctxtag: tp.Hashable
+
+  def __post_init__(self):
+    functools.update_wrapper(self, self.f)
+
+  def __call__(self, *pure_args, **pure_kwargs):
+    args, kwargs = extract.from_tree(
+      (pure_args, pure_kwargs),
+      merge_fn=_jit_merge_fn,
+      ctxtag=self.ctxtag,
+      is_inner=True,
+    )
+
+    out = self.f(*args, **kwargs)
+
+    args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
+    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
+      (args_out, kwargs_out, out),
+      prefix=(self.in_specs, self.kwarg_specs, self.out_specs),
+      ctxtag=self.ctxtag,
+      split_fn=_jit_split_fn,
+    )
+
+    return pure_args_out, pure_kwargs_out, pure_out
+
+
+@tp.overload
+def shard_map(
+  f: F,
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> F: ...
+@tp.overload
+def shard_map(
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> tp.Callable[[F], F]: ...
+def shard_map(
+  f: F | type[Missing] = Missing,
+  *,
+  mesh: Mesh | AbstractMesh,
+  in_specs: Specs,
+  out_specs: Specs,
+  check_rep: bool = True,
+  auto: frozenset[AxisName] = frozenset(),
+) -> F | tp.Callable[[F], F]:
+  """
+  Lifted version of
+  `jax.experimental.shard_map.shard_map <https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html>`_
+  that can handle Modules / graph nodes as arguments.
+
+  Simple data parallel example::
+
+    import jax
+    import jax.numpy as jnp
+    from flax import nnx
+    from jax.sharding import PartitionSpec as P
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))
+
+    m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    @nnx.shard_map(
+      mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
+    )
+    def f(m, x):
+      return m(x)
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(y)
+
+  Notice that here we simply used a ``PartitionSpec`` to define the spec for
+  the whole model and the data. This works for simple cases, but if we need
+  to assign a different ``PartitionSpec`` to different parts of the model we
+  need to use ``StateSharding`` and create some filters that allow us to target
+  specific parts of the model. Here's an example of how to do tensor parallelism
+  for a simple MLP block using ``StateSharding`` and filters::
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+    class MLP(nnx.Module):
+      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+      def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    def path_ends_with(*path_suffix):  # custom filter
+      return lambda path, value: path[-len(path_suffix):] == path_suffix
+
+    model_spec = nnx.StateSharding({
+      path_ends_with('linear1', 'kernel'): P(None, 'model'),
+      path_ends_with('linear2', 'kernel'): P('model', None),
+    })
+
+    @nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
+    def f(m, x):
+      y = m(x)
+      return jax.lax.psum(y, 'model')
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+    jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+
+  Alternatively, a ``State`` object with the exact ``PartitionSpec`` for each
+  state can be passed to ``StateSharding``::
+
+    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+    class MLP(nnx.Module):
+      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+      def __call__(self, x):
+        return self.linear2(jax.nn.relu(self.linear1(x)))
+
+    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+    x = jnp.ones((32, 2))
+
+    model_spec = nnx.State(
+      {
+        'linear1': {'kernel': P(None, 'model')},
+        'linear2': {'kernel': P('model', None)},
+      }
+    )
+
+    @nnx.shard_map(
+      mesh=mesh,
+      in_specs=(nnx.StateSharding(model_spec), P(None)),
+      out_specs=P(None),
+    )
+    def f(m, x):
+      y = m(x)
+      return jax.lax.psum(y, 'model')
+
+    y = f(m, x)
+
+    jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+    jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+  Here ``model_spec`` was created manually, but you can also use
+  ``nnx.get_partition_spec`` to automatically create it for you (see
+  `Scale up on multiple devices <https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html>`_
+  ).
+
+  Args:
+    f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
+      takes as input a shard of the mapped-over arguments and produces a shard
+      of the output.
+    mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
+      to shard the data and on which to execute instances of ``f``. The names of
+      the ``Mesh`` can be used in collective communication operations in ``f``.
+      This is typically created by a utility function like
+      :func:`jax.experimental.mesh_utils.create_device_mesh`.
+    in_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+      (mapping substates to ``PartitionSpec``s) instances as leaves,
+      with a tree structure that is a tree prefix of the
+      args tuple to be mapped over. Similar to ``jax.sharding.NamedSharding``,
+      each ``PartitionSpec`` represents how the corresponding argument (or subtree
+      of arguments) should be sharded along the named axes of ``mesh``. In each
+      ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses sharding
+      the corresponding argument array axis along that positional axis; not
+      mentioning an axis name expresses replication. If an argument, or argument
+      subtree, has a corresponding spec of None, that argument is not sharded.
+    out_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+      (mapping substates to ``PartitionSpec``s) instances as leaves, with a tree structure
+      that is a tree prefix of the output of ``f``.
+      Each ``PartitionSpec`` represents how the corresponding output shards should be
+      concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
+      a position expresses concatenation of that mesh axis's shards along the
+      corresponding positional axis. Not mentioning a ``mesh`` axis name
+      expresses a promise that the output values are equal along that mesh axis,
+      and that, rather than concatenating, only a single value should be produced.
+    check_rep: If True (default), enable additional validity checks and automatic
+      differentiation optimizations. The validity checks concern whether any mesh
+      axis names not mentioned in ``out_specs`` are consistent with how the outputs
+      of ``f`` are replicated. Must be set to False if using a Pallas kernel in ``f``.
+    auto: (experimental) an optional set of axis names from ``mesh`` over which we
+      do not shard the data or map the function, but rather we allow the
+      compiler to control sharding. These names cannot be used in ``in_specs``,
+      ``out_specs``, or in communication collectives in ``f``.
+
+  Returns:
+    A callable that applies the input function ``f`` across data sharded according to
+    the ``mesh`` and ``in_specs``.
+  """
+  if f is Missing:
+    return functools.partial(
+      shard_map,
+      mesh=mesh,
+      in_specs=in_specs,
+      out_specs=out_specs,
+      check_rep=check_rep,
+      auto=auto,
+    )  # type: ignore[return-value]
+  assert not isinstance(f, type)
+
+  kwarg_specs = PartitionSpec()
+  jax_in_specs = jax.tree.map(
+    lambda x: extract.NodeStates(
+      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+      states=x.shardings,
+      metadata=x,
+    )
+    if isinstance(x, StateSharding)
+    else x,
+    in_specs,
+  )
+  jax_out_specs = jax.tree.map(
+    lambda x: extract.NodeStates(
+      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+      states=x.shardings,
+      metadata=x,
+    )
+    if isinstance(x, StateSharding)
+    else x,
+    out_specs,
+  )
+
+  @functools.wraps(f)
+  def shard_map_wrapper(*args, **kwargs):
+    # run dynamic_cache_context before update_context
+    with graph.update_context(shard_map_wrapper):
+      pure_args, pure_kwargs = extract.to_tree(
+        (args, kwargs),
+        prefix=(in_specs, kwarg_specs)
+        if in_specs is not None or kwarg_specs is not None
+        else None,
+        split_fn=_jit_split_fn,
+        check_aliasing=in_specs is not None or kwarg_specs is not None,
+        ctxtag=shard_map_wrapper,
+      )
+      pure_args_out, pure_kwargs_out, pure_out = shard_map_fn(
+        *pure_args, **pure_kwargs
+      )
+      _args_out, _kwargs_out, out = extract.from_tree(
+        (pure_args_out, pure_kwargs_out, pure_out),
+        merge_fn=_jit_merge_fn,
+        is_inner=False,
+        ctxtag=shard_map_wrapper,
+      )
+    return out
+
+  shard_map_fn = jax.experimental.shard_map.shard_map(
+    ShardMapFn(f, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
+    mesh=mesh,
+    in_specs=jax_in_specs,
+    out_specs=(jax_in_specs, kwarg_specs, jax_out_specs),  # type: ignore
+    check_rep=check_rep,
+    auto=auto,
+  )
+
+  shard_map_wrapper.inner = shard_map_fn  # type: ignore
+
+  return shard_map_wrapper  # type: ignore
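
The docstring above notes that ``model_spec`` can also be derived with ``nnx.get_partition_spec`` instead of being written by hand. The following is a minimal sketch of that workflow, not part of this commit: it assumes the kernels are annotated with ``nnx.with_partitioning`` so that sharding metadata is available, and that the resulting ``State`` of ``PartitionSpec``s can be passed to ``nnx.StateSharding`` as in the manual example:

# Hedged sketch (not from this commit): derive StateSharding specs from
# parameter metadata rather than spelling them out manually.
import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    # attach sharding metadata to each kernel at creation time
    self.linear1 = nnx.Linear(
      din, dhidden, use_bias=False,
      kernel_init=nnx.with_partitioning(
        nnx.initializers.lecun_normal(), (None, 'model')
      ),
      rngs=rngs,
    )
    self.linear2 = nnx.Linear(
      dhidden, dout, use_bias=False,
      kernel_init=nnx.with_partitioning(
        nnx.initializers.lecun_normal(), ('model', None)
      ),
      rngs=rngs,
    )

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

# PartitionSpecs are read from the metadata instead of written by hand
model_spec = nnx.get_partition_spec(nnx.state(m))

@nnx.shard_map(
  mesh=mesh,
  in_specs=(nnx.StateSharding(model_spec), P(None)),
  out_specs=P(None),
)
def f(m, x):
  return jax.lax.psum(m(x), 'model')

y = f(m, x)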

tests/nnx/bridge/wrappers_test.py

Lines changed: 10 additions & 0 deletions
@@ -521,6 +521,7 @@ def __call__(self):
     foo.apply({})
 
   def test_compact_basic(self):
+    test = self
     class Linear(bridge.Module):
       dout: int
 
@@ -540,11 +541,20 @@ def __call__(self, x):
         din = x.shape[-1]
         self.linear = Linear(self.dout)
         x = self.linear(x)
+
+        # NNX
+        graphdef, state = nnx.split(self)
+        test.assertIn('Linear_0', state)
+        test.assertIn('w', state['Linear_0'])
+        test.assertIn('b', state['Linear_0'])
+
         return x
 
     foo = Foo(5)
     x = jnp.ones((3, 2))
 
+    self.assertIsInstance(foo, nnx.Module)
+
     variables = foo.init(0, x)
     params = variables['params']
 