Skip to content

Commit 951c532

Browse files
Add support for returning values from defcal statements (#24)
* Added 'returns' statement to oqpy.Program for adding return statement to defcal and subroutine statements, added support for using Program.returns in subroutine decorator based on type hinting and added tests. * Update oqpy/subroutines.py Co-authored-by: Phil Reinhold <[email protected]> * �Added new types: port, frame and waveform to represent the types of PortVar, FrameVar and WaveformVar * port, waveform and frame now aliases of openpulse.ast PortType, WaveformType and FrameType * Added missing period at the end of a docstring Co-authored-by: Phil Reinhold <[email protected]>
1 parent 854d40b commit 951c532

File tree

4 files changed

+108
-1
lines changed

4 files changed

+108
-1
lines changed

oqpy/program.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,11 @@ def shift_scale(self, frame: AstConvertible, scale: AstConvertible) -> Program:
405405
self.function_call("shift_scale", [frame, scale])
406406
return self
407407

408+
def returns(self, expression: AstConvertible) -> Program:
409+
"""Return a statement from a function definition or a defcal statement."""
410+
self._add_statement(ast.ReturnStatement(to_ast(self, expression)))
411+
return self
412+
408413
def gate(
409414
self, qubits: AstConvertible | Iterable[AstConvertible], name: str, *args: Any
410415
) -> Program:

oqpy/pulse.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
from oqpy.base import AstConvertible
2525
from oqpy.classical_types import OQFunctionCall, _ClassicalVar
2626

27-
__all__ = ["PortVar", "WaveformVar", "FrameVar"]
27+
__all__ = ["PortVar", "WaveformVar", "FrameVar", "port", "waveform", "frame"]
28+
29+
port = ast.PortType()
30+
waveform = ast.WaveformType()
31+
frame = ast.FrameType()
2832

2933

3034
class PortVar(_ClassicalVar):

oqpy/subroutines.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,14 @@ def wrapper(program: oqpy.Program, *args: AstConvertible) -> OQFunctionCall:
9999
body.append(ast.ReturnStatement(to_ast(inner_prog, output)))
100100
elif output is None:
101101
return_type = None
102+
if type_hints.get("return", False):
103+
return_hint = type_hints["return"]()
104+
if isinstance(return_hint, _ClassicalVar):
105+
return_type = return_hint
106+
elif return_hint is not None:
107+
raise ValueError(
108+
f"Type hint for return variable on subroutine {name} is not an oqpy classical type."
109+
)
102110
else:
103111
raise ValueError(
104112
"Output type of subroutine {name} was neither oqpy expression nor None."

tests/test_directives.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,14 @@ def return1(prog: Program) -> float:
407407

408408
return1(prog)
409409

410+
with pytest.raises(ValueError):
411+
412+
@subroutine
413+
def return2(prog: Program) -> float:
414+
prog.returns(1.0)
415+
416+
return2(prog)
417+
410418
with pytest.raises(ValueError):
411419

412420
@subroutine
@@ -661,6 +669,88 @@ def test_defcals():
661669
)
662670

663671

672+
def test_returns():
673+
prog = Program()
674+
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
675+
676+
rx_port = PortVar("rx_port")
677+
tx_port = PortVar("tx_port")
678+
rx_frame = FrameVar(rx_port, 5.752e9, name="rx_frame")
679+
tx_frame = FrameVar(tx_port, 5.752e9, name="tx_frame")
680+
capture_v2 = oqpy.declare_extern(
681+
"capture_v2", [("output", oqpy.frame), ("duration", oqpy.duration)], oqpy.bit
682+
)
683+
684+
q0 = PhysicalQubits[0]
685+
686+
with defcal(prog, q0, "measure_v1", return_type=oqpy.bit):
687+
prog.play(tx_frame, constant(2.4e-6, 0.2))
688+
prog.returns(capture_v2(rx_frame, 2.4e-6))
689+
690+
@subroutine
691+
def increment_variable_return(prog: Program, i: IntVar) -> IntVar:
692+
prog.increment(i, 1)
693+
prog.returns(i)
694+
695+
j = IntVar(0, name="j")
696+
k = IntVar(0, name="k")
697+
prog.declare(j)
698+
prog.declare(k)
699+
prog.set(k, increment_variable_return(prog, j))
700+
701+
expected = textwrap.dedent(
702+
"""
703+
OPENQASM 3.0;
704+
extern constant(duration, complex[float[64]]) -> waveform;
705+
extern capture_v2(frame, duration) -> bit;
706+
def increment_variable_return(int[32] i) -> int[32] {
707+
i += 1;
708+
return i;
709+
}
710+
port rx_port;
711+
port tx_port;
712+
frame tx_frame = newframe(tx_port, 5752000000.0, 0);
713+
frame rx_frame = newframe(rx_port, 5752000000.0, 0);
714+
defcal measure_v1 $0 -> bit {
715+
play(tx_frame, constant(2400.0ns, 0.2));
716+
return capture_v2(rx_frame, 2400.0ns);
717+
}
718+
int[32] j = 0;
719+
int[32] k = 0;
720+
k = increment_variable_return(j);
721+
"""
722+
).strip()
723+
print(prog.to_qasm())
724+
assert prog.to_qasm() == expected
725+
726+
expected_defcal_measure_v1_q0 = textwrap.dedent(
727+
"""
728+
defcal measure_v1 $0 -> bit {
729+
play(tx_frame, constant(2400.0ns, 0.2));
730+
return capture_v2(rx_frame, 2400.0ns);
731+
}
732+
"""
733+
).strip()
734+
735+
assert (
736+
dumps(prog.defcals[(("$0",), "measure_v1", ())], indent=" ").strip()
737+
== expected_defcal_measure_v1_q0
738+
)
739+
740+
expected_function_definition = textwrap.dedent(
741+
"""
742+
def increment_variable_return(int[32] i) -> int[32] {
743+
i += 1;
744+
return i;
745+
}
746+
"""
747+
).strip()
748+
assert (
749+
dumps(prog.subroutines["increment_variable_return"], indent=" ").strip()
750+
== expected_function_definition
751+
)
752+
753+
664754
def test_ramsey_example():
665755
prog = Program()
666756
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])

0 commit comments

Comments
 (0)