Skip to content

Commit 5413850

Browse files
IvyZXFlax Authors
authored andcommitted
[bridge] Set _initializing correctly and avoid return RNG states
PiperOrigin-RevId: 731073396
1 parent d96be6c commit 5413850

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

flax/nnx/bridge/module.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def _setattr(self, name: str, value: tp.Any) -> None:
190190
graph.update(value, state)
191191
for leaf in jax.tree.leaves(value):
192192
if isinstance(leaf, Module):
193+
leaf._object__state._initializing = self.is_initializing()
193194
_bind_module(self, leaf)
194195
super()._setattr(name, value)
195196

@@ -308,6 +309,11 @@ def _get_variables(self) -> tp.Mapping:
308309

309310
variable_state: variablelib.VariableState
310311
for path, variable_state in statelib.to_flat_state(state):
312+
313+
if issubclass(variable_state.type, rnglib.RngState):
314+
# Don't return RNG states, since Linen doesn't have them.
315+
continue
316+
311317
try:
312318
collection = variablelib.variable_name_from_type(variable_state.type)
313319
except ValueError:

0 commit comments

Comments
 (0)