Status: Closed

Description

Repro:
import flax.linen as nn
import flax.nnx as nnx
class LinenModule(nn.Module):
    """Linen module that writes to the "cache" collection only when allowed.

    Outside of initialization, and only when the "cache" collection is
    mutable, the value 0 is stored under ("cache", "x"). Every call returns
    ``self.get_variable("cache", "x")``, whatever is stored there at that
    point.
    """

    @nn.compact
    def __call__(self):
        # Short-circuit order matters for parity with the original:
        # is_initializing() is checked first, and is_mutable_collection()
        # is only consulted when not initializing.
        may_write_cache = (
            not self.is_initializing()
            and self.is_mutable_collection("cache")
        )
        if may_write_cache:
            self.put_variable("cache", "x", 0)
        return self.get_variable("cache", "x")
class NNXModule(nnx.Module):
    """NNX module wrapping ``LinenModule`` through ``nnx.bridge.ToNNX``.

    ``__call__`` reproduces the reported issue. It calls the wrapped module
    once with the "cache" collection mutable, then once without, and asserts
    that both calls return 0.
    """

    def __init__(self):
        # Wrap the Linen module for NNX, then initialize it lazily.
        bridged = nnx.bridge.ToNNX(LinenModule())
        self.module = bridged.lazy_init()

    def __call__(self):
        # Call 1: "cache" is mutable, so LinenModule is expected to write 0
        # and return it.
        first = self.module(mutable=["cache"])
        assert first == 0
        # Call 2: nothing is mutable, so the previously cached 0 should be
        # read back.
        second = self.module()
        assert second == 0, second  # fails: second is None
# Run the repro. Constructing NNXModule lazily initializes the wrapped Linen
# module; calling it performs the two checked calls in NNXModule.__call__.
# Per the report, the second assertion fails because the second call
# returns None.
module = NNXModule()
module()
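I believe this happens because `self.linen_attributes` (presumably on the `ToNNX` wrapper) is never updated after initialization. As a result, the value written to the "cache" collection during the `mutable=["cache"]` call is not visible to the next call, which therefore returns None. (The original "see here" link to the relevant source line was lost when this page was captured.)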
Metadata
Assignees: (none listed)
Labels: No labels