It looks like the issue is your |
I am trying to define custom VJPs in JAX that work with JAX sparse BCOO arrays as inputs. However, I am running into an issue where, in the backward pass, the shape inferred for the sparse array is the number of specified elements (nse) rather than the shape of the matrix.
I am not sure how best to resolve this and would be grateful for any pointers or help!
Below is a minimal example that reproduces the core issue:
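One likely explanation: a BCOO array is a pytree whose leaves are data (the nse stored values) and indices, with the dense shape kept as static metadata. A custom_vjp defined directly on a BCOO argument therefore sees cotangents shaped like those leaves, i.e. (nse,) for the values, not (rows, cols). Below is a minimal sketch of one common workaround. It assumes a 2-D BCOO with no batch dimensions and uses a sparse matrix-vector product as a stand-in operation. The idea is to close over the sparsity pattern and define the custom VJP on the stored values and the dense operand. make_spmv is a hypothetical helper, not a JAX API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse


def make_spmv(indices, shape):
    """Hypothetical helper: y = M @ x for a 2-D BCOO M, with a hand-written VJP.

    The sparsity pattern (indices) and dense shape are closed over, so the
    custom_vjp only differentiates the stored values `data` (shape (nse,))
    and the dense vector `x`.
    """
    rows, cols = indices[:, 0], indices[:, 1]

    @jax.custom_vjp
    def spmv(data, x):
        M = sparse.BCOO((data, indices), shape=shape)
        return M @ x

    def spmv_fwd(data, x):
        # Save the inputs as residuals for the backward pass.
        return spmv(data, x), (data, x)

    def spmv_bwd(res, g):
        data, x = res
        # y[rows[k]] += data[k] * x[cols[k]], so dy/ddata[k] = x[cols[k]].
        g_data = g[rows] * x[cols]  # shape (nse,): one entry per stored value
        # Cotangent for x is M^T @ g, accumulated per column index.
        g_x = jax.ops.segment_sum(data * g[rows], cols, num_segments=shape[1])
        return g_data, g_x

    spmv.defvjp(spmv_fwd, spmv_bwd)
    return spmv


dense = jnp.array([[1.0, 0.0, 2.0],
                   [0.0, 3.0, 0.0]])
M = sparse.BCOO.fromdense(dense)
x = jnp.array([1.0, 2.0, 3.0])
spmv = make_spmv(M.indices, M.shape)

# Gradient w.r.t. the stored values has shape (nse,), not M.shape.
g_data = jax.grad(lambda d: spmv(d, x).sum())(M.data)
# Scatter it back onto the sparsity pattern to view it with the matrix's shape.
g_dense = sparse.BCOO((g_data, M.indices), shape=M.shape).todense()
# Gradient w.r.t. the dense operand.
g_x = jax.grad(lambda v: spmv(M.data, v).sum())(x)
```

The design choice here is to keep integer indices and the static shape out of the differentiated arguments entirely, so the backward rule only has to produce float cotangents. If a gradient with the matrix's dense shape is needed, the (nse,) cotangent can be scattered back onto the sparsity pattern, as g_dense does; entries outside the pattern stay zero.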