Skip to content

Allow specifying method as a string #2809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,7 +1363,7 @@ def apply(self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Optional[Callable[..., Any]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = False,
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
**kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]:
Expand All @@ -1382,6 +1382,11 @@ def apply(self,

encoded = model.apply({'params': params}, x, method=model.encode)

You can also pass a string to a callable attribute of the module. For
example, the previous can be written as::

encoded = model.apply({'params': params}, x, method='encode')

Note ``method`` can also be a function that is not defined in
``Transformer``. In that case, the function should have at least one
argument representing an instance of the Module class::
Expand All @@ -1401,7 +1406,8 @@ def other_fn(instance, ...):
The "params" PRNG sequence is used to initialize parameters.
method: A function to call apply on. This is generally a function in the
module. If provided, applies this method. If not provided, applies the
``__call__`` method of the module.
``__call__`` method of the module. A string can also be provided to
specify a method by name.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
Expand All @@ -1420,7 +1426,13 @@ def other_fn(instance, ...):
"""
Module._module_checks(self)

if method is None:
if isinstance(method, str):
attribute_name = method
method = getattr(self, attribute_name)
if not callable(method):
class_name = type(self).__name__
raise TypeError(f'\'{class_name}.{attribute_name}\' must be a callable, got {type(method)}.')
elif method is None:
method = self.__call__
method = _get_unbound_fn(method)
return apply(
Expand All @@ -1433,7 +1445,7 @@ def other_fn(instance, ...):
def init_with_output(self,
rngs: Union[PRNGKey, RNGSequences],
*args,
method: Optional[Callable[..., Any]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = DenyList('intermediates'),
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
**kwargs) -> Tuple[Any, FrozenVariableDict]:
Expand All @@ -1443,7 +1455,8 @@ def init_with_output(self,
rngs: The rngs for the variable collections.
*args: Named arguments passed to the init function.
method: An optional method. If provided, applies this method. If not
provided, applies the ``__call__`` method.
provided, applies the ``__call__`` method. A string can also be'
provided to specify a method by name.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
Expand All @@ -1468,7 +1481,14 @@ def init_with_output(self,
'RNGs should be of shape (2,) or KeyArray in Module '
f'{self.__class__.__name__}, but rngs are: {rngs}')
rngs = {'params': rngs}
if method is None:

if isinstance(method, str):
attribute_name = method
method = getattr(self, attribute_name)
if not callable(method):
class_name = type(self).__name__
raise TypeError(f'\'{class_name}.{attribute_name}\' must be a callable, got {type(method)}.')
elif method is None:
method = self.__call__
method = _get_unbound_fn(method)
return init_with_output(
Expand All @@ -1482,7 +1502,7 @@ def init_with_output(self,
def init(self,
rngs: Union[PRNGKey, RNGSequences],
*args,
method: Optional[Callable[..., Any]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = DenyList('intermediates'),
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
**kwargs) -> FrozenVariableDict:
Expand All @@ -1499,7 +1519,8 @@ def init(self,
rngs: The rngs for the variable collections.
*args: Named arguments passed to the init function.
method: An optional method. If provided, applies this method. If not
provided, applies the ``__call__`` method.
provided, applies the ``__call__`` method. A string can also be
provided to specify a method by name.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
Expand Down
18 changes: 17 additions & 1 deletion tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ def __call__(self, x):
def test_module_apply_method(self):

class Foo(nn.Module):
not_callable: int = 1

@nn.compact
def __call__(self):
Expand All @@ -849,8 +850,23 @@ def test(self):
with self.assertRaisesRegex(errors.ApplyModuleInvalidMethodError, msg):
Foo().apply({}, method=lambda: True)

with self.assertRaisesRegex(errors.ApplyModuleInvalidMethodError, msg):
# string method names are also allowed.
Foo().apply({}, method='test')
# test same for init.
Foo().init({}, method='test')

# non-existent attribute names will yield AttributeError.
with self.assertRaisesRegex(AttributeError, "allowed_apply_fn"):
Foo().apply({}, method='allowed_apply_fn')
# test same for init.
Foo().init({}, method='allowed_apply_fn')

# attributes which are not callables yield TypeError.
with self.assertRaisesRegex(TypeError, "'Foo.not_callable' must be a callable"):
Foo().apply({}, method='not_callable')
# test same for init.
Foo().init({}, method='not_callable')


def test_call_unbound_compact_module_methods(self):
dense = Dense(3)
Expand Down