Skip to content

Commit 88ea291

Browse files
author
Flax Authors
committed
Merge pull request google#4555 from IvyZX:linx_setup_name
PiperOrigin-RevId: 729262782
2 parents 013147a + ee377f4 commit 88ea291

File tree

2 files changed

+31
-20
lines changed

2 files changed

+31
-20
lines changed

flax/nnx/bridge/module.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -123,28 +123,31 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
123123
parent = None
124124
module: M
125125

126-
if parent_ctx is not None and parent_ctx.in_compact:
127-
if 'parent' in kwargs:
128-
parent = kwargs.pop('parent')
129-
if parent is not None:
130-
raise ValueError(
131-
f"'parent' can only be set to None, got {type(parent).__name__}"
132-
)
133-
name = None
134-
else:
135-
type_index = parent_ctx.type_counter[cls]
136-
parent_ctx.type_counter[cls] += 1
126+
name = None
127+
if parent_ctx is not None:
128+
if not parent_ctx.in_compact and 'name' in kwargs:
129+
raise ValueError(
130+
f"'name' can only be set in @compact functions. If in setup(), "
131+
"use parent's `self.<attr_name> to set the submodule name.")
137132

138-
if 'name' in kwargs:
139-
name = kwargs.pop('name')
140-
if not isinstance(name, str):
141-
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
133+
if parent_ctx.in_compact:
134+
if 'parent' in kwargs:
135+
parent = kwargs.pop('parent')
136+
if parent is not None:
137+
raise ValueError(
138+
f"'parent' can only be set to None, got {type(parent).__name__}"
139+
)
142140
else:
143-
name = f'{cls.__name__}_{type_index}'
144-
145-
parent = parent_ctx.module
146-
else:
147-
name = None
141+
type_index = parent_ctx.type_counter[cls]
142+
parent_ctx.type_counter[cls] += 1
143+
144+
if 'name' in kwargs:
145+
name = kwargs.pop('name')
146+
if not isinstance(name, str):
147+
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
148+
else:
149+
name = f'{cls.__name__}_{type_index}'
150+
parent = parent_ctx.module
148151

149152
module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs)
150153
module.scope = None

tests/nnx/bridge/wrappers_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,14 @@ def __call__(self, x):
640640
y = bar.apply(variables, x)
641641
self.assertEqual(y.shape, (1, 5))
642642

643+
with self.assertRaises(ValueError):
644+
class SetupBar(bridge.Module):
645+
def setup(self):
646+
self.xyz = Foo(5, name='xyz')
647+
def __call__(self, x):
648+
return self.xyz(x)
649+
SetupBar().init(0, x)
650+
643651
def test_dense_port(self):
644652
class Dense(bridge.Module):
645653
features: int

0 commit comments

Comments
 (0)