sowing across nnx.while_loop
#4799
Unanswered
jacknewsom asked this question in Q&A
Replies: 1 comment
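@jacknewsom To effectively `sow` inside a while loop, you need to preallocate an array with a maximum size and write each sown value at an index that is incremented on every `sow` call. You can do this with a custom `init_fn` and `reduce_fn`:

```python
import jax.numpy as jnp

max_size: int = ...  # maximum number of values that will be sown


def init_fn(x):
    # Preallocate a fixed-size buffer (one row per sow call) plus a write index.
    # sow calls init_fn with no arguments, so pass it as a closure,
    # e.g. init_fn=lambda: init_fn(x).
    return dict(
        i=jnp.array(0),
        x=jnp.zeros((max_size, *x.shape), x.dtype),
    )


def reduce_fn(acc, x):
    # Write x at the current index, then advance the index.
    # The shapes never change, so the accumulator is a valid loop carry.
    return dict(
        i=acc['i'] + 1,
        x=acc['x'].at[acc['i']].set(x),
    )
```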
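The buffer idea can be checked without Flax. The following is a minimal pure-JAX sketch, not Flax's implementation: it assumes a hypothetical capacity `MAX_SIZE` and a toy loop body (`x * 2.0`) standing in for the module call, and threads the "sow" accumulator through `jax.lax.while_loop` by hand. Because the accumulator's shapes never change between iterations, the loop accepts it as part of its carry, even though the trip count `n_steps` is only known at run time.

```python
import jax
import jax.numpy as jnp

MAX_SIZE = 8  # hypothetical capacity; must be >= the number of iterations


def init_fn(x):
    # Fixed-size buffer shaped like one sown value, plus a write index.
    return dict(i=jnp.array(0), x=jnp.zeros((MAX_SIZE, *x.shape), x.dtype))


def reduce_fn(acc, x):
    # Store x at the current index and advance the index.
    return dict(i=acc['i'] + 1, x=acc['x'].at[acc['i']].set(x))


@jax.jit
def run(x0, n_steps):
    def cond(state):
        step, _, _ = state
        return step < n_steps

    def body(state):
        step, x, acc = state
        x = x * 2.0              # toy stand-in for the module's computation
        acc = reduce_fn(acc, x)  # "sow" the new value into the buffer
        return step + 1, x, acc

    init = (jnp.array(0), x0, init_fn(x0))
    _, x_final, acc = jax.lax.while_loop(cond, body, init)
    return x_final, acc


x_final, acc = run(jnp.ones(3), 4)
print(int(acc['i']))  # 4 values were sown; rows 4..7 of acc['x'] stay zero
```

`acc['i']` tells you how many rows of the buffer are valid. JAX skips updates at out-of-bounds indices instead of raising, so sowing more than `MAX_SIZE` values silently drops the extras; size the buffer for the worst case.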
Then pass them to `sow` inside the module:

```python
import flax.linen as nn


class MLP(nn.Module):
    depth: int  # number of times Block is applied by nn.scan

    @nn.compact  # required to create Block() inside __call__
    def __call__(self, x):
        # sow expects a zero-argument init_fn, so close over the value
        self.sow("intermediates", "x", x,
                 init_fn=lambda: init_fn(x), reduce_fn=reduce_fn)
        block = Block()  # Block is defined elsewhere (not shown)
        if self.is_initializing():
            # During initialization only, sow once on the block outside the scan.
            carry = ...  # TODO
            block.sow("intermediates", "carry", carry,
                      init_fn=lambda: init_fn(carry), reduce_fn=reduce_fn)
        carry, _ = nn.scan(
            Block.__call__,
            variable_axes={"params": 0, "intermediates": 0},
            split_rngs={"params": True},
            length=self.depth,
        )(block, x, None)
        return carry
```
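When the trip count is static, as with `length=self.depth` in the `nn.scan` call above, collecting one value per iteration is something a scan already does by stacking along a new leading axis, which is roughly the effect of listing `"intermediates": 0` in `variable_axes`. A rough pure-JAX analogue with `jax.lax.scan` follows; the `step` function and `DEPTH` are hypothetical stand-ins for `Block.__call__` and `self.depth`, and this is an illustration rather than how `nn.scan` is implemented.

```python
import jax
import jax.numpy as jnp

DEPTH = 4  # hypothetical number of layers (plays the role of self.depth)


def step(carry, _):
    # Toy stand-in for Block.__call__: update the carry and emit a value.
    carry = carry * 2.0
    # The second return value is collected by scan and stacked along axis 0.
    return carry, carry


final, stacked = jax.lax.scan(step, jnp.ones(3), None, length=DEPTH)
print(stacked.shape)  # (4, 3)
```

`jax.lax.while_loop` has no static length to stack along, which is why the reply preallocates a `max_size` buffer for the while-loop case instead.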
Hello,
Is it possible to `sow` values inside an `nnx.Module` called within `nnx.while_loop`? The documentation says:

For example, the snippet below fails:
with this error:
Is there any way around this that lets me use the JIT and sow in a loop?