Replies: 1 comment
-
Hey @ariG23498, you can use a small helper like this:

import jax
from flax import nnx

def put_model(model, device):
    state = nnx.state(model)               # extract the model's variables as a pytree of arrays
    state = jax.device_put(state, device)  # move every array onto the target device
    nnx.update(model, state)               # write the moved arrays back into the model in place
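
For the juggling you describe, here is a minimal usage sketch (my own example, not tested against your pipeline): it assumes a host CPU plus one default accelerator and uses nnx.Linear as a hypothetical stand-in for one of your four sub-models. The idea is to keep a sub-model in host memory while the others run, and onload it right before it is needed.

import jax
from flax import nnx

def put_model(model, device):
    # Same helper as above: move all of the model's arrays to `device`.
    state = nnx.state(model)
    state = jax.device_put(state, device)
    nnx.update(model, state)

cpu = jax.devices("cpu")[0]   # host memory
accel = jax.devices()[0]      # default backend: GPU/TPU if available, else CPU

text_encoder = nnx.Linear(4, 4, rngs=nnx.Rngs(0))  # hypothetical stand-in for one sub-model

put_model(text_encoder, accel)  # onload right before running it
# ... run the text encoder ...
put_model(text_encoder, cpu)    # offload to host RAM

Once nothing references the old accelerator arrays any more, JAX should release that memory, freeing room for whichever sub-model you onload next.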
-
Hey folks!
Is there a one-stop solution for onloading and offloading a model to and from any accelerator device (GPU, TPU)?
I am working on a diffusion model that has 4 sub-models in total (2 text encoders, 1 flow model, and an autoencoder). I would like to juggle between onloading and offloading these models for better memory management.
Any help would be great!