Description
@andsteing suggests:

- mention what `nn` is
- use shape inference
- use dataclass attributes
- personally I find "inline submodules" a bit cryptic (see the setup-style sketch after the example)

maybe something like this? (feel free to change)
```python
import flax.linen as nn
import jax.numpy as jnp
from jax.random import PRNGKey


class MLP(nn.Module):
  out_dims: int  # dataclass attribute

  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))  # flatten; feature size inferred from input
    x = nn.Dense(128)(x)             # fully connected layer (inline submodule)
    x = nn.relu(x)
    x = nn.Dense(self.out_dims)(x)   # uses the dataclass attribute
    return x


model = MLP(out_dims=10)                # create model
x = jnp.ones((1, 28, 28, 1))            # fake data
variables = model.init(PRNGKey(42), x)  # initialize weights
y = model.apply(variables, x)           # make forward pass
```
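
As a quick follow-up (not part of the original suggestion), inspecting `variables` makes the shape inference point concrete: the first `Dense` kernel picks up the flattened input size at `init` time, without it ever being declared. A minimal sketch:

```python
import jax

# Parameter shapes were inferred from the example input at init time:
print(jax.tree_util.tree_map(lambda p: p.shape, variables))
# Dense_0 kernel is (784, 128); Dense_1 kernel is (128, 10)
print(y.shape)  # (1, 10)
```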
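
And since "inline submodules" reads a bit cryptic, here is a hedged sketch of the same model written with `setup` instead of `@nn.compact`, which is the explicit alternative Flax offers (`MLPExplicit` and its attribute names are hypothetical, not from the original):

```python
class MLPExplicit(nn.Module):
  out_dims: int

  def setup(self):
    # submodules declared up front, instead of inline in __call__
    self.hidden = nn.Dense(128)
    self.out = nn.Dense(self.out_dims)

  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))
    x = nn.relu(self.hidden(x))
    return self.out(x)
```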