
Move capture code closer to its equivalent use case #7608

Open · wants to merge 6 commits into base: master
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
@@ -572,6 +572,9 @@ Here's a list of deprecations made this release. For a more detailed breakdown o

<h3>Internal changes ⚙️</h3>

* Move program capture code closer to where it is used.
[(#7608)](https://github.com/PennyLaneAI/pennylane/pull/7608)

* Tests using `OpenFermion` in `tests/qchem` do not fail with NumPy>=2.0.0 any more.
[(#7626)](https://github.com/PennyLaneAI/pennylane/pull/7626)

100 changes: 91 additions & 9 deletions pennylane/_grad.py
@@ -15,27 +15,109 @@
This module contains the autograd wrappers :class:`grad` and :func:`jacobian`
"""
import warnings
from functools import partial, wraps
from functools import lru_cache, partial, wraps

from autograd import jacobian as _jacobian
from autograd.core import make_vjp as _make_vjp
from autograd.extend import vspace
from autograd.numpy.numpy_boxes import ArrayBox
from autograd.wrap_util import unary_to_nary

from pennylane.capture import determine_abstracted_axes, enabled
from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
from pennylane.capture.flatfn import FlatFn
from pennylane import capture
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError

make_vjp = unary_to_nary(_make_vjp)


has_jax = True
try:
    import jax
except ImportError:
    has_jax = False


@lru_cache
def _get_grad_prim():
    """Create a primitive for gradient computations.
    This primitive is used when capturing ``qml.grad``.
    """
    if not has_jax:  # pragma: no cover
        return None

    grad_prim = capture.QmlPrimitive("grad")
    grad_prim.multiple_results = True
    grad_prim.prim_type = "higher_order"

    @grad_prim.def_impl
    def _(*args, argnum, jaxpr, n_consts, method, h):
        if method or h:  # pragma: no cover
            raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
        consts = args[:n_consts]
        args = args[n_consts:]

        def func(*inner_args):
            return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)[0]

        return jax.grad(func, argnums=argnum)(*args)

    # pylint: disable=unused-argument
    @grad_prim.def_abstract_eval
    def _(*args, argnum, jaxpr, n_consts, method, h):
        if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != ():
            raise TypeError("Grad only applies to scalar-output functions. Try jacobian.")
        return tuple(args[i + n_consts] for i in argnum)

    return grad_prim
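
As a standalone sketch of the evaluation path the impl rule takes (outside the diff; assumes JAX is installed, and `f` is a stand-in scalar function): capture a jaxpr, evaluate it with jax.core.eval_jaxpr, and differentiate the wrapper with jax.grad.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

closed = jax.make_jaxpr(f)(0.5)

def func(x):
    # mirrors the impl rule: evaluate the captured jaxpr and take its single output
    return jax.core.eval_jaxpr(closed.jaxpr, closed.consts, x)[0]

print(jax.grad(func)(0.5))  # same value as jax.grad(f)(0.5)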


def _shape(shape, dtype):
    if jax.config.jax_dynamic_shapes and any(not isinstance(s, int) for s in shape):
        return jax.core.DShapedArray(shape, dtype)
    return jax.core.ShapedArray(shape, dtype)
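
A quick illustration of the helper's fallback branch (a sketch, assuming JAX is importable and dynamic shapes are disabled): with all-integer dimensions it returns a plain ShapedArray.

import jax.numpy as jnp

aval = _shape((2, 3), jnp.float32)
print(aval)  # ShapedArray(float32[2,3]) when every dimension is a concrete int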


@lru_cache
def _get_jacobian_prim():
    """Create a primitive for Jacobian computations.
    This primitive is used when capturing ``qml.jacobian``.
    """
    if not has_jax:  # pragma: no cover
        return None

    jacobian_prim = capture.QmlPrimitive("jacobian")
    jacobian_prim.multiple_results = True
    jacobian_prim.prim_type = "higher_order"

    @jacobian_prim.def_impl
    def _(*args, argnum, jaxpr, n_consts, method, h):
        if method or h:  # pragma: no cover
            raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
        consts = args[:n_consts]
        args = args[n_consts:]

        def func(*inner_args):
            return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)

        return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))

    # pylint: disable=unused-argument
    @jacobian_prim.def_abstract_eval
    def _(*args, argnum, jaxpr, n_consts, method, h):
        in_avals = tuple(args[i + n_consts] for i in argnum)
        out_shapes = tuple(outvar.aval.shape for outvar in jaxpr.outvars)
        return [
            _shape(out_shape + in_aval.shape, in_aval.dtype)
            for out_shape in out_shapes
            for in_aval in in_avals
        ]

    return jacobian_prim
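
The abstract-eval rule above encodes the convention that each Jacobian block has shape out_shape + in_shape. A small sanity check of that convention in plain JAX (a sketch, not part of the diff):

import jax
import jax.numpy as jnp

def f(x):
    # maps R^3 -> R^2
    return jnp.stack([x.sum(), (x ** 2).sum()])

x = jnp.ones(3)
print(jax.jacobian(f)(x).shape)  # (2, 3) == out_shape (2,) + in_shape (3,)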


def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
    """Capture-compatible gradient computation."""
    # pylint: disable=import-outside-toplevel
    import jax
    from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple

    if argnum is None:
@@ -70,9 +152,9 @@ def new_func(*args, **kwargs):
        flat_argnum = sum(flat_argnum_gen, start=[])

        # Create fully flattened function (flat inputs & outputs)
        flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
        flat_fn = capture.FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
        flat_args = sum(flat_args, start=[])
        abstracted_axes, abstract_shapes = determine_abstracted_axes(tuple(flat_args))
        abstracted_axes, abstract_shapes = capture.determine_abstracted_axes(tuple(flat_args))
        jaxpr = jax.make_jaxpr(flat_fn, abstracted_axes=abstracted_axes)(*flat_args)

        num_abstract_shapes = len(abstract_shapes)
@@ -165,7 +247,7 @@ def __new__(cls, func, argnum=None, method=None, h=None):
            ops_loader = available_eps[active_jit]["ops"].load()
            return ops_loader.grad(func, method=method, h=h, argnums=argnum)

        if enabled():
        if capture.enabled():
            return _capture_diff(func, argnum, _get_grad_prim(), method=method, h=h)

        if method or h:  # pragma: no cover
@@ -496,7 +578,7 @@ def circuit(x):
        ops_loader = available_eps[active_jit]["ops"].load()
        return ops_loader.jacobian(func, method=method, h=h, argnums=argnum)

    if enabled():
    if capture.enabled():
        return _capture_diff(func, argnum, _get_jacobian_prim(), method=method, h=h)

    if method or h:
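For context, roughly how the moved grad primitive is reached once capture is enabled (a hedged sketch; the device, circuit, and call pattern are illustrative, not taken from this PR):

import pennylane as qml
import jax.numpy as jnp

qml.capture.enable()

@qml.qnode(qml.device("default.qubit", wires=1))
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# with capture on, qml.grad routes through _capture_diff and binds the "grad" primitive
print(qml.grad(circuit)(jnp.array(0.5)))

qml.capture.disable()
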
26 changes: 7 additions & 19 deletions pennylane/capture/__init__.py
@@ -30,10 +30,6 @@
    ~enable
    ~enabled
    ~pause
    ~create_operator_primitive
    ~create_measurement_obs_primitive
    ~create_measurement_wires_primitive
    ~create_measurement_mcm_primitive
    ~determine_abstracted_axes
    ~expand_plxpr_transforms
    ~eval_jaxpr
@@ -162,16 +158,10 @@ class MyCustomOp(qml.operation.Operator):
    def _(*args, **kwargs):
        return type.__call__(MyCustomOp, *args, **kwargs)
"""
from typing import Callable
from typing import Callable, Type

from .switches import disable, enable, enabled, pause
from .capture_meta import CaptureMeta, ABCCaptureMeta
from .capture_operators import create_operator_primitive
from .capture_measurements import (
    create_measurement_obs_primitive,
    create_measurement_wires_primitive,
    create_measurement_mcm_primitive,
)
from .flatfn import FlatFn
from .make_plxpr import make_plxpr, run_autograph
from .dynamic_shapes import determine_abstracted_axes, register_custom_staging_rule
@@ -185,14 +175,16 @@ def _(*args, **kwargs):
PlxprInterpreter: type
expand_plxpr_transforms: Callable[[Callable], Callable]
eval_jaxpr: Callable
QmlPrimitive: "Type[jax.extend.core.Primitive]"


class CaptureError(Exception):
    """Errors related to PennyLane's Program Capture execution pipeline."""


# pylint: disable=import-outside-toplevel, redefined-outer-name, too-many-return-statements
def __getattr__(key):
    if key == "QmlPrimitive":
        from .custom_primitives import QmlPrimitive

        return QmlPrimitive

# pylint: disable=import-outside-toplevel, redefined-outer-name
def __getattr__(key):
    if key == "AbstractOperator":
        from .primitives import _get_abstract_operator
@@ -233,10 +225,6 @@ def __getattr__(key):
    "eval_jaxpr",
    "CaptureMeta",
    "ABCCaptureMeta",
    "create_operator_primitive",
    "create_measurement_obs_primitive",
    "create_measurement_wires_primitive",
    "create_measurement_mcm_primitive",
    "determine_abstracted_axes",
    "expand_plxpr_transforms",
    "register_custom_staging_rule",
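The module-level __getattr__ shown above is the standard lazy-import pattern (PEP 562), which defers pulling in JAX-dependent submodules until the attribute is first touched. A minimal standalone sketch with a hypothetical package and submodule name:

# mypackage/__init__.py (hypothetical)
def __getattr__(key):
    if key == "Heavy":
        # imported only on first access to mypackage.Heavy
        from ._heavy import Heavy

        return Heavy
    raise AttributeError(f"module {__name__!r} has no attribute {key!r}")
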
106 changes: 0 additions & 106 deletions pennylane/capture/capture_diff.py

This file was deleted.
