Skip to content

Commit 40757c1

Browse files
committed
fix regression with cache
1 parent 6efb6f5 commit 40757c1

File tree

2 files changed

+95
-21
lines changed

2 files changed

+95
-21
lines changed

oqpy/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ class CachedExpressionConvertible(Protocol):
363363

364364
_oqpy_cache_key: Hashable
365365

366-
def _to_cached_oqpy_expression(self) -> HasToAst: ... # pragma: no cover
366+
def _to_cached_oqpy_expression(self) -> AstConvertible: ... # pragma: no cover
367367

368368

369369
class OQPyUnaryExpression(OQPyExpression):
@@ -490,8 +490,10 @@ def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
490490
if item._oqpy_cache_key is None:
491491
item._oqpy_cache_key = uuid.uuid1()
492492
if item._oqpy_cache_key not in program.expr_cache:
493-
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression()
494-
item = program.expr_cache[item._oqpy_cache_key]
493+
program.expr_cache[item._oqpy_cache_key] = to_ast(
494+
program, item._to_cached_oqpy_expression()
495+
)
496+
return program.expr_cache[item._oqpy_cache_key]
495497
if isinstance(item, (complex, np.complexfloating)):
496498
if item.imag == 0:
497499
return to_ast(program, item.real)
@@ -507,7 +509,9 @@ def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
507509
ast.ImaginaryLiteral(-item.imag),
508510
)
509511
return ast.BinaryExpression(
510-
ast.BinaryOperator["+"], ast.FloatLiteral(item.real), ast.ImaginaryLiteral(item.imag)
512+
ast.BinaryOperator["+"],
513+
ast.FloatLiteral(item.real),
514+
ast.ImaginaryLiteral(item.imag),
511515
)
512516
if isinstance(item, (bool, np.bool_)):
513517
return ast.BooleanLiteral(item)

tests/test_directives.py

Lines changed: 87 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,13 @@
3030

3131
import oqpy
3232
from oqpy import *
33-
from oqpy.base import OQPyBinaryExpression, OQPyExpression, expr_matches, logical_and, logical_or
33+
from oqpy.base import (
34+
OQPyBinaryExpression,
35+
OQPyExpression,
36+
expr_matches,
37+
logical_and,
38+
logical_or,
39+
)
3440
from oqpy.classical_types import OQIndexExpression
3541
from oqpy.quantum_types import PhysicalQubits
3642
from oqpy.timing import OQDurationLiteral
@@ -227,19 +233,36 @@ def test_array_declaration():
227233
b = ArrayVar(name="b", init_expression=[True, False], dimensions=[2], base_type=BoolVar)
228234
i = ArrayVar(name="i", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar)
229235
i55 = ArrayVar(
230-
name="i55", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar[55]
236+
name="i55",
237+
init_expression=[0, 1, 2, 3, 4],
238+
dimensions=[5],
239+
base_type=IntVar[55],
231240
)
232241
u = ArrayVar(name="u", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=UintVar)
233242
x = ArrayVar(
234-
name="x", init_expression=[0e-9, 1e-9, 2e-9], dimensions=[3], base_type=DurationVar
243+
name="x",
244+
init_expression=[0e-9, 1e-9, 2e-9],
245+
dimensions=[3],
246+
base_type=DurationVar,
247+
)
248+
y = ArrayVar(
249+
name="y",
250+
init_expression=[0.0, 1.0, 2.0, 3.0],
251+
dimensions=[4],
252+
base_type=FloatVar,
235253
)
236-
y = ArrayVar(name="y", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=FloatVar)
237254
ang = ArrayVar(
238-
name="ang", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=AngleVar
255+
name="ang",
256+
init_expression=[0.0, 1.0, 2.0, 3.0],
257+
dimensions=[4],
258+
base_type=AngleVar,
239259
)
240260
comp = ArrayVar(name="comp", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar)
241261
comp55 = ArrayVar(
242-
name="comp55", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar[float_(55)]
262+
name="comp55",
263+
init_expression=[0, 1 + 1j],
264+
dimensions=[2],
265+
base_type=ComplexVar[float_(55)],
243266
)
244267
ang_partial = ArrayVar[AngleVar, 2](name="ang_part", init_expression=[oqpy.pi, oqpy.pi / 2])
245268
simple = ArrayVar[FloatVar](name="no_init", dimensions=[5])
@@ -253,7 +276,21 @@ def test_array_declaration():
253276
base_type=DurationVar,
254277
)
255278

256-
vars = [b, i, i55, u, x, y, ang, comp, comp55, ang_partial, simple, multidim, npinit]
279+
vars = [
280+
b,
281+
i,
282+
i55,
283+
u,
284+
x,
285+
y,
286+
ang,
287+
comp,
288+
comp55,
289+
ang_partial,
290+
simple,
291+
multidim,
292+
npinit,
293+
]
257294

258295
prog = oqpy.Program(version=None)
259296
prog.declare(vars)
@@ -728,7 +765,10 @@ def test_for_in_var_types():
728765
program = oqpy.Program()
729766
pyphases = [0] + [oqpy.pi / i for i in range(10, 1, -2)]
730767
phases = ArrayVar(
731-
name="phases", dimensions=[len(pyphases)], init_expression=pyphases, base_type=AngleVar
768+
name="phases",
769+
dimensions=[len(pyphases)],
770+
init_expression=pyphases,
771+
base_type=AngleVar,
732772
)
733773

734774
with oqpy.ForIn(program, range(len(pyphases)), "idx") as idx:
@@ -1199,7 +1239,8 @@ def test_defcals():
11991239
).strip()
12001240
assert (
12011241
dumps(
1202-
prog.defcals[(("$1", "$2"), "xy", ("angle[32] theta", "pi / 2"))], indent=" "
1242+
prog.defcals[(("$1", "$2"), "xy", ("angle[32] theta", "pi / 2"))],
1243+
indent=" ",
12031244
).strip()
12041245
== expect_defcal_xy_theta_pio2
12051246
)
@@ -1320,7 +1361,12 @@ def test_ramsey_example():
13201361
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
13211362
gaussian = declare_waveform_generator(
13221363
"gaussian",
1323-
[("length", duration), ("sigma", duration), ("amplitude", float64), ("phase", float64)],
1364+
[
1365+
("length", duration),
1366+
("sigma", duration),
1367+
("amplitude", float64),
1368+
("phase", float64),
1369+
],
13241370
)
13251371
tx_waveform = constant(2.4e-6, 0.2)
13261372

@@ -1442,7 +1488,12 @@ def test_rabi_example():
14421488
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
14431489
gaussian = declare_waveform_generator(
14441490
"gaussian",
1445-
[("length", duration), ("sigma", duration), ("amplitude", float64), ("phase", float64)],
1491+
[
1492+
("length", duration),
1493+
("sigma", duration),
1494+
("amplitude", float64),
1495+
("phase", float64),
1496+
],
14461497
)
14471498

14481499
zcu216_dac231_0 = PortVar("zcu216_dac231_0")
@@ -1665,6 +1716,7 @@ def _to_cached_oqpy_expression(self):
16651716
assert dur.count == 2
16661717
# This gets computed just once
16671718
assert frame.count == 1
1719+
assert all(isinstance(v, ast.QASMNode) for v in prog.expr_cache.values())
16681720

16691721

16701722
def test_waveform_extern_arg_passing():
@@ -1825,14 +1877,22 @@ def test_annotate():
18251877
prog = Program()
18261878
gaussian = declare_waveform_generator(
18271879
"gaussian",
1828-
[("length", duration), ("sigma", duration), ("amplitude", float64), ("phase", float64)],
1880+
[
1881+
("length", duration),
1882+
("sigma", duration),
1883+
("amplitude", float64),
1884+
("phase", float64),
1885+
],
18291886
annotations=["annotating_extern_decl"],
18301887
)
18311888

18321889
some_port = PortVar("some_port", annotations=["makeport", ("some_keyword", "some_command")])
18331890
another_port = PortVar("another_port", annotations=["makeport"])
18341891
q0_transmon_xy_frame = FrameVar(
1835-
some_port, 3911851971.26885, name="q0_transmon_xy_frame", annotations=["makeframe"]
1892+
some_port,
1893+
3911851971.26885,
1894+
name="q0_transmon_xy_frame",
1895+
annotations=["makeframe"],
18361896
)
18371897
rabi_pulse_wf = WaveformVar(
18381898
gaussian(5.2e-8, 1.3e-8, 1.0, 0.0), "rabi_pulse_wf", annotations=["makepulse"]
@@ -2191,7 +2251,11 @@ def test_ramsey_example_blog():
21912251
)
21922252
gaussian_waveform = oqpy.declare_waveform_generator(
21932253
"gaussian",
2194-
[("length", oqpy.duration), ("sigma", oqpy.duration), ("amplitude", oqpy.float64)],
2254+
[
2255+
("length", oqpy.duration),
2256+
("sigma", oqpy.duration),
2257+
("amplitude", oqpy.float64),
2258+
],
21952259
)
21962260

21972261
with oqpy.defcal(defcals_prog, qubit, "reset"):
@@ -2471,7 +2535,11 @@ def test_gate_declarations():
24712535
prog,
24722536
q,
24732537
"u",
2474-
[oqpy.AngleVar(name="alpha"), oqpy.AngleVar(name="beta"), oqpy.AngleVar(name="gamma")],
2538+
[
2539+
oqpy.AngleVar(name="alpha"),
2540+
oqpy.AngleVar(name="beta"),
2541+
oqpy.AngleVar(name="gamma"),
2542+
],
24752543
) as (alpha, beta, gamma):
24762544
prog.gate(q, "a", alpha)
24772545
prog.gate(q, "b", beta)
@@ -2629,6 +2697,7 @@ def test_box_with_negative_duration():
26292697
def test_expr_matches_handles_outside_data():
26302698
x1 = oqpy.FloatVar(3, name="x")
26312699
x2 = oqpy.FloatVar(3, name="x")
2700+
26322701
class MyEntity:
26332702
def __init__(self):
26342703
self.self_ref = self
@@ -2643,15 +2712,15 @@ def __eq__(self, other):
26432712
class MyEntityNoEq:
26442713
def __init__(self):
26452714
self.self_ref = self
2715+
26462716
def __eq__(self, other):
26472717
raise RuntimeError("Eq not allowed")
26482718

26492719
x1._entity = MyEntityNoEq()
26502720
x2._entity = x1._entity
26512721
oqpy.base.expr_matches(x1, x2)
26522722

2653-
class MyFloatVar(oqpy.FloatVar):
2654-
...
2723+
class MyFloatVar(oqpy.FloatVar): ...
26552724

26562725
x1 = MyFloatVar(3, name="x")
26572726
x2 = MyFloatVar(3, name="x")
@@ -2660,6 +2729,7 @@ class MyFloatVar(oqpy.FloatVar):
26602729

26612730
class MyFloatVarWithIgnoredData(oqpy.FloatVar):
26622731
ignored: int
2732+
26632733
def _expr_matches(self, other):
26642734
if not isinstance(other, type(self)):
26652735
return False

0 commit comments

Comments
 (0)