Skip to content

Commit 6efcebd

Browse files
committed
allow method argument to accept submodules
1 parent a767363 commit 6efcebd

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

flax/linen/module.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1911,6 +1911,12 @@ def other_fn(instance, ...):
19111911
f"'{class_name}.{attribute_name}' must be a callable, got"
19121912
f' {type(method)}.'
19131913
)
1914+
# if the `method` string is a submodule, we create a lambda function
1915+
# that calls the submodule, forwarding all arguments.
1916+
if isinstance(method, Module):
1917+
method = lambda self, *args, **kwargs: getattr(self, attribute_name)(
1918+
*args, **kwargs
1919+
)
19141920
elif method is None:
19151921
method = self.__call__
19161922
method = _get_unbound_fn(method)

tests/linen/linen_module_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,19 @@ def test(self):
887887
# test same for init.
888888
Foo().init({}, method='not_callable')
889889

890+
def test_module_apply_method_submodule(self):
891+
class Foo(nn.Module):
892+
bar: nn.Module
893+
894+
@nn.compact
895+
def __call__(self, x):
896+
return self.bar(x)
897+
898+
foo = Foo(nn.Dense(3))
899+
variables = foo.init(jax.random.PRNGKey(0), jnp.zeros(3))
900+
901+
foo.apply(variables, jnp.zeros(3), method='bar')
902+
890903
def test_call_unbound_compact_module_methods(self):
891904
dense = Dense(3)
892905
msg = r'Can\'t call compact methods on unbound modules'

0 commit comments

Comments
 (0)