File tree Expand file tree Collapse file tree 2 files changed +29
-0
lines changed Expand file tree Collapse file tree 2 files changed +29
-0
lines changed Original file line number Diff line number Diff line change @@ -714,6 +714,17 @@ def __getattr__(self, name: str) -> Any:
714
714
raise AttributeError (
715
715
f'"{ self .__class__ .__name__ } " object has no attribute "{ name } "' )
716
716
717
+ def __getattribute__ (self , name ):
718
+ """Call setup() before accessing any submodule attributes."""
719
+ # NB: all code here is very "hot" and will be run very frequently.
720
+ if name in object .__getattribute__ (self , '__dataclass_fields__' ):
721
+ if (name != 'parent' and
722
+ object .__getattribute__ (self , '__dataclass_fields__' )[name ].init and
723
+ isinstance (object .__getattribute__ (self , name ), Module )):
724
+ object .__getattribute__ (self , '_try_setup' )()
725
+ # always run original python __getattribute__
726
+ return object .__getattribute__ (self , name )
727
+
717
728
def __dir__ (self ) -> Iterable [str ]:
718
729
"""Call setup() before listing attributes."""
719
730
self ._try_setup ()
Original file line number Diff line number Diff line change @@ -1517,5 +1517,23 @@ def __call__(self):
1517
1517
self .assertTrue (foo .apply ({}, rngs = {'bar' : k }))
1518
1518
self .assertFalse (foo .apply ({}, rngs = {'baz' : k }))
1519
1519
1520
+ def test_getattribute_triggers_setup (self ):
1521
+ class B (nn .Module ):
1522
+ def setup (self ):
1523
+ self .p1 = self .param ('p1' , lambda k : jnp .ones ((2 ,)))
1524
+ def fn1 (self , x ):
1525
+ return self .p1 + x
1526
+ class A (nn .Module ):
1527
+ b : nn .Module
1528
+ def __call__ (self , x ):
1529
+ return self .b .fn1 (x )
1530
+ a = A (b = B ())
1531
+ k = random .PRNGKey (0 )
1532
+ x = jnp .zeros ((2 ,))
1533
+ vs = nn .init (lambda a ,x : a (x ), a )(k , x )
1534
+ y = nn .apply (lambda a ,x : a .b .fn1 (x ), a )(vs , x )
1535
+ np .testing .assert_array_equal (y , jnp .ones ((2 ,)))
1536
+
1537
+
1520
1538
if __name__ == '__main__' :
1521
1539
absltest .main ()
You can’t perform that action at this time.
0 commit comments