Skip to content

Commit def54d5

Browse files
authored
input tweaks and tests (#99)
* input tweaks and tests * handle edge cases better
1 parent 906bdcd commit def54d5

File tree

2 files changed

+87
-7
lines changed

2 files changed

+87
-7
lines changed

oqpy/classical_types.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,12 @@ def to_ast(self, program: Program) -> ast.Identifier:
227227
def make_declaration_statement(self, program: Program) -> ast.Statement:
228228
"""Make an ast statement that declares the OQpy variable."""
229229
if isinstance(self.init_expression, str) and self.init_expression in ("input", "output"):
230-
return ast.IODeclaration(
230+
stmt = ast.IODeclaration(
231231
ast.IOKeyword[self.init_expression], self.type, self.to_ast(program)
232232
)
233-
init_expression_ast = optional_ast(program, self.init_expression)
234-
stmt = ast.ClassicalDeclaration(self.type, self.to_ast(program), init_expression_ast)
233+
else:
234+
init_expression_ast = optional_ast(program, self.init_expression)
235+
stmt = ast.ClassicalDeclaration(self.type, self.to_ast(program), init_expression_ast)
235236
stmt.annotations = make_annotations(self.annotations)
236237
return stmt
237238

@@ -416,10 +417,17 @@ def __init__(
416417
array_base_type = base_type_instance.type_cls()
417418

418419
# Automatically handle Duration array.
419-
if base_type is DurationVar and kwargs["init_expression"] is not None:
420-
kwargs["init_expression"] = (
421-
convert_float_to_duration(i) for i in kwargs["init_expression"]
420+
init = kwargs.get("init_expression")
421+
if (
422+
base_type is DurationVar
423+
and init is not None
424+
and not (
425+
# type check to avoid element-wise numpy comparison for array init
426+
isinstance(init, str)
427+
and init == "input"
422428
)
429+
):
430+
kwargs["init_expression"] = [convert_float_to_duration(i) for i in init]
423431

424432
super().__init__(
425433
*args,

tests/test_directives.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,70 @@ def test_variable_declaration():
173173
_check_respects_type_hints(prog)
174174

175175

176+
def test_input_variable_declaration():
177+
b = BoolVar("input", "b")
178+
i = IntVar("input", "i")
179+
j = IntVar[None]("input", "j")
180+
u = UintVar("input", "u")
181+
x = DurationVar("input", "blah")
182+
y = FloatVar[50]("input", "y")
183+
z = FloatVar[50]("input", "z")
184+
z.annotations = [("my_annotation", "1,2,3")]
185+
ang = AngleVar("input", name="ang")
186+
arr = BitVar[20]("input", name="arr")
187+
c = BitVar("input", name="c")
188+
vars = [b, i, j, u, x, y, z, ang, arr, c]
189+
190+
prog = Program(version=None)
191+
prog.declare(vars)
192+
prog.set(arr[1], 0)
193+
index = IntVar("input", "index")
194+
prog.set(arr[index], 1)
195+
prog.set(arr[index + 1], 0)
196+
197+
y2 = FloatVar(2.5, "y")
198+
with pytest.raises(RuntimeError):
199+
prog.set(y2, 3.0)
200+
with pytest.raises(IndexError):
201+
prog.set(arr[40], 2)
202+
with pytest.raises(ValueError):
203+
BitVar[2.1](name="d")
204+
with pytest.raises(ValueError):
205+
BitVar[0](name="d")
206+
with pytest.raises(ValueError):
207+
BitVar[-1](name="d")
208+
with pytest.raises(IndexError):
209+
prog.set(arr[1.3], 0)
210+
with pytest.raises(IndexError):
211+
prog.set(arr[index * 2.0], 0)
212+
with pytest.raises(TypeError):
213+
prog.set(c[0], 1)
214+
215+
expected = textwrap.dedent(
216+
"""
217+
input int[32] index;
218+
input bool b;
219+
input int[32] i;
220+
input int j;
221+
input uint[32] u;
222+
input duration blah;
223+
input float[50] y;
224+
@my_annotation 1,2,3
225+
input float[50] z;
226+
input angle[32] ang;
227+
input bit[20] arr;
228+
input bit c;
229+
arr[1] = 0;
230+
arr[index] = 1;
231+
arr[index + 1] = 0;
232+
"""
233+
).strip()
234+
235+
assert isinstance(arr[14], OQIndexExpression)
236+
assert prog.to_qasm() == expected
237+
_check_respects_type_hints(prog)
238+
239+
176240
def test_complex_numbers_declaration():
177241
vars = [
178242
ComplexVar(name="z"),
@@ -225,7 +289,10 @@ def test_complex_numbers_declaration():
225289

226290
def test_array_declaration():
227291
b = ArrayVar(name="b", init_expression=[True, False], dimensions=[2], base_type=BoolVar)
292+
b_in = ArrayVar(name="b_in", init_expression="input", dimensions=[2], base_type=BoolVar)
228293
i = ArrayVar(name="i", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar)
294+
i_in = ArrayVar(name="i_in", init_expression="input", dimensions=[5], base_type=IntVar)
295+
d_in = ArrayVar(name="d_in", init_expression="input", dimensions=[5], base_type=DurationVar)
229296
i55 = ArrayVar(
230297
name="i55", init_expression=[0, 1, 2, 3, 4], dimensions=[5], base_type=IntVar[55]
231298
)
@@ -242,6 +309,7 @@ def test_array_declaration():
242309
name="comp55", init_expression=[0, 1 + 1j], dimensions=[2], base_type=ComplexVar[float_(55)]
243310
)
244311
ang_partial = ArrayVar[AngleVar, 2](name="ang_part", init_expression=[oqpy.pi, oqpy.pi / 2])
312+
arg_array = ArrayVar([1, 2, 3], name="arg_array", dimensions=[3])
245313
simple = ArrayVar[FloatVar](name="no_init", dimensions=[5])
246314
multidim = ArrayVar[FloatVar[32], 3, 2](
247315
name="multiDim", init_expression=[[1.1, 1.2], [2.1, 2.2], [3.1, 3.2]]
@@ -253,7 +321,7 @@ def test_array_declaration():
253321
base_type=DurationVar,
254322
)
255323

256-
vars = [b, i, i55, u, x, y, ang, comp, comp55, ang_partial, simple, multidim, npinit]
324+
vars = [b, b_in, i, i_in, d_in, i55, u, x, y, ang, comp, comp55, ang_partial, arg_array, simple, multidim, npinit]
257325

258326
prog = oqpy.Program(version=None)
259327
prog.declare(vars)
@@ -271,7 +339,10 @@ def test_array_declaration():
271339
int[32] val = 10;
272340
duration d = 0.0ns;
273341
array[bool, 2] b = {true, false};
342+
input array[bool, 2] b_in;
274343
array[int[32], 5] i = {0, 1, 2, 3, 4};
344+
input array[int[32], 5] i_in;
345+
input array[duration, 5] d_in;
275346
array[int[55], 5] i55 = {0, 1, 2, 3, 4};
276347
array[uint[32], 5] u = {0, 1, 2, 3, 4};
277348
array[duration, 3] x = {0.0ns, 1.0ns, 2.0ns};
@@ -280,6 +351,7 @@ def test_array_declaration():
280351
array[complex[float[64]], 2] comp = {0, 1.0 + 1.0im};
281352
array[complex[float[55]], 2] comp55 = {0, 1.0 + 1.0im};
282353
array[angle[32], 2] ang_part = {pi, pi / 2};
354+
array[int[32], 3] arg_array = {1, 2, 3};
283355
array[float[64], 5] no_init;
284356
array[float[32], 3, 2] multiDim = {{1.1, 1.2}, {2.1, 2.2}, {3.1, 3.2}};
285357
array[duration, 11] npinit = {0.0ns, 1.0ns, 2.0ns, 4.0ns};

0 commit comments

Comments
 (0)