@@ -17,6 +17,11 @@
import functools
import typing as tp

+ import jax.experimental
+ import jax.experimental.shard_map
+ from jax.sharding import PartitionSpec
+ from jax.sharding import Mesh, AbstractMesh
+
from flax.nnx import (
  extract,
  filterlib,
@@ -31,7 +36,8 @@
from flax.typing import Missing

F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
-
+ Specs = tp.Any
+ AxisName = tp.Hashable

# -------------------------------
# jit
@@ -341,7 +347,6 @@ def jit_wrapper(*args, **kwargs):
        check_aliasing=in_shardings is not None or kwarg_shardings is not None,
        ctxtag=jit_wrapper,
      )
-       jax_in_shardings, kwarg_shardings, jax_out_shardings
      pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
        *pure_args, **pure_kwargs
      )
@@ -371,3 +376,291 @@ def jit_wrapper(*args, **kwargs):
  jit_wrapper.inner = jitted_fn  # type: ignore

  return jit_wrapper  # type: ignore
+
+ # -------------------------------
+ # shard_map
+ # -------------------------------
+
+ # TODO: create StateSpec and consider enabling a mode that does
+ # not use filters during split for performance. Overall there might
+ # be performance limitations for using shard_map at a top-level
+
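+ # ShardMapFn wraps the user function for jax's shard_map: it merges the
+ # "pure" arg/kwarg pytrees back into graph nodes, calls `f`, and then splits
+ # the (possibly mutated) inputs together with the outputs so that state
+ # updates can be propagated back out of the shard_map call.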
+ @dataclasses.dataclass(eq=False)
+ class ShardMapFn:
+   f: tp.Callable[..., tp.Any]
+   in_specs: tp.Any
+   out_specs: tp.Any
+   kwarg_specs: tp.Any
+   ctxtag: tp.Hashable
+
+   def __post_init__(self):
+     functools.update_wrapper(self, self.f)
+
+   def __call__(self, *pure_args, **pure_kwargs):
+     args, kwargs = extract.from_tree(
+       (pure_args, pure_kwargs),
+       merge_fn=_jit_merge_fn,
+       ctxtag=self.ctxtag,
+       is_inner=True,
+     )
+
+     out = self.f(*args, **kwargs)
+
+     args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
+     pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
+       (args_out, kwargs_out, out),
+       prefix=(self.in_specs, self.kwarg_specs, self.out_specs),
+       ctxtag=self.ctxtag,
+       split_fn=_jit_split_fn,
+     )
+
+     return pure_args_out, pure_kwargs_out, pure_out
+
+
+ @tp.overload
+ def shard_map(
+   f: F,
+   *,
+   mesh: Mesh | AbstractMesh,
+   in_specs: Specs,
+   out_specs: Specs,
+   check_rep: bool = True,
+   auto: frozenset[AxisName] = frozenset(),
+ ) -> F: ...
+ @tp.overload
+ def shard_map(
+   *,
+   mesh: Mesh | AbstractMesh,
+   in_specs: Specs,
+   out_specs: Specs,
+   check_rep: bool = True,
+   auto: frozenset[AxisName] = frozenset(),
+ ) -> tp.Callable[[F], F]: ...
+ def shard_map(
+   f: F | type[Missing] = Missing,
+   *,
+   mesh: Mesh | AbstractMesh,
+   in_specs: Specs,
+   out_specs: Specs,
+   check_rep: bool = True,
+   auto: frozenset[AxisName] = frozenset(),
+ ) -> F | tp.Callable[[F], F]:
+   """
+   Lifted version of
+   `jax.experimental.shard_map.shard_map <https://docs.jax.dev/en/latest/_autosummary/jax.experimental.shard_map.shard_map.html>`_
+   that can handle Modules / graph nodes as arguments.
+
+   Simple data parallel example::
+
+     import jax
+     import jax.numpy as jnp
+     from flax import nnx
+     from jax.sharding import PartitionSpec as P
+
+     mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))
+
+     m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
+     x = jnp.ones((32, 2))
+
+     @nnx.shard_map(
+       mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
+     )
+     def f(m, x):
+       return m(x)
+
+     y = f(m, x)
+
+     jax.debug.visualize_array_sharding(y)
+
+   Notice that here we simply used plain ``PartitionSpec``s to define the
+   specs for the whole model and the data. This works for simple cases, but if
+   we need to assign different ``PartitionSpec``s to different parts of the
+   model we need to use ``StateSharding`` and create some filters that allow
+   us to target specific parts of the model. Here's an example of how to do
+   tensor parallelism for a simple MLP block using ``StateSharding`` and
+   filters::
+
+     mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+     class MLP(nnx.Module):
+       def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+         self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+         self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+       def __call__(self, x):
+         return self.linear2(jax.nn.relu(self.linear1(x)))
+
+     m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+     x = jnp.ones((32, 2))
+
+     def path_ends_with(*path_suffix):  # custom filter
+       return lambda path, value: path[-len(path_suffix):] == path_suffix
+
+     model_spec = nnx.StateSharding({
+       path_ends_with('linear1', 'kernel'): P(None, 'model'),
+       path_ends_with('linear2', 'kernel'): P('model', None),
+     })
+
+     @nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
+     def f(m, x):
+       y = m(x)
+       return jax.lax.psum(y, 'model')
+
+     y = f(m, x)
+
+     jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+     jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+
+   Alternatively, a ``State`` object with the exact ``PartitionSpec`` for each
+   state can be passed to ``StateSharding``::
+
+     mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))
+
+     class MLP(nnx.Module):
+       def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
+         self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
+         self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)
+
+       def __call__(self, x):
+         return self.linear2(jax.nn.relu(self.linear1(x)))
+
+     m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
+     x = jnp.ones((32, 2))
+
+     model_spec = nnx.State(
+       {
+         'linear1': {'kernel': P(None, 'model')},
+         'linear2': {'kernel': P('model', None)},
+       }
+     )
+
+     @nnx.shard_map(
+       mesh=mesh,
+       in_specs=(nnx.StateSharding(model_spec), P(None)),
+       out_specs=P(None),
+     )
+     def f(m, x):
+       y = m(x)
+       return jax.lax.psum(y, 'model')
+
+     y = f(m, x)
+
+     jax.debug.visualize_array_sharding(m.linear1.kernel.value)
+     jax.debug.visualize_array_sharding(m.linear2.kernel.value)
+
+   Here ``model_spec`` was created manually, but you can also use
+   ``nnx.get_partition_spec`` to create it automatically (see
+   `Scale up on multiple devices <https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html>`_
+   ).
+
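+   For illustration, a minimal sketch of that approach (assuming the model's
+   parameters carry sharding metadata, e.g. added via ``nnx.with_partitioning``)
+   could look like::
+
+     state = nnx.state(m)
+     model_spec = nnx.StateSharding(nnx.get_partition_spec(state))
+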
+   Args:
+     f: callable to be mapped. Each application of ``f``, or "instance" of ``f``,
+       takes as input a shard of the mapped-over arguments and produces a shard
+       of the output.
+     mesh: a ``jax.sharding.Mesh`` representing the array of devices over which
+       to shard the data and on which to execute instances of ``f``. The names of
+       the ``Mesh`` can be used in collective communication operations in ``f``.
+       This is typically created by a utility function like
+       :func:`jax.experimental.mesh_utils.create_device_mesh`.
+     in_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+       (mapping substates to ``PartitionSpec``s) instances as leaves,
+       with a tree structure that is a tree prefix of the
+       args tuple to be mapped over. Similar to ``jax.sharding.NamedSharding``,
+       each ``PartitionSpec`` represents how the corresponding argument (or subtree
+       of arguments) should be sharded along the named axes of ``mesh``. In each
+       ``PartitionSpec``, mentioning a ``mesh`` axis name at a position expresses sharding
+       the corresponding argument array axis along that positional axis; not
+       mentioning an axis name expresses replication. If an argument, or argument
+       subtree, has a corresponding spec of None, that argument is not sharded.
+     out_specs: a pytree with ``jax.sharding.PartitionSpec`` or ``nnx.StateSharding``
+       (mapping substates to ``PartitionSpec``s) instances as leaves, with a tree structure
+       that is a tree prefix of the output of ``f``.
+       Each ``PartitionSpec`` represents how the corresponding output shards should be
+       concatenated. In each ``PartitionSpec``, mentioning a ``mesh`` axis name at
+       a position expresses concatenation of that mesh axis's shards along the
+       corresponding positional axis. Not mentioning a ``mesh`` axis name
+       expresses a promise that the output values are equal along that mesh axis,
+       and that, rather than concatenating, only a single value should be produced.
+     check_rep: If True (default) enable additional validity checks and automatic
+       differentiation optimizations. The validity checks concern whether any mesh
+       axis names not mentioned in ``out_specs`` are consistent with how the outputs
+       of ``f`` are replicated. Must be set False if using a Pallas kernel in ``f``.
+     auto: (experimental) an optional set of axis names from ``mesh`` over which we
+       do not shard the data or map the function, but rather we allow the
+       compiler to control sharding. These names cannot be used in ``in_specs``,
+       ``out_specs``, or in communication collectives in ``f``.
+
+   Returns:
+     A callable that applies the input function ``f`` across data sharded according to
+     the ``mesh`` and ``in_specs``.
+   """
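+   # Called without `f`: return a partial of shard_map so it can be used as a
+   # decorator factory, e.g. `@nnx.shard_map(mesh=..., in_specs=..., out_specs=...)`.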
+   if f is Missing:
+     return functools.partial(
+       shard_map,
+       mesh=mesh,
+       in_specs=in_specs,
+       out_specs=out_specs,
+       check_rep=check_rep,
+       auto=auto,
+     )  # type: ignore[return-value]
+   assert not isinstance(f, type)
+
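+   # Translate the user-facing specs into the pytree prefixes that
+   # jax.experimental.shard_map.shard_map expects: StateSharding leaves are
+   # wrapped as extract.NodeStates prefixes (their per-filter PartitionSpecs
+   # go in `states`), while plain PartitionSpec leaves pass through unchanged.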
+   kwarg_specs = PartitionSpec()
+   jax_in_specs = jax.tree.map(
+     lambda x: extract.NodeStates(
+       _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+       states=x.shardings,
+       metadata=x,
+     )
+     if isinstance(x, StateSharding)
+     else x,
+     in_specs,
+   )
+   jax_out_specs = jax.tree.map(
+     lambda x: extract.NodeStates(
+       _graphdef=PartitionSpec(),  # type: ignore[arg-type]
+       states=x.shardings,
+       metadata=x,
+     )
+     if isinstance(x, StateSharding)
+     else x,
+     out_specs,
+   )
+
+   @functools.wraps(f)
+   def shard_map_wrapper(*args, **kwargs):
+     # run dynamic_cache_context before update_context
+     with graph.update_context(shard_map_wrapper):
+       pure_args, pure_kwargs = extract.to_tree(
+         (args, kwargs),
+         prefix=(in_specs, kwarg_specs)
+         if in_specs is not None or kwarg_specs is not None
+         else None,
+         split_fn=_jit_split_fn,
+         check_aliasing=in_specs is not None or kwarg_specs is not None,
+         ctxtag=shard_map_wrapper,
+       )
+       pure_args_out, pure_kwargs_out, pure_out = shard_map_fn(
+         *pure_args, **pure_kwargs
+       )
+       _args_out, _kwargs_out, out = extract.from_tree(
+         (pure_args_out, pure_kwargs_out, pure_out),
+         merge_fn=_jit_merge_fn,
+         is_inner=False,
+         ctxtag=shard_map_wrapper,
+       )
+     return out
+
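+   # ShardMapFn returns (pure_args_out, pure_kwargs_out, pure_out), so the
+   # out_specs passed below mirrors that triple: the input/kwarg specs for the
+   # propagated state plus the user-provided output specs.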
+   shard_map_fn = jax.experimental.shard_map.shard_map(
+     ShardMapFn(f, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
+     mesh=mesh,
+     in_specs=jax_in_specs,
+     out_specs=(jax_in_specs, kwarg_specs, jax_out_specs),  # type: ignore
+     check_rep=check_rep,
+     auto=auto,
+   )
+
+   shard_map_wrapper.inner = shard_map_fn  # type: ignore
+
+   return shard_map_wrapper  # type: ignore