Replies: 1 comment
-
Hey @OhadRubin, let's say you have this MLP that defines some layers in `setup`:

```python
import flax.linen as nn
import jax.numpy as jnp
import jax


class MLPBlock(nn.Module):
  features: int

  def setup(self):
    self.dense = nn.Dense(self.features)

  def __call__(self, x):
    return nn.relu(self.dense(x))


class MLP(nn.Module):
  n_layers: int
  features: int

  def setup(self):
    self.layers = [MLPBlock(self.features) for _ in range(self.n_layers)]

  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    return x


x = jnp.ones((3, 5))
module = MLP(10, 5)
y, variables = module.init_with_output(jax.random.PRNGKey(0), x)
print("Regular MLP")
print(jax.tree_map(jnp.shape, variables))
print(y.shape)
print()
```

You can refactor it to use `nn.remat_scan` while keeping `setup`:

```python
class MLPScan(nn.Module):
  n_layers: int
  features: int

  def setup(self):
    Layers = nn.remat_scan(
        MLPBlock, variable_axes={'params': 0},
        split_rngs={'params': True}, lengths=(self.n_layers,))
    self.layers = Layers(self.features)

  def __call__(self, x):
    return self.layers(x)


x = jnp.ones((3, 5))
module = MLPScan(10, 5)
y, variables = module.init_with_output(jax.random.PRNGKey(0), x)
print("MLPScan")
print(jax.tree_map(jnp.shape, variables))
print(y.shape)
print()
```
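For reference, a minimal sketch of running a forward pass with the `module` and `variables` from the snippet above; the `apply` call is the same for both versions:

```python
# Forward pass with the variables returned by init_with_output.
# Note: under remat_scan the per-layer parameters are stacked along a leading
# axis of length n_layers, so the parameter tree has a different layout than
# the plain setup-based MLP.
out = module.apply(variables, x)
print(out.shape)  # (3, 5)
```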
-
Hey,
My code uses `self.setup` to define parameters, and compilation is really slow. It isn't clear how to refactor it to use `remat_scan`, since all the examples use `nn.compact`.
I'd also like to emphasize that I want to keep the `self.setup` declaration, so I can still run models I've already trained.