Skip to content

Commit 862d0c1

Browse files
Qubit and BitVar supports get item (#67)
* Support BitVar indexing * move validation to OQIndexExpression * qubit class supports get item * add IndexedQubitArray class * move OQIndexExpression to classical_types * Update oqpy/quantum_types.py * Update oqpy/classical_types.py * Update oqpy/classical_types.py * Apply suggestions from code review --------- Co-authored-by: Phil Reinhold <[email protected]>
1 parent 0790911 commit 862d0c1

File tree

4 files changed

+86
-30
lines changed

4 files changed

+86
-30
lines changed

oqpy/classical_types.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
"ComplexVar",
6363
"DurationVar",
6464
"OQFunctionCall",
65-
"OQIndexExpression",
6665
"StretchVar",
6766
"_ClassicalVar",
6867
"duration",
@@ -236,6 +235,25 @@ def __init__(self, *args: Any, size: int | None = None, **kwargs: Any):
236235
self.size = size
237236
super().__init__(*args, **kwargs, size=ast.IntegerLiteral(self.size) if self.size else None)
238237

238+
def _validate_getitem_index(self, index: AstConvertible) -> None:
239+
"""Validate the index and variable for `__getitem__`.
240+
241+
Args:
242+
var (_SizedVar): Variable to apply `__getitem__`.
243+
index (AstConvertible): Index for `__getitem__`.
244+
"""
245+
if self.size is None:
246+
raise TypeError(f"'{self.name}' is not subscriptable")
247+
248+
if isinstance(index, int):
249+
if not 0 <= index < self.size:
250+
raise IndexError("list index out of range.")
251+
elif isinstance(index, OQPyExpression):
252+
if not isinstance(index.type, (ast.IntType, ast.UintType)):
253+
raise IndexError("The list index must be an integer.")
254+
else:
255+
raise IndexError("The list index must be an integer.")
256+
239257

240258
_SizedVarT = TypeVar("_SizedVarT", bound=_SizedVar)
241259

@@ -273,22 +291,9 @@ class BitVar(_SizedVar):
273291

274292
type_cls = ast.BitType
275293

276-
def __getitem__(self, idx: Union[int, slice, Iterable[int]]) -> BitVar:
277-
if self.size is None:
278-
raise TypeError(f"'{self.type_cls}' object is not subscriptable")
279-
if isinstance(idx, int):
280-
if 0 <= idx < self.size:
281-
return BitVar(
282-
init_expression=ast.IndexExpression(
283-
ast.Identifier(self.name), [ast.IntegerLiteral(idx)]
284-
),
285-
name=f"{self.name}[{idx}]",
286-
needs_declaration=False,
287-
)
288-
else:
289-
raise IndexError("list index out of range.")
290-
else:
291-
raise IndexError("The list index must be an integer.")
294+
def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
295+
self._validate_getitem_index(index)
296+
return OQIndexExpression(collection=self, index=index, type_=self.type_cls())
292297

293298

294299
class ComplexVar(_ClassicalVar):
@@ -394,18 +399,16 @@ def __init__(
394399
)
395400

396401
def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
397-
return OQIndexExpression(collection=self, index=index)
402+
return OQIndexExpression(collection=self, index=index, type_=self.base_type().type_cls())
398403

399404

400405
class OQIndexExpression(OQPyExpression):
401406
"""An oqpy expression corresponding to an index expression."""
402407

403-
def __init__(self, collection: AstConvertible, index: AstConvertible):
408+
def __init__(self, collection: AstConvertible, index: AstConvertible, type_: ast.ClassicalType):
404409
self.collection = collection
405410
self.index = index
406-
407-
if isinstance(collection, ArrayVar):
408-
self.type = collection.base_type().type_cls()
411+
self.type = type_
409412

410413
def to_ast(self, program: Program) -> ast.IndexExpression:
411414
"""Converts this oqpy index expression into an ast node."""

oqpy/program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def gate(
483483
self, qubits: AstConvertible | Iterable[AstConvertible], name: str, *args: Any
484484
) -> Program:
485485
"""Apply a gate to a qubit or set of qubits."""
486-
if isinstance(qubits, quantum_types.Qubit):
486+
if isinstance(qubits, (quantum_types.Qubit, quantum_types.IndexedQubitArray)):
487487
qubits = [qubits]
488488
assert isinstance(qubits, Iterable)
489489
self._add_statement(

oqpy/quantum_types.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from oqpy.program import Program
3131

3232

33-
__all__ = ["Qubit", "QubitArray", "defcal", "gate", "PhysicalQubits", "Cal"]
33+
__all__ = ["Qubit", "defcal", "gate", "PhysicalQubits", "Cal"]
3434

3535

3636
class Qubit(Var):
@@ -39,11 +39,13 @@ class Qubit(Var):
3939
def __init__(
4040
self,
4141
name: str,
42+
size: Optional[int] = None,
4243
needs_declaration: bool = True,
4344
annotations: Sequence[str | tuple[str, str]] = (),
4445
):
4546
super().__init__(name, needs_declaration=needs_declaration)
4647
self.name = name
48+
self.size = size
4749
self.annotations = annotations
4850

4951
def to_ast(self, prog: Program) -> ast.Expression:
@@ -53,10 +55,18 @@ def to_ast(self, prog: Program) -> ast.Expression:
5355

5456
def make_declaration_statement(self, program: Program) -> ast.Statement:
5557
"""Make an ast statement that declares the OQpy variable."""
56-
decl = ast.QubitDeclaration(ast.Identifier(self.name), size=None)
58+
decl = ast.QubitDeclaration(
59+
ast.Identifier(self.name),
60+
size=ast.IntegerLiteral(self.size) if self.size else self.size,
61+
)
5762
decl.annotations = make_annotations(self.annotations)
5863
return decl
5964

65+
def __getitem__(self, index: AstConvertible) -> IndexedQubitArray:
66+
if self.size is None:
67+
raise TypeError(f"'{self.name}' is not subscriptable")
68+
return IndexedQubitArray(collection=self, index=index)
69+
6070

6171
class PhysicalQubits:
6272
"""Provides a means of accessing qubit variables corresponding to physical qubits.
@@ -68,9 +78,18 @@ def __class_getitem__(cls, item: int) -> Qubit:
6878
return Qubit(f"${item}", needs_declaration=False)
6979

7080

71-
# Todo (#51): support QubitArray
72-
class QubitArray:
73-
"""Represents an array of qubits."""
81+
class IndexedQubitArray:
82+
"""Represents an indexed qubit array."""
83+
84+
def __init__(self, collection: Qubit, index: AstConvertible):
85+
self.collection = collection
86+
self.index = index
87+
88+
def to_ast(self, program: Program) -> ast.IndexExpression:
89+
"""Converts this indexed qubit array into an ast node."""
90+
return ast.IndexExpression(
91+
collection=to_ast(program, self.collection), index=[to_ast(program, self.index)]
92+
)
7493

7594

7695
@contextlib.contextmanager

tests/test_directives.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import oqpy
3232
from oqpy import *
3333
from oqpy.base import OQPyExpression, expr_matches, logical_and, logical_or
34+
from oqpy.classical_types import OQIndexExpression
3435
from oqpy.quantum_types import PhysicalQubits
3536
from oqpy.timing import OQDurationLiteral
3637

@@ -126,6 +127,9 @@ def test_variable_declaration():
126127
prog = Program(version=None)
127128
prog.declare(vars)
128129
prog.set(arr[1], 0)
130+
index = IntVar(2, "index")
131+
prog.set(arr[index], 1)
132+
prog.set(arr[index + 1], 0)
129133

130134
with pytest.raises(IndexError):
131135
prog.set(arr[40], 2)
@@ -142,6 +146,7 @@ def test_variable_declaration():
142146

143147
expected = textwrap.dedent(
144148
"""
149+
int[32] index = 2;
145150
bool b = true;
146151
int[32] i = -4;
147152
uint[32] u = 5;
@@ -151,10 +156,12 @@ def test_variable_declaration():
151156
bit[20] arr;
152157
bit c;
153158
arr[1] = 0;
159+
arr[index] = 1;
160+
arr[index + 1] = 0;
154161
"""
155162
).strip()
156163

157-
assert isinstance(arr[14], BitVar)
164+
assert isinstance(arr[14], OQIndexExpression)
158165
assert prog.to_qasm() == expected
159166
_check_respects_type_hints(prog)
160167

@@ -2215,7 +2222,12 @@ def test_invalid_gates():
22152222
def test_gate_declarations():
22162223
prog = oqpy.Program()
22172224
q = oqpy.Qubit("q", needs_declaration=False)
2218-
with oqpy.gate(prog, q, "u", [oqpy.AngleVar(name="alpha"), oqpy.AngleVar(name="beta"), oqpy.AngleVar(name="gamma")]) as (alpha, beta, gamma):
2225+
with oqpy.gate(
2226+
prog,
2227+
q,
2228+
"u",
2229+
[oqpy.AngleVar(name="alpha"), oqpy.AngleVar(name="beta"), oqpy.AngleVar(name="gamma")],
2230+
) as (alpha, beta, gamma):
22192231
prog.gate(q, "a", alpha)
22202232
prog.gate(q, "b", beta)
22212233
prog.gate(q, "c", gamma)
@@ -2266,3 +2278,25 @@ def test_include():
22662278

22672279
assert prog.to_qasm() == expected
22682280

2281+
2282+
def test_qubit_array():
2283+
prog = oqpy.Program()
2284+
q = oqpy.Qubit("q", size=2)
2285+
prog.gate(q[0], "h")
2286+
prog.gate([q[0], q[1]], "cnot")
2287+
2288+
expected = textwrap.dedent(
2289+
"""
2290+
OPENQASM 3.0;
2291+
qubit[2] q;
2292+
h q[0];
2293+
cnot q[0], q[1];
2294+
"""
2295+
).strip()
2296+
2297+
assert prog.to_qasm() == expected
2298+
2299+
with pytest.raises(TypeError):
2300+
prog = oqpy.Program()
2301+
q = oqpy.Qubit("q")
2302+
prog.gate(q[0], "h")

0 commit comments

Comments
 (0)