Skip to content

[WIP] Add measurements-from-samples pass to Python compiler #7620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pennylane/compiler/python_compiler/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from xdsl.transforms.transform_interpreter import TransformInterpreterPass
from .apply_transform_sequence import ApplyTransformSequence, register_pass
from .cancel_inverses import iterative_cancel_inverses_pass, IterativeCancelInversesPass
from .measurements_from_samples import MeasurementsFromSamplesPass
from .merge_rotations import merge_rotations_pass, MergeRotationsPass
from .utils import xdsl_transform

Expand All @@ -24,6 +25,7 @@
"ApplyTransformSequence",
"iterative_cancel_inverses_pass",
"IterativeCancelInversesPass",
"MeasurementsFromSamplesPass",
"merge_rotations_pass",
"MergeRotationsPass",
"TransformInterpreterPass",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains the implementation of the measurements_from_samples transform,
written using xDSL."""

from dataclasses import dataclass

import jax
import jax.numpy as jnp

from xdsl import context, passes, pattern_rewriter, ir

Check notice on line 23 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L23

Unused ir imported from xdsl (unused-import)
from xdsl.dialects import builtin, func, tensor

Check notice on line 24 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L24

Unused tensor imported from xdsl.dialects (unused-import)
from xdsl.rewriter import InsertPoint

from pennylane.compiler.python_compiler import quantum_dialect as quantum
from pennylane.compiler.python_compiler.jax_utils import xdsl_module

# ----- Stuff that should be upstreamed ---------------------------------------------------------- #

import pennylane as qml
from pennylane.devices.preprocess import null_postprocessing

from .apply_transform_sequence import register_pass


def xdsl_transform(_klass):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been merged now I believe?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just yesterday.

"""Register the xdsl transform into the plxpr to catalyst map.

