Skip to content

Commit 7916a80

Browse files
authored
Add capability to annotate statements (#34)
* Add capability to annotate statements * Address review comments * style/lint fixes
1 parent ddffe17 commit 7916a80

File tree

7 files changed

+236
-28
lines changed

7 files changed

+236
-28
lines changed

oqpy/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import sys
2525
from abc import ABC, abstractmethod
26-
from typing import TYPE_CHECKING, Any, Iterable, Union
26+
from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union
2727

2828
import numpy as np
2929
from openpulse import ast
@@ -335,3 +335,15 @@ def optional_ast(program: Program, item: AstConvertible | None) -> ast.Expressio
335335
def map_to_ast(program: Program, items: Iterable[AstConvertible]) -> list[ast.Expression]:
336336
"""Convert a sequence of items into a sequence of ast nodes."""
337337
return [to_ast(program, item) for item in items]
338+
339+
340+
def make_annotations(vals: Sequence[str | tuple[str, str]]) -> list[ast.Annotation]:
341+
"""Convert strings/tuples of strings into Annotation ast nodes."""
342+
anns: list[ast.Annotation] = []
343+
for val in vals:
344+
if isinstance(val, str):
345+
anns.append(ast.Annotation(val))
346+
else:
347+
keyword, command = val
348+
anns.append(ast.Annotation(keyword, command))
349+
return anns

oqpy/classical_types.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,24 @@
2020
import functools
2121
import random
2222
import string
23-
from typing import TYPE_CHECKING, Any, Callable, Iterable, Type, TypeVar, Union
23+
from typing import (
24+
TYPE_CHECKING,
25+
Any,
26+
Callable,
27+
Iterable,
28+
Sequence,
29+
Type,
30+
TypeVar,
31+
Union,
32+
)
2433

2534
from openpulse import ast
2635

2736
from oqpy.base import (
2837
AstConvertible,
2938
OQPyExpression,
3039
Var,
40+
make_annotations,
3141
map_to_ast,
3242
optional_ast,
3343
to_ast,
@@ -170,12 +180,14 @@ def __init__(
170180
init_expression: AstConvertible | None = None,
171181
name: str | None = None,
172182
needs_declaration: bool = True,
183+
annotations: Sequence[str | tuple[str, str]] = (),
173184
**type_kwargs: Any,
174185
):
175186
name = name or "".join([random.choice(string.ascii_letters) for _ in range(10)])
176187
super().__init__(name, needs_declaration=needs_declaration)
177188
self.type = self.type_cls(**type_kwargs)
178189
self.init_expression = init_expression
190+
self.annotations = annotations
179191

180192
def to_ast(self, program: Program) -> ast.Identifier:
181193
"""Converts the OQpy variable into an ast node."""
@@ -185,7 +197,9 @@ def to_ast(self, program: Program) -> ast.Identifier:
185197
def make_declaration_statement(self, program: Program) -> ast.Statement:
186198
"""Make an ast statement that declares the OQpy variable."""
187199
init_expression_ast = optional_ast(program, self.init_expression)
188-
return ast.ClassicalDeclaration(self.type, self.to_ast(program), init_expression_ast)
200+
stmt = ast.ClassicalDeclaration(self.type, self.to_ast(program), init_expression_ast)
201+
stmt.annotations = make_annotations(self.annotations)
202+
return stmt
189203

190204

191205
class BoolVar(_ClassicalVar):

oqpy/program.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from __future__ import annotations
2525

26+
import warnings
2627
from copy import deepcopy
2728
from typing import Any, Iterable, Iterator, Optional
2829

@@ -56,6 +57,7 @@ class ProgramState:
5657
def __init__(self) -> None:
5758
self.body: list[ast.Statement] = []
5859
self.if_clause: Optional[ast.BranchingStatement] = None
60+
self.annotations: list[ast.Annotation] = []
5961

6062
def add_if_clause(self, condition: ast.Expression, if_clause: list[ast.Statement]) -> None:
6163
self.finalize_if_clause()
@@ -73,7 +75,11 @@ def finalize_if_clause(self) -> None:
7375
self.add_statement(if_clause)
7476

7577
def add_statement(self, stmt: ast.Statement) -> None:
78+
assert isinstance(stmt, ast.Statement)
7679
self.finalize_if_clause()
80+
if self.annotations:
81+
stmt.annotations = self.annotations + list(stmt.annotations)
82+
self.annotations = []
7783
self.body.append(stmt)
7884

7985

@@ -149,6 +155,8 @@ def _pop(self) -> ProgramState:
149155
"""Close a context by removing the program state from the top stack, and return it."""
150156
state = self.stack.pop()
151157
state.finalize_if_clause()
158+
if state.annotations:
159+
warnings.warn(f"Annotation(s) {state.annotations} not applied to any statement")
152160
return state
153161

154162
def _add_var(self, var: Var) -> None:
@@ -259,6 +267,8 @@ def to_ast(
259267

260268
assert len(self.stack) == 1
261269
self._state.finalize_if_clause()
270+
if self._state.annotations:
271+
warnings.warn(f"Annotation(s) {self._state.annotations} not applied to any statement")
262272
statements = []
263273
if include_externs:
264274
statements += self._make_externs_statements(encal_declarations)
@@ -459,6 +469,11 @@ def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) ->
459469
)
460470
)
461471

