Skip to content

Commit dba3fa5

Browse files
author
Flax Authors
committed
Merge pull request #2028 from levskaya:getattribute
PiperOrigin-RevId: 439935122
2 parents 0c647cb + aca3e29 commit dba3fa5

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

flax/linen/module.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,17 @@ def __getattr__(self, name: str) -> Any:
714714
raise AttributeError(
715715
f'"{self.__class__.__name__}" object has no attribute "{name}"')
716716

717+
def __getattribute__(self, name):
718+
"""Call setup() before accessing any submodule attributes."""
719+
# NB: all code here is very "hot" and will be run very frequently.
720+
if name in object.__getattribute__(self, '__dataclass_fields__'):
721+
if (name != 'parent' and
722+
object.__getattribute__(self, '__dataclass_fields__')[name].init and
723+
isinstance(object.__getattribute__(self, name), Module)):
724+
object.__getattribute__(self, '_try_setup')()
725+
# always run original python __getattribute__
726+
return object.__getattribute__(self, name)
727+
717728
def __dir__(self) -> Iterable[str]:
718729
"""Call setup() before listing attributes."""
719730
self._try_setup()

tests/linen/linen_module_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,5 +1517,23 @@ def __call__(self):
15171517
self.assertTrue(foo.apply({}, rngs={'bar': k}))
15181518
self.assertFalse(foo.apply({}, rngs={'baz': k}))
15191519

1520+
def test_getattribute_triggers_setup(self):
1521+
class B(nn.Module):
1522+
def setup(self):
1523+
self.p1 = self.param('p1', lambda k: jnp.ones((2,)))
1524+
def fn1(self, x):
1525+
return self.p1 + x
1526+
class A(nn.Module):
1527+
b: nn.Module
1528+
def __call__(self, x):
1529+
return self.b.fn1(x)
1530+
a = A(b=B())
1531+
k = random.PRNGKey(0)
1532+
x = jnp.zeros((2,))
1533+
vs = nn.init(lambda a,x: a(x), a)(k, x)
1534+
y = nn.apply(lambda a,x: a.b.fn1(x), a)(vs, x)
1535+
np.testing.assert_array_equal(y, jnp.ones((2,)))
1536+
1537+
15201538
if __name__ == '__main__':
15211539
absltest.main()

0 commit comments

Comments
 (0)