Skip to content

Commit 93d10cd

Browse files
committed
fix regression with cache
1 parent 6efb6f5 commit 93d10cd

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

oqpy/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ class CachedExpressionConvertible(Protocol):
363363

364364
_oqpy_cache_key: Hashable
365365

366-
def _to_cached_oqpy_expression(self) -> HasToAst: ... # pragma: no cover
366+
def _to_cached_oqpy_expression(self) -> AstConvertible: ... # pragma: no cover
367367

368368

369369
class OQPyUnaryExpression(OQPyExpression):
@@ -490,8 +490,10 @@ def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
490490
if item._oqpy_cache_key is None:
491491
item._oqpy_cache_key = uuid.uuid1()
492492
if item._oqpy_cache_key not in program.expr_cache:
493-
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression()
494-
item = program.expr_cache[item._oqpy_cache_key]
493+
program.expr_cache[item._oqpy_cache_key] = to_ast(
494+
program, item._to_cached_oqpy_expression()
495+
)
496+
return program.expr_cache[item._oqpy_cache_key]
495497
if isinstance(item, (complex, np.complexfloating)):
496498
if item.imag == 0:
497499
return to_ast(program, item.real)
@@ -507,7 +509,9 @@ def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
507509
ast.ImaginaryLiteral(-item.imag),
508510
)
509511
return ast.BinaryExpression(
510-
ast.BinaryOperator["+"], ast.FloatLiteral(item.real), ast.ImaginaryLiteral(item.imag)
512+
ast.BinaryOperator["+"],
513+
ast.FloatLiteral(item.real),
514+
ast.ImaginaryLiteral(item.imag),
511515
)
512516
if isinstance(item, (bool, np.bool_)):
513517
return ast.BooleanLiteral(item)

tests/test_directives.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,7 @@ def _to_cached_oqpy_expression(self):
16651665
assert dur.count == 2
16661666
# This gets computed just once
16671667
assert frame.count == 1
1668+
assert all(isinstance(v, ast.QASMNode) for v in prog.expr_cache.values())
16681669

16691670

16701671
def test_waveform_extern_arg_passing():

0 commit comments

Comments
 (0)