Commit 8254dd0

Author: Flax Authors

Merge pull request google#4584 from IvyZX:linx-misc

PiperOrigin-RevId: 733916659

2 parents 0769411 + 06a50e8, commit 8254dd0

File tree

4 files changed: +151, -5 lines changed


flax/nnx/bridge/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -24,4 +24,7 @@
 from .module import Module as Module
 from .module import Scope as Scope
 from .module import compact as compact
-from flax.nnx.nn import initializers as initializers
+from .module import current_context as current_context
+from .module import current_module as current_module
+from .interop import wrap_nnx_mdl as wrap_nnx_mdl
+from flax.nnx.nn import initializers as initializers
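
The new helpers are re-exported at the bridge package level, so downstream code can reach them without importing the internal submodules. A minimal sketch of the resulting import surface (paths as in this commit, the bare attribute accesses are only illustrative):

from flax.nnx import bridge

# Names re-exported by flax/nnx/bridge/__init__.py after this change:
bridge.wrap_nnx_mdl      # from .interop
bridge.current_context   # from .module
bridge.current_module    # from .module
bridge.initializers      # from flax.nnx.nn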

flax/nnx/bridge/interop.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright 2025 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing as tp
+
+from flax.nnx import graph, rnglib
+from flax.nnx.bridge import module as bdg_module
+import flax.nnx.module as nnx_module
+from flax.nnx.transforms.transforms import eval_shape as nnx_eval_shape
+from flax.nnx.transforms.compilation import jit as nnx_jit
+
+
+def wrap_nnx_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
+                 name: str | None = None):
+  """Create module at init time, or make abstract module and let parent bind it with its state. Use current bridge module scope for RNG generation."""
+  parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module()
+  assert parent_ctx is not None and parent is not None, 'wrap_nnx_mdl only needed inside bridge Module'
+  parent = parent_ctx.module
+  assert parent.scope is not None
+
+  if parent.is_initializing():
+    module = factory(parent.scope.rngs)
+  else:
+    rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7)  # dummy
+    module = nnx_eval_shape(factory, rngs)
+
+    @nnx_jit
+    def rng_state(rngs):
+      return graph.state(factory(rngs), rnglib.RngState)
+
+    # Make sure the internal rng state is not abstract - other vars shall be
+    if parent.scope.rngs:
+      graph.update(module, rng_state(parent.scope.rngs))
+
+  # Automatically set the attribute if compact. If setup, user is responsible
+  # for setting the attribute of the superlayer.
+  if parent_ctx.in_compact:
+    if name is None:
+      name = bdg_module._auto_submodule_name(parent_ctx, type(module))
+    setattr(parent, name, module)
+  return module
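
For orientation, a minimal usage sketch of wrap_nnx_mdl inside a bridge Module, in the spirit of the tests added below; the class name Block and the Linear sizes are illustrative and not part of this commit:

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

class Block(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # wrap_nnx_mdl builds the NNX submodule eagerly during init() and as an
    # abstract (eval_shape) module during apply(); the factory receives an
    # nnx.Rngs drawn from this bridge module's scope.
    dense = nnx.bridge.wrap_nnx_mdl(lambda r: nnx.Linear(4, 4, rngs=r))
    return dense(x)

x = jnp.ones((2, 4))
model = Block()
variables = model.init(jax.random.key(0), x)  # params keyed by auto-name, e.g. 'Linear_0'
y = model.apply(variables, x)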

flax/nnx/bridge/module.py

Lines changed: 20 additions & 4 deletions
@@ -106,6 +106,25 @@ def _bind_module(parent: Module, module: Module) -> Module:
   return module
 
 
+def current_context() -> ModuleStackEntry | None:
+  return MODULE_CONTEXT.module_stack[-1]
+
+
+def current_module() -> Module | None:
+  """A quick util to get the current bridge module."""
+  ctx = current_context()
+  if ctx is None:
+    return None
+  return ctx.module
+
+
+def _auto_submodule_name(parent_ctx, cls):
+  """Increment type count and generate a new submodule name."""
+  type_index = parent_ctx.type_counter[cls]
+  parent_ctx.type_counter[cls] += 1
+  return f'{cls.__name__}_{type_index}'
+
+
 class ModuleMeta(nnx_module.ModuleMeta):
 
   def _object_meta_construct(cls, self, *args, **kwargs):
@@ -134,15 +153,12 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
           f"'parent' can only be set to None, got {type(parent).__name__}"
         )
     else:
-      type_index = parent_ctx.type_counter[cls]
-      parent_ctx.type_counter[cls] += 1
-
       if 'name' in kwargs:
         name = kwargs.pop('name')
         if not isinstance(name, str):
          raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
       else:
-        name = f'{cls.__name__}_{type_index}'
+        name = _auto_submodule_name(parent_ctx, cls)
      parent = parent_ctx.module
 
   module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs)
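
The auto-naming logic formerly inlined in _module_meta_call now lives in _auto_submodule_name, so wrap_nnx_mdl and regular bridge submodule construction share the same per-type counter. A standalone sketch of that naming scheme (a re-implementation for illustration only, not the bridge code itself):

from collections import defaultdict

class FakeCtx:
  """Stand-in for a bridge module-stack entry, carrying only the type counter."""
  def __init__(self):
    self.type_counter = defaultdict(int)

def auto_name(ctx, cls):
  # Same scheme as _auto_submodule_name: '<ClassName>_<index>', index per type.
  index = ctx.type_counter[cls]
  ctx.type_counter[cls] += 1
  return f'{cls.__name__}_{index}'

ctx = FakeCtx()
class Dense: ...
class Dropout: ...
print(auto_name(ctx, Dense))    # Dense_0
print(auto_name(ctx, Dropout))  # Dropout_0
print(auto_name(ctx, Dense))    # Dense_1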

tests/nnx/bridge/module_test.py

Lines changed: 75 additions & 0 deletions
@@ -302,6 +302,81 @@ def __call__(self, x):
     y: jax.Array = foo.apply(variables, x)
     self.assertEqual(y.shape, (3, 5))
 
+  def test_pure_nnx_submodule(self):
+    class NNXLayer(nnx.Module):
+      def __init__(self, dim, dropout, rngs):
+        self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
+        self.dropout = nnx.Dropout(dropout, rngs=rngs)
+        self.count = nnx.Intermediate(jnp.array([0.]))
+      def __call__(self, x):
+        # Required check to avoid state update in `init()`. Can this be avoided?
+        if not bridge.current_module().is_initializing():
+          self.count.value = self.count.value + 1
+        x = self.linear(x)
+        x = self.dropout(x)
+        return x
+
+    class BridgeMLP(bridge.Module):
+      @bridge.compact
+      def __call__(self, x):
+        x = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
+        x = nnx.bridge.wrap_nnx_mdl(
+            lambda r: NNXLayer(8, 0.3, rngs=r), name='another')(x)
+        return x
+
+    model = BridgeMLP()
+    x = jax.random.normal(jax.random.key(0), (4, 8))
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(),
+                            ['NNXLayer_0', 'another'])
+    self.assertFalse(jnp.array_equal(
+        variables['params']['NNXLayer_0']['linear']['kernel'],
+        variables['params']['another']['linear']['kernel'], ))
+    self.assertEqual(variables['intermediates']['NNXLayer_0']['count'], 0)
+
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
+                             mutable=True)
+    self.assertEqual(updates['intermediates']['NNXLayer_0']['count'], 1)
+
+    class BridgeMLPSetup(bridge.Module):
+      def setup(self):
+        self.layer = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))
+      def __call__(self, x):
+        return self.layer(x)
+
+    model = BridgeMLPSetup()
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(), ['layer'])
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+  def test_pure_nnx_submodule_modified_rng(self):
+    class FooStack(nnx.Module):
+      def __init__(self, in_dim, key):
+        keys = jax.random.split(key, in_dim)
+        self.rngs = nnx.Rngs(keys)
+      def __call__(self, x):
+        @nnx.vmap
+        def generate_weights(r):
+          return jax.random.normal(r.default(), (2,))
+        w = generate_weights(self.rngs)
+        return x @ w
+
+    class BridgeFoo(bridge.Module):
+      @bridge.compact
+      def __call__(self, x):
+        x = nnx.bridge.wrap_nnx_mdl(lambda r: FooStack(4, r.default()))(x)
+        return x
+
+    model = BridgeFoo()
+    v = model.init(jax.random.key(1), jnp.ones((1, 4)))
+    y = model.apply(v, jnp.ones((1, 4)), rngs=jax.random.key(1))
 
 if __name__ == '__main__':
   absltest.main()
