Skip to content

Add CachedExpressionConvertible #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Protocol,
Sequence,
Union,
cast,
runtime_checkable,
)

Expand Down Expand Up @@ -339,6 +340,21 @@ def _to_oqpy_expression(self) -> HasToAst:
...


@runtime_checkable
class CachedExpressionConvertible(Protocol):
"""This is the protocol an object can implement in order to be usable as an expression.

The difference between this and `ExpressionConvertible` is that
this requires that the result of `_to_cached_oqpy_expression` be
constant across the lifetime of the OQPy Program. OQPy makes an
effort to minimize the number of calls to the AST constructor, but
no guarantees are made about this.
"""

def _to_cached_oqpy_expression(self) -> HasToAst:
...


class OQPyUnaryExpression(OQPyExpression):
"""An expression consisting of one expression preceded by an operator."""

Expand Down Expand Up @@ -435,14 +451,28 @@ def to_ast(self, program: Program) -> ast.Expression:


AstConvertible = Union[
HasToAst, bool, int, float, complex, Iterable, ExpressionConvertible, ast.Expression
HasToAst,
bool,
int,
float,
complex,
Iterable,
ExpressionConvertible,
CachedExpressionConvertible,
ast.Expression,
]


def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
"""Convert an object to an AST node."""
if hasattr(item, "_to_oqpy_expression"):
item = cast(ExpressionConvertible, item)
return item._to_oqpy_expression().to_ast(program)
if hasattr(item, "_to_cached_oqpy_expression"):
if id(item) not in program.expr_cache:
item = cast(CachedExpressionConvertible, item)
program.expr_cache[id(item)] = item._to_cached_oqpy_expression().to_ast(program)
return program.expr_cache[id(item)]
if isinstance(item, (complex, np.complexfloating)):
if item.imag == 0:
return to_ast(program, item.real)
Expand Down
1 change: 1 addition & 0 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"ComplexVar",
"DurationVar",
"OQFunctionCall",
"OQIndexExpression",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am sneaking in one unrelated change in this PR. I see that previously this was updated (cf 862d0c1), but OQIndexExpression was not actually moved. It should be exported as before.

"StretchVar",
"_ClassicalVar",
"duration",
Expand Down
11 changes: 10 additions & 1 deletion oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def __init__(self, version: Optional[str] = "3.0", simplify_constants: bool = Tr
self.simplify_constants = simplify_constants
self.declared_subroutines: set[str] = set()
self.declared_gates: set[str] = set()
self.expr_cache: dict[int, ast.Expression] = {}
"""A cache of ast made by CachedExpressionConvertible objects used in this program.

This is used by `to_ast` to avoid repetitively evaluating ast conversion methods.
"""

if version is None or (
len(version.split(".")) in [1, 2]
Expand Down Expand Up @@ -188,7 +193,11 @@ def _add_var(self, var: Var) -> None:
existing_var = self.declared_vars.get(name)
if existing_var is None:
existing_var = self.undeclared_vars.get(name)
if existing_var is not None and not expr_matches(var, existing_var):
if (
existing_var is not None
and var is not existing_var
and not expr_matches(var, existing_var)
):
raise RuntimeError(f"Program has conflicting variables with name {name}")
if name not in self.declared_vars:
self.undeclared_vars[name] = var
Expand Down
11 changes: 10 additions & 1 deletion oqpy/timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@

from openpulse import ast

from oqpy.base import ExpressionConvertible, HasToAst, OQPyExpression, optional_ast
from oqpy.base import (
CachedExpressionConvertible,
ExpressionConvertible,
HasToAst,
OQPyExpression,
optional_ast,
)
from oqpy.classical_types import AstConvertible

if TYPE_CHECKING:
Expand Down Expand Up @@ -68,6 +74,9 @@ def convert_float_to_duration(time: AstConvertible) -> HasToAst:
if hasattr(time, "_to_oqpy_expression"):
time = cast(ExpressionConvertible, time)
return time._to_oqpy_expression()
if hasattr(time, "_to_cached_oqpy_expression"):
time = cast(CachedExpressionConvertible, time)
return time._to_cached_oqpy_expression()
raise TypeError(
f"Expected either float, int, HasToAst or ExpressionConverible: Got {type(time)}"
)
Expand Down
44 changes: 44 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,50 @@ def _to_oqpy_expression(self):
_check_respects_type_hints(prog)


def test_cached_expression_convertible():
@dataclass
class A:
name: str
count: int = 0

def _to_cached_oqpy_expression(self):
self.count += 1
return DurationVar(1e-7, self.name)

@dataclass
class F:
name: str
count: int = 0

def _to_cached_oqpy_expression(self):
self.count += 1
return FrameVar(name=self.name)

frame = F(name="f1")
dur = A("dur")
prog = Program()
prog.set(dur, 2)
prog.delay(dur, frame)
prog.set(dur, 3)
expected = textwrap.dedent(
"""
OPENQASM 3.0;
duration dur = 100.0ns;
frame f1;
dur = 2;
delay[dur] f1;
dur = 3;
"""
).strip()
assert prog.to_qasm() == expected
_check_respects_type_hints(prog)
# This gets computed twice: once during program construction, and
# once in `convert_float_to_duration`
assert dur.count == 2
# This gets computed just once
assert frame.count == 1


def test_waveform_extern_arg_passing():
prog = Program()
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
Expand Down