diff --git a/oqpy/base.py b/oqpy/base.py index 4a45796..43dece9 100644 --- a/oqpy/base.py +++ b/oqpy/base.py @@ -60,12 +60,29 @@ def _to_binary( """Helper method to produce a binary expression.""" return OQPyBinaryExpression(ast.BinaryOperator[op_name], first, second) + @staticmethod + def _to_unary(op_name: str, exp: AstConvertible) -> OQPyUnaryExpression: + """Helper method to produce a binary expression.""" + return OQPyUnaryExpression(ast.UnaryOperator[op_name], exp) + + def __pos__(self) -> OQPyExpression: + return self + + def __neg__(self) -> OQPyUnaryExpression: + return self._to_unary("-", self) + def __add__(self, other: AstConvertible) -> OQPyBinaryExpression: return self._to_binary("+", self, other) def __radd__(self, other: AstConvertible) -> OQPyBinaryExpression: return self._to_binary("+", other, self) + def __sub__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("-", self, other) + + def __rsub__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("-", other, self) + def __mod__(self, other: AstConvertible) -> OQPyBinaryExpression: return self._to_binary("%", self, other) @@ -78,6 +95,18 @@ def __mul__(self, other: AstConvertible) -> OQPyBinaryExpression: def __rmul__(self, other: AstConvertible) -> OQPyBinaryExpression: return self._to_binary("*", other, self) + def __truediv__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("/", self, other) + + def __rtruediv__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("/", other, self) + + def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("**", self, other) + + def __rpow__(self, other: AstConvertible) -> OQPyBinaryExpression: + return self._to_binary("**", other, self) + def __eq__(self, other: AstConvertible) -> OQPyBinaryExpression: # type: ignore[override] return self._to_binary("==", self, other) @@ -132,6 +161,23 @@ def _to_oqpy_expression(self) -> HasToAst: ... +class OQPyUnaryExpression(OQPyExpression): + """An expression consisting of one expression preceded by an operator.""" + + def __init__(self, op: ast.UnaryOperator, exp: AstConvertible): + super().__init__() + self.op = op + self.exp = exp + if isinstance(exp, OQPyExpression): + self.type = exp.type + else: + raise TypeError("exp is an expression") + + def to_ast(self, program: Program) -> ast.UnaryExpression: + """Converts the OQpy expression into an ast node.""" + return ast.UnaryExpression(self.op, to_ast(program, self.exp)) + + class OQPyBinaryExpression(OQPyExpression): """An expression consisting of two subexpressions joined by an operator.""" diff --git a/oqpy/classical_types.py b/oqpy/classical_types.py index 630e336..7bc216f 100644 --- a/oqpy/classical_types.py +++ b/oqpy/classical_types.py @@ -38,6 +38,7 @@ from oqpy.program import Program __all__ = [ + "pi", "BoolVar", "IntVar", "UintVar", @@ -53,6 +54,7 @@ "stretch", "bool_", "bit_", + "bit", "bit8", "convert_range", "int_", @@ -78,24 +80,24 @@ # subclasses of ``_ClassicalVar`` instead. -def int_(size: int) -> ast.IntType: +def int_(size: int | None = None) -> ast.IntType: """Create a sized signed integer type.""" - return ast.IntType(ast.IntegerLiteral(size)) + return ast.IntType(ast.IntegerLiteral(size) if size is not None else None) -def uint_(size: int) -> ast.UintType: +def uint_(size: int | None = None) -> ast.UintType: """Create a sized unsigned integer type.""" - return ast.UintType(ast.IntegerLiteral(size)) + return ast.UintType(ast.IntegerLiteral(size) if size is not None else None) -def float_(size: int) -> ast.FloatType: +def float_(size: int | None = None) -> ast.FloatType: """Create a sized floating-point type.""" - return ast.FloatType(ast.IntegerLiteral(size)) + return ast.FloatType(ast.IntegerLiteral(size) if size is not None else None) -def angle_(size: int) -> ast.AngleType: +def angle_(size: int | None = None) -> ast.AngleType: """Create a sized angle type.""" - return ast.AngleType(ast.IntegerLiteral(size)) + return ast.AngleType(ast.IntegerLiteral(size) if size is not None else None) def complex_(size: int) -> ast.ComplexType: @@ -107,14 +109,15 @@ def complex_(size: int) -> ast.ComplexType: return ast.ComplexType(ast.FloatType(ast.IntegerLiteral(size // 2))) -def bit_(size: int) -> ast.BitType: +def bit_(size: int | None = None) -> ast.BitType: """Create a sized bit type.""" - return ast.BitType(ast.IntegerLiteral(size)) + return ast.BitType(ast.IntegerLiteral(size) if size is not None else None) duration = ast.DurationType() stretch = ast.StretchType() bool_ = ast.BoolType() +bit = ast.BitType() bit8 = bit_(8) int32 = int_(32) int64 = int_(64) @@ -136,6 +139,22 @@ def convert_range(program: Program, item: Union[slice, range]) -> ast.RangeDefin ) +class Identifier(OQPyExpression): + """Base class to specify constant symbols.""" + + name: str + + def __init__(self, name: str) -> None: + self.type = None + self.name = name + + def to_ast(self, program: Program) -> ast.Expression: + return ast.Identifier(name=self.name) + + +pi = Identifier(name="pi") + + class _ClassicalVar(Var, OQPyExpression): """Base type for variables with classical type. diff --git a/oqpy/program.py b/oqpy/program.py index 5e2e440..49c7d52 100644 --- a/oqpy/program.py +++ b/oqpy/program.py @@ -82,7 +82,9 @@ class Program: def __init__(self, version: Optional[str] = "3.0") -> None: self.stack: list[ProgramState] = [ProgramState()] - self.defcals: dict[tuple[tuple[str, ...], str], ast.CalibrationDefinition] = {} + self.defcals: dict[ + tuple[tuple[str, ...], str, tuple[str, ...]], ast.CalibrationDefinition + ] = {} self.subroutines: dict[str, ast.SubroutineDefinition] = {} self.externs: dict[str, ast.ExternDeclaration] = {} self.declared_vars: dict[str, Var] = {} @@ -196,13 +198,17 @@ def _add_subroutine(self, name: str, stmt: ast.SubroutineDefinition) -> None: self.subroutines[name] = stmt def _add_defcal( - self, qubit_names: list[str], name: str, stmt: ast.CalibrationDefinition + self, + qubit_names: list[str], + name: str, + arguments: list[str], + stmt: ast.CalibrationDefinition, ) -> None: """Register a defcal defined in this program. Defcals are added to the top of the program upon conversion to ast. """ - self.defcals[(tuple(qubit_names), name)] = stmt + self.defcals[(tuple(qubit_names), name, tuple(arguments))] = stmt def _make_externs_statements(self, auto_encal: bool = False) -> list[ast.ExternDeclaration]: """Return a list of extern statements for inclusion at beginning of program. diff --git a/oqpy/quantum_types.py b/oqpy/quantum_types.py index 44c1a12..6ca4368 100644 --- a/oqpy/quantum_types.py +++ b/oqpy/quantum_types.py @@ -18,11 +18,13 @@ from __future__ import annotations import contextlib -from typing import TYPE_CHECKING, Iterator, Union +from typing import TYPE_CHECKING, Iterator, Optional, Union from openpulse import ast +from openpulse.printer import dumps -from oqpy.base import Var +from oqpy.base import AstConvertible, Var, to_ast +from oqpy.classical_types import _ClassicalVar if TYPE_CHECKING: from oqpy.program import Program @@ -64,30 +66,57 @@ class QubitArray: @contextlib.contextmanager -def defcal(program: Program, qubits: Union[Qubit, list[Qubit]], name: str) -> Iterator[None]: +def defcal( + program: Program, + qubits: Union[Qubit, list[Qubit]], + name: str, + arguments: Optional[list[AstConvertible]] = None, + return_type: Optional[ast.ClassicalType] = None, +) -> Union[Iterator[None], Iterator[list[_ClassicalVar]], Iterator[_ClassicalVar]]: """Context manager for creating a defcal. .. code-block:: python - with defcal(program, q1, "X"): + with defcal(program, q1, "X", [AngleVar(name="theta"), oqpy.pi/2], oqpy.bit) as theta: program.play(frame, waveform) """ - program._push() - yield - state = program._pop() - if isinstance(qubits, Qubit): qubits = [qubits] + assert return_type is None or isinstance(return_type, ast.ClassicalType) + + arguments_ast = [] + variables = [] + if arguments is not None: + for arg in arguments: + if isinstance(arg, _ClassicalVar): + arguments_ast.append( + ast.ClassicalArgument(type=arg.type, name=ast.Identifier(name=arg.name)) + ) + arg._needs_declaration = False + variables.append(arg) + else: + arguments_ast.append(to_ast(program, arg)) + + program._push() + if len(variables) > 1: + yield variables + elif len(variables) == 1: + yield variables[0] + else: + yield + state = program._pop() stmt = ast.CalibrationDefinition( ast.Identifier(name), - [], # TODO (#52): support arguments + arguments_ast, [ast.Identifier(q.name) for q in qubits], - None, # TODO (#52): support return type, + return_type, state.body, ) program._add_statement(stmt) - program._add_defcal([qubit.name for qubit in qubits], name, stmt) + program._add_defcal( + [qubit.name for qubit in qubits], name, [dumps(a) for a in arguments_ast], stmt + ) @contextlib.contextmanager diff --git a/tests/test_directives.py b/tests/test_directives.py index 24115ef..e10271e 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -22,6 +22,7 @@ import pytest from openpulse.printer import dumps +import oqpy from oqpy import * from oqpy.base import expr_matches from oqpy.quantum_types import PhysicalQubits @@ -188,7 +189,10 @@ def test_binary_expressions(): i = IntVar(5, "i") j = IntVar(2, "j") prog.set(i, 2 * (i + j)) - prog.set(j, 2 % (2 + i) % 2) + prog.set(j, 2 % (2 - i) % 2) + prog.set(j, 1 + oqpy.pi) + prog.set(j, 1 / oqpy.pi**2 / 2 + 2**oqpy.pi) + prog.set(j, -oqpy.pi * oqpy.pi - i**j) expected = textwrap.dedent( """ @@ -196,7 +200,10 @@ def test_binary_expressions(): int[32] i = 5; int[32] j = 2; i = 2 * (i + j); - j = 2 % (2 + i) % 2; + j = 2 % (2 - i) % 2; + j = 1 + pi; + j = 1 / pi ** 2 / 2 + 2 ** pi; + j = -pi * pi - i ** j; """ ).strip() @@ -506,6 +513,154 @@ def test_set_shift_frequency(): assert prog.to_qasm() == expected +def test_defcals(): + prog = Program() + constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)]) + + q_port = PortVar("q_port") + rx_port = PortVar("rx_port") + tx_port = PortVar("tx_port") + q_frame = FrameVar(q_port, 6.431e9, name="q_frame") + rx_frame = FrameVar(rx_port, 5.752e9, name="rx_frame") + tx_frame = FrameVar(tx_port, 5.752e9, name="tx_frame") + + q1 = PhysicalQubits[1] + q2 = PhysicalQubits[2] + + with defcal(prog, q2, "x"): + prog.play(q_frame, constant(1e-6, 0.1)) + + with defcal(prog, q2, "rx", [AngleVar(name="theta")]) as theta: + prog.increment(theta, 0.1) + prog.play(q_frame, constant(1e-6, 0.1)) + + with defcal(prog, q2, "rx", [pi / 3]): + prog.play(q_frame, constant(1e-6, 0.1)) + + with defcal(prog, [q1, q2], "xy", [AngleVar(name="theta"), +pi / 2]) as theta: + prog.increment(theta, 0.1) + prog.play(q_frame, constant(1e-6, 0.1)) + + with defcal(prog, [q1, q2], "xy", [AngleVar(name="theta"), FloatVar(name="phi"), 10]) as params: + theta, phi = params + prog.increment(theta, 0.1) + prog.increment(phi, 0.2) + prog.play(q_frame, constant(1e-6, 0.1)) + + with defcal(prog, q2, "readout", return_type=oqpy.bit): + prog.play(tx_frame, constant(2.4e-6, 0.2)) + prog.capture(rx_frame, constant(2.4e-6, 1)) + + with pytest.raises(AssertionError): + + with defcal(prog, q2, "readout", return_type=bool): + prog.play(tx_frame, constant(2.4e-6, 0.2)) + prog.capture(rx_frame, constant(2.4e-6, 1)) + + expected = textwrap.dedent( + """ + OPENQASM 3.0; + extern constant(duration, complex[float[64]]) -> waveform; + port rx_port; + port tx_port; + port q_port; + frame q_frame = newframe(q_port, 6431000000.0, 0); + frame tx_frame = newframe(tx_port, 5752000000.0, 0); + frame rx_frame = newframe(rx_port, 5752000000.0, 0); + defcal x $2 { + play(q_frame, constant(1000.0ns, 0.1)); + } + defcal rx(angle[32] theta) $2 { + theta += 0.1; + play(q_frame, constant(1000.0ns, 0.1)); + } + defcal rx(pi / 3) $2 { + play(q_frame, constant(1000.0ns, 0.1)); + } + defcal xy(angle[32] theta, pi / 2) $1, $2 { + theta += 0.1; + play(q_frame, constant(1000.0ns, 0.1)); + } + defcal xy(angle[32] theta, float[64] phi, 10) $1, $2 { + theta += 0.1; + phi += 0.2; + play(q_frame, constant(1000.0ns, 0.1)); + } + defcal readout $2 -> bit { + play(tx_frame, constant(2400.0ns, 0.2)); + capture(rx_frame, constant(2400.0ns, 1)); + } + """ + ).strip() + assert prog.to_qasm() == expected + + expect_defcal_rx_theta = textwrap.dedent( + """ + defcal rx(angle[32] theta) $2 { + theta += 0.1; + play(q_frame, constant(1000.0ns, 0.1)); + } + """ + ).strip() + assert ( + dumps(prog.defcals[(("$2",), "rx", ("angle[32] theta",))], indent=" ").strip() + == expect_defcal_rx_theta + ) + expect_defcal_rx_pio2 = textwrap.dedent( + """ + defcal rx(pi / 3) $2 { + play(q_frame, constant(1000.0ns, 0.1)); + } + """ + ).strip() + assert ( + dumps(prog.defcals[(("$2",), "rx", ("pi / 3",))], indent=" ").strip() + == expect_defcal_rx_pio2 + ) + expect_defcal_xy_theta_pio2 = textwrap.dedent( + """ + defcal xy(angle[32] theta, pi / 2) $1, $2 { + theta += 0.1; + play(q_frame, constant(1000.0ns, 0.1)); + } + """ + ).strip() + assert ( + dumps( + prog.defcals[(("$1", "$2"), "xy", ("angle[32] theta", "pi / 2"))], indent=" " + ).strip() + == expect_defcal_xy_theta_pio2 + ) + expect_defcal_xy_theta_phi = textwrap.dedent( + """ + defcal xy(angle[32] theta, float[64] phi, 10) $1, $2 { + theta += 0.1; + phi += 0.2; + play(q_frame, constant(1000.0ns, 0.1)); + } + """ + ).strip() + assert ( + dumps( + prog.defcals[(("$1", "$2"), "xy", ("angle[32] theta", "float[64] phi", "10"))], + indent=" ", + ).strip() + == expect_defcal_xy_theta_phi + ) + expect_defcal_readout_q2 = textwrap.dedent( + """ + defcal readout $2 -> bit { + play(tx_frame, constant(2400.0ns, 0.2)); + capture(rx_frame, constant(2400.0ns, 1)); + } + """ + ).strip() + assert ( + dumps(prog.defcals[(("$2",), "readout", ())], indent=" ").strip() + == expect_defcal_readout_q2 + ) + + def test_ramsey_example(): prog = Program() constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)]) @@ -620,8 +775,11 @@ def test_ramsey_example(): ).strip() assert prog.to_qasm() == expected - assert dumps(prog.defcals[(("$2",), "x90")], indent=" ").strip() == expect_defcal_x90_q2 - assert dumps(prog.defcals[(("$2",), "readout")], indent=" ").strip() == expect_defcal_readout_q2 + assert dumps(prog.defcals[(("$2",), "x90", ())], indent=" ").strip() == expect_defcal_x90_q2 + assert ( + dumps(prog.defcals[(("$2",), "readout", ())], indent=" ").strip() + == expect_defcal_readout_q2 + ) def test_rabi_example(): @@ -748,11 +906,11 @@ def test_program_add(): ).strip() assert ( - dumps(prog2.defcals[(("$1", "$2"), "two_qubit_gate")], indent=" ").strip() + dumps(prog2.defcals[(("$1", "$2"), "two_qubit_gate", ())], indent=" ").strip() == expected_defcal_two_qubit_gate ) assert ( - dumps(prog.defcals[(("$1", "$2"), "two_qubit_gate")], indent=" ").strip() + dumps(prog.defcals[(("$1", "$2"), "two_qubit_gate", ())], indent=" ").strip() == expected_defcal_two_qubit_gate )