Skip to content

[WIP] Add measurements-from-samples pass to Python compiler #7620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pennylane/compiler/python_compiler/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from xdsl.transforms.transform_interpreter import TransformInterpreterPass
from .apply_transform_sequence import ApplyTransformSequence, register_pass
from .cancel_inverses import iterative_cancel_inverses_pass, IterativeCancelInversesPass
from .measurements_from_samples import MeasurementsFromSamplesPass
from .merge_rotations import merge_rotations_pass, MergeRotationsPass
from .utils import xdsl_transform

Expand All @@ -24,6 +25,7 @@
"ApplyTransformSequence",
"iterative_cancel_inverses_pass",
"IterativeCancelInversesPass",
"MeasurementsFromSamplesPass",
"merge_rotations_pass",
"MergeRotationsPass",
"TransformInterpreterPass",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains the implementation of the measurements_from_samples transform,
written using xDSL."""

from dataclasses import dataclass

import jax
import jax.numpy as jnp
from xdsl import context, passes, pattern_rewriter
from xdsl.dialects import arith, builtin, func
from xdsl.rewriter import InsertPoint

from pennylane.compiler.python_compiler import quantum_dialect as quantum
from pennylane.compiler.python_compiler.jax_utils import xdsl_module
from pennylane.compiler.python_compiler.transforms.utils import xdsl_transform


@xdsl_module
@jax.jit
def postprocessing_expval(samples, wire):
return jnp.mean(1.0 - 2.0 * samples[:, wire])


@xdsl_module
@jax.jit
def postprocessing_var(samples, wire):
return jnp.var(1.0 - 2.0 * samples[:, wire])


@xdsl_module
@jax.jit
def postprocessing_probs(samples, wire):
raise NotImplementedError("probs not implemented")


@xdsl_module
@jax.jit
def postprocessing_counts(samples, wire):
raise NotImplementedError("counts not implemented")


class MeasurementsFromSamplesPattern(pattern_rewriter.RewritePattern):
# pylint: disable=too-few-public-methods
"""Rewrite pattern for the ``measurements_from_samples`` transform, which replaces all terminal
measurements in a program with a single :func:`pennylane.sample` measurement, and adds
postprocessing instructions to recover the original measurement.
"""

@classmethod
def _validate_observable_op(cls, op: quantum.NamedObsOp):
"""TODO"""
assert isinstance(
op, quantum.NamedObsOp
), f"Expected op to be a quantum.NamedObsOp, but got {type(op)}"
if not op.type.data == "PauliZ":
raise NotImplementedError(
f"Observable '{op.type.data}' used as input to expval is not "
f"supported for the measurements_from_samples transform; currently only "
f"PauliZ operations are permitted"
)

Check notice on line 73 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L73

Missing function or method docstring (missing-function-docstring)

# pylint: disable=arguments-differ,no-self-use
@pattern_rewriter.op_type_rewrite_pattern
def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter):
"""Implementation of the match-and-rewrite pattern for FuncOps that may contain terminal
measurements to be replaced with a single sample measurement."""

Check notice on line 79 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L79

Missing function or method docstring (missing-function-docstring)
# Walk through the func operations
supported_measure_ops = (quantum.ExpvalOp, quantum.VarianceOp, quantum.SampleOp)
unsupported_measure_ops = (quantum.ProbsOp, quantum.CountsOp, quantum.StateOp)
measure_ops = supported_measure_ops + unsupported_measure_ops

shots = None

Check notice on line 85 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L85

Missing function or method docstring (missing-function-docstring)

print("[DEBUG]: Starting walk")

for op in funcOp.body.walk():
print(f"[DEBUG]: Visiting op {op.name}")

Check notice on line 91 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py

View check run for this annotation

codefactor.io / CodeFactor

pennylane/compiler/python_compiler/transforms/measurements_from_samples.py#L91

Missing function or method docstring (missing-function-docstring)
if shots is None and isinstance(op, quantum.DeviceInitOp):
shots = op.shots
static_shots_val = (
shots.owner.operands[0].owner.properties["value"].get_int_values()[0]
)

if not isinstance(op, measure_ops):
continue

if isinstance(op, unsupported_measure_ops):
raise NotImplementedError(
f"The measurements_from_samples transform does not support "
f"operations of type '{op.name}'"
)

if isinstance(op, quantum.SampleOp):
continue

if isinstance(op, quantum.ExpvalOp):
print("[DEBUG]: Found expval op")
observable_op = op.operands[0].owner
self._validate_observable_op(observable_op)

in_qubit = observable_op.operands[0]

# Steps:
# 1. Insert op quantum.compbasis with input = in_qubit, output is a quantum.obs
# 2. Insert op quantum.sample with input = the quantum.obs vs output of compbasis
# 3. Erase op quantum.expval
# 4. Erase op quantum.namedobs [PauliZ] (assuming we only support PauliZ basis!)
compbasis_op = quantum.ComputationalBasisOp(
operands=[in_qubit, None], result_types=[quantum.ObservableType()]
)
rewriter.insert_op(compbasis_op, insertion_point=InsertPoint.before(observable_op))

# TODO: this assumes MP acts on 1 wire, what if there are more?
sample_op = quantum.SampleOp(
operands=[compbasis_op.results[0], None, None],
result_types=[builtin.TensorType(builtin.Float64Type(), [static_shots_val, 1])],
)
rewriter.insert_op(sample_op, insertion_point=InsertPoint.after(compbasis_op))

# Insert the post-processing function
postprocessing_func_name = f"expval_from_samples.tensor.{static_shots_val}x1xf64"
block_ops = funcOp.parent.ops
have_inserted_postproc = False
for block_op in block_ops:
if (
isinstance(block_op, func.FuncOp)
and block_op.sym_name.data == postprocessing_func_name
):
print(f"[DEBUG]: funcOp '{postprocessing_func_name}' already defined")
postprocessing_func = block_op
have_inserted_postproc = True

if not have_inserted_postproc:
postprocessing_module = postprocessing_expval(
jax.core.ShapedArray([static_shots_val, 1], float), 1 # FIXME
)

for _func in postprocessing_module.body.walk():
postprocessing_func = _func
break
postprocessing_func = postprocessing_func.clone()

postprocessing_func.sym_name = builtin.StringAttr(data=postprocessing_func_name)

funcOp.parent.insert_op_after(postprocessing_func, funcOp)

# Insert the call to the post-processing function
value_zero = builtin.IntegerAttr(0, value_type=64)
value_zero_op = arith.ConstantOp(
builtin.DenseIntOrFPElementsAttr.create_dense_int(
type=builtin.TensorType(value_zero.type, shape=()), data=value_zero
)
)
rewriter.insert_op(value_zero_op, insertion_point=InsertPoint.after(sample_op))
postprocessing_func_call_op = func.CallOp(
callee=builtin.FlatSymbolRefAttr(postprocessing_func.sym_name),
arguments=[sample_op.results[0], value_zero_op],
return_types=[builtin.TensorType(builtin.Float64Type(), shape=())],
)

op_to_replace = list(op.results[0].uses)[0].operation
rewriter.replace_op(op_to_replace, postprocessing_func_call_op)
rewriter.erase_op(op)
rewriter.erase_op(observable_op)

elif isinstance(op, quantum.VarianceOp):
pass


@xdsl_transform
@dataclass(frozen=True)
class MeasurementsFromSamplesPass(passes.ModulePass):
"""Pass that replaces all terminal measurements in a program with a single
:func:`pennylane.sample` measurement, and adds postprocessing instructions to recover the
original measurement.
"""

name = "measurements-from-samples"

# pylint: disable=arguments-renamed,no-self-use
def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None:
"""Apply the measurements-from-samples pass."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments:

  1. If your post-processing functions can be generated in the apply method here instead of the pattern, that in my opinion is preferrable. A pattern is inside of a loop. You can just create a Rewriter() inside the apply method.
  2. If they can't be automatically generated (for example, you may generate N different functions which depend on the types of N number of inputs) then you will likely need to implement some sort of namespacing. E.g.,
