Skip to content

Commit ac3e85a

Browse files
author
Flax Authors
committed
Merge pull request #4379 from google:nnx-fix-fori-loop
PiperOrigin-RevId: 696274913
2 parents 480a196 + c48905a commit ac3e85a

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

flax/nnx/transforms/iteration.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1336,9 +1336,13 @@ def per_node_state(ns: extract.NodeStates | tp.Any):
13361336
):
13371337
return ns
13381338

1339-
def per_node_def(nd: graph.NodeDef | tp.Any):
1339+
def per_node_def(nd: graph.NodeDef | graph.NodeRef):
13401340
if nd.index >= 0:
13411341
global_index_mapping[nd.index] = nd.index
1342+
1343+
if isinstance(nd, graph.NodeRef):
1344+
return
1345+
13421346
for sub_nd in nd.subgraphs.values():
13431347
per_node_def(sub_nd)
13441348
for l in nd.leaves.values():

tests/nnx/transforms_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2934,6 +2934,40 @@ def fwd_fn(i, input):
29342934
_, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
29352935
np.testing.assert_array_equal(y, x * 2 * 3)
29362936

2937+
def test_fori_loop_with_sharing(self):
2938+
class A(nnx.Object):
2939+
def __init__(self):
2940+
self.params = nnx.Param(jnp.zeros((10,), dtype=int))
2941+
2942+
class B(nnx.Object):
2943+
def __init__(self, a: A):
2944+
self.a = a
2945+
2946+
class C(nnx.Object):
2947+
def __init__(self, a: A):
2948+
self.a = a
2949+
2950+
class D(nnx.Object):
2951+
def __init__(self):
2952+
self.a = A()
2953+
self.b = B(self.a)
2954+
self.c = C(self.a)
2955+
2956+
def increment(_, d: D) -> D:
2957+
d.a.params += 1
2958+
return d
2959+
2960+
@nnx.jit
2961+
def rollout(d: D):
2962+
nnx.fori_loop(0, 10, increment, d)
2963+
2964+
d = D()
2965+
rollout(d)
2966+
2967+
np.testing.assert_array_equal(
2968+
d.a.params.value, np.full((10,), 10, dtype=int)
2969+
)
2970+
29372971

29382972
class TestSplitMergeInputs(absltest.TestCase):
29392973
def test_split_inputs(self):

0 commit comments

Comments
 (0)