@@ -1363,7 +1363,7 @@ def apply(self,
1363
1363
variables : VariableDict ,
1364
1364
* args ,
1365
1365
rngs : Optional [RNGSequences ] = None ,
1366
- method : Optional [Callable [..., Any ]] = None ,
1366
+ method : Union [Callable [..., Any ], str , None ] = None ,
1367
1367
mutable : CollectionFilter = False ,
1368
1368
capture_intermediates : Union [bool , Callable [['Module' , str ], bool ]] = False ,
1369
1369
** kwargs ) -> Union [Any , Tuple [Any , FrozenVariableDict ]]:
@@ -1382,6 +1382,11 @@ def apply(self,
1382
1382
1383
1383
encoded = model.apply({'params': params}, x, method=model.encode)
1384
1384
1385
+ You can also pass a string to a callable attribute of the module. For
1386
+ example, the previous can be written as::
1387
+
1388
+ encoded = model.apply({'params': params}, x, method='encode')
1389
+
1385
1390
Note ``method`` can also be a function that is not defined in
1386
1391
``Transformer``. In that case, the function should have at least one
1387
1392
argument representing an instance of the Module class::
@@ -1420,7 +1425,14 @@ def other_fn(instance, ...):
1420
1425
"""
1421
1426
Module ._module_checks (self )
1422
1427
1423
- if method is None :
1428
+ if isinstance (method , str ):
1429
+ attribute_name = method
1430
+ method = getattr (self , attribute_name )
1431
+ if not callable (method ):
1432
+ class_name = type (self ).__name__
1433
+ raise TypeError (f'\' { class_name } .{ attribute_name } \' must be a callable, got { type (method )} .' )
1434
+
1435
+ elif method is None :
1424
1436
method = self .__call__
1425
1437
method = _get_unbound_fn (method )
1426
1438
return apply (
0 commit comments