Skip to content

Allow for loops with arbitrary variable type #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions oqpy/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,21 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Iterable, Iterator, Optional
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar, overload

from openpulse import ast

from oqpy.base import OQPyExpression, to_ast
from oqpy.classical_types import AstConvertible, IntVar, _ClassicalVar, convert_range
from oqpy.classical_types import (
AstConvertible,
DurationVar,
IntVar,
_ClassicalVar,
convert_range,
)
from oqpy.timing import make_duration

ClassicalVarT = TypeVar("ClassicalVarT", bound=_ClassicalVar)

if TYPE_CHECKING:
from oqpy.program import Program
Expand Down Expand Up @@ -68,12 +77,35 @@ def Else(program: Program) -> Iterator[None]:
program._state.add_else_clause(state.body)


# Overloads needed due mypy bug, see
# github.com/python/mypy/issues/8739
# https://github.com/python/mypy/issues/3737
@overload
def ForIn(
program: Program,
iterator: Iterable[AstConvertible] | range | AstConvertible,
identifier_name: Optional[str],
) -> contextlib._GeneratorContextManager[IntVar]:
...


@overload
def ForIn(
program: Program,
iterator: Iterable[AstConvertible] | range | AstConvertible,
identifier_name: Optional[str],
identifier_type: type[ClassicalVarT],
) -> contextlib._GeneratorContextManager[ClassicalVarT]:
...


@contextlib.contextmanager
def ForIn(
program: Program,
iterator: Iterable[AstConvertible] | range | AstConvertible,
identifier_name: Optional[str] = None,
) -> Iterator[IntVar]:
identifier_type: type[ClassicalVarT] | type[IntVar] = IntVar,
) -> Iterator[ClassicalVarT | IntVar]:
"""Context manager for looping a particular portion of a program.

.. code-block:: python
Expand All @@ -84,20 +116,28 @@ def ForIn(

"""
program._push()
var = IntVar(name=identifier_name, needs_declaration=False)
var = identifier_type(name=identifier_name, needs_declaration=False)
yield var
state = program._pop()

if isinstance(iterator, range):
iterator = convert_range(program, iterator)
# A range can only be iterated over integers.
assert identifier_type is IntVar, "A range can only be looped over an integer."
set_declaration = convert_range(program, iterator)
elif isinstance(iterator, Iterable):
iterator = ast.DiscreteSet([to_ast(program, i) for i in iterator])
if identifier_type is DurationVar:
iterator = (make_duration(i) for i in iterator)

set_declaration = ast.DiscreteSet([to_ast(program, i) for i in iterator])
elif isinstance(iterator, _ClassicalVar):
iterator = to_ast(program, iterator)
set_declaration = to_ast(program, iterator)
assert isinstance(set_declaration, ast.Identifier), type(set_declaration)
else:
raise TypeError(f"'{type(iterator)}' object is not iterable")

stmt = ast.ForInLoop(ast.IntType(size=None), var.to_ast(program), iterator, state.body)
stmt = ast.ForInLoop(
identifier_type.type_cls(), var.to_ast(program), set_declaration, state.body
)
program._add_statement(stmt)


Expand Down
62 changes: 62 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,68 @@ def test_for_in():
assert prog.to_qasm() == expected


def test_for_in_var_types():
port = oqpy.PortVar("my_port")
frame = oqpy.FrameVar(port, 3e9, 0, "my_frame")

# Test over floating point array.
program = oqpy.Program()
frequencies = [0.1, 0.2, 0.5]
with oqpy.ForIn(program, frequencies, "frequency", FloatVar) as f:
program.set_frequency(frame, f)

expected = textwrap.dedent(
"""
OPENQASM 3.0;
port my_port;
frame my_frame = newframe(my_port, 3000000000.0, 0);
for float frequency in {0.1, 0.2, 0.5} {
set_frequency(my_frame, frequency);
}
"""
).strip()

assert program.to_qasm() == expected

# Test over duration array.
program = oqpy.Program()
delays = [1e-9, 2e-9, 5e-9, 10e-9, 1e-6]

with oqpy.ForIn(program, delays, "d", DurationVar) as delay:
program.delay(delay, frame)

expected = textwrap.dedent(
"""
OPENQASM 3.0;
port my_port;
frame my_frame = newframe(my_port, 3000000000.0, 0);
for duration d in {1.0ns, 2.0ns, 5.0ns, 10.0ns, 1000.0ns} {
delay[d] my_frame;
}
"""
).strip()

# Test over angle array
program = oqpy.Program()
phases = [0] + [oqpy.pi / i for i in range(10, 1, -2)]

with oqpy.ForIn(program, phases, "phi", AngleVar) as phase:
program.set_phase(phase, frame)

expected = textwrap.dedent(
"""
OPENQASM 3.0;
port my_port;
frame my_frame = newframe(my_port, 3000000000.0, 0);
for angle phi in {0, pi / 10, pi / 8, pi / 6, pi / 4, pi / 2} {
set_phase(phi, my_frame);
}
"""
).strip()

assert program.to_qasm() == expected


def test_while():
prog = Program()
j = IntVar(0, "j")
Expand Down