Skip to content

Commit d16f4af

Browse files
authored
Allow custom ExternArgument in extern declaration (#83)
* Allow custom ExternArgument in extern declaration This is useful if people want to specify ArrayReferenceType as an argument to an extern function. ArrayReferenceType require an explicit Access modifier (mutable or readonly). Since this is just one edge case, we can let users still specify just the type for all other extern declarations. * Hide direct AST node manipulation * Coverage * Docstrings * Rename * Allow None as argument access
1 parent 43d4456 commit d16f4af

File tree

3 files changed

+80
-7
lines changed

3 files changed

+80
-7
lines changed

oqpy/classical_types.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
"complex128",
8787
"angle_",
8888
"angle32",
89+
"arrayreference_",
8990
]
9091

9192
# The following methods and constants are useful for creating signatures
@@ -129,6 +130,26 @@ def bit_(size: int | None = None) -> ast.BitType:
129130
return ast.BitType(ast.IntegerLiteral(size) if size is not None else None)
130131

131132

133+
def arrayreference_(
134+
dtype: Union[
135+
ast.IntType,
136+
ast.UintType,
137+
ast.FloatType,
138+
ast.AngleType,
139+
ast.DurationType,
140+
ast.BitType,
141+
ast.BoolType,
142+
ast.ComplexType,
143+
],
144+
dims: int | list[int],
145+
) -> ast.ArrayReferenceType:
146+
"""Create an array reference type."""
147+
dim = (
148+
ast.IntegerLiteral(dims) if isinstance(dims, int) else [ast.IntegerLiteral(d) for d in dims]
149+
)
150+
return ast.ArrayReferenceType(base_type=dtype, dimensions=dim)
151+
152+
132153
duration = ast.DurationType()
133154
stretch = ast.StretchType()
134155
bool_ = ast.BoolType()

oqpy/subroutines.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
import functools
2121
import inspect
22-
from typing import Any, Callable, Optional, Sequence, TypeVar, get_type_hints
22+
from dataclasses import dataclass
23+
from typing import Any, Callable, Literal, Optional, Sequence, TypeVar, get_type_hints
2324

2425
from mypy_extensions import VarArg
2526
from openpulse import ast
@@ -30,13 +31,26 @@
3031
from oqpy.quantum_types import Qubit
3132
from oqpy.timing import convert_float_to_duration
3233

33-
__all__ = ["subroutine", "declare_extern", "declare_waveform_generator"]
34+
__all__ = ["subroutine", "declare_extern", "declare_waveform_generator", "OQPyArgument"]
3435

3536
SubroutineParams = [oqpy.Program, VarArg(AstConvertible)]
3637

3738
FnType = TypeVar("FnType", bound=Callable[..., Any])
3839

3940