472+
def do_expression(self, expression: AstConvertible) -> Program:
473+
"""Add a statement which evaluates a given expression without assigning the output."""
474+
self._add_statement(ast.ExpressionStatement(to_ast(self, expression)))
475+
return self
476+
462477
def set(
463478
self,
464479
var: classical_types._ClassicalVar | classical_types.OQIndexExpression,
@@ -484,6 +499,11 @@ def mod_equals(self, var: classical_types.IntVar, value: AstConvertible) -> Prog
484499
self._do_assignment(var, "%=", value)
485500
return self
486501

502+
def annotate(self, keyword: str, command: Optional[str] = None) -> Program:
503+
"""Add an annotation to the next statement."""
504+
self._state.annotations.append(ast.Annotation(keyword, command))
505+
return self
506+
487507

488508
class MergeCalStatementsPass(QASMVisitor[None]):
489509
"""Merge adjacent CalibrationStatement ast nodes."""

oqpy/pulse.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import Any
20+
from typing import Any, Sequence
2121

2222
from openpulse import ast
2323

@@ -67,6 +67,7 @@ def __init__(
6767
phase: AstConvertible = 0,
6868
name: str | None = None,
6969
needs_declaration: bool = True,
70+
annotations: Sequence[str | tuple[str, str]] = (),
7071
):
7172
if (port is None) != (frequency is None):
7273
raise ValueError("Must declare both port and frequency or neither.")
@@ -75,4 +76,6 @@ def __init__(
7576
else:
7677
assert frequency is not None
7778
init_expression = OQFunctionCall("newframe", [port, frequency, phase], ast.FrameType)
78-
super().__init__(init_expression, name, needs_declaration=needs_declaration)
79+
super().__init__(
80+
init_expression, name, needs_declaration=needs_declaration, annotations=annotations
81+
)

oqpy/quantum_types.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from __future__ import annotations
1919

2020
import contextlib
21-
from typing import TYPE_CHECKING, Iterator, Optional, Union
21+
from typing import TYPE_CHECKING, Iterator, Optional, Sequence, Union
2222

2323
from openpulse import ast
2424
from openpulse.printer import dumps
2525

26-
from oqpy.base import AstConvertible, Var, to_ast
26+
from oqpy.base import AstConvertible, Var, make_annotations, to_ast
2727
from oqpy.classical_types import _ClassicalVar
2828

2929
if TYPE_CHECKING:
@@ -36,9 +36,15 @@
3636
class Qubit(Var):
3737
"""OQpy variable representing a single qubit."""
3838

39-
def __init__(self, name: str, needs_declaration: bool = True):
39+
def __init__(
40+
self,
41+
name: str,
42+
needs_declaration: bool = True,
43+
annotations: Sequence[str | tuple[str, str]] = (),
44+
):
4045
super().__init__(name, needs_declaration=needs_declaration)
4146
self.name = name
47+
self.annotations = annotations
4248

4349
def to_ast(self, prog: Program) -> ast.Expression:
4450
"""Converts the OQpy variable into an ast node."""
@@ -47,7 +53,9 @@ def to_ast(self, prog: Program) -> ast.Expression:
4753

4854
def make_declaration_statement(self, program: Program) -> ast.Statement:
4955
"""Make an ast statement that declares the OQpy variable."""
50-
return ast.QubitDeclaration(ast.Identifier(self.name), size=None)
56+
decl = ast.QubitDeclaration(ast.Identifier(self.name), size=None)
57+
decl.annotations = make_annotations(self.annotations)
58+
return decl
5159

5260

5361
class PhysicalQubits:

oqpy/subroutines.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@
1919

2020
import functools
2121
import inspect
22-
from typing import Callable, get_type_hints
22+
from typing import Any, Callable, Sequence, TypeVar, get_type_hints
2323

2424
from mypy_extensions import VarArg
2525
from openpulse import ast
2626

2727
import oqpy.program
28-
from oqpy.base import AstConvertible, OQPyExpression, to_ast
28+
from oqpy.base import AstConvertible, OQPyExpression, make_annotations, to_ast
2929
from oqpy.classical_types import OQFunctionCall, _ClassicalVar
3030
from oqpy.quantum_types import Qubit
3131
from oqpy.timing import make_duration
3232

33-
__all__ = ["subroutine", "declare_extern", "declare_waveform_generator"]
33+
__all__ = ["subroutine", "annotate_subroutine", "declare_extern", "declare_waveform_generator"]
3434

3535
SubroutineParams = [oqpy.Program, VarArg(AstConvertible)]
3636

@@ -68,7 +68,11 @@ def increment_variable(int[32] i) {
6868
"""
6969

7070
@functools.wraps(func)
71-
def wrapper(program: oqpy.Program, *args: AstConvertible) -> OQFunctionCall:
71+
def wrapper(
72+
program: oqpy.Program,
73+
*args: AstConvertible,
74+
annotations: Sequence[str | tuple[str, str]] = (),
75+
) -> OQFunctionCall:
7276
name = func.__name__
7377
identifier = ast.Identifier(func.__name__)
7478
argnames = list(inspect.signature(func).parameters.keys())
@@ -117,13 +121,42 @@ def wrapper(program: oqpy.Program, *args: AstConvertible) -> OQFunctionCall:
117121
return_type=return_type,
118122
body=body,
119123
)
124+
stmt.annotations = make_annotations(annotations)
120125
return OQFunctionCall(identifier, args, return_type, subroutine_decl=stmt)
121126

122127
return wrapper
123128

124129

130+
FnType = TypeVar("FnType", bound=Callable[..., Any])
131+
132+
133+
def annotate_subroutine(keyword: str, command: str | None = None) -> Callable[[FnType], FnType]:
134+
"""Add annotation to a subroutine."""
135+
136+
def annotate_subroutine_decorator(func: FnType) -> FnType:
137+
@functools.wraps(func)
138+
def wrapper(
139+
program: oqpy.Program,
140+
*args: AstConvertible,
141+
annotations: Sequence[str | tuple[str, str]] = (),
142+
) -> OQFunctionCall:
143+
new_ann: str | tuple[str, str]
144+
if command is not None:
145+
new_ann = keyword, command
146+
else:
147+
new_ann = keyword
148+
return func(program, *args, annotations=list(annotations) + [new_ann])
149+
150+
return wrapper # type: ignore[return-value]
151+
152+
return annotate_subroutine_decorator
153+
154+
125155
def declare_extern(
126-
name: str, args: list[tuple[str, ast.ClassicalType]], return_type: ast.ClassicalType
156+
name: str,
157+
args: list[tuple[str, ast.ClassicalType]],
158+
return_type: ast.ClassicalType,
159+
annotations: Sequence[str | tuple[str, str]] = (),
127160
) -> Callable[..., OQFunctionCall]:
128161
"""Declare an extern and return a callable which adds the extern.
129162
@@ -145,6 +178,7 @@ def declare_extern(
145178
[ast.ExternArgument(type=t) for t in arg_types],
146179
ast.ExternArgument(type=return_type),
147180
)
181+
extern_decl.annotations = make_annotations(annotations)
148182

149183
def call_extern(*call_args: AstConvertible, **call_kwargs: AstConvertible) -> OQFunctionCall:
150184
new_args = list(call_args) + [None] * len(call_kwargs)
@@ -180,8 +214,10 @@ def call_extern(*call_args: AstConvertible, **call_kwargs: AstConvertible) -> OQ
180214

181215

182216
def declare_waveform_generator(
183-
name: str, argtypes: list[tuple[str, ast.ClassicalType]]
217+
name: str,
218+
argtypes: list[tuple[str, ast.ClassicalType]],
219+
annotations: Sequence[str | tuple[str, str]] = (),
184220
) -> Callable[..., OQFunctionCall]:
185221
"""Create a function which generates waveforms using a specified name and argument signature."""
186-
func = declare_extern(name, argtypes, ast.WaveformType())
222+
func = declare_extern(name, argtypes, ast.WaveformType(), annotations=annotations)
187223
return func

0 commit comments

Comments
 (0)