Skip to content

keep track of the subroutine definition order #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(self, version: Optional[str] = "3.0", simplify_constants: bool = Tr
self.declared_vars: dict[str, Var] = {}
self.undeclared_vars: dict[str, Var] = {}
self.simplify_constants = simplify_constants
self.declared_subroutines: set[str] = set()

if version is None or (
len(version.split(".")) in [1, 2]
Expand All @@ -121,7 +122,8 @@ def __iadd__(self, other: Program) -> Program:
self._state.if_clause = other._state.if_clause
self._state.finalize_if_clause()
self.defcals.update(other.defcals)
self.subroutines.update(other.subroutines)
for name, subroutine_stmt in other.subroutines.items():
self._add_subroutine(name, subroutine_stmt)
self.externs.update(other.externs)
for var in other.declared_vars.values():
self._mark_var_declared(var)
Expand Down Expand Up @@ -206,12 +208,16 @@ def _add_statement(self, stmt: ast.Statement) -> None:
"""Add a statment to the current context's program state."""
self._state.add_statement(stmt)

def _add_subroutine(self, name: str, stmt: ast.SubroutineDefinition) -> None:
def _add_subroutine(
self, name: str, stmt: ast.SubroutineDefinition, needs_declaration: bool = True
) -> None:
"""Register a subroutine which has been used.

Subroutines are added to the top of the program upon conversion to ast.
"""
self.subroutines[name] = stmt
if not needs_declaration:
self.declared_subroutines.add(name)

def _add_defcal(
self,
Expand Down Expand Up @@ -280,7 +286,11 @@ def to_ast(
statements = []
if include_externs:
statements += self._make_externs_statements(encal_declarations)
statements += list(self.subroutines.values()) + self._state.body
statements += [
self.subroutines[subroutine_name]
for subroutine_name in self.subroutines
if subroutine_name not in self.declared_subroutines
] + self._state.body
if encal:
statements = [ast.CalibrationStatement(statements)]
if encal_declarations:
Expand Down Expand Up @@ -342,12 +352,18 @@ def declare(
openqasm_vars.reverse()

for var in openqasm_vars:
stmt = var.make_declaration_statement(self)
if callable(var) and hasattr(var, "subroutine_declaration"):
name, stmt = var.subroutine_declaration # type: ignore[attr-defined]
self._add_subroutine(name, stmt, needs_declaration=False)
else:
stmt = var.make_declaration_statement(self)
self._mark_var_declared(var)

if to_beginning:
self._state.body.insert(0, stmt)
else:
self._add_statement(stmt)
self._mark_var_declared(var)

if openpulse_vars:
cal_stmt = ast.CalibrationStatement([])
for var in openpulse_vars:
Expand Down
164 changes: 84 additions & 80 deletions oqpy/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,28 @@
from oqpy.quantum_types import Qubit
from oqpy.timing import convert_float_to_duration

__all__ = ["subroutine", "annotate_subroutine", "declare_extern", "declare_waveform_generator"]
__all__ = ["subroutine", "declare_extern", "declare_waveform_generator"]

SubroutineParams = [oqpy.Program, VarArg(AstConvertible)]

FnType = TypeVar("FnType", bound=Callable[..., Any])


def enable_decorator_arguments(f: FnType) -> Callable[..., FnType]:
@functools.wraps(f)
def decorator(*args, **kwargs): # type: ignore[no-untyped-def]
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
return f(args[0])
else:
return lambda realf: f(realf, *args, **kwargs)

return decorator


@enable_decorator_arguments
def subroutine(
func: Callable[[oqpy.Program, VarArg(AstConvertible)], AstConvertible | None]
func: Callable[[oqpy.Program, VarArg(AstConvertible)], AstConvertible | None],
annotations: Sequence[str | tuple[str, str]] = (),
) -> Callable[[oqpy.Program, VarArg(AstConvertible)], OQFunctionCall]:
"""Decorator to declare a subroutine.

Expand All @@ -47,7 +62,7 @@ def subroutine(

.. code-block:: python

@subroutine
@subroutine(annotations=("optimize", "-O3"))
def increment_variable(program: Program, i: IntVar):
program.increment(i, 1)

Expand All @@ -58,105 +73,94 @@ def increment_variable(program: Program, i: IntVar):

.. code-block:: qasm3

@optimize -O3
def increment_variable(int[32] i) {
i += 1;
}

int[32] j = 0;
increment_variable(j);

Args:
func (Callable[[oqpy.Program, VarArg(AstConvertible)], AstConvertible | None]):
function to decorate. Its first argument must be an OQpy program.
annotations (Sequence[str | tuple[str, str]]): a collection of strings or
tuples of string that annotate the subroutine.

Returns:
Callable[[oqpy.Program, VarArg(AstConvertible)], AstConvertible | None]:
decorated function with added subroutine_declaration attribute.
"""
name = func.__name__
identifier = ast.Identifier(func.__name__)
argnames = list(inspect.signature(func).parameters.keys())
type_hints = get_type_hints(func)
inputs = {} # used as inputs when calling the actual python function
arguments = [] # used in the ast definition of the subroutine
for argname in argnames[1:]: # arg 0 should be program
if argname not in type_hints:
raise ValueError(f"No type hint provided for {argname} on subroutine {name}.")
input_ = inputs[argname] = type_hints[argname](name=argname)

if isinstance(input_, _ClassicalVar):
arguments.append(ast.ClassicalArgument(input_.type, ast.Identifier(argname)))
elif isinstance(input_, Qubit):
arguments.append(ast.QuantumArgument(ast.Identifier(input_.name), None))
else:
raise ValueError(
f"Type hint for {argname} on subroutine {name} is not an oqpy variable type."
)

inner_prog = oqpy.Program()
for input_val in inputs.values():
inner_prog._mark_var_declared(input_val)
output = func(inner_prog, **inputs)
inner_prog.autodeclare()
inner_prog._state.finalize_if_clause()
body = inner_prog._state.body
if isinstance(output, OQPyExpression):
return_type = output.type
body.append(ast.ReturnStatement(to_ast(inner_prog, output)))
elif output is None:
return_type = None
if type_hints.get("return", False):
return_hint = type_hints["return"]()
if isinstance(return_hint, _ClassicalVar):
return_type = return_hint.type
elif return_hint is not None:
raise ValueError(
f"Type hint for return variable on subroutine {name} is not an oqpy classical type."
)
else:
raise ValueError("Output type of subroutine {name} was neither oqpy expression nor None.")
stmt = ast.SubroutineDefinition(
identifier,
arguments=arguments,
return_type=return_type,
body=body,
)
stmt.annotations = make_annotations(annotations)

@functools.wraps(func)
def wrapper(
program: oqpy.Program,
*args: AstConvertible,
annotations: Sequence[str | tuple[str, str]] = (),
) -> OQFunctionCall:
name = func.__name__
identifier = ast.Identifier(func.__name__)
argnames = list(inspect.signature(func).parameters.keys())
type_hints = get_type_hints(func)
inputs = {} # used as inputs when calling the actual python function
arguments = [] # used in the ast definition of the subroutine
for argname in argnames[1:]: # arg 0 should be program
if argname not in type_hints:
raise ValueError(f"No type hint provided for {argname} on subroutine {name}.")
input_ = inputs[argname] = type_hints[argname](name=argname)

if isinstance(input_, _ClassicalVar):
arguments.append(ast.ClassicalArgument(input_.type, ast.Identifier(argname)))
elif isinstance(input_, Qubit):
arguments.append(ast.QuantumArgument(ast.Identifier(input_.name), None))
else:
raise ValueError(
f"Type hint for {argname} on subroutine {name} is not an oqpy variable type."
)

inner_prog = oqpy.Program()
for input_val in inputs.values():
inner_prog._mark_var_declared(input_val)
output = func(inner_prog, **inputs)
inner_prog.autodeclare()
inner_prog._state.finalize_if_clause()
body = inner_prog._state.body
if isinstance(output, OQPyExpression):
return_type = output.type
body.append(ast.ReturnStatement(to_ast(inner_prog, output)))
elif output is None:
return_type = None
if type_hints.get("return", False):
return_hint = type_hints["return"]()
if isinstance(return_hint, _ClassicalVar):
return_type = return_hint.type
elif return_hint is not None:
raise ValueError(
f"Type hint for return variable on subroutine {name} is not an oqpy classical type."
)
else:
raise ValueError(
"Output type of subroutine {name} was neither oqpy expression nor None."
)
program.defcals.update(inner_prog.defcals)
program.subroutines.update(inner_prog.subroutines)
for name, subroutine_stmt in inner_prog.subroutines.items():
program._add_subroutine(name, subroutine_stmt)
program.externs.update(inner_prog.externs)
stmt = ast.SubroutineDefinition(
return OQFunctionCall(
identifier,
arguments=arguments,
return_type=return_type,
body=body,
args,
return_type,
subroutine_decl=stmt,
)
stmt.annotations = make_annotations(annotations)
return OQFunctionCall(identifier, args, return_type, subroutine_decl=stmt)

setattr(wrapper, "subroutine_declaration", (name, stmt))
return wrapper


FnType = TypeVar("FnType", bound=Callable[..., Any])


def annotate_subroutine(keyword: str, command: str | None = None) -> Callable[[FnType], FnType]:
"""Add annotation to a subroutine."""

def annotate_subroutine_decorator(func: FnType) -> FnType:
@functools.wraps(func)
def wrapper(
program: oqpy.Program,
*args: AstConvertible,
annotations: Sequence[str | tuple[str, str]] = (),
) -> OQFunctionCall:
new_ann: str | tuple[str, str]
if command is not None:
new_ann = keyword, command
else:
new_ann = keyword
return func(program, *args, annotations=list(annotations) + [new_ann])

return wrapper # type: ignore[return-value]

return annotate_subroutine_decorator


def declare_extern(
name: str,
args: list[tuple[str, ast.ClassicalType]],
Expand Down
69 changes: 65 additions & 4 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,8 @@ def multiply(prog: Program, x: IntVar, y: IntVar) -> IntVar:
def declare(prog: Program, x: IntVar):
prog.declare([x])

# This won't define a subroutine because it was not called with do_expression.
# The call is NOT added to the program neither
declare(prog, y)

@subroutine
Expand Down Expand Up @@ -824,6 +826,42 @@ def delay50ns(qubit q) {
_check_respects_type_hints(prog)


def test_subroutine_order():
prog = Program()

@subroutine
def delay50ns(prog: Program, q: Qubit) -> None:
prog.delay(50e-9, q)

@subroutine
def multiply(prog: Program, x: IntVar, y: IntVar) -> IntVar:
return x * y

y = IntVar(2, "y")
prog.declare([delay50ns, multiply, y])
prog.set(y, multiply(prog, y, 3))
q = PhysicalQubits[0]
prog.do_expression(delay50ns(prog, q))

expected = textwrap.dedent(
"""
OPENQASM 3.0;
def delay50ns(qubit q) {
delay[50.0ns] q;
}
def multiply(int[32] x, int[32] y) -> int[32] {
return x * y;
}
int[32] y = 2;
y = multiply(y, 3);
delay50ns($0);
"""
).strip()

assert prog.to_qasm() == expected
_check_respects_type_hints(prog)


def test_box_and_timings():
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])

Expand Down Expand Up @@ -951,7 +989,6 @@ def test_defcals():
prog.capture(rx_frame, constant(2.4e-6, 1))

with pytest.raises(AssertionError):

with defcal(prog, q2, "readout", return_type=bool):
prog.play(tx_frame, constant(2.4e-6, 0.2))
prog.capture(rx_frame, constant(2.4e-6, 1))
Expand Down Expand Up @@ -1552,9 +1589,7 @@ def test_annotate():
q1 = Qubit("q1", annotations=["some_qubit"])
q2 = Qubit("q2", annotations=["other_qubit"])

@annotate_subroutine("inline")
@annotate_subroutine("optimize", "-O3")
@subroutine
@subroutine(annotations=["inline", ("optimize", "-O3")])
def f(prog: Program, x: IntVar) -> IntVar:
return x

Expand Down Expand Up @@ -1656,6 +1691,32 @@ def f(int[32] x) -> int[32] {
_check_respects_type_hints(prog)


def test_in_place_subroutine_declaration():
@subroutine(annotations=["inline", ("optimize", "-O3")])
def f(prog: Program, x: IntVar) -> IntVar:
return x

prog = Program()
i = IntVar(0, name="i")
prog.declare([i,f])
prog.increment(i, 1)

expected = textwrap.dedent(
"""
OPENQASM 3.0;
int[32] i = 0;
@inline
@optimize -O3
def f(int[32] x) -> int[32] {
return x;
}
i += 1;
"""
).strip()
assert prog.to_qasm() == expected
_check_respects_type_hints(prog)


def test_var_and_expr_matches():
p1 = PortVar("p1")
p2 = PortVar("p2")
Expand Down