
Commit d59132d

Flax Authors committed
Merge pull request google#4604 from IvyZX:linx-nn
PiperOrigin-RevId: 734672286
2 parents e3789de + 300494b commit d59132d

File tree: 6 files changed (+130, -29 lines)

flax/nnx/bridge/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -26,5 +26,6 @@
 from .module import compact as compact
 from .module import current_context as current_context
 from .module import current_module as current_module
-from .interop import wrap_nnx_mdl as wrap_nnx_mdl
+from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl
+from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl
 from flax.nnx.nn import initializers as initializers

flax/nnx/bridge/interop.py

Lines changed: 41 additions & 4 deletions

@@ -14,18 +14,32 @@
 
 import typing as tp
 
+from flax.linen import module as nn_module
 from flax.nnx import graph, rnglib
+from flax.nnx.bridge import wrappers
 from flax.nnx.bridge import module as bdg_module
 import flax.nnx.module as nnx_module
 from flax.nnx.transforms.transforms import eval_shape as nnx_eval_shape
 from flax.nnx.transforms.compilation import jit as nnx_jit
 
 
-def wrap_nnx_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
-                 name: str | None = None):
-  """Create module at init time, or make abstract module and let parent bind it with its state. Use current bridge module scope for RNG generation."""
+def nnx_in_bridge_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
+                      name: str | None = None) -> nnx_module.Module:
+  """Make pure NNX modules a submodule of a bridge module.
+
+  Create module at init time, or make abstract module and let parent bind
+  it with its state.
+  Use current bridge module scope for RNG generation.
+
+  Args:
+    factory: a function that takes an `nnx.Rngs` arg and returns an NNX module.
+    name: the name of the module. Only used during `bridge.compact` functions;
+      in setup() function the user will set it to an attribute explicitly.
+  Returns:
+    A submodule (`nnx.Module`) of the bridge module.
+  """
   parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module()
-  assert parent_ctx is not None and parent is not None, 'wrap_nnx_mdl only needed inside bridge Module'
+  assert parent_ctx is not None and parent is not None, 'nnx_in_bridge_mdl() only needed inside bridge Module'
   parent = parent_ctx.module
   assert parent.scope is not None
 
@@ -50,3 +64,26 @@ def rng_state(rngs):
       name = bdg_module._auto_submodule_name(parent_ctx, type(module))
     setattr(parent, name, module)
   return module
+
+
+def linen_in_bridge_mdl(linen_module: nn_module.Module,
+                        name: str | None = None) -> nnx_module.Module:
+  """Make Linen modules a submodule of a bridge module using wrappers.ToNNX().
+
+  Args:
+    linen_module: the underlying Linen module instance.
+    name: the name of the module. Only used during `bridge.compact` functions;
+      in setup() function the user will set it to an attribute explicitly.
+  Returns:
+    A submodule (`nnx.Module`) of the bridge module.
+  """
+  parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module()
+  assert parent_ctx is not None and parent is not None, 'linen_in_bridge_mdl() only needed inside bridge Module'
+  assert parent.scope is not None
+  module = wrappers.ToNNX(linen_module, parent.scope.rngs)
+  wrappers._set_initializing(module, parent.is_initializing())
+  if parent_ctx.in_compact:
+    if name is None:
+      name = bdg_module._auto_submodule_name(parent_ctx, type(linen_module))
+    setattr(parent, name, module)
+  return module
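
For orientation, here is a minimal usage sketch of the two new helpers. It is not part of this commit; it assumes the imports shown and mirrors the tests further down:

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

class Model(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # Pure NNX submodule: the factory receives the parent scope's nnx.Rngs.
    x = bridge.nnx_in_bridge_mdl(lambda r: nnx.Linear(4, 8, rngs=r))(x)
    # Linen submodule, wrapped via ToNNX under the hood.
    x = bridge.linen_in_bridge_mdl(nn.Dense(8), name='head')(x)
    return x

model = Model()
x = jnp.ones((1, 4))
variables = model.init(jax.random.key(0), x)  # bridge.Module keeps the Linen init/apply API
y = model.apply(variables, x, rngs=jax.random.key(1))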

flax/nnx/bridge/module.py

Lines changed: 7 additions & 4 deletions

@@ -25,6 +25,7 @@
 
 from flax import errors
 from flax.core import meta
+from flax.core.scope import CollectionFilter
 from flax.core.frozen_dict import FrozenDict
 from flax.nnx import graph, rnglib, statelib, traversals
 from flax.nnx import variablelib
@@ -63,11 +64,12 @@ class ModuleState(statelib.State):
 
 
 class Scope(Object):
-  def __init__(self, rngs: rnglib.Rngs):
+  def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter):
     self.rngs = rngs
+    self.mutable = mutable
 
   def copy(self):
-    return Scope(self.rngs)
+    return Scope(self.rngs, self.mutable)
 
 
 class _HasSetup(tp.Protocol):
@@ -365,7 +367,7 @@ def apply(
     *args,
     rngs: int | jax.Array | dict[str, jax.Array] | rnglib.Rngs | None = None,
     method: tp.Callable[..., tp.Any] | str = '__call__',
-    mutable: tp.Any = False,
+    mutable: CollectionFilter = False,
    _initialize: bool = False,
     **kwargs,
   ) -> tp.Any:
@@ -422,7 +424,7 @@ def to_variable(value):
       if isinstance(value, Object):
         value._object__state._initializing = _initialize
       if isinstance(value, Module):
-        value.scope = Scope(rngs)
+        value.scope = Scope(rngs, mutable)
         _maybe_call_setup(value)
 
     MODULE_CONTEXT.module_stack.append(
@@ -517,3 +519,4 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
     raise errors.ApplyModuleInvalidMethodError(method_or_fn)
 
   return method_or_fn
+
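
Because `Scope` now carries the `mutable` filter, whatever `CollectionFilter` is passed to `Module.apply` (a bool, a collection name such as 'intermediates', or a list of names) becomes visible to submodules through their scope. A minimal sketch, reusing `model`, `variables`, and `x` from the sketch above; as in Linen, a truthy `mutable` makes `apply` return an (output, updates) pair, which the new test below exercises with `mutable=True`:

# mutable=True collects updates to all mutable collections;
# the same filter is also recorded on each submodule's Scope.
y, updates = model.apply(variables, x,
                         rngs=jax.random.key(1),
                         mutable=True)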

flax/nnx/bridge/wrappers.py

Lines changed: 15 additions & 12 deletions

@@ -23,6 +23,7 @@
 from flax.nnx import graph
 from flax.nnx import variablelib
 from flax.nnx.bridge import variables as bv
+from flax.nnx.bridge import module as bdg_module
 from flax.nnx.module import Module
 from flax.nnx.object import Object
 from flax.nnx.rnglib import Rngs
@@ -124,7 +125,6 @@ def __init__(
   ):
     self.module = module
     self.rngs = rngs
-    self.linen_attributes: tuple[str, ...] = ()
 
   def lazy_init(self, *args, **kwargs):
     """A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
@@ -140,46 +140,49 @@ def __call__(
     rngs = self.rngs
     if self._object__state.initializing:
       _rngs = (
-        {name: stream.key.raw_value for name, stream in rngs.items()}
-        if rngs
-        else {}
+        {name: stream() for name, stream in rngs.items()} if rngs else {}
       )
       # rename default to params
       if 'params' not in _rngs and 'default' in _rngs:
         _rngs['params'] = _rngs.pop('default')
       out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
 
       nnx_attrs = bv.linen_vars_to_nnx_attrs(variables)
-      linen_attributes = set(self.linen_attributes)
       for attr_name, value in nnx_attrs.items():
         setattr(self, attr_name, value)
-        linen_attributes.add(attr_name)
-      self.linen_attributes = tuple(linen_attributes)  # make it hashable
 
     else:
-      nnx_attrs = {name: getattr(self, name) for name in self.linen_attributes}
+      nnx_attrs = {k: v for k, v in vars(self).items()
+                   if k not in ['module', 'rngs', '_object__state']}
       variables = bv.nnx_attrs_to_linen_vars(nnx_attrs)
 
       _rngs = (
         {name: stream() for name, stream in rngs.items()} if rngs else {}
       )
+
+      # Get `mutable` from top level bridge.Module context if any
+      if (m := bdg_module.current_module()) is not None:
+        assert m.scope is not None
+        mutable = m.scope.mutable
+        if 'mutable' in kwargs and kwargs['mutable'] != mutable:
+          raise ValueError(
+            f"Multiple `mutable` arguments detected: {mutable} at top level vs "
+            f"{kwargs['mutable']} in ToNNX() call")
+        kwargs['mutable'] = mutable
+
       out = self.module.apply(variables, *args, rngs=_rngs, method=method, **kwargs)
 
       # Split out the updates if `mutable` is passed into the Flax module
       if kwargs.get('mutable', False) != False:
         out, updates = out
         nnx_attrs = bv.linen_vars_to_nnx_attrs(updates)
-        linen_attributes = set(self.linen_attributes)
         for attr_name, value in nnx_attrs.items():
-          linen_attributes.add(attr_name)
           if hasattr(self, attr_name) and isinstance(value, dict):
             original_tree = getattr(self, attr_name)
             setattr(self, attr_name, original_tree | value)
           else:
             setattr(self, attr_name, value)
 
-        self.linen_attributes = tuple(linen_attributes)  # make it hashable
-
     return out
 
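
Two consequences of the `ToNNX` changes above: Linen variables are now rediscovered from `vars(self)` on every call instead of being tracked in a `linen_attributes` tuple, and a wrapper running inside a bridge `Module` inherits that module's `mutable` filter, raising a `ValueError` if a conflicting explicit `mutable` kwarg is passed. A standalone sketch (assumptions: standard Flax imports; attribute names follow the Linen layer's param names, e.g. `nn.Dense` exposes `kernel` and `bias`):

import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

# Wrap a Linen layer and initialize it lazily from sample input.
layer = bridge.ToNNX(nn.Dense(8), rngs=nnx.Rngs(0))
bridge.lazy_init(layer, jnp.ones((1, 4)))
# The Linen params now live as ordinary NNX attributes on the wrapper,
# so the next call picks them up via vars(self).
print(type(layer.kernel))  # an nnx.Variable holding the Dense kernel
y = layer(jnp.ones((1, 4)))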

tests/nnx/bridge/module_test.py

Lines changed: 59 additions & 4 deletions

@@ -319,8 +319,8 @@ def __call__(self, x):
     class BridgeMLP(bridge.Module):
       @bridge.compact
       def __call__(self, x):
-        x = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
-        x = nnx.bridge.wrap_nnx_mdl(
+        x = bridge.nnx_in_bridge_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
+        x = bridge.nnx_in_bridge_mdl(
           lambda r: NNXLayer(8, 0.3, rngs=r), name='another')(x)
         return x
 
@@ -345,7 +345,8 @@ def __call__(self, x):
 
     class BridgeMLPSetup(bridge.Module):
       def setup(self):
-        self.layer = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))
+        self.layer = bridge.nnx_in_bridge_mdl(
+          lambda r: NNXLayer(8, 0.3, rngs=r))
       def __call__(self, x):
         return self.layer(x)
 
@@ -371,13 +372,67 @@ def generate_weights(r):
     class BridgeFoo(bridge.Module):
       @bridge.compact
       def __call__(self, x):
-        x = nnx.bridge.wrap_nnx_mdl(lambda r: FooStack(4, r.default()))(x)
+        x = bridge.nnx_in_bridge_mdl(lambda r: FooStack(4, r.default()))(x)
         return x
 
     model = BridgeFoo()
     v = model.init(jax.random.key(1), jnp.ones((1, 4)))
     y = model.apply(v, jnp.ones((1, 4)), rngs=jax.random.key(1))
 
+  def test_linen_submodule(self):
+    class LinenLayer(nn.Module):
+      dim: int
+      dropout_rate: float
+      def setup(self):
+        self.linear = nn.Dense(self.dim, use_bias=False)
+        self.dropout = nn.Dropout(self.dropout_rate, deterministic=False)
+
+      def __call__(self, x):
+        if not self.is_initializing():
+          self.sow('intermediates', 'count', 1,
+                   init_fn=lambda: 0, reduce_fn=lambda a, b: a + b)
+        x = self.linear(x)
+        x = self.dropout(x)
+        return x
+
+    class BridgeMLP(bridge.Module):
+      @bridge.compact
+      def __call__(self, x):
+        x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3))(x)
+        x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3), name='another')(x)
+        return x
+
+    model = BridgeMLP()
+    x = jax.random.normal(jax.random.key(0), (4, 8))
+    variables = model.init(jax.random.key(1), x)
+    self.assertFalse(jnp.array_equal(
+      variables['params']['LinenLayer_0']['linear']['kernel'],
+      variables['params']['another']['linear']['kernel'],))
+
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
+                             mutable=True)
+    self.assertEqual(updates['intermediates']['LinenLayer_0']['count'], 1)
+
+    class BridgeMLPSetup(bridge.Module):
+      def setup(self):
+        self.layer = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3))
+      def __call__(self, x):
+        return self.layer(x)
+
+    model = BridgeMLPSetup()
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(), ['layer'])
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+
 if __name__ == '__main__':
   absltest.main()

tests/nnx/bridge/wrappers_test.py

Lines changed: 6 additions & 4 deletions

@@ -54,9 +54,10 @@ def test_linen_to_nnx(self):
     assert y.shape == (1, 64)
     self.assertIsInstance(model.kernel, nnx.Variable)
     # NNX automatically adds metadata box regardless of original Linen module.
-    linen_vars = linen_module.init(jax.random.key(0), x)
-    np.testing.assert_array_equal(linen_vars['params']['kernel'],
-                                  model.kernel.value)
+    linen_vars = {'params': {'kernel': model.kernel.value,
+                             'bias': model.bias.value}}
+    linen_y = linen_module.apply(linen_vars, x)
+    np.testing.assert_array_equal(y, linen_y)
 
   def test_linen_to_nnx_submodule(self):
     class NNXOuter(nnx.Module):
@@ -468,10 +469,11 @@ def __call__(self, x):
     # Test the RNG
     model = bridge.lazy_init(NNXOuter(dout=3, dropout_rate=0.5,
                                       rngs=nnx.Rngs(default=1, dropout=2)), x)
+    nnx.reseed(model, dropout=2)
     y1, y2 = model(x), model(x)
     # The dropout key of lowest NNX level still changes over stateful calls
     assert not jnp.allclose(y1, y2)
-    # Reseed resets the RNG key back
+    # Another reseed resets the RNG key back
     nnx.reseed(model, dropout=2)
     np.testing.assert_array_equal(y1, model(x))
