Skip to content

Commit db9ec10

Browse files
authored
Add support for arrays (#28)
* Add array variable * Enable indexing * Add a test * More tests * Allow specifying type and dimensions in declaration * Explicitly test setting with variable * Typing: allow assignment to indexed expressions * Union type typealias * Fix comment, provide OQIndexExpression from the module * Pydocstyle Why complain now? * Add base type to ComplexVar and propagate it to arrays
1 parent bfd3e28 commit db9ec10

File tree

3 files changed

+186
-3
lines changed

3 files changed

+186
-3
lines changed

oqpy/classical_types.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
__all__ = [
4141
"pi",
42+
"ArrayVar",
4243
"BoolVar",
4344
"IntVar",
4445
"UintVar",
@@ -48,6 +49,7 @@
4849
"ComplexVar",
4950
"DurationVar",
5051
"OQFunctionCall",
52+
"OQIndexExpression",
5153
"StretchVar",
5254
"_ClassicalVar",
5355
"duration",
@@ -272,18 +274,20 @@ class ComplexVar(_ClassicalVar):
272274
"""An oqpy variable with bit type."""
273275

274276
type_cls = ast.ComplexType
277+
base_type: ast.FloatType = float64
275278

276-
def __class_getitem__(cls, item: Type[ast.FloatType]) -> Callable[..., ComplexVar]:
279+
def __class_getitem__(cls, item: ast.FloatType) -> Callable[..., ComplexVar]:
277280
return functools.partial(cls, base_type=item)
278281

279282
def __init__(
280283
self,
281284
init_expression: AstConvertible | None = None,
282285
*args: Any,
283-
base_type: Type[ast.FloatType] = float64,
286+
base_type: ast.FloatType = float64,
284287
**kwargs: Any,
285288
) -> None:
286289
assert isinstance(base_type, ast.FloatType)
290+
self.base_type = base_type
287291

288292
if not isinstance(init_expression, (complex, type(None), OQPyExpression)):
289293
init_expression = complex(init_expression) # type: ignore[arg-type]
@@ -313,6 +317,80 @@ class StretchVar(_ClassicalVar):
313317
type_cls = ast.StretchType
314318

315319

320+
AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar]
321+
322+
323+
class ArrayVar(_ClassicalVar):
324+
"""An oqpy array variable."""
325+
326+
type_cls = ast.ArrayType
327+
dimensions: list[int]
328+
base_type: type[AllowedArrayTypes]
329+
330+
def __class_getitem__(
331+
cls, item: tuple[type[AllowedArrayTypes], int] | type[AllowedArrayTypes]
332+
) -> Callable[..., ArrayVar]:
333+
# Allows usage like ArrayVar[FloatVar, 32](...) or ArrayVar[FloatVar]
334+
if isinstance(item, tuple):
335+
base_type = item[0]
336+
dimensions = list(item[1:])
337+
return functools.partial(cls, dimensions=dimensions, base_type=base_type)
338+
else:
339+
return functools.partial(cls, base_type=item)
340+
341+
def __init__(
342+
self,
343+
*args: Any,
344+
dimensions: list[int],
345+
base_type: type[AllowedArrayTypes] = IntVar,
346+
**kwargs: Any,
347+
) -> None:
348+
self.dimensions = dimensions
349+
self.base_type = base_type
350+
351+
# Creating a dummy variable supports IntVar[64] etc.
352+
base_type_instance = base_type()
353+
if isinstance(base_type_instance, _SizedVar):
354+
array_base_type = base_type_instance.type_cls(
355+
size=ast.IntegerLiteral(base_type_instance.size)
356+
)
357+
elif isinstance(base_type_instance, ComplexVar):
358+
array_base_type = base_type_instance.type_cls(base_type=base_type_instance.base_type)
359+
else:
360+
array_base_type = base_type_instance.type_cls()
361+
362+
# Automatically handle Duration array.
363+
if base_type is DurationVar and kwargs["init_expression"]:
364+
kwargs["init_expression"] = (make_duration(i) for i in kwargs["init_expression"])
365+
366+
super().__init__(
367+
*args,
368+
**kwargs,
369+
dimensions=[ast.IntegerLiteral(dimension) for dimension in dimensions],
370+
base_type=array_base_type,
371+
)
372+
373+
def __getitem__(self, index: AstConvertible) -> OQIndexExpression:
374+
return OQIndexExpression(collection=self, index=index)
375+
376+
377+
class OQIndexExpression(OQPyExpression):
378+
"""An oqpy expression corresponding to an index expression."""
379+
380+
def __init__(self, collection: AstConvertible, index: AstConvertible):
381+
self.collection = collection
382+
self.index = index
383+
384+
if isinstance(collection, ArrayVar):
385+
self.type = collection.base_type().type_cls()
386+
387+
def to_ast(self, program: Program) -> ast.IndexExpression:
388+
"""Converts this oqpy index expression into an ast node."""
389+
return ast.IndexExpression(
390+
collection=to_ast(program, self.collection), index=[to_ast(program, self.index)]
391+
)
392+
393+
316394
class OQFunctionCall(OQPyExpression):
317395
"""An oqpy expression corresponding to a function call."""
318396

oqpy/program.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,11 @@ def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) ->
459459
)
460460
)
461461

462-
def set(self, var: classical_types._ClassicalVar, value: AstConvertible) -> Program:
462+
def set(
463+
self,
464+
var: classical_types._ClassicalVar | classical_types.OQIndexExpression,
465+
value: AstConvertible,
466+
) -> Program:
463467
"""Set a variable value."""
464468
self._do_assignment(var, "=", value)
465469
return self

tests/test_directives.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,87 @@ def test_complex_numbers_declaration():
138138

