Skip to content

Qubit and BitVar supports get item #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
15 changes: 15 additions & 0 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,21 @@ def to_ast(self, program: Program) -> ast.BinaryExpression:
return ast.BinaryExpression(self.op, to_ast(program, self.lhs), to_ast(program, self.rhs))


class OQIndexExpression(OQPyExpression):
"""An oqpy expression corresponding to an index expression."""

def __init__(self, collection: AstConvertible, index: AstConvertible, type: ast.ClassicalType):
self.collection = collection
self.index = index
self.type = type

def to_ast(self, program: Program) -> ast.IndexExpression:
"""Converts this oqpy index expression into an ast node."""
return ast.IndexExpression(
collection=to_ast(program, self.collection), index=[to_ast(program, self.index)]
)


class Var(ABC):
"""Abstract base class for both classical and quantum variables."""

Expand Down
58 changes: 24 additions & 34 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from oqpy.base import (
AstConvertible,
OQPyExpression,
OQIndexExpression,
Var,
make_annotations,
map_to_ast,
Expand All @@ -62,7 +63,6 @@
"ComplexVar",
"DurationVar",
"OQFunctionCall",
"OQIndexExpression",
"StretchVar",
"_ClassicalVar",
"duration",
Expand Down Expand Up @@ -273,22 +273,29 @@ class BitVar(_SizedVar):

type_cls = ast.BitType

def __getitem__(self, idx: Union[int, slice, Iterable[int]]) -> BitVar:
if self.size is None:
raise TypeError(f"'{self.type_cls}' object is not subscriptable")
if isinstance(idx, int):
if 0 <= idx < self.size:
return BitVar(
init_expression=ast.IndexExpression(
ast.Identifier(self.name), [ast.IntegerLiteral(idx)]
),
name=f"{self.name}[{idx}]",
needs_declaration=False,
)
else:
raise IndexError("list index out of range.")
else:
def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
_validate_sizedvar_getitem(self, index)
return OQIndexExpression(collection=self, index=index, type=self.type_cls())


def _validate_sizedvar_getitem(var: _SizedVar, index: AstConvertible) -> None:
"""Validate the index and variable for `__getitem__`.

Args:
var (_SizedVar): Variable to apply `__getitem__`.
index (AstConvertible): Index for `__getitem__`.
"""
if var.size is None:
raise TypeError(f"'{var.name}' is not subscriptable")

if isinstance(index, int):
if not 0 <= index < var.size:
raise IndexError("list index out of range.")
elif isinstance(index, OQPyExpression):
if not isinstance(index.type, (ast.IntType, ast.UintType)):
raise IndexError("The list index must be an integer.")
else:
raise IndexError("The list index must be an integer.")


class ComplexVar(_ClassicalVar):
Expand Down Expand Up @@ -394,24 +401,7 @@ def __init__(
)

def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
return OQIndexExpression(collection=self, index=index)


class OQIndexExpression(OQPyExpression):
"""An oqpy expression corresponding to an index expression."""

def __init__(self, collection: AstConvertible, index: AstConvertible):
self.collection = collection
self.index = index

if isinstance(collection, ArrayVar):
self.type = collection.base_type().type_cls()

def to_ast(self, program: Program) -> ast.IndexExpression:
"""Converts this oqpy index expression into an ast node."""
return ast.IndexExpression(
collection=to_ast(program, self.collection), index=[to_ast(program, self.index)]
)
return OQIndexExpression(collection=self, index=index, type=self.base_type().type_cls())


class OQFunctionCall(OQPyExpression):
Expand Down
8 changes: 6 additions & 2 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from oqpy import classical_types, quantum_types
from oqpy.base import (
AstConvertible,
OQIndexExpression,
Var,
expr_matches,
map_to_ast,
Expand Down Expand Up @@ -483,7 +484,10 @@ def gate(
self, qubits: AstConvertible | Iterable[AstConvertible], name: str, *args: Any
) -> Program:
"""Apply a gate to a qubit or set of qubits."""
if isinstance(qubits, quantum_types.Qubit):
if isinstance(qubits, quantum_types.Qubit) or (
isinstance(qubits, OQIndexExpression)
and isinstance(qubits.collection, quantum_types.Qubit)
):
qubits = [qubits]
assert isinstance(qubits, Iterable)
self._add_statement(
Expand Down Expand Up @@ -554,7 +558,7 @@ def do_expression(self, expression: AstConvertible) -> Program:

def set(
self,
var: classical_types._ClassicalVar | classical_types.OQIndexExpression,
var: classical_types._ClassicalVar | OQIndexExpression,
value: AstConvertible,
) -> Program:
"""Set a variable value."""
Expand Down
21 changes: 13 additions & 8 deletions oqpy/quantum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
from openpulse import ast
from openpulse.printer import dumps

from oqpy.base import AstConvertible, Var, make_annotations, to_ast
from oqpy.base import AstConvertible, Var, make_annotations, to_ast, OQIndexExpression
from oqpy.classical_types import AngleVar, _ClassicalVar

if TYPE_CHECKING:
from oqpy.program import Program


__all__ = ["Qubit", "QubitArray", "defcal", "gate", "PhysicalQubits", "Cal"]
__all__ = ["Qubit", "defcal", "gate", "PhysicalQubits", "Cal"]


class Qubit(Var):
Expand All @@ -39,11 +39,13 @@ class Qubit(Var):
def __init__(
self,
name: str,
size: int = None,
needs_declaration: bool = True,
annotations: Sequence[str | tuple[str, str]] = (),
):
super().__init__(name, needs_declaration=needs_declaration)
self.name = name
self.size = size
self.annotations = annotations

def to_ast(self, prog: Program) -> ast.Expression:
Expand All @@ -53,10 +55,18 @@ def to_ast(self, prog: Program) -> ast.Expression:

def make_declaration_statement(self, program: Program) -> ast.Statement:
"""Make an ast statement that declares the OQpy variable."""
decl = ast.QubitDeclaration(ast.Identifier(self.name), size=None)
decl = ast.QubitDeclaration(
ast.Identifier(self.name),
size=ast.IntegerLiteral(self.size) if self.size else self.size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

size=ast.IntegerLiteral(self.size) if self.size else None,

is better

)
decl.annotations = make_annotations(self.annotations)
return decl

def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
if self.size is None:
raise TypeError(f"'{self.name}' is not subscriptable")
return OQIndexExpression(collection=self, index=index, type=ast.Identifier)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think returning an OQIndexExpression here is not quite correct, since OQIndexExpression is a subtype of OQPyExpression which is only for classical typed expressions. I think here you could either make a new type (e.g. IndexedQubitArray) or just return the ast.IndexedExpresssion directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the feedback, I updated and added a IndexedQubitArray class



class PhysicalQubits:
"""Provides a means of accessing qubit variables corresponding to physical qubits.
Expand All @@ -68,11 +78,6 @@ def __class_getitem__(cls, item: int) -> Qubit:
return Qubit(f"${item}", needs_declaration=False)


# Todo (#51): support QubitArray
class QubitArray:
"""Represents an array of qubits."""


@contextlib.contextmanager
def gate(
program: Program,
Expand Down
39 changes: 36 additions & 3 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import oqpy
from oqpy import *
from oqpy.base import OQPyExpression, expr_matches, logical_and, logical_or
from oqpy.base import OQPyExpression, expr_matches, logical_and, logical_or, OQIndexExpression
from oqpy.quantum_types import PhysicalQubits
from oqpy.timing import OQDurationLiteral

Expand Down Expand Up @@ -126,6 +126,9 @@ def test_variable_declaration():
prog = Program(version=None)
prog.declare(vars)
prog.set(arr[1], 0)
index = IntVar(2, "index")
prog.set(arr[index], 1)
prog.set(arr[index + 1], 0)

with pytest.raises(IndexError):
prog.set(arr[40], 2)
Expand All @@ -142,6 +145,7 @@ def test_variable_declaration():

expected = textwrap.dedent(
"""
int[32] index = 2;
bool b = true;
int[32] i = -4;
uint[32] u = 5;
Expand All @@ -151,10 +155,12 @@ def test_variable_declaration():
bit[20] arr;
bit c;
arr[1] = 0;
arr[index] = 1;
arr[index + 1] = 0;
"""
).strip()

assert isinstance(arr[14], BitVar)
assert isinstance(arr[14], OQIndexExpression)
assert prog.to_qasm() == expected
_check_respects_type_hints(prog)

Expand Down Expand Up @@ -2215,7 +2221,12 @@ def test_invalid_gates():
def test_gate_declarations():
prog = oqpy.Program()
q = oqpy.Qubit("q", needs_declaration=False)
with oqpy.gate(prog, q, "u", [oqpy.AngleVar(name="alpha"), oqpy.AngleVar(name="beta"), oqpy.AngleVar(name="gamma")]) as (alpha, beta, gamma):
with oqpy.gate(
prog,
q,
"u",
[oqpy.AngleVar(name="alpha"), oqpy.AngleVar(name="beta"), oqpy.AngleVar(name="gamma")],
) as (alpha, beta, gamma):
prog.gate(q, "a", alpha)
prog.gate(q, "b", beta)
prog.gate(q, "c", gamma)
Expand Down Expand Up @@ -2266,3 +2277,25 @@ def test_include():

assert prog.to_qasm() == expected


def test_qubit_array():
prog = oqpy.Program()
q = oqpy.Qubit("q", size=2)
prog.gate(q[0], "h")
prog.gate([q[0], q[1]], "cnot")

expected = textwrap.dedent(
"""
OPENQASM 3.0;
qubit[2] q;
h q[0];
cnot q[0], q[1];
"""
).strip()

assert prog.to_qasm() == expected

with pytest.raises(TypeError):
prog = oqpy.Program()
q = oqpy.Qubit("q")
prog.gate(q[0], "h")