
[JOSS 6532] Example from README fails with TypeError #9

Closed
dfm opened this issue Mar 29, 2024 · 3 comments · Fixed by #11

@dfm

dfm commented Mar 29, 2024

When running the example code from the README:

```python
import jaxbind

# Primal: y = x1 * x2**2, written into the pre-allocated output buffer.
def f(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2 = args
    out[0][()] = x1 * x2**2

# Forward mode (JVP): dy = x2**2 * dx1 + 2 * x1 * x2 * dx2.
def f_jvp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dx1, dx2 = args
    out[0][()] = x2**2 * dx1 + 2 * x1 * x2 * dx2

# Reverse mode (VJP): (dx1, dx2) = (x2**2 * dy, 2 * x1 * x2 * dy).
def f_vjp(out, args, kwargs_dump):
    kwargs = jaxbind.load_kwargs(kwargs_dump)
    x1, x2, dy = args
    out[0][()] = x2**2 * dy
    out[1][()] = 2 * x1 * x2 * dy

# Shape/dtype of the single output of f.
def f_abstract(*args, **kwargs):
    assert args[0].shape == args[1].shape
    return ((args[0].shape, args[0].dtype),)

# Shapes/dtypes of the two outputs written by f_vjp.
def f_abstract_T(*args, **kwargs):
    return (
        (args[0].shape, args[0].dtype),
        (args[0].shape, args[0].dtype),
    )

f_jax = jaxbind.get_nonlinear_call(
    f, (f_jvp, f_vjp), f_abstract, f_abstract_T
)

import jax
import jax.numpy as jnp

inp = (jnp.full((4,3), 4.), jnp.full((4,3), 2.))
tan = (jnp.full((4,3), 1.), jnp.full((4,3), 1.))
res, res_tan = jax.jvp(f_jax, inp, tan)

cotan = (jnp.full((4,3), 6.),)
res, f_vjp = jax.vjp(f_jax, *inp)
res_cotan = f_vjp(cotan)
```
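Independent of jaxbind, the hand-written rules in `f_jvp` and `f_vjp` can be cross-checked against JAX's own autodiff of the same function. This is a minimal sketch that assumes only `jax` is installed; `f_ref` is a hypothetical plain-JAX stand-in for `f`, evaluated on the same inputs, tangents, and cotangent as above.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-JAX version of the function wrapped via jaxbind above.
def f_ref(x1, x2):
    return x1 * x2**2

x1 = jnp.full((4, 3), 4.0)
x2 = jnp.full((4, 3), 2.0)
dx1 = jnp.full((4, 3), 1.0)
dx2 = jnp.full((4, 3), 1.0)
dy = jnp.full((4, 3), 6.0)

# Forward mode: JAX's JVP vs. the rule hard-coded in f_jvp.
_, tan_auto = jax.jvp(f_ref, (x1, x2), (dx1, dx2))
tan_manual = x2**2 * dx1 + 2 * x1 * x2 * dx2

# Reverse mode: JAX's VJP vs. the rules hard-coded in f_vjp.
# f_ref returns a bare array, so its cotangent is a bare array too.
_, pullback = jax.vjp(f_ref, x1, x2)
ct_auto = pullback(dy)
ct_manual = (x2**2 * dy, 2 * x1 * x2 * dy)

print(bool(jnp.allclose(tan_auto, tan_manual)))  # True
print(all(bool(jnp.allclose(a, m)) for a, m in zip(ct_auto, ct_manual)))  # True
```

With these inputs every entry of the tangent is 4 + 16 = 20, and the two cotangents are 24 and 96.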

I get the following TypeError when calling f_vjp:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-4-dae6bad58eaa> in <cell line: 10>()
      8 cotan = (jnp.full((4,3), 6.),)
      9 res, f_vjp = jax.vjp(f_jax, *inp)
---> 10 res_cotan = f_vjp(cotan)
     11 
     12 f_jax_jit = jax.jit(f_jax)

1 frames

/usr/local/lib/python3.10/dist-packages/jax/_src/tree_util.py in __call__(self, *args, **kw)
    355 
    356   def __call__(self, *args, **kw):
--> 357     return self.fun(*args, **kw)
    358 
    359   def __hash__(self):

/usr/local/lib/python3.10/dist-packages/jax/_src/api.py in _vjp_pullback_wrapper(name, cotangent_dtypes, cotangent_shapes, io_tree, fun, *py_args_)
   2132   args, in_tree = tree_flatten(py_args)
   2133   if in_tree != in_tree_expected:
-> 2134     raise TypeError(f"Tree structure of cotangent input {in_tree}, does not match structure of "
   2135                     f"primal output {in_tree_expected}.")
   2136   for arg, ct_dtype, ct_shape in safe_zip(args, cotangent_dtypes, cotangent_shapes):

TypeError: Tree structure of cotangent input PyTreeDef((*,)), does not match structure of primal output PyTreeDef([*]).
```

I'm running this test in a fresh Google Colab environment with JAX v0.4.23 installed.
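The two `PyTreeDef`s in the error message differ only in their container: `(*,)` is a one-element tuple and `[*]` is a one-element list. A minimal sketch with `jax.tree_util.tree_structure` shows that JAX treats these as distinct structures even though each holds a single leaf:

```python
import jax

# Same single leaf, different containers -> different tree structures.
tuple_tree = jax.tree_util.tree_structure((1.0,))
list_tree = jax.tree_util.tree_structure([1.0])

print(tuple_tree)  # PyTreeDef((*,))
print(list_tree)   # PyTreeDef([*])
print(tuple_tree == list_tree)  # False
```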

ref: openjournals/joss-reviews#6532

@Edenhofer
Contributor

Thanks for the report! The cotangent `cotan` needs to be a list instead of a tuple, because the output of `f_jax` is always a list. #11 updates the README accordingly.
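To illustrate the fix without depending on jaxbind, here is a minimal sketch in plain JAX. `f_list` is a hypothetical stand-in that, like `f_jax` in this issue, returns its result as a one-element list, so the cotangent handed to the pullback must also be a one-element list:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for f_jax: returns its output wrapped in a list.
def f_list(x1, x2):
    return [x1 * x2**2]

inp = (jnp.full((4, 3), 4.0), jnp.full((4, 3), 2.0))
res, pullback = jax.vjp(f_list, *inp)

# The cotangent must mirror the output pytree: a list, not a tuple.
cotan = [jnp.full((4, 3), 6.0)]
print(jax.tree_util.tree_structure(cotan) == jax.tree_util.tree_structure(res))  # True

# One cotangent per primal input: (x2**2 * dy, 2 * x1 * x2 * dy).
dx1, dx2 = pullback(cotan)
```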

Edenhofer self-assigned this Mar 29, 2024
@roth-jakob
Collaborator

Thanks for noticing that!

@dfm
Author

dfm commented Mar 29, 2024

I can confirm that #11 fixes my issue.
