Skip to content

Commit c62c428

Browse files
committed
update landing page example
1 parent b8d1162 commit c62c428

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

docs/index.rst

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,19 +101,22 @@ Basic usage
101101

102102
.. testcode::
103103

104-
class MLP(nn.Module):
104+
class MLP(nn.Module): # create a Flax Module dataclass
105+
out_dims: int
106+
105107
@nn.compact
106108
def __call__(self, x):
107-
x = nn.Dense(16)(x) # inline submodules
109+
x = x.reshape((x.shape[0], -1))
110+
x = nn.Dense(128)(x) # create inline Flax Module submodules
108111
x = nn.relu(x)
109-
x = nn.Dense(16)(x) # inline submodules
112+
x = nn.Dense(self.out_dims)(x) # shape inference
110113
return x
111114

112-
model = MLP() # create model
115+
model = MLP(out_dims=10) # instantiate the MLP model
113116

114-
x = jnp.ones((4, 16)) # get some data
115-
variables = model.init(PRNGKey(42), x) # initialize weights
116-
y = model.apply(variables, x) # make forward pass
117+
x = jnp.empty((4, 28, 28, 1)) # generate random data
118+
variables = model.init(PRNGKey(42), x) # initialize the weights
119+
y = model.apply(variables, x) # make forward pass
117120

118121
----
119122

@@ -125,7 +128,6 @@ Learn more
125128
.. grid-item::
126129
:columns: 6 6 6 4
127130

128-
129131
.. card:: :material-regular:`rocket_launch;2em` Getting Started
130132
:class-card: sd-text-black sd-bg-light
131133
:link: getting_started.html

0 commit comments

Comments
 (0)