Skip to content

Commit 1a1b84c

Browse files
authored
Allow externs with no arguments and no return value (#58)
* Allow externs with no arguments * Allow externs with no return
1 parent 7379620 commit 1a1b84c

File tree

4 files changed

+62
-14
lines changed

4 files changed

+62
-14
lines changed

oqpy/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import math
2525
import sys
2626
from abc import ABC, abstractmethod
27-
from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union
27+
from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union
2828

2929
import numpy as np
3030
from openpulse import ast
@@ -50,7 +50,7 @@ class OQPyExpression:
5050
``==`` which produces a new expression instead of producing a python boolean.
5151
"""
5252

53-
type: ast.ClassicalType
53+
type: Optional[ast.ClassicalType]
5454

5555
def to_ast(self, program: Program) -> ast.Expression:
5656
"""Converts the oqpy expression into an ast node."""
@@ -175,7 +175,7 @@ def __bool__(self) -> bool:
175175
)
176176

177177

178-
def _get_type(val: AstConvertible) -> ast.ClassicalType:
178+
def _get_type(val: AstConvertible) -> Optional[ast.ClassicalType]:
179179
if isinstance(val, OQPyExpression):
180180
return val.type
181181
elif isinstance(val, int):

oqpy/classical_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Any,
2626
Callable,
2727
Iterable,
28+
Optional,
2829
Sequence,
2930
Type,
3031
TypeVar,
@@ -420,7 +421,7 @@ def __init__(
420421
self,
421422
identifier: Union[str, ast.Identifier],
422423
args: Iterable[AstConvertible],
423-
return_type: ast.ClassicalType,
424+
return_type: Optional[ast.ClassicalType],
424425
extern_decl: ast.ExternDeclaration | None = None,
425426
subroutine_decl: ast.SubroutineDefinition | None = None,
426427
):
@@ -429,7 +430,7 @@ def __init__(
429430
Args:
430431
identifier: The function name.
431432
args: The function arguments.
432-
return_type: The type returned by the function call.
433+
return_type: The type returned by the function call. If none, returns nothing.
433434
extern_decl: An optional extern declaration ast node. If present,
434435
this extern declaration will be added to the top of the program
435436
whenever this is converted to ast.

oqpy/subroutines.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import functools
2121
import inspect
22-
from typing import Any, Callable, Sequence, TypeVar, get_type_hints
22+
from typing import Any, Callable, Optional, Sequence, TypeVar, get_type_hints
2323

2424
from mypy_extensions import VarArg
2525
from openpulse import ast
@@ -164,7 +164,7 @@ def wrapper(
164164
def declare_extern(
165165
name: str,
166166
args: list[tuple[str, ast.ClassicalType]],
167-
return_type: ast.ClassicalType,
167+
return_type: Optional[ast.ClassicalType] = None,
168168
annotations: Sequence[str | tuple[str, str]] = (),
169169
) -> Callable[..., OQFunctionCall]:
170170
"""Declare an extern and return a callable which adds the extern.
@@ -180,8 +180,8 @@ def declare_extern(
180180
program.set(var, sqrt(0.5))
181181
182182
"""
183-
arg_names = list(zip(*(args)))[0]
184-
arg_types = list(zip(*(args)))[1]
183+
arg_names = list(zip(*(args)))[0] if args else []
184+
arg_types = list(zip(*(args)))[1] if args else []
185185
extern_decl = ast.ExternDeclaration(
186186
ast.Identifier(name),
187187
[ast.ExternArgument(type=t) for t in arg_types],

tests/test_directives.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
import numpy as np
2626
import pytest
2727
from openpulse import ast
28-
from openpulse.printer import dumps
2928
from openpulse.parser import QASMVisitor
29+
from openpulse.printer import dumps
3030

3131
import oqpy
3232
from oqpy import *
@@ -950,6 +950,54 @@ def test_set_shift_frequency():
950950
_check_respects_type_hints(prog)
951951

952952

953+
def test_declare_extern():
954+
program = Program()
955+
956+
# Test an extern with one input and output
957+
sqrt = declare_extern("sqrt", [("x", float32)], float32)
958+
959+
# Test an extern with two inputs and one output
960+
arctan = declare_extern("arctan", [("x", float32), ("y", float32)], float32)
961+
962+
# Test an extern with no input and one output
963+
time = declare_extern("time", [], int32)
964+
965+
# Test an extern with one input and no output
966+
set_global_voltage = declare_extern("set_voltage", [("voltage", int32)])
967+
968+
# Test an extern with no input and no output
969+
fire_bazooka = declare_extern("fire_bazooka", [])
970+
971+
f = oqpy.FloatVar(name="f", init_expression=0.0)
972+
i = oqpy.IntVar(name="i", init_expression=5)
973+
974+
program.set(f, sqrt(f))
975+
program.set(f, arctan(f, f))
976+
program.set(i, time())
977+
program.do_expression(set_global_voltage(i))
978+
program.do_expression(fire_bazooka())
979+
980+
expected = textwrap.dedent(
981+
"""
982+
OPENQASM 3.0;
983+
extern sqrt(float[32]) -> float[32];
984+
extern arctan(float[32], float[32]) -> float[32];
985+
extern time() -> int[32];
986+
extern set_voltage(int[32]);
987+
extern fire_bazooka();
988+
float[64] f = 0.0;
989+
int[32] i = 5;
990+
f = sqrt(f);
991+
f = arctan(f, f);
992+
i = time();
993+
set_voltage(i);
994+
fire_bazooka();
995+
"""
996+
).strip()
997+
998+
assert program.to_qasm() == expected
999+
1000+
9531001
def test_defcals():
9541002
prog = Program()
9551003
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
@@ -1530,7 +1578,7 @@ def test_needs_declaration():
15301578
).strip()
15311579

15321580
declared_vars = {}
1533-
undeclared_vars= ["i1", "i2", "f1", "f2", "q1", "q2"]
1581+
undeclared_vars = ["i1", "i2", "f1", "f2", "q1", "q2"]
15341582
statement_ast = [
15351583
ast.ClassicalAssignment(
15361584
lvalue=ast.Identifier(name="i1"),
@@ -1745,10 +1793,10 @@ def test_in_place_subroutine_declaration():
17451793
@subroutine(annotations=["inline", ("optimize", "-O3")])
17461794
def f(prog: Program, x: IntVar) -> IntVar:
17471795
return x
1748-
1796+
17491797
prog = Program()
17501798
i = IntVar(0, name="i")
1751-
prog.declare([i,f])
1799+
prog.declare([i, f])
17521800
prog.increment(i, 1)
17531801

17541802
expected = textwrap.dedent(
@@ -2124,7 +2172,6 @@ def f(prog: oqpy.Program) -> oqpy.IntVar:
21242172
def g(prog: oqpy.Program) -> oqpy.IntVar:
21252173
return f(prog)
21262174

2127-
21282175
prog = oqpy.Program()
21292176
x = oqpy.IntVar(name="x")
21302177
prog.set(x, g(prog))

0 commit comments

Comments
 (0)