Skip to content

Commit 0af69d9

Browse files
authored
Fix for nested subroutines and if statements in subroutines (#53)
* Fix for nested subroutines and if statements in subroutines * Also autodeclare
1 parent 0bb2c90 commit 0af69d9

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

oqpy/subroutines.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def wrapper(
9797
for input_val in inputs.values():
9898
inner_prog._mark_var_declared(input_val)
9999
output = func(inner_prog, **inputs)
100+
inner_prog.autodeclare()
101+
inner_prog._state.finalize_if_clause()
100102
body = inner_prog._state.body
101103
if isinstance(output, OQPyExpression):
102104
return_type = output.type
@@ -115,6 +117,9 @@ def wrapper(
115117
raise ValueError(
116118
"Output type of subroutine {name} was neither oqpy expression nor None."
117119
)
120+
program.defcals.update(inner_prog.defcals)
121+
program.subroutines.update(inner_prog.subroutines)
122+
program.externs.update(inner_prog.externs)
118123
stmt = ast.SubroutineDefinition(
119124
identifier,
120125
arguments=arguments,

tests/test_directives.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1999,3 +1999,41 @@ def test_io_declaration():
19991999
).strip()
20002000
assert prog.to_qasm() == expected
20012001
_check_respects_type_hints(prog)
2002+
2003+
2004+
def test_nested_subroutines():
2005+
@oqpy.subroutine
2006+
def f(prog: oqpy.Program) -> oqpy.IntVar:
2007+
i = oqpy.IntVar(name="i", init_expression=1)
2008+
with oqpy.If(prog, i == 1):
2009+
prog.increment(i, 1)
2010+
return i
2011+
2012+
@oqpy.subroutine
2013+
def g(prog: oqpy.Program) -> oqpy.IntVar:
2014+
return f(prog)
2015+
2016+
2017+
prog = oqpy.Program()
2018+
x = oqpy.IntVar(name="x")
2019+
prog.set(x, g(prog))
2020+
2021+
expected = textwrap.dedent(
2022+
"""
2023+
OPENQASM 3.0;
2024+
def f() -> int[32] {
2025+
int[32] i = 1;
2026+
if (i == 1) {
2027+
i += 1;
2028+
}
2029+
return i;
2030+
}
2031+
def g() -> int[32] {
2032+
return f();
2033+
}
2034+
int[32] x;
2035+
x = g();
2036+
"""
2037+
).strip()
2038+
2039+
assert prog.to_qasm() == expected

0 commit comments

Comments
 (0)