Replies: 1 comment
-
Take a look at those docs to see if you can get what you need from them!
-
I'm writing a custom JVP for a non-differentiable function. The gradient can be estimated, but only in a computationally expensive way for each argument, so I would like to restrict the gradient computation to only those components that require it. Is there a recommended way to do this?
Code example:
I appreciate the for loop is best expressed as a `lax.scan`/`fori_loop`: I'm just trying to illustrate with some pseudocode to get the question across clearly.
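The original code example did not come through, so here is a minimal sketch of the pattern being described, with hypothetical names (`f`, `expensive_grad_component`): a `jax.custom_jvp` rule that estimates the directional derivative one component at a time and skips any component whose tangent is zero. Note the skip test relies on the tangents being concrete values (as in a plain `jax.jvp` call); under `jit` the Python `if` on a traced value would fail.

```python
import jax
import jax.numpy as jnp

def _f_impl(x):
    # Stand-in for the non-differentiable function.
    return jnp.sum(jnp.round(x))

f = jax.custom_jvp(_f_impl)

def expensive_grad_component(x, i):
    # Stand-in for the expensive per-component gradient estimate,
    # here a central finite difference along basis direction i.
    eps = 1e-3
    e_i = jnp.zeros_like(x).at[i].set(1.0)
    return (jnp.sum(x + eps * e_i) - jnp.sum(x - eps * e_i)) / (2 * eps)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    out = _f_impl(x)
    out_dot = 0.0
    for i in range(x.shape[0]):
        # Only pay for the estimate when this component's tangent is
        # (concretely) nonzero; this check assumes eager tangents.
        if x_dot[i] != 0.0:
            out_dot = out_dot + expensive_grad_component(x, i) * x_dot[i]
    return out, jnp.asarray(out_dot, dtype=out.dtype)

x = jnp.arange(3.0)
v = jnp.array([0.0, 1.0, 0.0])  # only component 1 is probed
y, y_dot = jax.jvp(f, (x,), (v,))
```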
My question is: what form should
`x_i_gradient_needed`
take in the above code so that the computation is avoided when the gradient with respect to `x[i]` is not needed? Is it sufficient for it to be something like `x_dot[i] != 0.0`? Or is there a better JAX-esque way of doing this? Thanks for any help!