How to combine jax.experimental.custom_dce with JVP and transpose rules from jax.extend.core.Primitive? #26851
Replies: 1 comment
-
Nevermind, I need to use the underlying |
-
I have a `jax.extend.core.Primitive` primitive that implements JVP, transpose, and batching rules. I want to add custom_dce to my primitive because I'm using `jax.ffi.ffi_call` inside `partial(segmented_polynomial_impl, "cuda")`.

I tried to add custom_dce in two places:
- on `segmented_polynomial_impl`, but it gets ignored;
- on `segmented_polynomial_p.bind`: it works when I test the DCE, but it crashes with an error when I use AD.