diff --git a/flax/linen/module.py b/flax/linen/module.py index 1b913d51a..f106e14b9 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -862,6 +862,8 @@ def __set_name__(self, *args, **kwargs): self.wrapped.__set_name__(*args, **kwargs) def __getattr__(self, name): + if 'wrapped' not in vars(self): + raise AttributeError() return getattr(self.wrapped, name) return _DescriptorWrapper(descriptor) diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index c0f1b211f..da49524e2 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -22,6 +22,7 @@ import inspect import operator import sys +from tempfile import TemporaryDirectory from typing import ( Any, Callable, @@ -2399,6 +2400,31 @@ def record_interceptor(f, args, kwargs, context): self.assertIs(called[0], bar) self.assertIs(called[1], foo) + def test_cloudpickle_module(self): + from cloudpickle import cloudpickle_fast + + class NNModuleWithProperty(nn.Module): + a: int + b: str + + @property + def my_property(self): + return self.b * self.a + + m = NNModuleWithProperty(a=2, b='ok') + + with TemporaryDirectory() as tmpdir: + filename = f'{tmpdir}/module.pkl' + with open(filename, 'wb') as f: + cloudpickle_fast.dump(m, f) + + with open(filename, 'rb') as f: + obj_loaded = cloudpickle_fast.load(f) + + self.assertEqual(obj_loaded.a, 2) + self.assertEqual(obj_loaded.b, 'ok') + self.assertEqual(obj_loaded.my_property, 'okok') + class LeakTests(absltest.TestCase):