41+
@dataclass
42+
class OQPyArgument:
43+
"""An oqpy argument to extern declaration.."""
44+
45+
name: str
46+
dtype: ast.ClassicalType
47+
access: Literal["readonly", "mutable"] | None = None
48+
49+
def unzip(self) -> tuple[str, ast.ClassicalType, ast.AccessControl | None]:
50+
"""Returns the three values, name, dtype and access as a tuple."""
51+
return self.name, self.dtype, ast.AccessControl[self.access] if self.access else None
52+
53+
4054
def enable_decorator_arguments(f: FnType) -> Callable[..., FnType]:
4155
@functools.wraps(f)
4256
def decorator(*args, **kwargs): # type: ignore[no-untyped-def]
@@ -173,7 +187,7 @@ def wrapper(
173187

174188
def declare_extern(
175189
name: str,
176-
args: list[tuple[str, ast.ClassicalType]],
190+
args: list[tuple[str, ast.ClassicalType] | OQPyArgument],
177191
return_type: Optional[ast.ClassicalType] = None,
178192
annotations: Sequence[str | tuple[str, str]] = (),
179193
) -> Callable[..., OQFunctionCall]:
@@ -190,11 +204,28 @@ def declare_extern(
190204
program.set(var, sqrt(0.5))
191205
192206
"""
193-
arg_names = list(zip(*(args)))[0] if args else []
194-
arg_types = list(zip(*(args)))[1] if args else []
207+
arg_names: list[str] = []
208+
arg_types: list[ast.ClassicalType] = []
209+
arg_access: list[ast.AccessControl | None] = []
210+
211+
for arg in args:
212+
if isinstance(arg, tuple):
213+
arg_name, arg_type = arg
214+
access = None
215+
elif isinstance(arg, OQPyArgument):
216+
arg_name, arg_type, access = arg.unzip()
217+
else:
218+
raise Exception(f"Argument {arg} should have a proper type")
219+
arg_names.append(arg_name)
220+
arg_types.append(arg_type)
221+
arg_access.append(access)
222+
195223
extern_decl = ast.ExternDeclaration(
196224
ast.Identifier(name),
197-
[ast.ExternArgument(type=t) for t in arg_types],
225+
[
226+
ast.ExternArgument(type=ctype, access=access)
227+
for ctype, access in zip(arg_types, arg_access)
228+
],
198229
return_type,
199230
)
200231
extern_decl.annotations = make_annotations(annotations)
@@ -236,7 +267,7 @@ def call_extern(*call_args: AstConvertible, **call_kwargs: AstConvertible) -> OQ
236267

237268
def declare_waveform_generator(
238269
name: str,
239-
argtypes: list[tuple[str, ast.ClassicalType]],
270+
argtypes: list[tuple[str, ast.ClassicalType] | OQPyArgument],
240271
annotations: Sequence[str | tuple[str, str]] = (),
241272
) -> Callable[..., OQFunctionCall]:
242273
"""Create a function which generates waveforms using a specified name and argument signature."""

tests/test_directives.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,19 @@ def test_declare_extern():
10241024
# Test an extern with no input and no output
10251025
fire_bazooka = declare_extern("fire_bazooka", [])
10261026

1027+
# Test an extern with readonly array
1028+
print_array = declare_extern(
1029+
"print_array",
1030+
[
1031+
("style", int32),
1032+
OQPyArgument(
1033+
name="arr",
1034+
dtype=arrayreference_(int32, 1),
1035+
access="readonly",
1036+
),
1037+
],
1038+
)
1039+
10271040
f = oqpy.FloatVar(name="f", init_expression=0.0)
10281041
i = oqpy.IntVar(name="i", init_expression=5)
10291042

@@ -1032,6 +1045,7 @@ def test_declare_extern():
10321045
program.set(i, time())
10331046
program.do_expression(set_global_voltage(i))
10341047
program.do_expression(fire_bazooka())
1048+
program.do_expression(print_array(1, [0, 1, 2]))
10351049

10361050
expected = textwrap.dedent(
10371051
"""
@@ -1041,13 +1055,15 @@ def test_declare_extern():
10411055
extern time() -> int[32];
10421056
extern set_voltage(int[32]);
10431057
extern fire_bazooka();
1058+
extern print_array(int[32], readonly array[int[32], #dim=1]);
10441059
float[64] f = 0.0;
10451060
int[32] i = 5;
10461061
f = sqrt(f);
10471062
f = arctan(f, f);
10481063
i = time();
10491064
set_voltage(i);
10501065
fire_bazooka();
1066+
print_array(1, {0, 1, 2});
10511067
"""
10521068
).strip()
10531069

@@ -1057,6 +1073,11 @@ def test_declare_extern():
10571073
assert expr_matches(arctan(f, i).args["y"], i)
10581074

10591075

1076+
def test_invalid_extern_declaration():
1077+
# Test with invalid argument
1078+
with pytest.raises(Exception, match="Argument.*"):
1079+
_ = declare_extern("invalid", [int32])
1080+
10601081
def test_defcals():
10611082
prog = Program()
10621083
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])

0 commit comments

Comments
 (0)