Is there a way to initialize the nnx layers dynamically? #4365
Unanswered
yCobanoglu
asked this question in
Q&A
Replies: 1 comment
Hey @yCobanoglu, great question! We get this a lot as a downside of having explicit initialization. The nice thing is that you can infer the hard-to-compute constants by using `nnx.eval_shape`:

    from functools import partial

    import jax.numpy as jnp
    from flax import nnx


    class CNN(nnx.Module):
        """A simple CNN model."""

        def __init__(self, x, *, rngs: nnx.Rngs):
            self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
            self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
            self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
            # use `eval_shape` to compute the number of flat features without running the model
            flat_features = nnx.eval_shape(CNN._get_flat_features, self, x).shape[-1]
            self.linear1 = nnx.Linear(flat_features, 256, rngs=rngs)
            self.linear2 = nnx.Linear(256, 10, rngs=rngs)

        def _get_flat_features(self, x):
            # only the conv/pool stack, which exists before the linear layers are created
            x = self.avg_pool(nnx.relu(self.conv1(x)))
            x = self.avg_pool(nnx.relu(self.conv2(x)))
            x = x.reshape(x.shape[0], -1)
            return x

        def __call__(self, x):
            x = self.avg_pool(nnx.relu(self.conv1(x)))
            x = self.avg_pool(nnx.relu(self.conv2(x)))
            x = x.reshape(x.shape[0], -1)  # flatten
            x = nnx.relu(self.linear1(x))
            x = self.linear2(x)
            return x


    # the sample input is only used for its shape
    sample_x = jnp.ones((1, 64, 64, 1))
    model = CNN(sample_x, rngs=nnx.Rngs(0))

Here I'm duplicating some of the forward pass, but you could probably refactor the model to avoid that.
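As a sanity check on what `nnx.eval_shape` should report for this model, the same count can be worked out by hand. This is only a sketch: `flat_features` is a hypothetical helper, not part of Flax, and it assumes the layers keep their default padding, which as far as I know is `'SAME'` for `nnx.Conv` (spatial size unchanged at stride 1) and `'VALID'` for `nnx.avg_pool`.

```python
def flat_features(height, width, conv_channels=(32, 64), pool_window=2, pool_stride=2):
    """Count features after the conv/pool stack using plain shape arithmetic.

    Assumption: every conv uses 'SAME' padding with stride 1, and every
    average pool uses 'VALID' padding. Hypothetical helper, not a Flax API.
    """
    for _ in conv_channels:
        # 'SAME' conv leaves height and width unchanged.
        # 'VALID' pool: floor((size - window) / stride) + 1.
        height = (height - pool_window) // pool_stride + 1
        width = (width - pool_window) // pool_stride + 1
    # The flattened size is the spatial area times the last conv's channel count.
    return height * width * conv_channels[-1]


reply_input = flat_features(64, 64)  # sample_x of shape (1, 64, 64, 1)
mnist_input = flat_features(28, 28)  # MNIST-sized 28x28 images
print(reply_input, mnist_input)
```

For the 64x64 sample input this gives 16 * 16 * 64 = 16384 features, and for 28x28 images it gives 7 * 7 * 64 = 3136. Doing this by hand breaks as soon as padding, strides, or the number of layers change, which is the reason to let `nnx.eval_shape` trace the shapes instead.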
https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/mnist_tutorial.html
This model is from the tutorial, and the input size of the `linear1` layer is fixed, which makes it annoying to train the model on a different dataset. Is there a way to lazily initialize it somehow?
Setting `self.linear1 = None` and then initializing it on the first pass in `__call__` with an if-else causes an error.