diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 99d3b6b63e5..329353e326e 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -567,6 +567,9 @@ Here's a list of deprecations made this release. For a more detailed breakdown o

Internal changes ⚙️

+* Move program capture code closer to where it is used. + [(#7608)][https://github.com/PennyLaneAI/pennylane/pull/7608] + * Tests using `OpenFermion` in `tests/qchem` do not fail with NumPy>=2.0.0 any more. [(#7626)](https://github.com/PennyLaneAI/pennylane/pull/7626) diff --git a/pennylane/_grad.py b/pennylane/_grad.py index af30ba8ef40..ecbdafd57b3 100644 --- a/pennylane/_grad.py +++ b/pennylane/_grad.py @@ -15,7 +15,7 @@ This module contains the autograd wrappers :class:`grad` and :func:`jacobian` """ import warnings -from functools import partial, wraps +from functools import lru_cache, partial, wraps from autograd import jacobian as _jacobian from autograd.core import make_vjp as _make_vjp @@ -23,19 +23,101 @@ from autograd.numpy.numpy_boxes import ArrayBox from autograd.wrap_util import unary_to_nary -from pennylane.capture import determine_abstracted_axes, enabled -from pennylane.capture.capture_diff import _get_grad_prim, _get_jacobian_prim -from pennylane.capture.flatfn import FlatFn +from pennylane import capture from pennylane.compiler import compiler from pennylane.compiler.compiler import CompileError make_vjp = unary_to_nary(_make_vjp) +has_jax = True +try: + import jax +except ImportError: + has_jax = False + + +@lru_cache +def _get_grad_prim(): + """Create a primitive for gradient computations. + This primitive is used when capturing ``qml.grad``. + """ + if not has_jax: # pragma: no cover + return None + + grad_prim = capture.QmlPrimitive("grad") + grad_prim.multiple_results = True + grad_prim.prim_type = "higher_order" + + @grad_prim.def_impl + def _(*args, argnum, jaxpr, n_consts, method, h): + if method or h: # pragma: no cover + raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.") + consts = args[:n_consts] + args = args[n_consts:] + + def func(*inner_args): + return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)[0] + + return jax.grad(func, argnums=argnum)(*args) + + # pylint: disable=unused-argument + @grad_prim.def_abstract_eval + def _(*args, argnum, jaxpr, n_consts, method, h): + if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != (): + raise TypeError("Grad only applies to scalar-output functions. Try jacobian.") + return tuple(args[i + n_consts] for i in argnum) + + return grad_prim + + +def _shape(shape, dtype): + if jax.config.jax_dynamic_shapes and any(not isinstance(s, int) for s in shape): + return jax.core.DShapedArray(shape, dtype) + return jax.core.ShapedArray(shape, dtype) + + +@lru_cache +def _get_jacobian_prim(): + """Create a primitive for Jacobian computations. + This primitive is used when capturing ``qml.jacobian``. + """ + if not has_jax: # pragma: no cover + return None + + jacobian_prim = capture.QmlPrimitive("jacobian") + jacobian_prim.multiple_results = True + jacobian_prim.prim_type = "higher_order" + + @jacobian_prim.def_impl + def _(*args, argnum, jaxpr, n_consts, method, h): + if method or h: # pragma: no cover + raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.") + consts = args[:n_consts] + args = args[n_consts:] + + def func(*inner_args): + return jax.core.eval_jaxpr(jaxpr, consts, *inner_args) + + return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args)) + + # pylint: disable=unused-argument + @jacobian_prim.def_abstract_eval + def _(*args, argnum, jaxpr, n_consts, method, h): + in_avals = tuple(args[i + n_consts] for i in argnum) + out_shapes = tuple(outvar.aval.shape for outvar in jaxpr.outvars) + return [ + _shape(out_shape + in_aval.shape, in_aval.dtype) + for out_shape in out_shapes + for in_aval in in_avals + ] + + return jacobian_prim + + def _capture_diff(func, argnum=None, diff_prim=None, method=None, h=None): """Capture-compatible gradient computation.""" # pylint: disable=import-outside-toplevel - import jax from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten, treedef_tuple if argnum is None: @@ -70,9 +152,9 @@ def new_func(*args, **kwargs): flat_argnum = sum(flat_argnum_gen, start=[]) # Create fully flattened function (flat inputs & outputs) - flat_fn = FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree) + flat_fn = capture.FlatFn(partial(func, **kwargs) if kwargs else func, full_in_tree) flat_args = sum(flat_args, start=[]) - abstracted_axes, abstract_shapes = determine_abstracted_axes(tuple(flat_args)) + abstracted_axes, abstract_shapes = capture.determine_abstracted_axes(tuple(flat_args)) jaxpr = jax.make_jaxpr(flat_fn, abstracted_axes=abstracted_axes)(*flat_args) num_abstract_shapes = len(abstract_shapes) @@ -165,7 +247,7 @@ def __new__(cls, func, argnum=None, method=None, h=None): ops_loader = available_eps[active_jit]["ops"].load() return ops_loader.grad(func, method=method, h=h, argnums=argnum) - if enabled(): + if capture.enabled(): return _capture_diff(func, argnum, _get_grad_prim(), method=method, h=h) if method or h: # pragma: no cover @@ -496,7 +578,7 @@ def circuit(x): ops_loader = available_eps[active_jit]["ops"].load() return ops_loader.jacobian(func, method=method, h=h, argnums=argnum) - if enabled(): + if capture.enabled(): return _capture_diff(func, argnum, _get_jacobian_prim(), method=method, h=h) if method or h: diff --git a/pennylane/capture/__init__.py b/pennylane/capture/__init__.py index 716b73b1b8a..ae61f7be09a 100644 --- a/pennylane/capture/__init__.py +++ b/pennylane/capture/__init__.py @@ -30,10 +30,6 @@ ~enable ~enabled ~pause - ~create_operator_primitive - ~create_measurement_obs_primitive - ~create_measurement_wires_primitive - ~create_measurement_mcm_primitive ~determine_abstracted_axes ~expand_plxpr_transforms ~eval_jaxpr @@ -162,16 +158,10 @@ class MyCustomOp(qml.operation.Operator): def _(*args, **kwargs): return type.__call__(MyCustomOp, *args, **kwargs) """ -from typing import Callable +from typing import Callable, Type from .switches import disable, enable, enabled, pause from .capture_meta import CaptureMeta, ABCCaptureMeta -from .capture_operators import create_operator_primitive -from .capture_measurements import ( - create_measurement_obs_primitive, - create_measurement_wires_primitive, - create_measurement_mcm_primitive, -) from .flatfn import FlatFn from .make_plxpr import make_plxpr, run_autograph from .dynamic_shapes import determine_abstracted_axes, register_custom_staging_rule @@ -185,14 +175,16 @@ def _(*args, **kwargs): PlxprInterpreter: type expand_plxpr_transforms: Callable[[Callable], Callable] eval_jaxpr: Callable +QmlPrimitive: "Type[jax.extend.core.Primitive]" -class CaptureError(Exception): - """Errors related to PennyLane's Program Capture execution pipeline.""" +# pylint: disable=import-outside-toplevel, redefined-outer-name, too-many-return-statements +def __getattr__(key): + if key == "QmlPrimitive": + from .custom_primitives import QmlPrimitive + return QmlPrimitive -# pylint: disable=import-outside-toplevel, redefined-outer-name -def __getattr__(key): if key == "AbstractOperator": from .primitives import _get_abstract_operator @@ -233,10 +225,6 @@ def __getattr__(key): "eval_jaxpr", "CaptureMeta", "ABCCaptureMeta", - "create_operator_primitive", - "create_measurement_obs_primitive", - "create_measurement_wires_primitive", - "create_measurement_mcm_primitive", "determine_abstracted_axes", "expand_plxpr_transforms", "register_custom_staging_rule", diff --git a/pennylane/capture/capture_diff.py b/pennylane/capture/capture_diff.py deleted file mode 100644 index c41b25aaad3..00000000000 --- a/pennylane/capture/capture_diff.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This submodule offers differentiation-related primitives and types for -the PennyLane capture module. -""" -from functools import lru_cache - -has_jax = True -try: - import jax -except ImportError: - has_jax = False - - -@lru_cache -def _get_grad_prim(): - """Create a primitive for gradient computations. - This primitive is used when capturing ``qml.grad``. - """ - if not has_jax: # pragma: no cover - return None - - from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel - - grad_prim = QmlPrimitive("grad") - grad_prim.multiple_results = True - grad_prim.prim_type = "higher_order" - - @grad_prim.def_impl - def _(*args, argnum, jaxpr, n_consts, method, h): - if method or h: # pragma: no cover - raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.") - consts = args[:n_consts] - args = args[n_consts:] - - def func(*inner_args): - return jax.core.eval_jaxpr(jaxpr, consts, *inner_args)[0] - - return jax.grad(func, argnums=argnum)(*args) - - # pylint: disable=unused-argument - @grad_prim.def_abstract_eval - def _(*args, argnum, jaxpr, n_consts, method, h): - if len(jaxpr.outvars) != 1 or jaxpr.outvars[0].aval.shape != (): - raise TypeError("Grad only applies to scalar-output functions. Try jacobian.") - return tuple(args[i + n_consts] for i in argnum) - - return grad_prim - - -def _shape(shape, dtype): - if jax.config.jax_dynamic_shapes and any(not isinstance(s, int) for s in shape): - return jax.core.DShapedArray(shape, dtype) - return jax.core.ShapedArray(shape, dtype) - - -@lru_cache -def _get_jacobian_prim(): - """Create a primitive for Jacobian computations. - This primitive is used when capturing ``qml.jacobian``. - """ - if not has_jax: # pragma: no cover - return None - - from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel - - jacobian_prim = QmlPrimitive("jacobian") - jacobian_prim.multiple_results = True - jacobian_prim.prim_type = "higher_order" - - @jacobian_prim.def_impl - def _(*args, argnum, jaxpr, n_consts, method, h): - if method or h: # pragma: no cover - raise ValueError(f"Invalid values '{method=}' and '{h=}' without QJIT.") - consts = args[:n_consts] - args = args[n_consts:] - - def func(*inner_args): - return jax.core.eval_jaxpr(jaxpr, consts, *inner_args) - - return jax.tree_util.tree_leaves(jax.jacobian(func, argnums=argnum)(*args)) - - # pylint: disable=unused-argument - @jacobian_prim.def_abstract_eval - def _(*args, argnum, jaxpr, n_consts, method, h): - in_avals = tuple(args[i + n_consts] for i in argnum) - out_shapes = tuple(outvar.aval.shape for outvar in jaxpr.outvars) - return [ - _shape(out_shape + in_aval.shape, in_aval.dtype) - for out_shape in out_shapes - for in_aval in in_avals - ] - - return jacobian_prim diff --git a/pennylane/capture/capture_operators.py b/pennylane/capture/capture_operators.py deleted file mode 100644 index 84e5f42eb19..00000000000 --- a/pennylane/capture/capture_operators.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This submodule defines the abstract classes and primitives for capturing operators. -""" -from functools import lru_cache -from typing import Optional, Type - -import pennylane as qml - -has_jax = True -try: - import jax - -except ImportError: - has_jax = False - - -@lru_cache # construct the first time lazily -def _get_abstract_operator() -> type: - """Create an AbstractOperator once in a way protected from lack of a jax install.""" - if not has_jax: # pragma: no cover - raise ImportError("Jax is required for plxpr.") # pragma: no cover - - class AbstractOperator(jax.core.AbstractValue): - """An operator captured into plxpr.""" - - # pylint: disable=missing-function-docstring - def at_least_vspace(self): - # TODO: investigate the proper definition of this method - raise NotImplementedError - - # pylint: disable=missing-function-docstring - def join(self, other): - # TODO: investigate the proper definition of this method - raise NotImplementedError - - # pylint: disable=missing-function-docstring - def update(self, **kwargs): - # TODO: investigate the proper definition of this method - raise NotImplementedError - - def __eq__(self, other): - return isinstance(other, AbstractOperator) - - def __hash__(self): - return hash("AbstractOperator") - - @staticmethod - def _matmul(*args): - return qml.prod(*args) - - @staticmethod - def _mul(a, b): - return qml.s_prod(b, a) - - @staticmethod - def _rmul(a, b): - return qml.s_prod(b, a) - - @staticmethod - def _add(a, b): - return qml.sum(a, b) - - @staticmethod - def _pow(a, b): - return qml.pow(a, b) - - return AbstractOperator - - -def create_operator_primitive( - operator_type: Type["qml.operation.Operator"], -) -> Optional["jax.extend.core.Primitive"]: - """Create a primitive corresponding to an operator type. - - Called when defining any :class:`~.Operator` subclass, and is used to set the - ``Operator._primitive`` class property. - - Args: - operator_type (type): a subclass of qml.operation.Operator - - Returns: - Optional[jax.extend.core.Primitive]: A new jax primitive with the same name as the operator subclass. - ``None`` is returned if jax is not available. - - """ - if not has_jax: - return None - - from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel - - primitive = QmlPrimitive(operator_type.__name__) - primitive.prim_type = "operator" - - @primitive.def_impl - def _(*args, **kwargs): - if "n_wires" not in kwargs: - return type.__call__(operator_type, *args, **kwargs) - n_wires = kwargs.pop("n_wires") - - split = None if n_wires == 0 else -n_wires - # need to convert array values into integers - # for plxpr, all wires must be integers - # could be abstract when using tracing evaluation in interpreter - wire_args = args[split:] if split else () - wires = tuple(w if qml.math.is_abstract(w) else int(w) for w in wire_args) - return type.__call__(operator_type, *args[:split], wires=wires, **kwargs) - - abstract_type = _get_abstract_operator() - - @primitive.def_abstract_eval - def _(*_, **__): - return abstract_type() - - return primitive diff --git a/pennylane/capture/primitives.py b/pennylane/capture/primitives.py index a482323a829..b1c5a169c86 100644 --- a/pennylane/capture/primitives.py +++ b/pennylane/capture/primitives.py @@ -17,18 +17,17 @@ It has a jax dependency and should be located in a standard import path. """ +from pennylane._grad import _get_grad_prim, _get_jacobian_prim from pennylane.control_flow.for_loop import _get_for_loop_qfunc_prim from pennylane.control_flow.while_loop import _get_while_loop_qfunc_prim +from pennylane.measurements.capture_measurements import _get_abstract_measurement from pennylane.measurements.mid_measure import _create_mid_measure_primitive +from pennylane.operation import _get_abstract_operator from pennylane.ops.op_math.adjoint import _get_adjoint_qfunc_prim from pennylane.ops.op_math.condition import _get_cond_qfunc_prim from pennylane.ops.op_math.controlled import _get_ctrl_qfunc_prim from pennylane.workflow._capture_qnode import qnode_prim -from .capture_diff import _get_grad_prim, _get_jacobian_prim -from .capture_measurements import _get_abstract_measurement -from .capture_operators import _get_abstract_operator - AbstractOperator = _get_abstract_operator() AbstractMeasurement = _get_abstract_measurement() adjoint_transform_prim = _get_adjoint_qfunc_prim() diff --git a/pennylane/exceptions.py b/pennylane/exceptions.py index f59a13d6d06..2bf9c97e81b 100644 --- a/pennylane/exceptions.py +++ b/pennylane/exceptions.py @@ -16,6 +16,10 @@ """ +class CaptureError(Exception): + """Errors related to PennyLane's Program Capture execution pipeline.""" + + class DeviceError(Exception): # pragma: no cover """Exception raised when it encounters an illegal operation in the quantum circuit.""" diff --git a/pennylane/capture/capture_measurements.py b/pennylane/measurements/capture_measurements.py similarity index 91% rename from pennylane/capture/capture_measurements.py rename to pennylane/measurements/capture_measurements.py index 559cdea27c9..6cce83e12f3 100644 --- a/pennylane/capture/capture_measurements.py +++ b/pennylane/measurements/capture_measurements.py @@ -19,7 +19,9 @@ from functools import lru_cache from typing import Optional, Type -import pennylane as qml +from pennylane import capture +from pennylane.math import is_abstract +from pennylane.wires import Wires has_jax = True try: @@ -126,9 +128,7 @@ def create_measurement_obs_primitive( if not has_jax: return None - from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel - - primitive = QmlPrimitive(name + "_obs") + primitive = capture.QmlPrimitive(name + "_obs") primitive.prim_type = "measurement" @primitive.def_impl @@ -165,10 +165,7 @@ def create_measurement_mcm_primitive( if not has_jax: return None - - from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel - - primitive = QmlPrimitive(name + "_mcm") + primitive = capture.QmlPrimitive(name + "_mcm") primitive.prim_type = "measurement" @primitive.def_impl @@ -204,21 +201,17 @@ def create_measurement_wires_primitive( if not has_jax: return None - from .custom_primitives import QmlPrimitive # pylint: disable=import-outside-toplevel - - primitive = QmlPrimitive(name + "_wires") + primitive = capture.QmlPrimitive(name + "_wires") primitive.prim_type = "measurement" @primitive.def_impl def _(*args, has_eigvals=False, **kwargs): if has_eigvals: - wires = qml.wires.Wires( - tuple(w if qml.math.is_abstract(w) else int(w) for w in args[:-1]) - ) + wires = Wires(tuple(w if is_abstract(w) else int(w) for w in args[:-1])) kwargs["eigvals"] = args[-1] else: - wires = tuple(w if qml.math.is_abstract(w) else int(w) for w in args) - wires = qml.wires.Wires(wires) + wires = tuple(w if is_abstract(w) else int(w) for w in args) + wires = Wires(wires) return type.__call__(measurement_type, wires=wires, **kwargs) abstract_type = _get_abstract_measurement() diff --git a/pennylane/measurements/measurements.py b/pennylane/measurements/measurements.py index 1188338822d..028a73f5fa0 100644 --- a/pennylane/measurements/measurements.py +++ b/pennylane/measurements/measurements.py @@ -29,6 +29,12 @@ from pennylane.typing import TensorLike from pennylane.wires import Wires +from .capture_measurements import ( + create_measurement_mcm_primitive, + create_measurement_obs_primitive, + create_measurement_wires_primitive, +) + class MeasurementShapeError(ValueError): """An error raised when an unsupported operation is attempted with a @@ -60,9 +66,9 @@ class MeasurementProcess(ABC, metaclass=qml.capture.ABCCaptureMeta): def __init_subclass__(cls, **_): register_pytree(cls, cls._flatten, cls._unflatten) name = cls._shortname or cls.__name__ - cls._wires_primitive = qml.capture.create_measurement_wires_primitive(cls, name=name) - cls._obs_primitive = qml.capture.create_measurement_obs_primitive(cls, name=name) - cls._mcm_primitive = qml.capture.create_measurement_mcm_primitive(cls, name=name) + cls._wires_primitive = create_measurement_wires_primitive(cls, name=name) + cls._obs_primitive = create_measurement_obs_primitive(cls, name=name) + cls._mcm_primitive = create_measurement_mcm_primitive(cls, name=name) @classmethod def _primitive_bind_call(cls, obs=None, wires=None, eigvals=None, id=None, **kwargs): diff --git a/pennylane/operation.py b/pennylane/operation.py index 7b0b9a7cc95..3ba0f451cba 100644 --- a/pennylane/operation.py +++ b/pennylane/operation.py @@ -222,21 +222,29 @@ import warnings from collections.abc import Hashable, Iterable from enum import IntEnum -from typing import Any, Callable, Literal, Optional, Union +from functools import lru_cache +from typing import Any, Callable, Literal, Optional, Type, Union import numpy as np from scipy.sparse import spmatrix import pennylane as qml -from pennylane.capture import ABCCaptureMeta, create_operator_primitive +from pennylane import capture from pennylane.exceptions import PennyLaneDeprecationWarning -from pennylane.math import expand_matrix +from pennylane.math import expand_matrix, is_abstract from pennylane.queuing import QueuingManager from pennylane.typing import TensorLike from pennylane.wires import Wires, WiresLike from .pytrees import register_pytree +has_jax = True +try: + import jax + +except ImportError: + has_jax = False + # ============================================================================= # Errors # ============================================================================= @@ -369,6 +377,109 @@ def classproperty(func) -> ClassPropertyDescriptor: return ClassPropertyDescriptor(func) +# ============================================================================= +# Capture operators infrastructure +# ============================================================================= + + +@lru_cache # construct the first time lazily +def _get_abstract_operator() -> type: + """Create an AbstractOperator once in a way protected from lack of a jax install.""" + if not has_jax: # pragma: no cover + raise ImportError("Jax is required for plxpr.") # pragma: no cover + + class AbstractOperator(jax.core.AbstractValue): + """An operator captured into plxpr.""" + + # pylint: disable=missing-function-docstring + def at_least_vspace(self): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + # pylint: disable=missing-function-docstring + def join(self, other): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + # pylint: disable=missing-function-docstring + def update(self, **kwargs): + # TODO: investigate the proper definition of this method + raise NotImplementedError + + def __eq__(self, other): + return isinstance(other, AbstractOperator) + + def __hash__(self): + return hash("AbstractOperator") + + @staticmethod + def _matmul(*args): + return qml.prod(*args) + + @staticmethod + def _mul(a, b): + return qml.s_prod(b, a) + + @staticmethod + def _rmul(a, b): + return qml.s_prod(b, a) + + @staticmethod + def _add(a, b): + return qml.sum(a, b) + + @staticmethod + def _pow(a, b): + return qml.pow(a, b) + + return AbstractOperator + + +def create_operator_primitive( + operator_type: Type["qml.operation.Operator"], +) -> Optional["jax.extend.core.Primitive"]: + """Create a primitive corresponding to an operator type. + + Called when defining any :class:`~.Operator` subclass, and is used to set the + ``Operator._primitive`` class property. + + Args: + operator_type (type): a subclass of qml.operation.Operator + + Returns: + Optional[jax.extend.core.Primitive]: A new jax primitive with the same name as the operator subclass. + ``None`` is returned if jax is not available. + + """ + if not has_jax: + return None + + primitive = capture.QmlPrimitive(operator_type.__name__) + primitive.prim_type = "operator" + + @primitive.def_impl + def _(*args, **kwargs): + if "n_wires" not in kwargs: + return type.__call__(operator_type, *args, **kwargs) + n_wires = kwargs.pop("n_wires") + + split = None if n_wires == 0 else -n_wires + # need to convert array values into integers + # for plxpr, all wires must be integers + # could be abstract when using tracing evaluation in interpreter + wire_args = args[split:] if split else () + wires = tuple(w if is_abstract(w) else int(w) for w in wire_args) + return type.__call__(operator_type, *args[:split], wires=wires, **kwargs) + + abstract_type = _get_abstract_operator() + + @primitive.def_abstract_eval + def _(*_, **__): + return abstract_type() + + return primitive + + # ============================================================================= # Base Operator class # ============================================================================= @@ -388,13 +499,13 @@ def _mod_and_round(x, mod_val): else: mod_val = None - return str([id(d) if qml.math.is_abstract(d) else _mod_and_round(d, mod_val) for d in op.data]) + return str([id(d) if is_abstract(d) else _mod_and_round(d, mod_val) for d in op.data]) FlatPytree = tuple[Iterable[Any], Hashable] -class Operator(abc.ABC, metaclass=ABCCaptureMeta): +class Operator(abc.ABC, metaclass=capture.ABCCaptureMeta): r"""Base class representing quantum operators. Operators are uniquely defined by their name, the wires they act on, their (trainable) parameters, @@ -702,8 +813,6 @@ def _primitive_bind_call(cls, *args, **kwargs): # guard against this being called when primitive is not defined. return type.__call__(cls, *args, **kwargs) - import jax # pylint: disable=import-outside-toplevel - array_types = (jax.numpy.ndarray, np.ndarray) iterable_wires_types = ( list, @@ -1178,7 +1287,7 @@ def _check_batching(self): # There might be a way to support batching nonetheless, which remains to be # investigated. For now, the batch_size is left to be `None` when instantiating # an operation with abstract parameters that make `qml.math.ndim` fail. - if any(qml.math.is_abstract(p) for p in params): + if any(is_abstract(p) for p in params): self._batch_size = None self._ndim_params = (0,) * len(params) return diff --git a/pennylane/workflow/_capture_qnode.py b/pennylane/workflow/_capture_qnode.py index c17ff4f0b12..6cab0575e29 100644 --- a/pennylane/workflow/_capture_qnode.py +++ b/pennylane/workflow/_capture_qnode.py @@ -116,8 +116,8 @@ from jax.interpreters import ad, batching, mlir import pennylane as qml -from pennylane.capture import CaptureError, FlatFn -from pennylane.capture.custom_primitives import QmlPrimitive +from pennylane.capture import FlatFn, QmlPrimitive +from pennylane.exceptions import CaptureError from pennylane.logging import debug_logger from pennylane.typing import TensorLike diff --git a/tests/capture/workflow/test_capture_qnode.py b/tests/capture/workflow/test_capture_qnode.py index 35a12c1426d..d86d24b26d8 100644 --- a/tests/capture/workflow/test_capture_qnode.py +++ b/tests/capture/workflow/test_capture_qnode.py @@ -21,8 +21,7 @@ import pytest import pennylane as qml -from pennylane.capture import CaptureError -from pennylane.exceptions import QuantumFunctionError +from pennylane.exceptions import CaptureError, QuantumFunctionError pytestmark = [pytest.mark.jax, pytest.mark.usefixtures("enable_disable_plxpr")]