Skip to content

Commit f0885b7

Browse files
support output assignment for function calls (#70)
* support return assignment for function calls * update test --------- Co-authored-by: Phil Reinhold <[email protected]>
1 parent 9bee4c7 commit f0885b7

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

oqpy/program.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -432,11 +432,18 @@ def barrier(self, qubits_or_frames: Iterable[AstConvertible]) -> Program:
432432
self._add_statement(ast.QuantumBarrier(ast_qubits_or_frames))
433433
return self
434434

435-
def function_call(self, name: str, args: Iterable[AstConvertible]) -> None:
436-
"""Add a function call."""
437-
self._add_statement(
438-
ast.ExpressionStatement(ast.FunctionCall(ast.Identifier(name), map_to_ast(self, args)))
439-
)
435+
def function_call(
436+
self,
437+
name: str,
438+
args: Iterable[AstConvertible],
439+
assigns_to: AstConvertible = None,
440+
) -> None:
441+
"""Add a function call with an optional output assignment."""
442+
function_call_node = ast.FunctionCall(ast.Identifier(name), map_to_ast(self, args))
443+
if assigns_to is None:
444+
self.do_expression(function_call_node)
445+
else:
446+
self._do_assignment(to_ast(self, assigns_to), "=", function_call_node)
440447

441448
def play(self, frame: AstConvertible, waveform: AstConvertible) -> Program:
442449
"""Play a waveform on a particular frame."""

tests/test_directives.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,3 +2357,26 @@ def test_qubit_array():
23572357
prog_with_errors.gate(q0, "h")
23582358
with pytest.raises(ValueError):
23592359
prog_with_errors.to_qasm()
2360+
2361+
2362+
@pytest.mark.parametrize(
2363+
"args,assigns_to,expected",
2364+
[
2365+
([], None, "OPENQASM 3.0;\nmy_function();"),
2366+
(
2367+
[oqpy.BitVar(name="a0"), oqpy.BitVar(name="a1")],
2368+
None,
2369+
"OPENQASM 3.0;\nbit a0;\nbit a1;\nmy_function(a0, a1);",
2370+
),
2371+
(
2372+
[oqpy.BitVar(name="a0")],
2373+
oqpy.BitVar(name="b0"),
2374+
"OPENQASM 3.0;\nbit a0;\nbit b0;\nb0 = my_function(a0);",
2375+
),
2376+
],
2377+
)
2378+
def test_function_call(args, assigns_to, expected):
2379+
prog = Program()
2380+
prog.function_call("my_function", args, assigns_to)
2381+
assert prog.to_qasm() == expected
2382+
_check_respects_type_hints(prog)

0 commit comments

Comments
 (0)