Skip to content

Commit 3a9df8e

Browse files
author
Flax Authors
committed
Merge pull request #3286 from google:fix-cloudpickle
PiperOrigin-RevId: 558726064
2 parents a767363 + ed3ad43 commit 3a9df8e

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

flax/linen/module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -862,6 +862,8 @@ def __set_name__(self, *args, **kwargs):
862862
self.wrapped.__set_name__(*args, **kwargs)
863863

864864
def __getattr__(self, name):
865+
if 'wrapped' not in vars(self):
866+
raise AttributeError()
865867
return getattr(self.wrapped, name)
866868

867869
return _DescriptorWrapper(descriptor)

tests/linen/linen_module_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import inspect
2323
import operator
2424
import sys
25+
from tempfile import TemporaryDirectory
2526
from typing import (
2627
Any,
2728
Callable,
@@ -2399,6 +2400,31 @@ def record_interceptor(f, args, kwargs, context):
23992400
self.assertIs(called[0], bar)
24002401
self.assertIs(called[1], foo)
24012402

2403+
def test_cloudpickle_module(self):
2404+
from cloudpickle import cloudpickle_fast
2405+
2406+
class NNModuleWithProperty(nn.Module):
2407+
a: int
2408+
b: str
2409+
2410+
@property
2411+
def my_property(self):
2412+
return self.b * self.a
2413+
2414+
m = NNModuleWithProperty(a=2, b='ok')
2415+
2416+
with TemporaryDirectory() as tmpdir:
2417+
filename = f'{tmpdir}/module.pkl'
2418+
with open(filename, 'wb') as f:
2419+
cloudpickle_fast.dump(m, f)
2420+
2421+
with open(filename, 'rb') as f:
2422+
obj_loaded = cloudpickle_fast.load(f)
2423+
2424+
self.assertEqual(obj_loaded.a, 2)
2425+
self.assertEqual(obj_loaded.b, 'ok')
2426+
self.assertEqual(obj_loaded.my_property, 'okok')
2427+
24022428

24032429
class LeakTests(absltest.TestCase):
24042430

0 commit comments

Comments
 (0)