
Accessing submodule functions #3505

Answered by chiamp
stefano-1981 asked this question in Q&A

If you define the parameter `w` in `setup()`, calling the submodule's `f` method from the parent module should work:

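import jax
import jax.numpy as jnp
import flax.linen as nn

class Int(nn.Module):

  def setup(self):
    # Declared in setup(), `w` is available to every method (f and __call__),
    # not only to a single @nn.compact method.
    self.w = self.param('w', nn.initializers.uniform(), (2,))

  def f(self, x):
    return x * self.w * 2

  # No @nn.compact needed here: all parameters are declared in setup().
  def __call__(self, x):
    x = self.f(x)
    return x * 3

class Ext(nn.Module):
  def setup(self):
    self.i = Int()

  def __call__(self, x):
    # Calling a method other than __call__ on the submodule now works,
    # because Int.w is defined in setup().
    y = self.i.f(x)
    x = self.i(x)
    return 2 * x, y

ext = Ext()
params = ext.init(jax.random.PRNGKey(41944), jnp.ones((1,), dtype=jnp.float32))
print(ext.apply(params, jnp.ones((1,))))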

Answer selected by stefano-1981