NOTE: This function will eventually live somewhere in the pennylane.compiler.python_compiler
module, we have to add it here locally until it's added upstream.
"""

# avoid dependency on catalyst
import catalyst # pylint: disable=import-outside-toplevel

def identity_transform(tape):
"""Stub, we only need the name to be unique"""
return tape, null_postprocessing

identity_transform.__name__ = "xdsl_transform" + _klass.__name__
transform = qml.transform(identity_transform)

# Map from plxpr to register transform
catalyst.from_plxpr.register_transform(transform, _klass.name, False)

# Register this pass as available in the apply-transform-sequence
# interpreter
def get_pass_instance():
return _klass

# breakpoint()
register_pass(_klass.name, get_pass_instance)
return transform


# ----- (END) Stuff that should be upstreamed ---------------------------------------------------- #


@xdsl_module
@jax.jit
def postprocessing_expval(samples, wire):

Check notice on line 73 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L73

Missing function or method docstring (missing-function-docstring)
return jnp.mean(1.0 - 2.0 * samples[:, wire])


@xdsl_module
@jax.jit
def postprocessing_var(samples, wire):

Check notice on line 79 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L79

Missing function or method docstring (missing-function-docstring)
return jnp.var(1.0 - 2.0 * samples[:, wire])


@xdsl_module
@jax.jit
def postprocessing_probs(samples, wire):

Check notice on line 85 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L85

Missing function or method docstring (missing-function-docstring)
raise NotImplementedError("probs not implemented")


@xdsl_module
@jax.jit
def postprocessing_counts(samples, wire):

Check notice on line 91 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L91

Missing function or method docstring (missing-function-docstring)
raise NotImplementedError("counts not implemented")


class MeasurementsFromSamplesPattern(pattern_rewriter.RewritePattern):
# pylint: disable=too-few-public-methods
"""Rewrite pattern for the ``measurements_from_samples`` transform, which replaces all terminal
measurements in a program with a single :func:`pennylane.sample` measurement, and adds
postprocessing instructions to recover the original measurement.
"""

@classmethod
def _validate_observable_op(cls, op: quantum.NamedObsOp):
"""TODO"""
assert isinstance(
op, quantum.NamedObsOp
), f"Expected op to be a quantum.NamedObsOp, but got {type(op)}"
if not op.type.data == "PauliZ":
raise NotImplementedError(
f"Observable '{op.type.data}' used as input to expval is not "
f"supported for the measurements_from_samples transform; currently only "
f"PauliZ operations are permitted"
)

# @classmethod
# def _get_shots_from_device_op(cls, op: quantum.DeviceInitOp):
# """TODO"""
# assert isinstance(op, quantum.DeviceInitOp), (
# f"Expected op to be a quantum.DeviceInitOp, but got {type(op)}"
# )
# return op.shots

# pylint: disable=arguments-differ,no-self-use
@pattern_rewriter.op_type_rewrite_pattern
def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter):
"""Implementation of the match-and-rewrite pattern for FuncOps that may contain terminal
measurements to be replaced with a single sample measurement."""
# Walk through the func operations
supported_measure_ops = (quantum.ExpvalOp, quantum.VarianceOp, quantum.SampleOp)
unsupported_measure_ops = (quantum.ProbsOp, quantum.CountsOp, quantum.StateOp)
measure_ops = supported_measure_ops + unsupported_measure_ops

shots = None
have_inserted_postproc = False

for op in funcOp.body.walk():
if shots is None and isinstance(op, quantum.DeviceInitOp):
# shots = self._get_shots_from_device_op(op)
shots = op.shots

if not isinstance(op, measure_ops):
continue

if isinstance(op, unsupported_measure_ops):
raise NotImplementedError(
f"The measurements_from_samples transform does not support "
f"operations of type '{op.name}'"
)

if isinstance(op, quantum.SampleOp):
return

if isinstance(op, quantum.ExpvalOp):
observable_op = op.operands[0].owner
self._validate_observable_op(observable_op)

in_qubit = observable_op.operands[0]

# Steps:
# 1. Insert op quantum.compbasis with input = in_qubit, output is a quantum.obs
# 2. Insert op quantum.sample with input = the quantum.obs vs output of compbasis
# 3. Erase op quantum.expval
# 4. Erase op quantum.namedobs [PauliZ] (assuming we only support PauliZ basis!)
compbasis_op = quantum.ComputationalBasisOp(
operands=[in_qubit, None], result_types=[quantum.ObservableType()]
)
rewriter.insert_op(compbasis_op, insertion_point=InsertPoint.before(observable_op))

# TODO: this assumes MP acts on 1 wire, what if there are more?
sample_op = quantum.SampleOp(
operands=[compbasis_op.results[0], shots, None],
result_types=[builtin.TensorType(builtin.Float64Type(), [-1, 1])],
)
rewriter.insert_op(sample_op, insertion_point=InsertPoint.after(compbasis_op))

# TODO: We can't delete these yet because there are other ops that use their output
# rewriter.erase_matched_op()
# rewriter.erase_op(observable_op)

# Insert the post-processing function
if not have_inserted_postproc:
postprocessing_module = postprocessing_expval(
jax.core.ShapedArray([10, 1], int), 1 # FIXME
)

for _func in postprocessing_module.body.walk():
postprocessing_func = _func
break
postprocessing_func = postprocessing_func.clone()

postprocessing_func.sym_name = builtin.StringAttr(data="expval_from_samples")

funcOp.parent.insert_op_after(postprocessing_func, funcOp)

have_inserted_postproc = True

# Insert the call to the post-processing function
postprocessing_func_call_op = func.CallOp(
callee=builtin.FlatSymbolRefAttr(postprocessing_func.sym_name),
arguments=[sample_op.results[0], in_qubit], # FIXME
return_types=[builtin.Float64Type()],
)
rewriter.insert_op(
postprocessing_func_call_op, insertion_point=InsertPoint.after(sample_op)
)

elif isinstance(op, quantum.VarianceOp):
pass


@xdsl_transform
@dataclass(frozen=True)
class MeasurementsFromSamplesPass(passes.ModulePass):
"""Pass that replaces all terminal measurements in a program with a single
:func:`pennylane.sample` measurement, and adds postprocessing instructions to recover the
original measurement.
"""

name = "measurements-from-samples"

# pylint: disable=arguments-renamed,no-self-use
def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None:
"""Apply the measurements-from-samples pass."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments:

  1. If your post-processing functions can be generated in the apply method here instead of the pattern, that in my opinion is preferrable. A pattern is inside of a loop. You can just create a Rewriter() inside the apply method.
  2. If they can't be automatically generated (for example, you may generate N different functions which depend on the types of N number of inputs) then you will likely need to implement some sort of namespacing. E.g.,