func.func @postrocessing_expval.tensor.8xf64(%arg0 : tensor<8xf64>, %arg1: int) {
  ...

But maybe you could do this a bit more generic:

func.func @postrocessing_expval.tensor(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) {
  ...
  1. There's also rewriter.replace_all_uses_with which may help in some areas before deleting operations.
  2. You may want different patterns (one for expval, one for var...).
  3. If you only care about expval, you can match directly (you don't need to walk through the function)
    def match_and_rewrite(self, op_you_care_about: quantum.ExpvalOp, rewriter: pattern_rewriter.PatternRewriter):
  1. The body of your apply can be more complex, running a set of patterns first and then a second set of patterns. That way you can add some sequential logic.
    def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None:
        """Apply the measurements-from-samples pass."""
        pattern_rewriter.PatternRewriteWalker(
            pattern_rewriter.GreedyRewritePatternApplier([OneSetOfPatterns()])
        ).rewrite_module(module)

        pattern_rewriter.PatternRewriteWalker(
            pattern_rewriter.GreedyRewritePatternApplier([AnotherSetOfPatterns()])
        ).rewrite_module(module)

Copy link
Contributor

@erick-xanadu erick-xanadu Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, depending on how comfortable you are writing MLIR by handle, you can use:

@xdsl_from_docstring
def foo():
  """
    func.func @postrocessing_expval(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) ...
  """

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other function we discussed for possibly keeping it all in jax is jax.jit(f, abstracted_axes=...) but it is undocumented and not well behaved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with Erick's comments. In particular, using the dynamic tensor approach would also allow you to use the same piece of IR for all instances of the problem.

@xdsl_from_docstring
def foo():
"""
func.func @postrocessing_expval(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) ...
"""

Just a note on this last segment. There is no need to pass in a tensor size value (since this isn't C 😅), you can query the size of any dimension of any tensor with the tensor.dim operation.

pattern_rewriter.PatternRewriteWalker(
pattern_rewriter.GreedyRewritePatternApplier([MeasurementsFromSamplesPattern()])
).rewrite_module(module)
91 changes: 91 additions & 0 deletions tests/python_compiler/test_transform_measurements_from_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the Python compiler `measurements_from_samples` transform."""

# pylint: disable=wrong-import-position

import io

import pytest

# pytestmark = pytest.mark.external

xdsl = pytest.importorskip("xdsl")

# For lit tests of xdsl passes, we use https://github.com/AntonLydike/filecheck/,
# a Python re-implementation of FileCheck
filecheck = pytest.importorskip("filecheck")

import pennylane as qml
from pennylane.compiler.python_compiler.transforms import MeasurementsFromSamplesPass


def test_transform():
program = """
// CHECK: identity
func.func @identity(%arg0 : i32) -> i32 {
return %arg0 : i32
}
"""
ctx = xdsl.context.Context()
ctx.load_dialect(xdsl.dialects.builtin.Builtin)
ctx.load_dialect(xdsl.dialects.func.Func)
module = xdsl.parser.Parser(ctx, program).parse_module()
pipeline = xdsl.passes.PipelinePass((MeasurementsFromSamplesPass(),))
pipeline.apply(ctx, module)
from filecheck.finput import FInput
from filecheck.matcher import Matcher
from filecheck.options import parse_argv_options
from filecheck.parser import Parser, pattern_for_opts

opts = parse_argv_options(["filecheck", __file__])
matcher = Matcher(
opts,
FInput("no-name", str(module)),
Parser(opts, io.StringIO(program), *pattern_for_opts(opts)),
)
assert matcher.run() == 0


class TestMeasurementsFromSamplesExecution:
"""TODO"""

def test_measurements_from_samples_basic(self):
"""TODO"""
from catalyst.passes import xdsl_plugin

qml.capture.enable()

dev = qml.device("lightning.qubit", wires=2, shots=10)

@qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()])
@MeasurementsFromSamplesPass
@qml.qnode(dev)
def circuit():
qml.H(0)
return qml.expval(qml.Z(0)), qml.expval(qml.Z(1))

print("\n")
print(circuit())
print("\n")
print(circuit.mlir)

# TODO: Add asserts!

qml.capture.disable()


if __name__ == "__main__":
pytest.main(["-x", __file__])