diff --git a/CHANGELOG.md b/CHANGELOG.md index da8a87329..e911c166e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added +- ✨ Add MLIR pass for merging rotation gates ([#1019]) ([**@denialhaag**]) - ✨ Add functions to generate random vector DDs ([#975]) ([**@MatthiasReumann**]) - ✨ Add function to approximate decision diagrams ([#908]) ([**@MatthiasReumann**]) - 📦 Add Windows ARM64 wheels ([#926]) ([**@burgholzer**]) @@ -119,6 +120,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1020]: https://github.com/munich-quantum-toolkit/core/pull/1020 +[#1019]: https://github.com/munich-quantum-toolkit/core/pull/1019 [#984]: https://github.com/munich-quantum-toolkit/core/pull/984 [#982]: https://github.com/munich-quantum-toolkit/core/pull/982 [#975]: https://github.com/munich-quantum-toolkit/core/pull/975 diff --git a/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td b/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td index 26483c6bc..68c2a1cb3 100644 --- a/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td +++ b/mlir/include/mlir/Dialect/MQTOpt/IR/MQTOptInterfaces.td @@ -140,6 +140,15 @@ def UnitaryInterface : OpInterface<"UnitaryInterface"> { operands.insert(operands.end(), outQubits.begin(), outQubits.end()); operands.insert(operands.end(), controls.begin(), controls.end()); return operands; + }]>, + InterfaceMethod< + /*desc=*/ "Get params.", + /*returnType=*/ "mlir::ValueRange", + /*methodName=*/ "getParams", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return $_op.getParams(); }]> ]; diff --git a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.h b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.h index 0600c6f29..d88c5450e 100644 --- a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.h @@ -26,6 +26,7 @@ namespace mqt::ir::opt { #include "mlir/Dialect/MQTOpt/Transforms/Passes.h.inc" // IWYU pragma: export void populateCancelInversesPatterns(mlir::RewritePatternSet& patterns); +void populateMergeRotationGatesPatterns(mlir::RewritePatternSet& patterns); void populateQuantumSinkShiftPatterns(mlir::RewritePatternSet& patterns); void populateQuantumSinkPushPatterns(mlir::RewritePatternSet& patterns); void populateToQuantumComputationPatterns(mlir::RewritePatternSet& patterns, diff --git a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.td b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.td index 306f75fec..3a5d24897 100644 --- a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Passes.td @@ -36,6 +36,16 @@ def CancelConsecutiveInverses : Pass<"cancel-consecutive-inverses", "mlir::Modul }]; } +def MergeRotationGates : Pass<"merge-rotation-gates", "mlir::ModuleOp"> { + let summary = "This pass searches for consecutive applications of rotation gates that can be merged."; + let description = [{ + Consecutive applications of gphase, p, rx, ry, rz, rxx, ryy, rzz, and rzx are merged into one by adding their angles. + The merged gate is currently not removed if the angles add up to zero. + + This pass currently does not affect xxminusyy, xxplusyy, u, and u2. + }]; +} + def QuantumSinkPass : Pass<"quantum-sink", "mlir::ModuleOp"> { let summary = "This pass attempts to push down operations into branches for possible optimizations."; let description = [{ diff --git a/mlir/lib/Dialect/MQTOpt/Transforms/MergeRotationGates.cpp b/mlir/lib/Dialect/MQTOpt/Transforms/MergeRotationGates.cpp new file mode 100644 index 000000000..354969437 --- /dev/null +++ b/mlir/lib/Dialect/MQTOpt/Transforms/MergeRotationGates.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/Common/Compat.h" +#include "mlir/Dialect/MQTOpt/Transforms/Passes.h" + +#include +#include +#include + +namespace mqt::ir::opt { + +#define GEN_PASS_DEF_MERGEROTATIONGATES +#include "mlir/Dialect/MQTOpt/Transforms/Passes.h.inc" + +/** + * @brief This pattern attempts to merge consecutive rotation gates. + */ +struct MergeRotationGates final + : impl::MergeRotationGatesBase { + + void runOnOperation() override { + // Get the current operation being operated on. + auto op = getOperation(); + auto* ctx = &getContext(); + + // Define the set of patterns to use. + mlir::RewritePatternSet patterns(ctx); + populateMergeRotationGatesPatterns(patterns); + + // Apply patterns in an iterative and greedy manner. + if (mlir::failed(APPLY_PATTERNS_GREEDILY(op, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace mqt::ir::opt diff --git a/mlir/lib/Dialect/MQTOpt/Transforms/MergeRotationGatesPattern.cpp b/mlir/lib/Dialect/MQTOpt/Transforms/MergeRotationGatesPattern.cpp new file mode 100644 index 000000000..49e6cc7b8 --- /dev/null +++ b/mlir/lib/Dialect/MQTOpt/Transforms/MergeRotationGatesPattern.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/MQTOpt/IR/MQTOptDialect.h" +#include "mlir/Dialect/MQTOpt/Transforms/Passes.h" +#include "mlir/IR/BuiltinAttributes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mqt::ir::opt { + +static const std::unordered_set MERGEABLE_GATES = { + "gphase", "p", "rx", "ry", "rz", "rxx", "ryy", "rzz", "rzx"}; + +/** + * @brief This pattern attempts to merge consecutive rotation gates. + */ +struct MergeRotationGatesPattern final + : mlir::OpInterfaceRewritePattern { + + explicit MergeRotationGatesPattern(mlir::MLIRContext* context) + : OpInterfaceRewritePattern(context) {} + + /** + * @brief Checks if two gates can be merged. + * + * @param a The first gate. + * @param b The second gate. + * @return True if the gates can be merged, false otherwise. + */ + [[nodiscard]] static bool areGatesMergeable(mlir::Operation& a, + mlir::Operation& b) { + const auto aName = a.getName().stripDialect().str(); + const auto bName = b.getName().stripDialect().str(); + + return ((aName == bName) && (MERGEABLE_GATES.count(aName) == 1)); + } + + /** + * @brief Checks if all users of an operation are the same. + * + * @param users The users to check. + * @return True if all users are the same, false otherwise. + */ + [[nodiscard]] static bool + areUsersUnique(const mlir::ResultRange::user_range& users) { + return std::none_of(users.begin(), users.end(), + [&](auto* user) { return user != *users.begin(); }); + } + + mlir::LogicalResult match(UnitaryInterface op) const override { + const auto& users = op->getUsers(); + if (!areUsersUnique(users)) { + return mlir::failure(); + } + auto* user = *users.begin(); + if (!areGatesMergeable(*op, *user)) { + return mlir::failure(); + } + auto unitaryUser = mlir::dyn_cast(user); + if (op.getAllOutQubits() != unitaryUser.getAllInQubits()) { + return mlir::failure(); + } + if (op.getPosCtrlInQubits().size() != + unitaryUser.getPosCtrlInQubits().size() || + op.getNegCtrlInQubits().size() != + unitaryUser.getNegCtrlInQubits().size()) { + // We only need to check the sizes, because the order of the controls was + // already checked by the previous condition. + return mlir::failure(); + } + return mlir::success(); + } + + /** + * @brief Creates a new rotation gate. + * + * The new rotation gate is created by adding the angles of two compatible + * rotation gates. + * + * @tparam OpType The type of the operation to create. + * @param op The first instance of the rotation gate. + * @param user The second instance of the rotation gate. + * @param rewriter The pattern rewriter. + * @return A new rotation gate. + */ + template + static UnitaryInterface + createOpAdditiveAngle(UnitaryInterface op, UnitaryInterface user, + mlir::PatternRewriter& rewriter) { + auto loc = user->getLoc(); + + auto userInQubits = user.getInQubits(); + auto userPosCtrlInQubits = user.getPosCtrlInQubits(); + auto userNegCtrlInQubits = user.getNegCtrlInQubits(); + + auto opParam = op.getParams()[0]; + auto userParam = user.getParams()[0]; + auto add = rewriter.create(loc, opParam, userParam); + const llvm::SmallVector newParamsVec{add.getResult()}; + const mlir::ValueRange newParams(newParamsVec); + + return rewriter.create( + loc, userInQubits.getType(), userPosCtrlInQubits.getType(), + userNegCtrlInQubits.getType(), mlir::DenseF64ArrayAttr{}, + mlir::DenseBoolArrayAttr{}, newParams, userInQubits, + userPosCtrlInQubits, userNegCtrlInQubits); + } + + /** + * @brief Merges two consecutive rotation gates into a single gate. + * + * The function supports gphase, p, rx, ry, rz, rxx, ryy, rzz, and rzx. + * The gates are merged by adding their angles. + * The merged gate is not removed if the angles add up to zero. + * + * @param op The first instance of the rotation gate. + * @param rewriter The pattern rewriter. + */ + void static rewriteAdditiveAngle(UnitaryInterface op, + mlir::PatternRewriter& rewriter) { + auto const type = op->getName().stripDialect().str(); + + auto user = mlir::dyn_cast(*op->getUsers().begin()); + + UnitaryInterface newUser; + if (type == "gphase") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "p") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "rx") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "ry") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "rz") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "rxx") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "ryy") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "rzz") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else if (type == "rzx") { + newUser = createOpAdditiveAngle(op, user, rewriter); + } else { + throw std::runtime_error("Unsupported operation type: " + type); + } + + // Prepare erasure of op + const auto& opAllInQubits = op.getAllInQubits(); + const auto& newUserAllInQubits = newUser.getAllInQubits(); + for (size_t i = 0; i < newUser->getOperands().size(); i++) { + const auto& operand = newUser->getOperand(i); + const auto found = std::find(newUserAllInQubits.begin(), + newUserAllInQubits.end(), operand); + if (found == newUserAllInQubits.end()) { + continue; + } + const auto idx = std::distance(newUserAllInQubits.begin(), found); + rewriter.modifyOpInPlace( + newUser, [&] { newUser->setOperand(i, opAllInQubits[idx]); }); + } + + // Replace user with newUser + rewriter.replaceOp(user, newUser); + + // Erase op + rewriter.eraseOp(op); + } + + void rewrite(UnitaryInterface op, + mlir::PatternRewriter& rewriter) const override { + auto const type = op->getName().stripDialect().str(); + + if (MERGEABLE_GATES.count(type) == 1) { + rewriteAdditiveAngle(op, rewriter); + } else { + throw std::runtime_error("Unsupported operation type: " + type); + } + } +}; + +/** + * @brief Populates the given pattern set with the `MergeRotationGatesPattern`. + * + * @param patterns The pattern set to populate. + */ +void populateMergeRotationGatesPatterns(mlir::RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace mqt::ir::opt diff --git a/mlir/test/Dialect/MQTOpt/Transforms/merge-rotation-gates.mlir b/mlir/test/Dialect/MQTOpt/Transforms/merge-rotation-gates.mlir new file mode 100644 index 000000000..d476754e4 --- /dev/null +++ b/mlir/test/Dialect/MQTOpt/Transforms/merge-rotation-gates.mlir @@ -0,0 +1,362 @@ +// Copyright (c) 2023 - 2025 Chair for Design Automation, TUM +// Copyright (c) 2025 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +// RUN: quantum-opt %s -split-input-file --merge-rotation-gates | FileCheck %s + +// ----- +// This test checks that consecutive p gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRxGates + func.func @testMergeRxGates() { + // CHECK: %[[Res_2:.*]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[ANY:.*]] = mqtopt.p(%[[Res_2]]) %[[ANY:.*]] : !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]] = mqtopt.p(%[[ANY:.*]]) %[[ANY:.*]] : !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 2 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %q0_1 = mqtopt.p(%c_0) %q0_0 : !mqtopt.Qubit + %q0_2 = mqtopt.p(%c_0) %q0_1 : !mqtopt.Qubit + + %reg_2 = "mqtopt.insertQubit"(%reg_1, %q0_2) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_2) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that consecutive rx gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRxGates + func.func @testMergeRxGates() { + // CHECK: %[[Res_3:.*]] = arith.constant 3.000000e+00 : f64 + // CHECK: %[[ANY:.*]] = mqtopt.rx(%[[Res_3]]) %[[ANY:.*]] : !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]] = mqtopt.rx(%[[ANY:.*]]) %[[ANY:.*]] : !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 2 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %q0_1 = mqtopt.rx(%c_0) %q0_0 : !mqtopt.Qubit + %q0_2 = mqtopt.rx(%c_0) %q0_1 : !mqtopt.Qubit + %q0_3 = mqtopt.rx(%c_0) %q0_2 : !mqtopt.Qubit + + %reg_2 = "mqtopt.insertQubit"(%reg_1, %q0_3) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_2) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that consecutive ry gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRyGates + func.func @testMergeRyGates(%c_0 : f64, %c_1 : f64) { + // CHECK: %[[Res:.*]] = arith.addf %arg0, %arg1 : f64 + // CHECK: %[[ANY:.*]] = mqtopt.ry(%[[Res]]) %[[ANY:.*]] : !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]] = mqtopt.ry(%[[ANY:.*]]) %[[ANY:.*]] : !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 2 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %q0_1 = mqtopt.ry(%c_0) %q0_0 : !mqtopt.Qubit + %q0_2 = mqtopt.ry(%c_1) %q0_1 : !mqtopt.Qubit + + %reg_2 = "mqtopt.insertQubit"(%reg_1, %q0_2) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_2) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that consecutive rz gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRzGates + func.func @testMergeRzGates() { + // CHECK: %[[Res_3:.*]] = arith.constant 3.000000e+00 : f64 + // CHECK: %[[ANY:.*]] = mqtopt.rz(%[[Res_3]]) %[[ANY:.*]] : !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]] = mqtopt.rz(%[[ANY:.*]]) %[[ANY:.*]] : !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 2 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q0_1 = mqtopt.rz(%c_0) %q0_0 : !mqtopt.Qubit + %q0_2 = mqtopt.rz(%c_1) %q0_1 : !mqtopt.Qubit + + %reg_2 = "mqtopt.insertQubit"(%reg_1, %q0_2) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_2) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that incompatible single-qubit gates are not merged. +// The gates cannot be merged because their types are different. + +module { + // CHECK-LABEL: func.func @testDoNotMergeSingleQubitGatesDifferentGates + func.func @testDoNotMergeSingleQubitGatesDifferentGates() { + // CHECK: %[[Res_2:.*]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[Res_1:.*]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[Q0_1:.*]] = mqtopt.rx(%[[Res_1]]) %[[ANY:.*]] : !mqtopt.Qubit + // CHECK: %[[Q0_2:.*]] = mqtopt.ry(%[[Res_1]]) %[[Q0_1]] : !mqtopt.Qubit + // CHECK: %[[ANY:.*]] = mqtopt.rz(%[[Res_2]]) %[[Q0_2]] : !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 2 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q0_1 = mqtopt.rx(%c_0) %q0_0 : !mqtopt.Qubit + %q0_2 = mqtopt.ry(%c_0) %q0_1 : !mqtopt.Qubit + %q0_3 = mqtopt.rz(%c_1) %q0_2 : !mqtopt.Qubit + + %reg_2 = "mqtopt.insertQubit"(%reg_1, %q0_3) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_2) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that incompatible single-qubit gates are not merged. +// The gates cannot be merged because they act on different qubits. + +module { + // CHECK-LABEL: func.func @testDoNotMergeSingleQubitGatesIndependentGates + func.func @testDoNotMergeSingleQubitGatesIndependentGates() { + // CHECK: %[[Res_2:.*]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[Res_1:.*]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[ANY:.*]] = mqtopt.rx(%[[Res_1]]) %[[ANY:.*]] : !mqtopt.Qubit + // CHECK: %[[ANY:.*]] = mqtopt.rx(%[[Res_2]]) %[[ANY:.*]] : !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q0_1 = mqtopt.rx(%c_0) %q0_0 : !mqtopt.Qubit + %q1_1 = mqtopt.rx(%c_1) %q1_0 : !mqtopt.Qubit + + %reg_3 = "mqtopt.insertQubit"(%reg_2, %q0_1) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q1_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_4) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that consecutive rxx gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRxxGates + func.func @testMergeRxxGates() { + // CHECK: %[[Res_3:.*]] = arith.constant 3.000000e+00 : f64 + // CHECK: %[[ANY:.*]]:2 = mqtopt.rxx(%[[Res_3]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]]:2 = mqtopt.rxx(%[[ANY:.*]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %q01_1:2 = mqtopt.rxx(%c_0) %q0_0, %q1_0 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_2:2 = mqtopt.rxx(%c_0) %q01_1#0, %q01_1#1 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_3:2 = mqtopt.rxx(%c_0) %q01_2#0, %q01_2#1 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_3 = "mqtopt.insertQubit"(%reg_2, %q01_3#0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q01_3#1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_4) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that consecutive ryy gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRyyGates + func.func @testMergeRyyGates(%c_0 : f64, %c_1 : f64) { + // CHECK: %[[Res:.*]] = arith.addf %arg0, %arg1 : f64 + // CHECK: %[[ANY:.*]]:2 = mqtopt.ryy(%[[Res]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]]:2 = mqtopt.ryy(%[[ANY:.*]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %q01_1:2 = mqtopt.ryy(%c_0) %q0_0, %q1_0 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_2:2 = mqtopt.ryy(%c_1) %q01_1#0, %q01_1#1 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_3 = "mqtopt.insertQubit"(%reg_2, %q01_2#0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q01_2#1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_4) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that consecutive rzz gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRzzGates + func.func @testMergeRzzGates() { + // CHECK: %[[Res_3:.*]] = arith.constant 3.000000e+00 : f64 + // CHECK: %[[ANY:.*]]:2 = mqtopt.rzz(%[[Res_3]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]]:2 = mqtopt.rzz(%[[ANY:.*]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q01_1:2 = mqtopt.rzz(%c_0) %q0_0, %q1_0 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_2:2 = mqtopt.rzz(%c_1) %q01_1#0, %q01_1#1 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_3 = "mqtopt.insertQubit"(%reg_2, %q01_2#0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q01_2#1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_4) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that consecutive rzx gates are merged correctly. + +module { + // CHECK-LABEL: func.func @testMergeRzxGates + func.func @testMergeRzxGates() { + // CHECK: %[[Res_3:.*]] = arith.constant 3.000000e+00 : f64 + // CHECK: %[[ANY:.*]]:2 = mqtopt.rzx(%[[Res_3]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK-NOT: %[[ANY:.*]]:2 = mqtopt.rzx(%[[ANY:.*]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q01_1:2 = mqtopt.rzx(%c_0) %q0_0, %q1_0 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_2:2 = mqtopt.rzx(%c_1) %q01_1#0, %q01_1#1 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_3 = "mqtopt.insertQubit"(%reg_2, %q01_2#0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q01_2#1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_4) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that incompatible multi-qubit gates are not merged. +// The gates cannot be merged because their types are different. + +module { + // CHECK-LABEL: func.func @testDoNotMergeMultiQubitGatesDifferentGates + func.func @testDoNotMergeMultiQubitGatesDifferentGates() { + // CHECK: %[[Res_2:.*]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[Res_1:.*]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[Q01_1:.*]]:2 = mqtopt.rxx(%[[Res_1]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK: %[[Q01_2:.*]]:2 = mqtopt.ryy(%[[Res_1]]) %[[Q01_1]]#0, %[[Q01_1]]#1 : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK: %[[ANY:.*]]:2 = mqtopt.rzz(%[[Res_2]]) %[[Q01_2]]#0, %[[Q01_2]]#1 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q01_1:2 = mqtopt.rxx(%c_0) %q0_0, %q1_0 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_2:2 = mqtopt.ryy(%c_0) %q01_1#0, %q01_1#1 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_3:2 = mqtopt.rzz(%c_1) %q01_2#0, %q01_2#1 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_3 = "mqtopt.insertQubit"(%reg_2, %q01_3#0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q01_3#1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_4) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that incompatible multi-qubit gates are not merged. +// The gates cannot be merged because their types are different. + +module { + // CHECK-LABEL: func.func @testDoNotMergeMultiQubitGatesIndependentGates + func.func @testDoNotMergeMultiQubitGatesIndependentGates() { + // CHECK: %[[Res_2:.*]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[Res_1:.*]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[Q0_1:.*]]:2 = mqtopt.rxx(%[[Res_1]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK: %[[ANY:.*]]:2 = mqtopt.rxx(%[[Res_2]]) %[[Q0_1:.*]]#1, %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_3, %q2_0 = "mqtopt.extractQubit"(%reg_2) <{index_attr = 2 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q01_1:2 = mqtopt.rxx(%c_0) %q0_0, %q1_0 : !mqtopt.Qubit, !mqtopt.Qubit + %q12_1:2 = mqtopt.rxx(%c_1) %q01_1#1, %q2_0 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q01_1#0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_5 = "mqtopt.insertQubit"(%reg_4, %q12_1#0) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_6 = "mqtopt.insertQubit"(%reg_5, %q12_1#1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_6) : (!mqtopt.QubitRegister) -> () + + return + } +} + +// ----- +// This test checks that incompatible multi-qubit gates are not merged. +// The gates cannot be merged because their input qubits do not have the same order. +// This test should fail when a canonicalization pass is implemented with #1031. + +module { + // CHECK-LABEL: func.func @testDoNotMergeMultiQubitGatesDifferentInputQubitOrder + func.func @testDoNotMergeMultiQubitGatesDifferentInputQubitOrder() { + // CHECK: %[[Res_2:.*]] = arith.constant 2.000000e+00 : f64 + // CHECK: %[[Res_1:.*]] = arith.constant 1.000000e+00 : f64 + // CHECK: %[[Q0_1:.*]]:2 = mqtopt.rxx(%[[Res_1]]) %[[ANY:.*]], %[[ANY:.*]] : !mqtopt.Qubit, !mqtopt.Qubit + // CHECK: %[[ANY:.*]]:2 = mqtopt.rxx(%[[Res_2]]) %[[Q0_1]]#1, %[[Q0_1]]#0 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_0 = "mqtopt.allocQubitRegister"() <{size_attr = 3 : i64}> : () -> !mqtopt.QubitRegister + %reg_1, %q0_0 = "mqtopt.extractQubit"(%reg_0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + %reg_2, %q1_0 = "mqtopt.extractQubit"(%reg_1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister) -> (!mqtopt.QubitRegister, !mqtopt.Qubit) + + %c_0 = arith.constant 1.000000e+00 : f64 + %c_1 = arith.constant 2.000000e+00 : f64 + %q01_1:2 = mqtopt.rxx(%c_0) %q0_0, %q1_0 : !mqtopt.Qubit, !mqtopt.Qubit + %q01_2:2 = mqtopt.rxx(%c_1) %q01_1#1, %q01_1#0 : !mqtopt.Qubit, !mqtopt.Qubit + + %reg_3 = "mqtopt.insertQubit"(%reg_2, %q01_2#0) <{index_attr = 0 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + %reg_4 = "mqtopt.insertQubit"(%reg_3, %q01_2#1) <{index_attr = 1 : i64}> : (!mqtopt.QubitRegister, !mqtopt.Qubit) -> !mqtopt.QubitRegister + "mqtopt.deallocQubitRegister"(%reg_4) : (!mqtopt.QubitRegister) -> () + + return + } +}