Skip to content

Commit 0c647cb

Browse files
author
Flax Authors
committed
Merge pull request #2025 from jheek:fix-tranform-state-reuse
PiperOrigin-RevId: 439568379
2 parents ffba15b + 8c6a16b commit 0c647cb

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

flax/linen/transforms.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def create_trans_fn(fn_name, fn_trafo_args):
290290
# we need to create a scope-function from our class for the given method
291291
@functools.wraps(fn)
292292
def wrapped_fn(self, *args, **kwargs):
293+
state = self._state.export()
293294
# make a scope-function to transform
294295
def core_fn(scopes, *args, **kwargs):
295296
# make a clone of self using its arguments
@@ -301,7 +302,7 @@ def core_fn(scopes, *args, **kwargs):
301302
# we reference module_class, not self.__class__ to avoid infinite loop
302303
cloned = module_class(parent=None, **attrs)
303304
cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes)
304-
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
305+
object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access
305306
res = fn(cloned, *args, **kwargs)
306307
self._state.reimport(cloned._state) # pylint: disable=protected-access
307308
_test_transformed_return_values(res, fn_name)
@@ -343,12 +344,13 @@ def decorator_lift_transform(transform, class_fn, *trafo_args,
343344
prewrapped_fns = [wrap_method_once(class_fn) for class_fn in class_fns]
344345
@functools.wraps(prewrapped_fns[0])
345346
def wrapped_fn(self, *args, **kwargs):
347+
state = self._state.export()
346348
# make a scope-function to transform
347349
def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs):
348350
if not multi_scope:
349351
scopes = [scopes]
350352
cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes)
351-
object.__setattr__(cloned, '_state', self._state.export()) # pylint: disable=protected-access
353+
object.__setattr__(cloned, '_state', state.export()) # pylint: disable=protected-access
352354
res = prewrapped_fn(cloned, *args, **kwargs)
353355
self._state.reimport(cloned._state) # pylint: disable=protected-access
354356
_test_transformed_return_values(res, getattr(class_fn, '__name__', None))

tests/linen/linen_transforms_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,6 +1434,24 @@ def __call__(self, x):
14341434
with self.assertRaises(errors.TransformTargetError):
14351435
Foo().init(random.PRNGKey(0), jnp.zeros((2, 3)))
14361436

1437+
def test_scan_compact_count(self):
1438+
class Foo(nn.Module):
1439+
num_layers: int = 5
1440+
1441+
@nn.compact
1442+
def __call__(self, x):
1443+
def body_fn(mdl, x):
1444+
return nn.Dense(features=x.shape[-1])(x), ()
1445+
x, _ = nn.scan(body_fn, length=self.num_layers, variable_axes={"params": 0}, split_rngs={"params": True})(self, x)
1446+
return x
1447+
1448+
m = Foo()
1449+
x = jnp.ones((3,))
1450+
v = m.init(jax.random.PRNGKey(0), x)
1451+
self.assertEqual(v['params']['Dense_0']['kernel'].shape, (5, 3, 3))
1452+
m.apply(v, x)
1453+
1454+
14371455

14381456
if __name__ == '__main__':
14391457
absltest.main()

0 commit comments

Comments
 (0)