How to access the gradient value when using flax.nnx.value_and_grad #26863
adamhadani asked this question in Q&A
Hi @adamhadani,

```python
from flax import nnx
import jax.numpy as jnp

class Foo(nnx.Module):
    def __init__(self):
        self.a = nnx.Param(jnp.ones((2, 2)))

def loss_fn(model):
    return jnp.sum(model.a.value)

model = Foo()
grads = nnx.grad(loss_fn)(model)
print(grads)
```
Hello,
I have recently started using the jax/flax/optax stack (stax?) for some deep reinforcement learning projects.
I ran into the following issue pretty early on. Notice that I'm using `flax.nnx.value_and_grad`, the "lifted" version of `jax.value_and_grad` that provides some boilerplate for state management and for working with `nnx.Module`s. A single training step looks like this, for example:

I would like to track the gradient norms (e.g. for logging to TensorBoard). How do I access the actual numeric values of the computed gradients? The `grads` value returned here seems to be a function/callable used by `optimizer.update` rather than the actual computed gradient?

Sorry if this is a silly question or belongs in a different Flax forum!
Cheers,
Adam