-
Notifications
You must be signed in to change notification settings - Fork 672
[WIP] Add measurements-from-samples pass to Python compiler #7620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
# Copyright 2025 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""This module contains the implementation of the measurements_from_samples transform, | ||
written using xDSL.""" | ||
|
||
from dataclasses import dataclass | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from xdsl import context, passes, pattern_rewriter | ||
from xdsl.dialects import arith, builtin, func | ||
from xdsl.rewriter import InsertPoint | ||
|
||
from pennylane.compiler.python_compiler import quantum_dialect as quantum | ||
from pennylane.compiler.python_compiler.jax_utils import xdsl_module | ||
from pennylane.compiler.python_compiler.transforms.utils import xdsl_transform | ||
|
||
|
||
@xdsl_module | ||
@jax.jit | ||
def postprocessing_expval(samples, wire): | ||
return jnp.mean(1.0 - 2.0 * samples[:, wire]) | ||
|
||
|
||
@xdsl_module | ||
@jax.jit | ||
def postprocessing_var(samples, wire): | ||
return jnp.var(1.0 - 2.0 * samples[:, wire]) | ||
|
||
|
||
@xdsl_module | ||
@jax.jit | ||
def postprocessing_probs(samples, wire): | ||
raise NotImplementedError("probs not implemented") | ||
|
||
|
||
@xdsl_module | ||
@jax.jit | ||
def postprocessing_counts(samples, wire): | ||
raise NotImplementedError("counts not implemented") | ||
|
||
|
||
class MeasurementsFromSamplesPattern(pattern_rewriter.RewritePattern): | ||
# pylint: disable=too-few-public-methods | ||
"""Rewrite pattern for the ``measurements_from_samples`` transform, which replaces all terminal | ||
measurements in a program with a single :func:`pennylane.sample` measurement, and adds | ||
postprocessing instructions to recover the original measurement. | ||
""" | ||
|
||
@classmethod | ||
def _validate_observable_op(cls, op: quantum.NamedObsOp): | ||
"""TODO""" | ||
assert isinstance( | ||
op, quantum.NamedObsOp | ||
), f"Expected op to be a quantum.NamedObsOp, but got {type(op)}" | ||
if not op.type.data == "PauliZ": | ||
raise NotImplementedError( | ||
f"Observable '{op.type.data}' used as input to expval is not " | ||
f"supported for the measurements_from_samples transform; currently only " | ||
f"PauliZ operations are permitted" | ||
) | ||
Check notice on line 73 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py
|
||
|
||
# pylint: disable=arguments-differ,no-self-use | ||
@pattern_rewriter.op_type_rewrite_pattern | ||
def match_and_rewrite(self, funcOp: func.FuncOp, rewriter: pattern_rewriter.PatternRewriter): | ||
"""Implementation of the match-and-rewrite pattern for FuncOps that may contain terminal | ||
measurements to be replaced with a single sample measurement.""" | ||
Check notice on line 79 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py
|
||
# Walk through the func operations | ||
supported_measure_ops = (quantum.ExpvalOp, quantum.VarianceOp, quantum.SampleOp) | ||
unsupported_measure_ops = (quantum.ProbsOp, quantum.CountsOp, quantum.StateOp) | ||
measure_ops = supported_measure_ops + unsupported_measure_ops | ||
|
||
shots = None | ||
Check notice on line 85 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py
|
||
|
||
print("[DEBUG]: Starting walk") | ||
|
||
for op in funcOp.body.walk(): | ||
print(f"[DEBUG]: Visiting op {op.name}") | ||
|
||
Check notice on line 91 in pennylane/compiler/python_compiler/transforms/measurements_from_samples.py
|
||
if shots is None and isinstance(op, quantum.DeviceInitOp): | ||
shots = op.shots | ||
static_shots_val = ( | ||
shots.owner.operands[0].owner.properties["value"].get_int_values()[0] | ||
) | ||
|
||
if not isinstance(op, measure_ops): | ||
continue | ||
|
||
if isinstance(op, unsupported_measure_ops): | ||
raise NotImplementedError( | ||
f"The measurements_from_samples transform does not support " | ||
f"operations of type '{op.name}'" | ||
) | ||
|
||
if isinstance(op, quantum.SampleOp): | ||
continue | ||
|
||
if isinstance(op, quantum.ExpvalOp): | ||
print("[DEBUG]: Found expval op") | ||
observable_op = op.operands[0].owner | ||
self._validate_observable_op(observable_op) | ||
|
||
in_qubit = observable_op.operands[0] | ||
|
||
# Steps: | ||
# 1. Insert op quantum.compbasis with input = in_qubit, output is a quantum.obs | ||
# 2. Insert op quantum.sample with input = the quantum.obs vs output of compbasis | ||
# 3. Erase op quantum.expval | ||
# 4. Erase op quantum.namedobs [PauliZ] (assuming we only support PauliZ basis!) | ||
compbasis_op = quantum.ComputationalBasisOp( | ||
operands=[in_qubit, None], result_types=[quantum.ObservableType()] | ||
) | ||
rewriter.insert_op(compbasis_op, insertion_point=InsertPoint.before(observable_op)) | ||
|
||
# TODO: this assumes MP acts on 1 wire, what if there are more? | ||
sample_op = quantum.SampleOp( | ||
operands=[compbasis_op.results[0], None, None], | ||
result_types=[builtin.TensorType(builtin.Float64Type(), [static_shots_val, 1])], | ||
) | ||
rewriter.insert_op(sample_op, insertion_point=InsertPoint.after(compbasis_op)) | ||
|
||
# Insert the post-processing function | ||
postprocessing_func_name = f"expval_from_samples.tensor.{static_shots_val}x1xf64" | ||
block_ops = funcOp.parent.ops | ||
have_inserted_postproc = False | ||
for block_op in block_ops: | ||
if ( | ||
isinstance(block_op, func.FuncOp) | ||
and block_op.sym_name.data == postprocessing_func_name | ||
): | ||
print(f"[DEBUG]: funcOp '{postprocessing_func_name}' already defined") | ||
postprocessing_func = block_op | ||
have_inserted_postproc = True | ||
|
||
if not have_inserted_postproc: | ||
postprocessing_module = postprocessing_expval( | ||
jax.core.ShapedArray([static_shots_val, 1], float), 1 # FIXME | ||
) | ||
|
||
for _func in postprocessing_module.body.walk(): | ||
postprocessing_func = _func | ||
break | ||
postprocessing_func = postprocessing_func.clone() | ||
|
||
postprocessing_func.sym_name = builtin.StringAttr(data=postprocessing_func_name) | ||
|
||
funcOp.parent.insert_op_after(postprocessing_func, funcOp) | ||
|
||
# Insert the call to the post-processing function | ||
value_zero = builtin.IntegerAttr(0, value_type=64) | ||
value_zero_op = arith.ConstantOp( | ||
builtin.DenseIntOrFPElementsAttr.create_dense_int( | ||
type=builtin.TensorType(value_zero.type, shape=()), data=value_zero | ||
) | ||
) | ||
rewriter.insert_op(value_zero_op, insertion_point=InsertPoint.after(sample_op)) | ||
postprocessing_func_call_op = func.CallOp( | ||
callee=builtin.FlatSymbolRefAttr(postprocessing_func.sym_name), | ||
arguments=[sample_op.results[0], value_zero_op], | ||
return_types=[builtin.TensorType(builtin.Float64Type(), shape=())], | ||
) | ||
|
||
op_to_replace = list(op.results[0].uses)[0].operation | ||
rewriter.replace_op(op_to_replace, postprocessing_func_call_op) | ||
rewriter.erase_op(op) | ||
rewriter.erase_op(observable_op) | ||
|
||
elif isinstance(op, quantum.VarianceOp): | ||
pass | ||
|
||
|
||
@xdsl_transform | ||
@dataclass(frozen=True) | ||
class MeasurementsFromSamplesPass(passes.ModulePass): | ||
"""Pass that replaces all terminal measurements in a program with a single | ||
:func:`pennylane.sample` measurement, and adds postprocessing instructions to recover the | ||
original measurement. | ||
""" | ||
|
||
name = "measurements-from-samples" | ||
|
||
# pylint: disable=arguments-renamed,no-self-use | ||
def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None: | ||
"""Apply the measurements-from-samples pass.""" | ||
pattern_rewriter.PatternRewriteWalker( | ||
pattern_rewriter.GreedyRewritePatternApplier([MeasurementsFromSamplesPattern()]) | ||
).rewrite_module(module) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright 2025 Xanadu Quantum Technologies Inc. | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Unit tests for the Python compiler `measurements_from_samples` transform.""" | ||
|
||
# pylint: disable=wrong-import-position | ||
|
||
import io | ||
|
||
import pytest | ||
|
||
# pytestmark = pytest.mark.external | ||
|
||
xdsl = pytest.importorskip("xdsl") | ||
|
||
# For lit tests of xdsl passes, we use https://github.com/AntonLydike/filecheck/, | ||
# a Python re-implementation of FileCheck | ||
filecheck = pytest.importorskip("filecheck") | ||
|
||
import pennylane as qml | ||
from pennylane.compiler.python_compiler.transforms import MeasurementsFromSamplesPass | ||
|
||
|
||
def test_transform(): | ||
program = """ | ||
// CHECK: identity | ||
func.func @identity(%arg0 : i32) -> i32 { | ||
return %arg0 : i32 | ||
} | ||
""" | ||
ctx = xdsl.context.Context() | ||
ctx.load_dialect(xdsl.dialects.builtin.Builtin) | ||
ctx.load_dialect(xdsl.dialects.func.Func) | ||
module = xdsl.parser.Parser(ctx, program).parse_module() | ||
pipeline = xdsl.passes.PipelinePass((MeasurementsFromSamplesPass(),)) | ||
pipeline.apply(ctx, module) | ||
from filecheck.finput import FInput | ||
from filecheck.matcher import Matcher | ||
from filecheck.options import parse_argv_options | ||
from filecheck.parser import Parser, pattern_for_opts | ||
|
||
opts = parse_argv_options(["filecheck", __file__]) | ||
matcher = Matcher( | ||
opts, | ||
FInput("no-name", str(module)), | ||
Parser(opts, io.StringIO(program), *pattern_for_opts(opts)), | ||
) | ||
assert matcher.run() == 0 | ||
|
||
|
||
class TestMeasurementsFromSamplesExecution: | ||
"""TODO""" | ||
|
||
def test_measurements_from_samples_basic(self): | ||
"""TODO""" | ||
from catalyst.passes import xdsl_plugin | ||
|
||
qml.capture.enable() | ||
|
||
dev = qml.device("lightning.qubit", wires=2, shots=10) | ||
|
||
@qml.qjit(pass_plugins=[xdsl_plugin.getXDSLPluginAbsolutePath()]) | ||
@MeasurementsFromSamplesPass | ||
@qml.qnode(dev) | ||
def circuit(): | ||
qml.H(0) | ||
return qml.expval(qml.Z(0)), qml.expval(qml.Z(1)) | ||
|
||
print("\n") | ||
print(circuit()) | ||
print("\n") | ||
print(circuit.mlir) | ||
|
||
# TODO: Add asserts! | ||
|
||
qml.capture.disable() | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main(["-x", __file__]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some comments:
Rewriter()
inside theapply
method.But maybe you could do this a bit more generic:
rewriter.replace_all_uses_with
which may help in some areas before deleting operations.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, depending on how comfortable you are writing MLIR by handle, you can use:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other function we discussed for possibly keeping it all in jax is
jax.jit(f, abstracted_axes=...)
but it is undocumented and not well behaved.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with Erick's comments. In particular, using the dynamic tensor approach would also allow you to use the same piece of IR for all instances of the problem.
Just a note on this last segment. There is no need to pass in a tensor size value (since this isn't C 😅), you can query the size of any dimension of any tensor with the
tensor.dim
operation.