Skip to content

Add support for arguments & return types to defcals #20

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"stretch",
"bool_",
"bit_",
"bit",
"bit8",
"convert_range",
"int_",
Expand All @@ -78,24 +79,24 @@
# subclasses of ``_ClassicalVar`` instead.


def int_(size: int) -> ast.IntType:
def int_(size: int | None = None) -> ast.IntType:
"""Create a sized signed integer type."""
return ast.IntType(ast.IntegerLiteral(size))
return ast.IntType(ast.IntegerLiteral(size) if size is not None else None)


def uint_(size: int) -> ast.UintType:
def uint_(size: int | None = None) -> ast.UintType:
"""Create a sized unsigned integer type."""
return ast.UintType(ast.IntegerLiteral(size))
return ast.UintType(ast.IntegerLiteral(size) if size is not None else None)


def float_(size: int) -> ast.FloatType:
def float_(size: int | None = None) -> ast.FloatType:
"""Create a sized floating-point type."""
return ast.FloatType(ast.IntegerLiteral(size))
return ast.FloatType(ast.IntegerLiteral(size) if size is not None else None)


def angle_(size: int) -> ast.AngleType:
def angle_(size: int | None = None) -> ast.AngleType:
"""Create a sized angle type."""
return ast.AngleType(ast.IntegerLiteral(size))
return ast.AngleType(ast.IntegerLiteral(size) if size is not None else None)