func.func @postrocessing_expval.tensor.8xf64(%arg0 : tensor<8xf64>, %arg1: int) {
  ...

But maybe you could do this a bit more generic:

func.func @postrocessing_expval.tensor(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) {
  ...
  1. There's also rewriter.replace_all_uses_with which may help in some areas before deleting operations.
  2. You may want different patterns (one for expval, one for var...).
  3. If you only care about expval, you can match directly (you don't need to walk through the function)
    def match_and_rewrite(self, op_you_care_about: quantum.ExpvalOp, rewriter: pattern_rewriter.PatternRewriter):
  1. The body of your apply can be more complex, running a set of patterns first and then a second set of patterns. That way you can add some sequential logic.
    def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None:
        """Apply the measurements-from-samples pass."""
        pattern_rewriter.PatternRewriteWalker(
            pattern_rewriter.GreedyRewritePatternApplier([OneSetOfPatterns()])
        ).rewrite_module(module)

        pattern_rewriter.PatternRewriteWalker(
            pattern_rewriter.GreedyRewritePatternApplier([AnotherSetOfPatterns()])
        ).rewrite_module(module)

Copy link
Contributor

@erick-xanadu erick-xanadu Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, depending on how comfortable you are writing MLIR by handle, you can use:

@xdsl_from_docstring
def foo():
  """
    func.func @postrocessing_expval(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) ...
  """

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other function we discussed for possibly keeping it all in jax is jax.jit(f, abstracted_axes=...) but it is undocumented and not well behaved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with Erick's comments. In particular, using the dynamic tensor approach would also allow you to use the same piece of IR for all instances of the problem.

@xdsl_from_docstring
def foo():
"""
func.func @postrocessing_expval(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) ...
"""

Just a note on this last segment. There is no need to pass in a tensor size value (since this isn't C 😅), you can query the size of any dimension of any tensor with the tensor.dim operation.

pattern_rewriter.PatternRewriteWalker(
pattern_rewriter.GreedyRewritePatternApplier([MeasurementsFromSamplesPattern()])
).rewrite_module(module)
83 changes: 83 additions & 0 deletions tests/python_compiler/test_transform_measurements_from_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the Python compiler `measurements_from_samples` transform."""

# pylint: disable=wrong-import-position

import io
import pytest

# pytestmark = pytest.mark.external

xdsl = pytest.importorskip("xdsl")

# For lit tests of xdsl passes, we use https://github.com/AntonLydike/filecheck/,
# a Python re-implementation of FileCheck
filecheck = pytest.importorskip("filecheck")

import pennylane as qml
from pennylane.compiler.python_compiler.transforms import MeasurementsFromSamplesPass

def test_transform():
program = """
// CHECK: identity
func.func @identity(%arg0 : i32) -> i32 {
return %arg0 : i32
}
"""
ctx = xdsl.context.Context()
ctx.load_dialect(xdsl.dialects.builtin.Builtin)
ctx.load_dialect(xdsl.dialects.func.Func)
module = xdsl.parser.Parser(ctx, program).parse_module()
pipeline = xdsl.passes.PipelinePass((MeasurementsFromSamplesPass(),))
pipeline.apply(ctx, module)
from filecheck.matcher import Matcher
from filecheck.options import parse_argv_options
from filecheck.parser import Parser, pattern_for_opts
from filecheck.finput import FInput
opts = parse_argv_options(["filecheck", __file__])
matcher = Matcher(opts, FInput("no-name", str(module)), Parser(opts, io.StringIO(program),*pattern_for_opts(opts)))
assert matcher.run() == 0


class TestMeasurementsFromSamplesExecution:
"""TODO"""
def test_measurements_from_samples_basic(self):
"""TODO"""
from catalyst.passes import xdsl_plugin

qml.capture.enable()

dev = qml.device("lightning.qubit", wires=2, shots=10)

@qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()])
# @qml.qjit(keep_intermediate=True, pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()])
@MeasurementsFromSamplesPass
@qml.qnode(dev)
def deleteme():
qml.H(0)
return qml.expval(qml.Z(0)), qml.expval(qml.Z(1))
# return qml.sample(wires=0), qml.sample(wires=1)

print("\n")
print(deleteme())
print("\n")
print(deleteme.mlir)

qml.capture.disable()


if __name__ == "__main__":
pytest.main(["-x", __file__])