Skip to content

Commit 21e64ec

Browse files
author
Flax Authors
committed
Merge pull request #4766 from IvyZX:bdg
PiperOrigin-RevId: 771171776
2 parents 893a660 + 6a6561e commit 21e64ec

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

flax/nnx/bridge/wrappers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
from functools import partial
1617
import typing as tp
1718
from typing import Any
1819

@@ -133,6 +134,16 @@ def lazy_init(self, *args, **kwargs):
133134
"""A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
134135
return lazy_init(self, *args, **kwargs)
135136

137+
def __getattr__(self, name: str):
138+
if hasattr(super(), name):
139+
return super().__getattribute__(name)
140+
maybe_method = getattr(self.module.__class__, name, None)
141+
if callable(maybe_method):
142+
method = partial(self.__call__, method=maybe_method)
143+
method.__self__ = self
144+
return method
145+
return super().__getattribute__(name)
146+
136147
def __call__(
137148
self, *args: Any, rngs: tp.Optional[Rngs] = None,
138149
method: tp.Callable[..., Any] | str | None = None, **kwargs: Any

tests/nnx/bridge/wrappers_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,18 @@ def dot(self, x):
9797
w = self.param('w', nn.initializers.lecun_normal(), (4, 3))
9898
return x @ w
9999

100+
def rngs(self):
101+
raise ValueError('This should not be called because ToNNX has .rngs')
102+
100103
x = jax.random.normal(jax.random.key(0), (2, 4))
101104
model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0))
102-
bridge.lazy_init(model, x, method=model.module.dot)
103-
y = model(x, method=model.module.dot)
105+
bridge.lazy_init(model.dot, x)
106+
y = model.dot(x)
104107
np.testing.assert_allclose(y, x @ nnx.state(model)['w'].value)
105108
# lazy_init only initialized param w inside dot(), so calling __call__ should fail
106109
with self.assertRaises(flax.errors.ScopeParamNotFoundError):
107110
y = model(x)
111+
assert isinstance(model.rngs, nnx.Rngs)
108112

109113
def test_linen_to_nnx_mutable(self):
110114
class Foo(nn.Module):

0 commit comments

Comments
 (0)