Skip to content

Commit 07c769d

Browse files
committed
allow specifying method as string
1 parent e51d017 commit 07c769d

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

flax/linen/module.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1363,7 +1363,7 @@ def apply(self,
13631363
variables: VariableDict,
13641364
*args,
13651365
rngs: Optional[RNGSequences] = None,
1366-
method: Optional[Callable[..., Any]] = None,
1366+
method: Union[Callable[..., Any], str, None] = None,
13671367
mutable: CollectionFilter = False,
13681368
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
13691369
**kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]:
@@ -1382,6 +1382,11 @@ def apply(self,
13821382
13831383
encoded = model.apply({'params': params}, x, method=model.encode)
13841384
1385+
You can also pass a string to a callable attribute of the module. For
1386+
example, the previous can be written as::
1387+
1388+
encoded = model.apply({'params': params}, x, method='encode')
1389+
13851390
Note ``method`` can also be a function that is not defined in
13861391
``Transformer``. In that case, the function should have at least one
13871392
argument representing an instance of the Module class::
@@ -1420,7 +1425,14 @@ def other_fn(instance, ...):
14201425
"""
14211426
Module._module_checks(self)
14221427

1423-
if method is None:
1428+
if isinstance(method, str):
1429+
attribute_name = method
1430+
method = getattr(self, attribute_name)
1431+
if not callable(method):
1432+
class_name = type(self).__name__
1433+
raise TypeError(f'\'{class_name}.{attribute_name}\' must be a callable, got {type(method)}.')
1434+
1435+
elif method is None:
14241436
method = self.__call__
14251437
method = _get_unbound_fn(method)
14261438
return apply(

tests/linen/linen_module_test.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,7 @@ def __call__(self, x):
826826
def test_module_apply_method(self):
827827

828828
class Foo(nn.Module):
829+
not_callable: int = 1
829830

830831
@nn.compact
831832
def __call__(self):
@@ -849,9 +850,18 @@ def test(self):
849850
with self.assertRaisesRegex(errors.ApplyModuleInvalidMethodError, msg):
850851
Foo().apply({}, method=lambda: True)
851852

852-
with self.assertRaisesRegex(errors.ApplyModuleInvalidMethodError, msg):
853+
# string method names are also allowed.
854+
Foo().apply({}, method='test')
855+
856+
# non-existent attribute names will yield AttributeError.
857+
with self.assertRaisesRegex(AttributeError, "allowed_apply_fn"):
853858
Foo().apply({}, method='allowed_apply_fn')
854859

860+
# attributes which are not callables yield TypeError.
861+
with self.assertRaisesRegex(TypeError, "'Foo.not_callable' must be a callable"):
862+
Foo().apply({}, method='not_callable')
863+
864+
855865
def test_call_unbound_compact_module_methods(self):
856866
dense = Dense(3)
857867
msg = r'Can\'t call compact methods on unbound modules'

0 commit comments

Comments
 (0)