Description
I can't seem to get a NNX module to work with FlaxLayer. Normal flax works fine, but when I create a nnx.Module and called nnx.bridge.ToLinen I get the error at the botttom.
class MyFlax(nn.Module):
@nn.compact
def call(self, x):
return nn.Dense(features=1)(x)
l = MyFlax()
f = keras.layers.FlaxLayer(l)
f(jnp.ones(1))
class MyNnx(nnx.Module):
def init(self, *, rngs: nnx.Rngs):
self.l = nnx.Linear(1, 1, rngs=rngs)
def call(self, **kwargs):
print( kwargs)
if 'inputs' in kwargs:
return self.l(kwargs['inputs'])
MyNnx(rngs=nnx.Rngs(0))(inputs=jnp.ones(1))
l = nnx.bridge.ToLinen(MyNnx(rngs=nnx.Rngs(0)))
f = keras.layers.FlaxLayer(l)
f(inputs=jnp.ones(1))
Cell In[101], line 1
----> 1 f(inputs=jnp.ones(1))
File ~/projects/joe/.venv/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:122, in filter_traceback..error_handler(*args, **kwargs)
119 filtered_tb = _process_traceback_frames(e.traceback)
120 # To get the full stack trace, call:
121 # keras.config.disable_traceback_filtering()
--> 122 raise e.with_traceback(filtered_tb) from None
123 finally:
124 del filtered_tb
[... skipping hidden 9 frame]
File ~/projects/joe/.venv/lib/python3.11/site-packages/flax/nnx/bridge/wrappers.py:263, in ToLinen.call(self, *args, **kwargs)
260 # TODO: add lazy_init here in case there's an ToNNX
submodule under module
.
261 # update linen variables before call module to save initial state
262 self._update_variables(module)
--> 263 out = module(*args, **kwargs)
264 return out
266 # create state
TypeError: 'NoneType' object is not callable