Replies: 1 comment 5 replies
-
Hey, we currently don't have a built-in way to serialize the full model structure, but you can do something like this:

```python
import orbax.checkpoint as orbax
from flax import nnx

def load_model(path: str) -> MLP:
    # create the model with abstract (shape-only) parameters
    model = nnx.eval_shape(lambda: create_model(0))
    state = nnx.state(model)
    # load the parameters from the checkpoint
    checkpointer = orbax.PyTreeCheckpointer()
    state = checkpointer.restore(f'{path}/state', item=state)
    # update the model with the loaded state
    nnx.update(model, state)
    return model
```

This is taken from 08_save_load_checkpoints.py.
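For reference, the matching save side would look roughly like this (a sketch, not quoted from the example; it assumes the `{path}/state` layout that `load_model` above reads from):

```python
import orbax.checkpoint as orbax
from flax import nnx

def save_model(model: MLP, path: str) -> None:
    # extract only the parameter/state pytree; the structure itself is
    # rebuilt at load time via create_model + nnx.eval_shape
    state = nnx.state(model)
    checkpointer = orbax.PyTreeCheckpointer()
    checkpointer.save(f'{path}/state', state)
```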
-
Hi,
I want to be able to serialize multiple different nnx models to disk (not just the weights but also the full layer structure). This is helpful when trying out a bunch of different architectures that I trained beforehand and just want to test in eval/inference mode.
Currently, I am using Orbax to save the model train state, but this requires that the train state structure be created before loading the checkpoint. I am doing something like this:
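(The code block from the original post is not preserved here; the following is a minimal sketch of the described pattern, assuming a hypothetical `create_model(seed)` factory and Orbax's `PyTreeCheckpointer`.)

```python
import orbax.checkpoint as orbax
from flax import nnx

# Saving is straightforward: the model structure already exists in memory.
model = create_model(0)  # hypothetical factory returning an nnx.Module
checkpointer = orbax.PyTreeCheckpointer()
checkpointer.save('/tmp/my_model/state', nnx.state(model))

# Loading requires rebuilding the exact same structure first,
# i.e. the architecture must be known before restoring.
model = nnx.eval_shape(lambda: create_model(0))
state = checkpointer.restore('/tmp/my_model/state', item=nnx.state(model))
nnx.update(model, state)
```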
What I would like to do is something like this:
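(Again just a sketch, not from the original post, using a hypothetical `load_model` helper; the key point is that restoring should not require knowing which architecture was saved.)

```python
# Desired API: restore structure *and* weights from disk alone.
model = load_model('/tmp/my_model')  # hypothetical helper, no create_model needed
y = model(x)                         # go straight to eval/inference
```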
This might be possible by pickling the create_model function (though that could fail because of lambda functions inside create_model), but I guess this is not the idiomatic way.
In torch you can do:
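(This refers to the standard PyTorch idiom: `torch.save` pickles the whole module, structure included.)

```python
import torch

torch.save(model, 'model.pt')                        # serializes structure + weights
model = torch.load('model.pt', weights_only=False)   # no class or constructor needed here
model.eval()
```

(`weights_only=False` is needed on recent PyTorch versions to allow unpickling a full module rather than just a state dict.)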
I basically would like to use the same kind of API as in torch, but with flax/nnx.