File tree Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Expand file tree Collapse file tree 1 file changed +10
-8
lines changed Original file line number Diff line number Diff line change @@ -101,19 +101,22 @@ Basic usage
101
101
102
102
.. testcode ::
103
103
104
- class MLP(nn.Module):
104
+ class MLP(nn.Module): # create a Flax Module dataclass
105
+ out_dims: int
106
+
105
107
@nn.compact
106
108
def __call__(self, x):
107
- x = nn.Dense(16)(x) # inline submodules
109
+ x = x.reshape((x.shape[0], -1))
110
+ x = nn.Dense(128)(x) # create inline Flax Module submodules
108
111
x = nn.relu(x)
109
- x = nn.Dense(16 )(x) # inline submodules
112
+ x = nn.Dense(self.out_dims )(x) # shape inference
110
113
return x
111
114
112
- model = MLP() # create model
115
+ model = MLP(out_dims=10 ) # instantiate the MLP model
113
116
114
- x = jnp.ones ((4, 16 )) # get some data
115
- variables = model.init(PRNGKey(42), x) # initialize weights
116
- y = model.apply(variables, x) # make forward pass
117
+ x = jnp.empty ((4, 28, 28, 1 )) # generate random data
118
+ variables = model.init(PRNGKey(42), x) # initialize the weights
119
+ y = model.apply(variables, x) # make forward pass
117
120
118
121
----
119
122
@@ -125,7 +128,6 @@ Learn more
125
128
.. grid-item ::
126
129
:columns: 6 6 6 4
127
130
128
-
129
131
.. card :: :material-regular:`rocket_launch;2em` Getting Started
130
132
:class-card: sd-text-black sd-bg-light
131
133
:link: getting_started.html
You can’t perform that action at this time.
0 commit comments