Skip to content

Commit 1e75509

Browse files
committed
[nnx] add submodule iterator
1 parent 85eb8c0 commit 1e75509

File tree

4 files changed

+47
-24
lines changed

4 files changed

+47
-24
lines changed

flax/experimental/nnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
from flax.linen.pooling import pool as pool
1919

2020
from .nnx import compatibility as compatibility
21-
from .nnx import graph_utils
2221
from .nnx.dataclasses import dataclass as dataclass
2322
from .nnx.dataclasses import field as field
23+
from .nnx import graph_utils as graph_utils
2424
from .nnx.dataclasses import param_field as param_field
2525
from .nnx.dataclasses import treenode_field as treenode_field
2626
from .nnx.dataclasses import variable_field as variable_field

flax/experimental/nnx/nnx/graph_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ def _graph_unflatten(
458458
if graphdef.index in index_to_node:
459459
raise RuntimeError(f'GraphDef index {graphdef.index} already used.')
460460

461+
state = state.copy()
461462
node_impl = get_node_impl(graphdef.type)
462463

463464
def _get_children():
@@ -721,6 +722,27 @@ def clone(node: Node) -> Node:
721722
return static.merge(state)
722723

723724

725+
def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[Path, tp.Any]]:
726+
visited: set[int] = set()
727+
path_parts: PathParts = ()
728+
yield from _iter_nodes(node, visited, path_parts)
729+
730+
731+
def _iter_nodes(
732+
node: tp.Any, visited: set[int], path_parts: PathParts
733+
) -> tp.Iterator[tuple[Path, tp.Any]]:
734+
if not is_node(node):
735+
return
736+
if id(node) in visited:
737+
return
738+
visited.add(id(node))
739+
path = '/'.join(path_parts)
740+
yield path, node
741+
node_impl = get_node_impl(node)
742+
for key, value in node_impl.items(node):
743+
yield from _iter_nodes(value, visited, (*path_parts, key))
744+
745+
724746
# -----------------------------
725747
# register node types
726748
# -----------------------------

flax/experimental/nnx/nnx/module.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -482,29 +482,10 @@ def sow(
482482
reduced_value = reduce_fn(init_fn(), value)
483483
setattr(self, name, variable_type(reduced_value))
484484

485-
def for_each(
486-
self, module_type: tp.Type[M], fn: tp.Callable[[M], None]
487-
) -> None:
488-
visited: tp.Set[ids.UUID] = set()
489-
self._on_all(module_type, fn, visited)
490-
491-
def _on_all(
492-
self,
493-
module_type: tp.Type[M],
494-
fn: tp.Callable[[M], None],
495-
visited: tp.Set[ids.UUID],
496-
) -> None:
497-
if self._module__state.id in visited:
498-
return
499-
500-
visited.add(self._module__state.id)
501-
502-
if isinstance(self, module_type):
503-
fn(self)
504-
505-
for value in vars(self).values():
485+
def modules(self) -> tp.Iterator[tuple[Path, Module]]:
486+
for path, value in graph_utils.iter_nodes(self):
506487
if isinstance(value, Module):
507-
value._on_all(module_type, fn, visited)
488+
yield path, value
508489

509490
def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
510491
super().__init_subclass__()
@@ -603,7 +584,7 @@ def first_from(arg_name: str, *args: tp.Optional[A]) -> A:
603584

604585

605586
def merge(
606-
state_and_def: tuple[tpe.Unpack[tuple[State, ...]], GraphDef[M]]
587+
state_and_def: tuple[tpe.Unpack[tuple[State, ...]], GraphDef[M]],
607588
) -> M:
608589
*states, graphdef = state_and_def
609590
return graphdef.merge(*states)

flax/experimental/nnx/tests/test_module.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,3 +615,23 @@ def __call__(self, x, *, rngs: nnx.Rngs):
615615
y, (state, graphdef) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1))
616616

617617
assert isinstance(y, jax.Array)
618+
619+
def test_modules_iterator(self):
620+
class Foo(nnx.Module):
621+
def __init__(self, *, rngs: nnx.Rngs):
622+
self.submodules = [
623+
{'a': nnx.Linear(1, 1, rngs=rngs)},
624+
{'b': nnx.Conv(1, 1, 1, rngs=rngs)},
625+
]
626+
627+
module = Foo(rngs=nnx.Rngs(0))
628+
629+
modules = list(module.modules())
630+
631+
assert len(modules) == 3
632+
assert modules[0][0] == ''
633+
assert isinstance(modules[0][1], Foo)
634+
assert modules[1][0] == 'submodules/0/a'
635+
assert isinstance(modules[1][1], nnx.Linear)
636+
assert modules[2][0] == 'submodules/1/b'
637+
assert isinstance(modules[2][1], nnx.Conv)

0 commit comments

Comments
 (0)