def complex_(size: int) -> ast.ComplexType:
Expand All @@ -107,14 +108,15 @@ def complex_(size: int) -> ast.ComplexType:
return ast.ComplexType(ast.FloatType(ast.IntegerLiteral(size // 2)))


def bit_(size: int) -> ast.BitType:
def bit_(size: int | None = None) -> ast.BitType:
"""Create a sized bit type."""
return ast.BitType(ast.IntegerLiteral(size))
return ast.BitType(ast.IntegerLiteral(size) if size is not None else None)


duration = ast.DurationType()
stretch = ast.StretchType()
bool_ = ast.BoolType()
bit = ast.BitType()
bit8 = bit_(8)
int32 = int_(32)
int64 = int_(64)
Expand Down
12 changes: 9 additions & 3 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class Program:

def __init__(self, version: Optional[str] = "3.0") -> None:
self.stack: list[ProgramState] = [ProgramState()]
self.defcals: dict[tuple[tuple[str, ...], str], ast.CalibrationDefinition] = {}
self.defcals: dict[
tuple[tuple[str, ...], str, tuple[str, ...]], ast.CalibrationDefinition
] = {}
self.subroutines: dict[str, ast.SubroutineDefinition] = {}
self.externs: dict[str, ast.ExternDeclaration] = {}
self.declared_vars: dict[str, Var] = {}
Expand Down Expand Up @@ -196,13 +198,17 @@ def _add_subroutine(self, name: str, stmt: ast.SubroutineDefinition) -> None:
self.subroutines[name] = stmt

def _add_defcal(
self, qubit_names: list[str], name: str, stmt: ast.CalibrationDefinition
self,
qubit_names: list[str],
name: str,
arguments: list[str],
stmt: ast.CalibrationDefinition,
) -> None:
"""Register a defcal defined in this program.

Defcals are added to the top of the program upon conversion to ast.
"""
self.defcals[(tuple(qubit_names), name)] = stmt
self.defcals[(tuple(qubit_names), name, tuple(arguments))] = stmt

def _make_externs_statements(self, auto_encal: bool = False) -> list[ast.ExternDeclaration]:
"""Return a list of extern statements for inclusion at beginning of program.
Expand Down
51 changes: 41 additions & 10 deletions oqpy/quantum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Iterator, Union
from typing import TYPE_CHECKING, Iterator, Optional, Union

from openpulse import ast
from openpulse.printer import dumps

from oqpy.base import Var
from oqpy.classical_types import _ClassicalVar

if TYPE_CHECKING:
from oqpy.program import Program
Expand Down Expand Up @@ -64,30 +66,59 @@ class QubitArray:


@contextlib.contextmanager
def defcal(program: Program, qubits: Union[Qubit, list[Qubit]], name: str) -> Iterator[None]:
def defcal(
program: Program,
qubits: Union[Qubit, list[Qubit]],
name: str,
arguments: Optional[list[Union[_ClassicalVar, str]]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior for fixed parameters is not quite right here, in that if we reparse the generated openqasm it won't match the AST we store. The issue is assuming the string is an identifier. In fact the value you pass in your examples is not an identifier, but instead the binary expression pi/2.

I think the type here should be Optional[list[AstConvertible]] and shouldn't take arbitrary strings.

We may also want to define a module constant oqpy.pi = AngleType(name="pi", needs_declaration=False). so that the user need not define a pi constant for themselves

Copy link
Collaborator Author

@jcjaskula-aws jcjaskula-aws Nov 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It won't with the reference parser indeed. I'll see if I can be satisfied with OQPyExpression.

return_type: Optional[ast.ClassicalType] = None,
) -> Union[Iterator[None], Iterator[list[_ClassicalVar]], Iterator[_ClassicalVar]]:
"""Context manager for creating a defcal.

.. code-block:: python

with defcal(program, q1, "X"):
with defcal(program, q1, "X", [AngleVar(name="theta"), "pi/2"], oqpy.bit) as theta:
program.play(frame, waveform)
"""
program._push()
yield
state = program._pop()

if isinstance(qubits, Qubit):
qubits = [qubits]
assert return_type is None or isinstance(return_type, ast.ClassicalType)

arguments_ast = []
variables = []
if arguments is not None:
for arg in arguments:
if isinstance(arg, _ClassicalVar):
arguments_ast.append(
ast.ClassicalArgument(type=arg.type, name=ast.Identifier(name=arg.name))
)
arg._needs_declaration = False
variables.append(arg)
elif isinstance(arg, str):
arguments_ast.append(ast.Identifier(name=arg))
else:
raise TypeError(f"{arg} should be of type oqpy._Classical or str")

program._push()
if len(variables) > 1:
yield variables
elif len(variables) == 1:
yield variables[0]
else:
yield
state = program._pop()

stmt = ast.CalibrationDefinition(
ast.Identifier(name),
[], # TODO (#52): support arguments
arguments_ast,
[ast.Identifier(q.name) for q in qubits],
None, # TODO (#52): support return type,
return_type,
state.body,
)
program._add_statement(stmt)
program._add_defcal([qubit.name for qubit in qubits], name, stmt)
program._add_defcal(
[qubit.name for qubit in qubits], name, [dumps(a) for a in arguments_ast], stmt
)


@contextlib.contextmanager
Expand Down
159 changes: 155 additions & 4 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest
from openpulse.printer import dumps

import oqpy
from oqpy import *
from oqpy.base import expr_matches
from oqpy.quantum_types import PhysicalQubits
Expand Down Expand Up @@ -506,6 +507,153 @@ def test_set_shift_frequency():
assert prog.to_qasm() == expected


def test_defcals():
prog = Program()
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])

q_port = PortVar("q_port")
rx_port = PortVar("rx_port")
tx_port = PortVar("tx_port")
q_frame = FrameVar(q_port, 6.431e9, name="q_frame")
rx_frame = FrameVar(rx_port, 5.752e9, name="rx_frame")
tx_frame = FrameVar(tx_port, 5.752e9, name="tx_frame")

q1 = PhysicalQubits[1]
q2 = PhysicalQubits[2]

with defcal(prog, q2, "x"):
prog.play(q_frame, constant(1e-6, 0.1))

with defcal(prog, q2, "rx", [AngleVar(name="theta")]) as theta:
prog.increment(theta, 0.1)
prog.play(q_frame, constant(1e-6, 0.1))

with defcal(prog, q2, "rx", ["pi/2"]):
prog.play(q_frame, constant(1e-6, 0.1))

with defcal(prog, [q1, q2], "xy", [AngleVar(name="theta"), "pi/2"]) as theta:
prog.increment(theta, 0.1)
prog.play(q_frame, constant(1e-6, 0.1))

with defcal(prog, [q1, q2], "xy", [AngleVar(name="theta"), FloatVar(name="phi")]) as params:
theta, phi = params
prog.increment(theta, 0.1)
prog.increment(phi, 0.2)
prog.play(q_frame, constant(1e-6, 0.1))

with defcal(prog, q2, "readout", return_type=oqpy.bit):
prog.play(tx_frame, constant(2.4e-6, 0.2))
prog.capture(rx_frame, constant(2.4e-6, 1))

with pytest.raises(AssertionError):

with defcal(prog, q2, "readout", return_type=bool):
prog.play(tx_frame, constant(2.4e-6, 0.2))
prog.capture(rx_frame, constant(2.4e-6, 1))

expected = textwrap.dedent(
"""
OPENQASM 3.0;
extern constant(duration, complex[float[64]]) -> waveform;
port rx_port;
port tx_port;
port q_port;
frame q_frame = newframe(q_port, 6431000000.0, 0);
frame tx_frame = newframe(tx_port, 5752000000.0, 0);
frame rx_frame = newframe(rx_port, 5752000000.0, 0);
defcal x $2 {
play(q_frame, constant(1000.0ns, 0.1));
}
defcal rx(angle[32] theta) $2 {
theta += 0.1;
play(q_frame, constant(1000.0ns, 0.1));
}
defcal rx(pi/2) $2 {
play(q_frame, constant(1000.0ns, 0.1));
}
defcal xy(angle[32] theta, pi/2) $1, $2 {
theta += 0.1;
play(q_frame, constant(1000.0ns, 0.1));
}
defcal xy(angle[32] theta, float[64] phi) $1, $2 {
theta += 0.1;
phi += 0.2;
play(q_frame, constant(1000.0ns, 0.1));
}
defcal readout $2 -> bit {
play(tx_frame, constant(2400.0ns, 0.2));
capture(rx_frame, constant(2400.0ns, 1));
}
"""
).strip()
assert prog.to_qasm() == expected

expect_defcal_rx_theta = textwrap.dedent(
"""
defcal rx(angle[32] theta) $2 {
theta += 0.1;
play(q_frame, constant(1000.0ns, 0.1));
}
"""
).strip()
assert (
dumps(prog.defcals[(("$2",), "rx", ("angle[32] theta",))], indent=" ").strip()
== expect_defcal_rx_theta
)
expect_defcal_rx_pio2 = textwrap.dedent(
"""
defcal rx(pi/2) $2 {
play(q_frame, constant(1000.0ns, 0.1));
}
"""
).strip()
assert (
dumps(prog.defcals[(("$2",), "rx", ("pi/2",))], indent=" ").strip()
== expect_defcal_rx_pio2
)
expect_defcal_xy_theta_pio2 = textwrap.dedent(
"""
defcal xy(angle[32] theta, pi/2) $1, $2 {
theta += 0.1;
play(q_frame, constant(1000.0ns, 0.1));
}
"""
).strip()
assert (
dumps(
prog.defcals[(("$1", "$2"), "xy", ("angle[32] theta", "pi/2"))], indent=" "
).strip()
== expect_defcal_xy_theta_pio2
)
expect_defcal_xy_theta_phi = textwrap.dedent(
"""
defcal xy(angle[32] theta, float[64] phi) $1, $2 {
theta += 0.1;
phi += 0.2;
play(q_frame, constant(1000.0ns, 0.1));
}
"""
).strip()
assert (
dumps(
prog.defcals[(("$1", "$2"), "xy", ("angle[32] theta", "float[64] phi"))], indent=" "
).strip()
== expect_defcal_xy_theta_phi
)
expect_defcal_readout_q2 = textwrap.dedent(
"""
defcal readout $2 -> bit {
play(tx_frame, constant(2400.0ns, 0.2));
capture(rx_frame, constant(2400.0ns, 1));
}
"""
).strip()
assert (
dumps(prog.defcals[(("$2",), "readout", ())], indent=" ").strip()
== expect_defcal_readout_q2
)


def test_ramsey_example():
prog = Program()
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
Expand Down Expand Up @@ -620,8 +768,11 @@ def test_ramsey_example():
).strip()

assert prog.to_qasm() == expected
assert dumps(prog.defcals[(("$2",), "x90")], indent=" ").strip() == expect_defcal_x90_q2
assert dumps(prog.defcals[(("$2",), "readout")], indent=" ").strip() == expect_defcal_readout_q2
assert dumps(prog.defcals[(("$2",), "x90", ())], indent=" ").strip() == expect_defcal_x90_q2
assert (
dumps(prog.defcals[(("$2",), "readout", ())], indent=" ").strip()
== expect_defcal_readout_q2
)


def test_rabi_example():
Expand Down Expand Up @@ -748,11 +899,11 @@ def test_program_add():
).strip()

assert (
dumps(prog2.defcals[(("$1", "$2"), "two_qubit_gate")], indent=" ").strip()
dumps(prog2.defcals[(("$1", "$2"), "two_qubit_gate", ())], indent=" ").strip()
== expected_defcal_two_qubit_gate
)
assert (
dumps(prog.defcals[(("$1", "$2"), "two_qubit_gate")], indent=" ").strip()
dumps(prog.defcals[(("$1", "$2"), "two_qubit_gate", ())], indent=" ").strip()
== expected_defcal_two_qubit_gate
)

Expand Down