139139
assert prog.to_qasm() == expected
140140

141+
def test_array_declaration():
142+
b = ArrayVar(name="b", init_expression=[True, False], dimensions=[2], base_type=BoolVar)
143+
i = ArrayVar(name="i", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar)
144+
i55 = ArrayVar(name="i55", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar[55])
145+
u = ArrayVar(name="u", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=UintVar)
146+
x = ArrayVar(name="x", init_expression=[0e-9, 1e-9, 2e-9], dimensions=[3], base_type=DurationVar)
147+
y = ArrayVar(name="y", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=FloatVar)
148+
ang = ArrayVar(name="ang", init_expression=[0.0, 1.0, 2.0, 3.0], dimensions=[4], base_type=AngleVar)
149+
comp = ArrayVar(name="comp", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar)
150+
comp55 = ArrayVar(name="comp55", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar[float_(55)])
151+
ang_partial = ArrayVar[AngleVar, 2](name="ang_part", init_expression=[oqpy.pi, oqpy.pi/2])
152+
simple = ArrayVar[FloatVar](name="no_init", dimensions=[5])
153+
multidim = ArrayVar[FloatVar[32], 3, 2](name="multiDim", init_expression=[[1.1, 1.2], [2.1, 2.2], [3.1, 3.2]])
154+
155+
vars = [b, i, i55, u, x, y, ang, comp, comp55, ang_partial, simple, multidim]
156+
157+
prog = oqpy.Program(version=None)
158+
prog.declare(vars)
159+
prog.set(i[1], 0) # Set with literal values
160+
idx = IntVar(name="idx", init_expression=5)
161+
val = IntVar(name="val", init_expression=10)
162+
prog.set(i[idx], val)
163+
164+
expected = textwrap.dedent(
165+
"""
166+
int[32] idx = 5;
167+
int[32] val = 10;
168+
array[bool, 2] b = {true, false};
169+
array[int[32], 5] i = {0, 1, 2, 3, 4};
170+
array[int[55], 5] i55 = {0, 1, 2, 3, 4};
171+
array[uint[32], 5] u = {0, 1, 2, 3, 4};
172+
array[duration, 3] x = {0.0ns, 1.0ns, 2.0ns};
173+
array[float[64], 4] y = {0.0, 1.0, 2.0, 3.0};
174+
array[angle[32], 4] ang = {0.0, 1.0, 2.0, 3.0};
175+
array[complex[float[64]], 2] comp = {0, 1.0 + 1.0im};
176+
array[complex[float[55]], 2] comp55 = {0, 1.0 + 1.0im};
177+
array[angle[32], 2] ang_part = {pi, pi / 2};
178+
array[float[64], 5] no_init;
179+
array[float[32], 3, 2] multiDim = {{1.1, 1.2}, {2.1, 2.2}, {3.1, 3.2}};
180+
i[1] = 0;
181+
i[idx] = val;
182+
"""
183+
).strip()
184+
185+
assert prog.to_qasm() == expected
186+
187+
def test_non_trivial_array_access():
188+
prog = oqpy.Program()
189+
port = oqpy.PortVar(name="my_port")
190+
frame = oqpy.FrameVar(name="my_frame", port=port, frequency=1e9, phase=0)
191+
192+
zero_to_one = oqpy.ArrayVar(
193+
name='duration_array',
194+
init_expression=[0.0, 0.25, 0.5, 0.75, 1],
195+
dimensions=[5],
196+
base_type=oqpy.DurationVar
197+
)
198+
one_second = oqpy.DurationVar(init_expression=1, name="one_second")
199+
200+
one = oqpy.IntVar(name="one", init_expression=1)
201+
202+
with oqpy.ForIn(prog, range(4), "idx") as idx:
203+
prog.delay(zero_to_one[idx + one] + one_second, frame)
204+
prog.set(zero_to_one[idx], 5)
205+
206+
expected = textwrap.dedent(
207+
"""
208+
OPENQASM 3.0;
209+
port my_port;
210+
array[duration, 5] duration_array = {0.0ns, 250000000.0ns, 500000000.0ns, 750000000.0ns, 1000000000.0ns};
211+
int[32] one = 1;
212+
duration one_second = 1000000000.0ns;
213+
frame my_frame = newframe(my_port, 1000000000.0, 0);
214+
for int idx in [0:3] {
215+
delay[duration_array[idx + one] + one_second] my_frame;
216+
duration_array[idx] = 5;
217+
}
218+
"""
219+
).strip()
220+
221+
assert prog.to_qasm() == expected
141222

142223
def test_non_trivial_variable_declaration():
143224
prog = Program()
@@ -389,6 +470,26 @@ def test_for_in_var_types():
389470
"""
390471
).strip()
391472

473+
# Test indexing over an ArrayVar
474+
program = oqpy.Program()
475+
pyphases = [0] + [oqpy.pi / i for i in range(10, 1, -2)]
476+
phases = ArrayVar(name="phases", dimensions=[len(pyphases)], init_expression=pyphases, base_type=AngleVar)
477+
478+
with oqpy.ForIn(program, range(len(pyphases)), "idx") as idx:
479+
program.shift_phase(phases[idx], frame)
480+
481+
expected = textwrap.dedent(
482+
"""
483+
OPENQASM 3.0;
484+
port my_port;
485+
array[angle[32], 6] phases = {0, pi / 10, pi / 8, pi / 6, pi / 4, pi / 2};
486+
frame my_frame = newframe(my_port, 3000000000.0, 0);
487+
for int idx in [0:5] {
488+
shift_phase(phases[idx], my_frame);
489+
}
490+
"""
491+
).strip()
492+
392493
assert program.to_qasm() == expected
393494

394495

0 commit comments

Comments
 (0)