Accessing submodule functions #3505
-
Hi, I have part of an inner module (say, written as a method) that I would like to access from a container module. I first tried the following, but it didn't work:
import flax.linen as nn
class Int(nn.Module):
    """Inner module from the first attempt; its helper `f` is what the parent wants to reuse."""

    def f(self, x):
        """Multiply `x` by the parameter `w` (declared here, shape `(2,)`) and by 2.

        `w` is declared lazily inside this plain method -- not in `setup()`
        and not under `@nn.compact`.
        """
        w = self.param('w', nn.initializers.uniform(), (2,))
        return x * w * 2

    @nn.compact
    def __call__(self, x):
        """Apply `f`, then scale the result by 3."""
        # NOTE(review): reaching `f` through this `@nn.compact` method is
        # presumably fine; the error reported in the question only occurs
        # when the parent calls `f` directly.
        x = self.f(x)
        return x * 3
class Ext(nn.Module):
    """Container module from the first attempt, wrapping an `Int` submodule."""

    def setup(self):
        # Register the submodule so it can be reached as `self.i`.
        self.i = Int()

    def __call__(self, x):
        """Return `(2 * Int(x), Int.f(x))` -- the intended outputs."""
        # Calling the helper `f` directly is what fails: `Int` creates `w`
        # inside `f` rather than in its own `setup()`.
        y = self.i.f(x) # <- Error: must be initialized in `setup()`
        x = self.i(x)
        return 2 * x, y
# Instantiate the container and initialise its variables on a dummy input.
# NOTE(review): `jax` and `jnp` (presumably `jax.numpy`) are used here, but the
# snippet never imports them.
ext = Ext()
params = ext.init(jax.random.PRNGKey(41944), jnp.ones((1,), dtype=jnp.float32))
print (ext.apply(params, jnp.ones((1,))))
I then thought of the following workaround:
class Int(nn.Module):
def f(self, x):
    """Multiply `x` by the parameter `w` (declared here, shape `(2,)`) and by 2."""
    w = self.param('w', nn.initializers.uniform(), (2,))
    return x * w * 2
@nn.compact
def __call__(self, x, only_f:bool):
    """Workaround entry point: return `f(x)` if `only_f`, else `f(x) * 3`."""
    x = self.f(x)
    # Early exit lets the parent reuse just the `f` part via `__call__`,
    # at the cost of an ad-hoc flag.
    if only_f:
        return x
    return x * 3
class Ext(nn.Module):
@nn.compact
def __call__(self, x):
i = Int()
y = i(x, only_f=True)
x = i(x, only_f=False)
return 2 * x, y
However, this feels a bit ugly. I only need a small part of the inner module (in my case, just a convolution), and I would prefer not to add an ad-hoc parameter. Is there a proper way to achieve this?
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
If you define the parameter within
|
Beta Was this translation helpful? Give feedback.
If you define the parameter within `setup`, it should work: