Skip to content

Commit f9d6fea

Browse files
Consider cases when ExpressionConvertible returns float, int, duration (#89)
* consider case when ExpressionConvertible returns float, int, duration * use to_ast to create the ast node after _to_oqpy_expression * allow use of str in OQPyBinaryExpression * change type hints * black --------- Co-authored-by: Phil Reinhold <[email protected]>
1 parent 08671da commit f9d6fea

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

oqpy/base.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def expr_matches(a: Any, b: Any) -> bool:
338338
class ExpressionConvertible(Protocol):
339339
"""This is the protocol an object can implement in order to be usable as an expression."""
340340

341-
def _to_oqpy_expression(self) -> HasToAst: ... # pragma: no cover
341+
def _to_oqpy_expression(self) -> AstConvertible: ... # pragma: no cover
342342

343343

344344
@runtime_checkable
@@ -379,12 +379,17 @@ class OQPyBinaryExpression(OQPyExpression):
379379

380380
def __init__(
381381
self,
382-
op: ast.BinaryOperator,
382+
op: ast.BinaryOperator | str,
383383
lhs: AstConvertible,
384384
rhs: AstConvertible,
385385
ast_type: ast.ClassicalType | None = None,
386386
):
387387
super().__init__()
388+
if isinstance(op, str):
389+
try:
390+
op = ast.BinaryOperator[op]
391+
except KeyError as e:
392+
raise ValueError(f"Invalid binary operator {op}") from e
388393
self.op = op
389394
self.lhs = lhs
390395
self.rhs = rhs
@@ -396,7 +401,9 @@ def __init__(
396401
elif isinstance(rhs, OQPyExpression):
397402
ast_type = rhs.type
398403
else:
399-
raise TypeError("Neither lhs nor rhs is an expression?")
404+
raise TypeError(
405+
"Cannot infer ast_type from lhs or rhs. Please provide it if possible."
406+
)
400407
self.type = ast_type
401408

402409
# Adding floats to durations is not allowed. So we promote types as necessary.
@@ -468,17 +475,14 @@ def to_ast(self, program: Program) -> ast.Expression:
468475
def to_ast(program: Program, item: AstConvertible) -> ast.Expression:
469476
"""Convert an object to an AST node."""
470477
if hasattr(item, "_to_oqpy_expression"):
471-
item = cast(ExpressionConvertible, item)
472-
return item._to_oqpy_expression().to_ast(program)
478+
item = cast(ExpressionConvertible, item)._to_oqpy_expression()
473479
if hasattr(item, "_to_cached_oqpy_expression"):
474480
item = cast(CachedExpressionConvertible, item)
475481
if item._oqpy_cache_key is None:
476482
item._oqpy_cache_key = uuid.uuid1()
477483
if item._oqpy_cache_key not in program.expr_cache:
478-
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression().to_ast(
479-
program
480-
)
481-
return program.expr_cache[item._oqpy_cache_key]
484+
program.expr_cache[item._oqpy_cache_key] = item._to_cached_oqpy_expression()
485+
item = program.expr_cache[item._oqpy_cache_key]
482486
if isinstance(item, (complex, np.complexfloating)):
483487
if item.imag == 0:
484488
return to_ast(program, item.real)

oqpy/timing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,16 @@ def convert_float_to_duration(time: AstConvertible, require_nonnegative: bool =
6868
require_nonnegative: if True, raise an exception if the time value is known to
6969
be negative.
7070
"""
71-
if isinstance(time, (float, int)):
72-
if require_nonnegative and time < 0:
73-
raise ValueError(f"Expected a non-negative duration, but got {time}")
74-
return OQDurationLiteral(time)
7571
if hasattr(time, "_to_oqpy_expression"):
7672
time = cast(ExpressionConvertible, time)
7773
time = time._to_oqpy_expression()
7874
if hasattr(time, "_to_cached_oqpy_expression"):
7975
time = cast(CachedExpressionConvertible, time)
8076
time = time._to_cached_oqpy_expression()
77+
if isinstance(time, (float, int)):
78+
if require_nonnegative and time < 0:
79+
raise ValueError(f"Expected a non-negative duration, but got {time}")
80+
return OQDurationLiteral(time)
8181
if isinstance(time, OQPyExpression):
8282
if isinstance(time.type, (ast.UintType, ast.IntType, ast.FloatType)):
8383
time = time * OQDurationLiteral(1)

tests/test_directives.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
import oqpy
3232
from oqpy import *
33-
from oqpy.base import OQPyExpression, expr_matches, logical_and, logical_or
33+
from oqpy.base import OQPyBinaryExpression, OQPyExpression, expr_matches, logical_and, logical_or
3434
from oqpy.classical_types import OQIndexExpression
3535
from oqpy.quantum_types import PhysicalQubits
3636
from oqpy.timing import OQDurationLiteral
@@ -421,6 +421,7 @@ def test_binary_expressions():
421421
prog.set(d, 5e-9 - d)
422422
prog.set(d, d + convert_float_to_duration(10e-9))
423423
prog.set(f, d / convert_float_to_duration(1))
424+
prog.set(k, OQPyBinaryExpression("+", 2, k))
424425

425426
with pytest.raises(ValueError):
426427
prog.set(f, "a" * i)
@@ -436,6 +437,8 @@ def test_binary_expressions():
436437
prog.set(d, 5j / d)
437438
with pytest.raises(TypeError):
438439
prog.set(d, 5j * d)
440+
with pytest.raises(ValueError):
441+
OQPyBinaryExpression(".", d, d)
439442

440443
expected = textwrap.dedent(
441444
"""
@@ -479,6 +482,7 @@ def test_binary_expressions():
479482
d = 5.0ns - d;
480483
d = d + 10.0ns;
481484
f = d / 1s;
485+
k = 2 + k;
482486
"""
483487
).strip()
484488

@@ -1583,21 +1587,34 @@ class B:
15831587
def _to_oqpy_expression(self):
15841588
return FloatVar(1e-7, self.name)
15851589

1590+
@dataclass
1591+
class C:
1592+
def _to_oqpy_expression(self):
1593+
return 1e-7
1594+
1595+
def __rmul__(self, other):
1596+
return other * self._to_oqpy_expression()
1597+
15861598
frame = FrameVar(name="f1")
15871599
prog = Program()
15881600
prog.set(A("a1"), 2)
1601+
prog.set(FloatVar(name="c1"), 3 * C())
15891602
prog.delay(A("a2"), frame)
15901603
prog.delay(B("b1"), frame)
1604+
prog.delay(C(), frame)
15911605
expected = textwrap.dedent(
15921606
"""
15931607
OPENQASM 3.0;
15941608
duration a1 = 100.0ns;
1609+
float[64] c1;
15951610
duration a2 = 100.0ns;
15961611
frame f1;
15971612
float[64] b1 = 1e-07;
15981613
a1 = 2;
1614+
c1 = 3e-07;
15991615
delay[a2] f1;
16001616
delay[b1 * 1s] f1;
1617+
delay[100.0ns] f1;
16011618
"""
16021619
).strip()
16031620
assert prog.to_qasm() == expected

0 commit comments

Comments
 (0)