Discussed in #4433
Originally posted by onnoeberhard December 13, 2024
I want to train two models at the same time. To do this, I use nnx.fori_loop:
import jax
from flax import nnx
model = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(0)))
model2 = nnx.Linear(2, 2, rngs=nnx.Rngs(jax.random.PRNGKey(1)))
def f(i, x):
    # Loop body: return the carry (the tuple of models) unchanged
    return x
nnx.fori_loop(0, 10, f, (model, model2))
The above code throws the following error: ValueError: nnx.fori_loop requires body function's input and output to have the same reference and pytree structure, but they differ. If the mismatch comes from index_mapping field, you might have modified reference structure within the body function, which is not allowed.
If I loop with only one model, for example nnx.fori_loop(0, 10, f, (model, model)), there is no error. What is the problem here?