Skip to content

Commit d85570b

Browse files
Add support for arguments & return types to defcals (#20)
* Add arguments and return type to defcals * Modify how arguments are specified * Support OQPyExpression in defcal * Add docstring * Overload more operators with OQPyExpression * Fix typo in __rpow__ * Improve test coverage
1 parent 32cb438 commit d85570b

File tree

5 files changed

+288
-30
lines changed

5 files changed

+288
-30
lines changed

oqpy/base.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,29 @@ def _to_binary(
6060
"""Helper method to produce a binary expression."""
6161
return OQPyBinaryExpression(ast.BinaryOperator[op_name], first, second)
6262

63+
@staticmethod
64+
def _to_unary(op_name: str, exp: AstConvertible) -> OQPyUnaryExpression:
65+
"""Helper method to produce a binary expression."""
66+
return OQPyUnaryExpression(ast.UnaryOperator[op_name], exp)
67+
68+
def __pos__(self) -> OQPyExpression:
69+
return self
70+
71+
def __neg__(self) -> OQPyUnaryExpression:
72+
return self._to_unary("-", self)
73+
6374
def __add__(self, other: AstConvertible) -> OQPyBinaryExpression:
6475
return self._to_binary("+", self, other)
6576

6677
def __radd__(self, other: AstConvertible) -> OQPyBinaryExpression:
6778
return self._to_binary("+", other, self)
6879

80+
def __sub__(self, other: AstConvertible) -> OQPyBinaryExpression:
81+
return self._to_binary("-", self, other)
82+
83+
def __rsub__(self, other: AstConvertible) -> OQPyBinaryExpression:
84+
return self._to_binary("-", other, self)
85+
6986
def __mod__(self, other: AstConvertible) -> OQPyBinaryExpression:
7087
return self._to_binary("%", self, other)
7188

@@ -78,6 +95,18 @@ def __mul__(self, other: AstConvertible) -> OQPyBinaryExpression:
7895
def __rmul__(self, other: AstConvertible) -> OQPyBinaryExpression:
7996
return self._to_binary("*", other, self)
8097

98+
def __truediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
99+
return self._to_binary("/", self, other)
100+
101+
def __rtruediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
102+
return self._to_binary("/", other, self)
103+
104+
def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression:
105+
return self._to_binary("**", self, other)
106+
107+
def __rpow__(self, other: AstConvertible) -> OQPyBinaryExpression:
108+
return self._to_binary("**", other, self)
109+
81110
def __eq__(self, other: AstConvertible) -> OQPyBinaryExpression: # type: ignore[override]
82111
return self._to_binary("==", self, other)
83112

@@ -132,6 +161,23 @@ def _to_oqpy_expression(self) -> HasToAst:
132161
...
133162

134163

164+
class OQPyUnaryExpression(OQPyExpression):
165+
"""An expression consisting of one expression preceded by an operator."""
166+
167+
def __init__(self, op: ast.UnaryOperator, exp: AstConvertible):
168+
super().__init__()
169+
self.op = op
170+
self.exp = exp
171+
if isinstance(exp, OQPyExpression):
172+
self.type = exp.type
173+
else:
174+
raise TypeError("exp is an expression")
175+
176+
def to_ast(self, program: Program) -> ast.UnaryExpression:
177+
"""Converts the OQpy expression into an ast node."""
178+
return ast.UnaryExpression(self.op, to_ast(program, self.exp))
179+
180+
135181
class OQPyBinaryExpression(OQPyExpression):
136182
"""An expression consisting of two subexpressions joined by an operator."""
137183

oqpy/classical_types.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from oqpy.program import Program
3939

4040
__all__ = [
41+
"pi",
4142
"BoolVar",
4243
"IntVar",
4344
"UintVar",
@@ -53,6 +54,7 @@
5354
"stretch",
5455
"bool_",
5556
"bit_",
57+
"bit",
5658
"bit8",
5759
"convert_range",
5860
"int_",
@@ -78,24 +80,24 @@
7880
# subclasses of ``_ClassicalVar`` instead.
7981

8082

81-
def int_(size: int) -> ast.IntType:
83+
def int_(size: int | None = None) -> ast.IntType:
8284
"""Create a sized signed integer type."""
83-
return ast.IntType(ast.IntegerLiteral(size))
85+
return ast.IntType(ast.IntegerLiteral(size) if size is not None else None)
8486

8587

86-
def uint_(size: int) -> ast.UintType:
88+
def uint_(size: int | None = None) -> ast.UintType:
8789
"""Create a sized unsigned integer type."""
88-
return ast.UintType(ast.IntegerLiteral(size))
90+
return ast.UintType(ast.IntegerLiteral(size) if size is not None else None)
8991

9092

91-
def float_(size: int) -> ast.FloatType:
93+
def float_(size: int | None = None) -> ast.FloatType:
9294
"""Create a sized floating-point type."""
93-
return ast.FloatType(ast.IntegerLiteral(size))
95+
return ast.FloatType(ast.IntegerLiteral(size) if size is not None else None)
9496

9597

96-
def angle_(size: int) -> ast.AngleType:
98+
def angle_(size: int | None = None) -> ast.AngleType:
9799
"""Create a sized angle type."""
98-
return ast.AngleType(ast.IntegerLiteral(size))
100+
return ast.AngleType(ast.IntegerLiteral(size) if size is not None else None)
99101

100102

101103
def complex_(size: int) -> ast.ComplexType:
@@ -107,14 +109,15 @@ def complex_(size: int) -> ast.ComplexType:
107109
return ast.ComplexType(ast.FloatType(ast.IntegerLiteral(size // 2)))
108110

109111

110-
def bit_(size: int) -> ast.BitType:
112+
def bit_(size: int | None = None) -> ast.BitType:
111113
"""Create a sized bit type."""
112-
return ast.BitType(ast.IntegerLiteral(size))
114+
return ast.BitType(ast.IntegerLiteral(size) if size is not None else None)
113115

114116

115117
duration = ast.DurationType()
116118
stretch = ast.StretchType()
117119
bool_ = ast.BoolType()
120+
bit = ast.BitType()
118121
bit8 = bit_(8)
119122
int32 = int_(32)
120123
int64 = int_(64)
@@ -136,6 +139,22 @@ def convert_range(program: Program, item: Union[slice, range]) -> ast.RangeDefin
136139
)
137140

138141

142+
class Identifier(OQPyExpression):
143+
"""Base class to specify constant symbols."""
144+
145+
name: str
146+
147+
def __init__(self, name: str) -> None:
148+
self.type = None
149+
self.name = name
150+
151+
def to_ast(self, program: Program) -> ast.Expression:
152+
return ast.Identifier(name=self.name)
153+
154+
155+
pi = Identifier(name="pi")
156+
157+
139158
class _ClassicalVar(Var, OQPyExpression):
140159
"""Base type for variables with classical type.
141160

oqpy/program.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ class Program:
8282

8383
def __init__(self, version: Optional[str] = "3.0") -> None:
8484
self.stack: list[ProgramState] = [ProgramState()]
85-
self.defcals: dict[tuple[tuple[str, ...], str], ast.CalibrationDefinition] = {}
85+
self.defcals: dict[
86+
tuple[tuple[str, ...], str, tuple[str, ...]], ast.CalibrationDefinition
87+
] = {}
8688
self.subroutines: dict[str, ast.SubroutineDefinition] = {}
8789
self.externs: dict[str, ast.ExternDeclaration] = {}
8890
self.declared_vars: dict[str, Var] = {}
@@ -196,13 +198,17 @@ def _add_subroutine(self, name: str, stmt: ast.SubroutineDefinition) -> None:
196198
self.subroutines[name] = stmt
197199

198200
def _add_defcal(
199-
self, qubit_names: list[str], name: str, stmt: ast.CalibrationDefinition
201+
self,
202+
qubit_names: list[str],
203+
name: str,
204+
arguments: list[str],
205+
stmt: ast.CalibrationDefinition,
200206
) -> None:
201207
"""Register a defcal defined in this program.
202208
203209
Defcals are added to the top of the program upon conversion to ast.
204210
"""
205-
self.defcals[(tuple(qubit_names), name)] = stmt
211+
self.defcals[(tuple(qubit_names), name, tuple(arguments))] = stmt
206212

207213
def _make_externs_statements(self, auto_encal: bool = False) -> list[ast.ExternDeclaration]:
208214
"""Return a list of extern statements for inclusion at beginning of program.

oqpy/quantum_types.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
from __future__ import annotations
1919

2020
import contextlib
21-
from typing import TYPE_CHECKING, Iterator, Union
21+
from typing import TYPE_CHECKING, Iterator, Optional, Union
2222

2323
from openpulse import ast
24+
from openpulse.printer import dumps
2425

25-
from oqpy.base import Var
26+
from oqpy.base import AstConvertible, Var, to_ast
27+
from oqpy.classical_types import _ClassicalVar
2628

2729
if TYPE_CHECKING:
2830
from oqpy.program import Program
@@ -64,30 +66,57 @@ class QubitArray:
6466

6567

6668
@contextlib.contextmanager
67-
def defcal(program: Program, qubits: Union[Qubit, list[Qubit]], name: str) -> Iterator[None]:
69+
def defcal(
70+
program: Program,
71+
qubits: Union[Qubit, list[Qubit]],
72+
name: str,
73+
arguments: Optional[list[AstConvertible]] = None,
74+
return_type: Optional[ast.ClassicalType] = None,
75+
) -> Union[Iterator[None], Iterator[list[_ClassicalVar]], Iterator[_ClassicalVar]]:
6876
"""Context manager for creating a defcal.
6977
7078
.. code-block:: python
7179
72-
with defcal(program, q1, "X"):
80+
with defcal(program, q1, "X", [AngleVar(name="theta"), oqpy.pi/2], oqpy.bit) as theta:
7381
program.play(frame, waveform)
7482
"""
75-
program._push()
76-
yield
77-
state = program._pop()
78-
7983
if isinstance(qubits, Qubit):
8084
qubits = [qubits]
85+
assert return_type is None or isinstance(return_type, ast.ClassicalType)
86+
87+
arguments_ast = []
88+
variables = []
89+
if arguments is not None:
90+
for arg in arguments:
91+
if isinstance(arg, _ClassicalVar):
92+
arguments_ast.append(
93+
ast.ClassicalArgument(type=arg.type, name=ast.Identifier(name=arg.name))
94+
)
95+
arg._needs_declaration = False
96+
variables.append(arg)
97+
else:
98+
arguments_ast.append(to_ast(program, arg))
99+
100+
program._push()
101+
if len(variables) > 1:
102+
yield variables
103+
elif len(variables) == 1:
104+
yield variables[0]
105+
else:
106+
yield
107+
state = program._pop()
81108

82109
stmt = ast.CalibrationDefinition(
83110
ast.Identifier(name),
84-
[], # TODO (#52): support arguments
111+
arguments_ast,
85112
[ast.Identifier(q.name) for q in qubits],
86-
None, # TODO (#52): support return type,
113+
return_type,
87114
state.body,
88115
)
89116
program._add_statement(stmt)
90-
program._add_defcal([qubit.name for qubit in qubits], name, stmt)
117+
program._add_defcal(
118+
[qubit.name for qubit in qubits], name, [dumps(a) for a in arguments_ast], stmt
119+
)
91120

92121

93122
@contextlib.contextmanager

0 commit comments

Comments
 (0)