Skip to content

[WIP] Add measurements-from-samples pass to Python compiler #7620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

joeycarter
Copy link
Contributor

@joeycarter joeycarter commented Jun 5, 2025

Before submitting

Please complete the following checklist when submitting a PR:

  • All new features must include a unit test.
    If you've fixed a bug or added code that should be tested, add a test to the
    test directory!

  • All new functions and code must be clearly commented and documented.
    If you do make documentation changes, make sure that the docs build and
    render correctly by running make docs.

  • Ensure that the test suite passes, by running make test.

  • Add a new entry to the doc/releases/changelog-dev.md file, summarizing the
    change, and including a link back to the PR.

  • The PennyLane source code conforms to
    PEP8 standards.
    We check all of our code against Pylint.
    To lint modified files, simply pip install pylint, and then
    run pylint pennylane/path/to/file.py.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


Context:

Description of the Change:

Benefits:

Possible Drawbacks:

Related Shortcut Stories:
[sc-92338]

Copy link
Contributor

github-actions bot commented Jun 5, 2025

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.


# pylint: disable=arguments-renamed,no-self-use
def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None:
"""Apply the measurements-from-samples pass."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some comments:

  1. If your post-processing functions can be generated in the apply method here instead of the pattern, that in my opinion is preferrable. A pattern is inside of a loop. You can just create a Rewriter() inside the apply method.
  2. If they can't be automatically generated (for example, you may generate N different functions which depend on the types of N number of inputs) then you will likely need to implement some sort of namespacing. E.g.,
func.func @postrocessing_expval.tensor.8xf64(%arg0 : tensor<8xf64>, %arg1: int) {
  ...

But maybe you could do this a bit more generic:

func.func @postrocessing_expval.tensor(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) {
  ...
  1. There's also rewriter.replace_all_uses_with which may help in some areas before deleting operations.
  2. You may want different patterns (one for expval, one for var...).
  3. If you only care about expval, you can match directly (you don't need to walk through the function)
    def match_and_rewrite(self, op_you_care_about: quantum.ExpvalOp, rewriter: pattern_rewriter.PatternRewriter):
  1. The body of your apply can be more complex, running a set of patterns first and then a second set of patterns. That way you can add some sequential logic.
    def apply(self, _ctx: context.Context, module: builtin.ModuleOp) -> None:
        """Apply the measurements-from-samples pass."""
        pattern_rewriter.PatternRewriteWalker(
            pattern_rewriter.GreedyRewritePatternApplier([OneSetOfPatterns()])
        ).rewrite_module(module)

        pattern_rewriter.PatternRewriteWalker(
            pattern_rewriter.GreedyRewritePatternApplier([AnotherSetOfPatterns()])
        ).rewrite_module(module)

Copy link
Contributor

@erick-xanadu erick-xanadu Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, depending on how comfortable you are writing MLIR by handle, you can use:

@xdsl_from_docstring
def foo():
  """
    func.func @postrocessing_expval(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) ...
  """

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other function we discussed for possibly keeping it all in jax is jax.jit(f, abstracted_axes=...) but it is undocumented and not well behaved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with Erick's comments. In particular, using the dynamic tensor approach would also allow you to use the same piece of IR for all instances of the problem.

@xdsl_from_docstring
def foo():
"""
func.func @postrocessing_expval(%arg0 : tensor<?xf64>, %tensor_size: i64, %wire: i64) ...
"""

Just a note on this last segment. There is no need to pass in a tensor size value (since this isn't C 😅), you can query the size of any dimension of any tensor with the tensor.dim operation.

from .apply_transform_sequence import register_pass


def xdsl_transform(_klass):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been merged now I believe?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just yesterday.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants