Can a nnx.Module be in the carry of a jax.lax.fori_loop? #4691
Ah, it seems that it is necessary to use nnx.fori_loop instead. My bad, I missed that transform.
I wanted to write a jax.lax.fori_loop that iteratively applies a model to some input, with the model being part of the carry argument. Strangely, I get the following error:
with this minimal reproducible example:
I think this error is connected to the fact that JAX traces the body of a fori_loop (much as jax.jit does), but even explicitly wrapping body_func with nnx.jit does not resolve the error. Is there a better way to carry a model through a fori_loop?
There was another issue with nnx.Module and fori_loop some time ago, but the errors seem to be different: #4436