
Improve Landing Page example #2332

Closed
@cgarciae

Description


@andsteing suggests:

  • mention what nn is
  • use shape inference
  • use dataclass attributes
  • personally I find "inline submodules" a bit cryptic

maybe something like this?
(feel free to change)

    import jax
    import jax.numpy as jnp
    import flax.linen as nn                # nn: Flax's neural-network library (Linen API)

    class MLP(nn.Module):
      out_dims: int                        # dataclass attribute, set in the constructor

      @nn.compact
      def __call__(self, x):
        x = x.reshape((x.shape[0], -1))    # flatten each example
        x = nn.Dense(128)(x)               # fully connected layer; input size inferred from x (shape inference)
        x = nn.relu(x)
        x = nn.Dense(self.out_dims)(x)     # output size comes from the dataclass attribute
        return x

    model = MLP(out_dims=10)               # create model

    x = jnp.ones((1, 28, 28, 1))           # fake data
    variables = model.init(jax.random.PRNGKey(42), x)  # initialize weights
    y = model.apply(variables, x)          # make forward pass

Metadata

Labels

Priority: P1 - soon (Response within 5 business days. Resolution within 30 days. Assignee required.)
