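Hey @jlperla, the easiest way is to select the `nnx.Param` variables with `nnx.state` and sum the sizes of the resulting leaves. `nnx.Param` is the variable type NNX uses for trainable weights, so non-trainable state such as `nnx.BatchStat` is left out:

    params = nnx.state(model, nnx.Param)
    total_params = sum(np.prod(x.shape) for x in jax.tree.leaves(params))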
I want to get a count of all the trainable parameters in a module. In PyTorch I can iterate over `model.parameters()` and check whether each one requires a gradient (`requires_grad`), etc. How would I do this with a recursive function and filters in NNX?
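For comparison, the PyTorch approach described in the question looks roughly like this. It is a sketch assuming PyTorch is installed; the three-layer model is hypothetical, and its first layer is frozen only so the `requires_grad` check has something to exclude. `model.parameters()` already recurses into submodules, and BatchNorm running statistics are registered as buffers rather than parameters, so they are not counted.

```python
import torch
from torch import nn

# Hypothetical model with the same layer sizes as a 2 -> 4 -> 3 MLP.
model = nn.Sequential(nn.Linear(2, 4), nn.BatchNorm1d(4), nn.Linear(4, 3))

# Freeze the first Linear layer so the requires_grad filter matters.
for p in model[0].parameters():
    p.requires_grad_(False)

# parameters() yields every parameter of every submodule (recurse=True by default).
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 35 23
```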