diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 99d3b6b63e5..329353e326e 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -567,6 +567,9 @@ Here's a list of deprecations made this release. For a more detailed breakdown o
Internal changes ⚙️
+* Move program capture code closer to where it is used.
+ [(#7608)][https://github.com/PennyLaneAI/pennylane/pull/7608]
+
* Tests using `OpenFermion` in `tests/qchem` do not fail with NumPy>=2.0.0 any more.
[(#7626)](https://github.com/PennyLaneAI/pennylane/pull/7626)
diff --git a/pennylane/_grad.py b/pennylane/_grad.py
index af30ba8ef40..ecbdafd57b3 100644
--- a/pennylane/_grad.py
+++ b/pennylane/_grad.py
@@ -15,7 +15,7 @@
This module contains the autograd wrappers :class:`grad` and :func:`jacobian`
"""
import warnings
-from functools import partial, wraps
+from functools import lru_cache, partial, wraps
from autograd import jacobian as _jacobian
from autograd.core import make_vjp as _make_vjp
@@ -23,19 +23,101 @@
from autograd.numpy.numpy_boxes import ArrayBox
from autograd.wrap_util import unary_to_nary
-from pennylane.capture import determine_abstracted_axes, enabled
-from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim
-from pennylane.capture.flatfn import FlatFn
+from pennylane import capture
from pennylane.compiler import compiler
from pennylane.compiler.compiler import CompileError
make_vjp = unary_to_nary(_make_vjp)
+has_jax = True
+try:
+ import jax
+except ImportError:
+ has_jax = False
+
+
+@lru_cache
+def _get_grad_prim():
+ """Create a primitive for gradient computations.
+ This primitive is used when capturing ``qml.grad``.
+ """
+ if not has_jax: # pragma: no cover
+ return None
+
+ grad_prim = capture.QmlPrimitive("grad")
+ grad_prim.multiple_results = True
+ grad_prim.prim_type = "higher_order"
+
+ @grad_prim.def_impl
+ def _(*args, argnum, jaxpr, n_consts, method, h):
+ if method or h: # pragma: no cover
+ raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
+ consts = args[:n_consts]
+ args = args[n_consts:]
+
+ def func(*inner_args):
+ return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)[0]
+
+ return jax.grad(func, argnums=argnum)(*args)
+
+ # pylint: disable=unused-argument
+ @grad_prim.def_abstract_eval
+ def _(*args, argnum, jaxpr, n_consts, method, h):
+ if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != ():
+ raise TypeError("Grad only applies to scalar-output functions. Try jacobian.")
+ return tuple(args[i + n_consts] for i in argnum)
+
+ return grad_prim
+
+
+def _shape(shape, dtype):
+ if jax.config.jax_dynamic_shapes and any(not isinstance(s, int) for s in shape):
+ return jax.core.DShapedArray(shape, dtype)
+ return jax.core.ShapedArray(shape, dtype)
+
+
+@lru_cache
+def _get_jacobian_prim():
+ """Create a primitive for Jacobian computations.
+ This primitive is used when capturing ``qml.jacobian``.
+ """
+ if not has_jax: # pragma: no cover
+ return None
+
+ jacobian_prim = capture.QmlPrimitive("jacobian")
+ jacobian_prim.multiple_results = True
+ jacobian_prim.prim_type = "higher_order"
+
+ @jacobian_prim.def_impl
+ def _(*args, argnum, jaxpr, n_consts, method, h):
+ if method or h: # pragma: no cover
+ raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
+ consts = args[:n_consts]
+ args = args[n_consts:]
+
+ def func(*inner_args):
+ return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)
+
+ return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))
+
+ # pylint: disable=unused-argument
+ @jacobian_prim.def_abstract_eval
+ def _(*args, argnum, jaxpr, n_consts, method, h):
+ in_avals = tuple(args[i + n_consts] for i in argnum)
+ out_shapes = tuple(outvar.aval.shape for outvar in jaxpr.outvars)
+ return [
+ _shape(out_shape + in_aval.shape, in_aval.dtype)
+ for out_shape in out_shapes
+ for in_aval in in_avals
+ ]
+
+ return jacobian_prim
+
+
def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None):
"""Capture-compatible gradient computation."""
# pylint: disable=import-outside-toplevel
- import jax
from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple
if argnum is None:
@@ -70,9 +152,9 @@ def new_func(*args, **kwargs):
flat_argnum = sum(flat_argnum_gen, start=[])
# Create fully flattened function (flat inputs & outputs)
- flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
+ flat_fn = capture.FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree)
flat_args = sum(flat_args, start=[])
- abstracted_axes, abstract_shapes = determine_abstracted_axes(tuple(flat_args))
+ abstracted_axes, abstract_shapes = capture.determine_abstracted_axes(tuple(flat_args))
jaxpr = jax.make_jaxpr(flat_fn, abstracted_axes=abstracted_axes)(*flat_args)
num_abstract_shapes = len(abstract_shapes)
@@ -165,7 +247,7 @@ def __new__(cls, func, argnum=None, method=None, h=None):
ops_loader = available_eps[active_jit]["ops"].load()
return ops_loader.grad(func, method=method, h=h, argnums=argnum)
- if enabled():
+ if capture.enabled():
return _capture_diff(func, argnum, _get_grad_prim(), method=method, h=h)
if method or h: # pragma: no cover
@@ -496,7 +578,7 @@ def circuit(x):
ops_loader = available_eps[active_jit]["ops"].load()
return ops_loader.jacobian(func, method=method, h=h, argnums=argnum)
- if enabled():
+ if capture.enabled():
return _capture_diff(func, argnum, _get_jacobian_prim(), method=method, h=h)
if method or h:
diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py
index 716b73b1b8a..ae61f7be09a 100644
--- a/pennylane/capture/__init__.py
+++ b/pennylane/capture/__init__.py
@@ -30,10 +30,6 @@
~enable
~enabled
~pause
- ~create_operator_primitive
- ~create_measurement_obs_primitive
- ~create_measurement_wires_primitive
- ~create_measurement_mcm_primitive
~determine_abstracted_axes
~expand_plxpr_transforms
~eval_jaxpr
@@ -162,16 +158,10 @@ class MyCustomOp(qml.operation.Operator):
def _(*args, **kwargs):
return type.__call__(MyCustomOp, *args, **kwargs)
"""
-from typing import Callable
+from typing import Callable, Type
from .switches import disable, enable, enabled, pause
from .capture_meta import CaptureMeta, ABCCaptureMeta
-from .capture_operators import create_operator_primitive
-from .capture_measurements import (
- create_measurement_obs_primitive,
- create_measurement_wires_primitive,
- create_measurement_mcm_primitive,
-)
from .flatfn import FlatFn
from .make_plxpr import make_plxpr, run_autograph
from .dynamic_shapes import determine_abstracted_axes, register_custom_staging_rule
@@ -185,14 +175,16 @@ def _(*args, **kwargs):
PlxprInterpreter: type
expand_plxpr_transforms: Callable[[Callable], Callable]
eval_jaxpr: Callable
+QmlPrimitive: "Type[jax.extend.core.Primitive]"
-class CaptureError(Exception):
- """Errors related to PennyLane's Program Capture execution pipeline."""
+# pylint: disable=import-outside-toplevel, redefined-outer-name, too-many-return-statements
+def __getattr__(key):
+ if key == "QmlPrimitive":
+ from .custom_primitives import QmlPrimitive
+ return QmlPrimitive
-# pylint: disable=import-outside-toplevel, redefined-outer-name
-def __getattr__(key):
if key == "AbstractOperator":
from .primitives import _get_abstract_operator
@@ -233,10 +225,6 @@ def __getattr__(key):
"eval_jaxpr",
"CaptureMeta",
"ABCCaptureMeta",
- "create_operator_primitive",
- "create_measurement_obs_primitive",
- "create_measurement_wires_primitive",
- "create_measurement_mcm_primitive",
"determine_abstracted_axes",
"expand_plxpr_transforms",
"register_custom_staging_rule",
diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py
deleted file mode 100644
index c41b25aaad3..00000000000
--- a/pennylane/capture/capture_diff.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright 2024 Xanadu Quantum Technologies Inc.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This submodule offers differentiation-related primitives and types for
-the PennyLane capture module.
-"""
-from functools import lru_cache
-
-has_jax = True
-try:
- import jax
-except ImportError:
- has_jax = False
-
-
-@lru_cache
-def _get_grad_prim():
- """Create a primitive for gradient computations.
- This primitive is used when capturing ``qml.grad``.
- """
- if not has_jax: # pragma: no cover
- return None
-
- from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel
-
- grad_prim = QmlPrimitive("grad")
- grad_prim.multiple_results = True
- grad_prim.prim_type = "higher_order"
-
- @grad_prim.def_impl
- def _(*args, argnum, jaxpr, n_consts, method, h):
- if method or h: # pragma: no cover
- raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
- consts = args[:n_consts]
- args = args[n_consts:]
-
- def func(*inner_args):
- return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)[0]
-
- return jax.grad(func, argnums=argnum)(*args)
-
- # pylint: disable=unused-argument
- @grad_prim.def_abstract_eval
- def _(*args, argnum, jaxpr, n_consts, method, h):
- if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != ():
- raise TypeError("Grad only applies to scalar-output functions. Try jacobian.")
- return tuple(args[i + n_consts] for i in argnum)
-
- return grad_prim
-
-
-def _shape(shape, dtype):
- if jax.config.jax_dynamic_shapes and any(not isinstance(s, int) for s in shape):
- return jax.core.DShapedArray(shape, dtype)
- return jax.core.ShapedArray(shape, dtype)
-
-
-@lru_cache
-def _get_jacobian_prim():
- """Create a primitive for Jacobian computations.
- This primitive is used when capturing ``qml.jacobian``.
- """
- if not has_jax: # pragma: no cover
- return None
-
- from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel
-
- jacobian_prim = QmlPrimitive("jacobian")
- jacobian_prim.multiple_results = True
- jacobian_prim.prim_type = "higher_order"
-
- @jacobian_prim.def_impl
- def _(*args, argnum, jaxpr, n_consts, method, h):
- if method or h: # pragma: no cover
- raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.")
- consts = args[:n_consts]
- args = args[n_consts:]
-
- def func(*inner_args):
- return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)
-
- return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args))
-
- # pylint: disable=unused-argument
- @jacobian_prim.def_abstract_eval
- def _(*args, argnum, jaxpr, n_consts, method, h):
- in_avals = tuple(args[i + n_consts] for i in argnum)
- out_shapes = tuple(outvar.aval.shape for outvar in jaxpr.outvars)
- return [
- _shape(out_shape + in_aval.shape, in_aval.dtype)
- for out_shape in out_shapes
- for in_aval in in_avals
- ]
-
- return jacobian_prim
diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py
deleted file mode 100644
index 84e5f42eb19..00000000000
--- a/pennylane/capture/capture_operators.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Copyright 2024 Xanadu Quantum Technologies Inc.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This submodule defines the abstract classes and primitives for capturing operators.
-"""
-from functools import lru_cache
-from typing import Optional, Type
-
-import pennylane as qml
-
-has_jax = True
-try:
- import jax
-
-except ImportError:
- has_jax = False
-
-
-@lru_cache # construct the first time lazily
-def _get_abstract_operator() -> type:
- """Create an AbstractOperator once in a way protected from lack of a jax install."""
- if not has_jax: # pragma: no cover
- raise ImportError("Jax is required for plxpr.") # pragma: no cover
-
- class AbstractOperator(jax.core.AbstractValue):
- """An operator captured into plxpr."""
-
- # pylint: disable=missing-function-docstring
- def at_least_vspace(self):
- # TODO: investigate the proper definition of this method
- raise NotImplementedError
-
- # pylint: disable=missing-function-docstring
- def join(self, other):
- # TODO: investigate the proper definition of this method
- raise NotImplementedError
-
- # pylint: disable=missing-function-docstring
- def update(self, **kwargs):
- # TODO: investigate the proper definition of this method
- raise NotImplementedError
-
- def __eq__(self, other):
- return isinstance(other, AbstractOperator)
-
- def __hash__(self):
- return hash("AbstractOperator")
-
- @staticmethod
- def _matmul(*args):
- return qml.prod(*args)
-
- @staticmethod
- def _mul(a, b):
- return qml.s_prod(b, a)
-
- @staticmethod
- def _rmul(a, b):
- return qml.s_prod(b, a)
-
- @staticmethod
- def _add(a, b):
- return qml.sum(a, b)
-
- @staticmethod
- def _pow(a, b):
- return qml.pow(a, b)
-
- return AbstractOperator
-
-
-def create_operator_primitive(
- operator_type: Type["qml.operation.Operator"],
-) -> Optional["jax.extend.core.Primitive"]:
- """Create a primitive corresponding to an operator type.
-
- Called when defining any :class:`~.Operator` subclass, and is used to set the
- ``Operator._primitive`` class property.
-
- Args:
- operator_type (type): a subclass of qml.operation.Operator
-
- Returns:
- Optional[jax.extend.core.Primitive]: A new jax primitive with the same name as the operator subclass.
- ``None`` is returned if jax is not available.
-
- """
- if not has_jax:
- return None
-
- from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel
-
- primitive = QmlPrimitive(operator_type.__name__)
- primitive.prim_type = "operator"
-
- @primitive.def_impl
- def _(*args, **kwargs):
- if "n_wires" not in kwargs:
- return type.__call__(operator_type, *args, **kwargs)
- n_wires = kwargs.pop("n_wires")
-
- split = None if n_wires == 0 else -n_wires
- # need to convert array values into integers
- # for plxpr, all wires must be integers
- # could be abstract when using tracing evaluation in interpreter
- wire_args = args[split:] if split else ()
- wires = tuple(w if qml.math.is_abstract(w) else int(w) for w in wire_args)
- return type.__call__(operator_type, *args[:split], wires=wires, **kwargs)
-
- abstract_type = _get_abstract_operator()
-
- @primitive.def_abstract_eval
- def _(*_, **__):
- return abstract_type()
-
- return primitive
diff --git a/pennylane/capture/primitives.py b/pennylane/capture/primitives.py
index a482323a829..b1c5a169c86 100644
--- a/pennylane/capture/primitives.py
+++ b/pennylane/capture/primitives.py
@@ -17,18 +17,17 @@
It has a jax dependency and should be located in a standard import path.
"""
+from pennylane._grad import _get_grad_prim, _get_jacobian_prim
from pennylane.control_flow.for_loop import _get_for_loop_qfunc_prim
from pennylane.control_flow.while_loop import _get_while_loop_qfunc_prim
+from pennylane.measurements.capture_measurements import _get_abstract_measurement
from pennylane.measurements.mid_measure import _create_mid_measure_primitive
+from pennylane.operation import _get_abstract_operator
from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim
from pennylane.ops.op_math.condition import _get_cond_qfunc_prim
from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim
from pennylane.workflow._capture_qnode import qnode_prim
-from .capture_diff import _get_grad_prim, _get_jacobian_prim
-from .capture_measurements import _get_abstract_measurement
-from .capture_operators import _get_abstract_operator
-
AbstractOperator = _get_abstract_operator()
AbstractMeasurement = _get_abstract_measurement()
adjoint_transform_prim = _get_adjoint_qfunc_prim()
diff --git a/pennylane/exceptions.py b/pennylane/exceptions.py
index f59a13d6d06..2bf9c97e81b 100644
--- a/pennylane/exceptions.py
+++ b/pennylane/exceptions.py
@@ -16,6 +16,10 @@
"""
+class CaptureError(Exception):
+ """Errors related to PennyLane's Program Capture execution pipeline."""
+
+
class DeviceError(Exception): # pragma: no cover
"""Exception raised when it encounters an illegal operation in the quantum circuit."""
diff --git a/pennylane/capture/capture_measurements.py b/pennylane/measurements/capture_measurements.py
similarity index 91%
rename from pennylane/capture/capture_measurements.py
rename to pennylane/measurements/capture_measurements.py
index 559cdea27c9..6cce83e12f3 100644
--- a/pennylane/capture/capture_measurements.py
+++ b/pennylane/measurements/capture_measurements.py
@@ -19,7 +19,9 @@
from functools import lru_cache
from typing import Optional, Type
-import pennylane as qml
+from pennylane import capture
+from pennylane.math import is_abstract
+from pennylane.wires import Wires
has_jax = True
try:
@@ -126,9 +128,7 @@ def create_measurement_obs_primitive(
if not has_jax:
return None
- from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel
-
- primitive = QmlPrimitive(name + "_obs")
+ primitive = capture.QmlPrimitive(name + "_obs")
primitive.prim_type = "measurement"
@primitive.def_impl
@@ -165,10 +165,7 @@ def create_measurement_mcm_primitive(
if not has_jax:
return None
-
- from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel
-
- primitive = QmlPrimitive(name + "_mcm")
+ primitive = capture.QmlPrimitive(name + "_mcm")
primitive.prim_type = "measurement"
@primitive.def_impl
@@ -204,21 +201,17 @@ def create_measurement_wires_primitive(
if not has_jax:
return None
- from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel
-
- primitive = QmlPrimitive(name + "_wires")
+ primitive = capture.QmlPrimitive(name + "_wires")
primitive.prim_type = "measurement"
@primitive.def_impl
def _(*args, has_eigvals=False, **kwargs):
if has_eigvals:
- wires = qml.wires.Wires(
- tuple(w if qml.math.is_abstract(w) else int(w) for w in args[:-1])
- )
+ wires = Wires(tuple(w if is_abstract(w) else int(w) for w in args[:-1]))
kwargs["eigvals"] = args[-1]
else:
- wires = tuple(w if qml.math.is_abstract(w) else int(w) for w in args)
- wires = qml.wires.Wires(wires)
+ wires = tuple(w if is_abstract(w) else int(w) for w in args)
+ wires = Wires(wires)
return type.__call__(measurement_type, wires=wires, **kwargs)
abstract_type = _get_abstract_measurement()
diff --git a/pennylane/measurements/measurements.py b/pennylane/measurements/measurements.py
index 1188338822d..028a73f5fa0 100644
--- a/pennylane/measurements/measurements.py
+++ b/pennylane/measurements/measurements.py
@@ -29,6 +29,12 @@
from pennylane.typing import TensorLike
from pennylane.wires import Wires
+from .capture_measurements import (
+ create_measurement_mcm_primitive,
+ create_measurement_obs_primitive,
+ create_measurement_wires_primitive,
+)
+
class MeasurementShapeError(ValueError):
"""An error raised when an unsupported operation is attempted with a
@@ -60,9 +66,9 @@ class MeasurementProcess(ABC, metaclass=qml.capture.ABCCaptureMeta):
def __init_subclass__(cls, **_):
register_pytree(cls, cls._flatten, cls._unflatten)
name = cls._shortname or cls.__name__
- cls._wires_primitive = qml.capture.create_measurement_wires_primitive(cls, name=name)
- cls._obs_primitive = qml.capture.create_measurement_obs_primitive(cls, name=name)
- cls._mcm_primitive = qml.capture.create_measurement_mcm_primitive(cls, name=name)
+ cls._wires_primitive = create_measurement_wires_primitive(cls, name=name)
+ cls._obs_primitive = create_measurement_obs_primitive(cls, name=name)
+ cls._mcm_primitive = create_measurement_mcm_primitive(cls, name=name)
@classmethod
def _primitive_bind_call(cls, obs=None, wires=None, eigvals=None, id=None, **kwargs):
diff --git a/pennylane/operation.py b/pennylane/operation.py
index 7b0b9a7cc95..3ba0f451cba 100644
--- a/pennylane/operation.py
+++ b/pennylane/operation.py
@@ -222,21 +222,29 @@
import warnings
from collections.abc import Hashable, Iterable
from enum import IntEnum
-from typing import Any, Callable, Literal, Optional, Union
+from functools import lru_cache
+from typing import Any, Callable, Literal, Optional, Type, Union
import numpy as np
from scipy.sparse import spmatrix
import pennylane as qml
-from pennylane.capture import ABCCaptureMeta, create_operator_primitive
+from pennylane import capture
from pennylane.exceptions import PennyLaneDeprecationWarning
-from pennylane.math import expand_matrix
+from pennylane.math import expand_matrix, is_abstract
from pennylane.queuing import QueuingManager
from pennylane.typing import TensorLike
from pennylane.wires import Wires, WiresLike
from .pytrees import register_pytree
+has_jax = True
+try:
+ import jax
+
+except ImportError:
+ has_jax = False
+
# =============================================================================
# Errors
# =============================================================================
@@ -369,6 +377,109 @@ def classproperty(func) -> ClassPropertyDescriptor:
return ClassPropertyDescriptor(func)
+# =============================================================================
+# Capture operators infrastructure
+# =============================================================================
+
+
+@lru_cache # construct the first time lazily
+def _get_abstract_operator() -> type:
+ """Create an AbstractOperator once in a way protected from lack of a jax install."""
+ if not has_jax: # pragma: no cover
+ raise ImportError("Jax is required for plxpr.") # pragma: no cover
+
+ class AbstractOperator(jax.core.AbstractValue):
+ """An operator captured into plxpr."""
+
+ # pylint: disable=missing-function-docstring
+ def at_least_vspace(self):
+ # TODO: investigate the proper definition of this method
+ raise NotImplementedError
+
+ # pylint: disable=missing-function-docstring
+ def join(self, other):
+ # TODO: investigate the proper definition of this method
+ raise NotImplementedError
+
+ # pylint: disable=missing-function-docstring
+ def update(self, **kwargs):
+ # TODO: investigate the proper definition of this method
+ raise NotImplementedError
+
+ def __eq__(self, other):
+ return isinstance(other, AbstractOperator)
+
+ def __hash__(self):
+ return hash("AbstractOperator")
+
+ @staticmethod
+ def _matmul(*args):
+ return qml.prod(*args)
+
+ @staticmethod
+ def _mul(a, b):
+ return qml.s_prod(b, a)
+
+ @staticmethod
+ def _rmul(a, b):
+ return qml.s_prod(b, a)
+
+ @staticmethod
+ def _add(a, b):
+ return qml.sum(a, b)
+
+ @staticmethod
+ def _pow(a, b):
+ return qml.pow(a, b)
+
+ return AbstractOperator
+
+
+def create_operator_primitive(
+ operator_type: Type["qml.operation.Operator"],
+) -> Optional["jax.extend.core.Primitive"]:
+ """Create a primitive corresponding to an operator type.
+
+ Called when defining any :class:`~.Operator` subclass, and is used to set the
+ ``Operator._primitive`` class property.
+
+ Args:
+ operator_type (type): a subclass of qml.operation.Operator
+
+ Returns:
+ Optional[jax.extend.core.Primitive]: A new jax primitive with the same name as the operator subclass.
+ ``None`` is returned if jax is not available.
+
+ """
+ if not has_jax:
+ return None
+
+ primitive = capture.QmlPrimitive(operator_type.__name__)
+ primitive.prim_type = "operator"
+
+ @primitive.def_impl
+ def _(*args, **kwargs):
+ if "n_wires" not in kwargs:
+ return type.__call__(operator_type, *args, **kwargs)
+ n_wires = kwargs.pop("n_wires")
+
+ split = None if n_wires == 0 else -n_wires
+ # need to convert array values into integers
+ # for plxpr, all wires must be integers
+ # could be abstract when using tracing evaluation in interpreter
+ wire_args = args[split:] if split else ()
+ wires = tuple(w if is_abstract(w) else int(w) for w in wire_args)
+ return type.__call__(operator_type, *args[:split], wires=wires, **kwargs)
+
+ abstract_type = _get_abstract_operator()
+
+ @primitive.def_abstract_eval
+ def _(*_, **__):
+ return abstract_type()
+
+ return primitive
+
+
# =============================================================================
# Base Operator class
# =============================================================================
@@ -388,13 +499,13 @@ def _mod_and_round(x, mod_val):
else:
mod_val = None
- return str([id(d) if qml.math.is_abstract(d) else _mod_and_round(d, mod_val) for d in op.data])
+ return str([id(d) if is_abstract(d) else _mod_and_round(d, mod_val) for d in op.data])
FlatPytree = tuple[Iterable[Any], Hashable]
-class Operator(abc.ABC, metaclass=ABCCaptureMeta):
+class Operator(abc.ABC, metaclass=capture.ABCCaptureMeta):
r"""Base class representing quantum operators.
Operators are uniquely defined by their name, the wires they act on, their (trainable) parameters,
@@ -702,8 +813,6 @@ def _primitive_bind_call(cls, *args, **kwargs):
# guard against this being called when primitive is not defined.
return type.__call__(cls, *args, **kwargs)
- import jax # pylint: disable=import-outside-toplevel
-
array_types = (jax.numpy.ndarray, np.ndarray)
iterable_wires_types = (
list,
@@ -1178,7 +1287,7 @@ def _check_batching(self):
# There might be a way to support batching nonetheless, which remains to be
# investigated. For now, the batch_size is left to be `None` when instantiating
# an operation with abstract parameters that make `qml.math.ndim` fail.
- if any(qml.math.is_abstract(p) for p in params):
+ if any(is_abstract(p) for p in params):
self._batch_size = None
self._ndim_params = (0,) * len(params)
return
diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py
index c17ff4f0b12..6cab0575e29 100644
--- a/pennylane/workflow/_capture_qnode.py
+++ b/pennylane/workflow/_capture_qnode.py
@@ -116,8 +116,8 @@
from jax.interpreters import ad, batching, mlir
import pennylane as qml
-from pennylane.capture import CaptureError, FlatFn
-from pennylane.capture.custom_primitives import QmlPrimitive
+from pennylane.capture import FlatFn, QmlPrimitive
+from pennylane.exceptions import CaptureError
from pennylane.logging import debug_logger
from pennylane.typing import TensorLike
diff --git a/tests/capture/workflow/test_capture_qnode.py b/tests/capture/workflow/test_capture_qnode.py
index 35a12c1426d..d86d24b26d8 100644
--- a/tests/capture/workflow/test_capture_qnode.py
+++ b/tests/capture/workflow/test_capture_qnode.py
@@ -21,8 +21,7 @@
import pytest
import pennylane as qml
-from pennylane.capture import CaptureError
-from pennylane.exceptions import QuantumFunctionError
+from pennylane.exceptions import CaptureError, QuantumFunctionError
pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")]