Skip to content

Allow externs with no arguments and no return value #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import math
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union
from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union

import numpy as np
from openpulse import ast
Expand All @@ -50,7 +50,7 @@ class OQPyExpression:
``==`` which produces a new expression instead of producing a python boolean.
"""

type: ast.ClassicalType
type: Optional[ast.ClassicalType]

def to_ast(self, program: Program) -> ast.Expression:
"""Converts the oqpy expression into an ast node."""
Expand Down Expand Up @@ -175,7 +175,7 @@ def __bool__(self) -> bool:
)


def _get_type(val: AstConvertible) -> ast.ClassicalType:
def _get_type(val: AstConvertible) -> Optional[ast.ClassicalType]:
if isinstance(val, OQPyExpression):
return val.type
elif isinstance(val, int):
Expand Down
5 changes: 3 additions & 2 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Any,
Callable,
Iterable,
Optional,
Sequence,
Type,
TypeVar,
Expand Down Expand Up @@ -420,7 +421,7 @@ def __init__(
self,
identifier: Union[str, ast.Identifier],
args: Iterable[AstConvertible],
return_type: ast.ClassicalType,
return_type: Optional[ast.ClassicalType],
extern_decl: ast.ExternDeclaration | None = None,
subroutine_decl: ast.SubroutineDefinition | None = None,
):
Expand All @@ -429,7 +430,7 @@ def __init__(
Args:
identifier: The function name.
args: The function arguments.
return_type: The type returned by the function call.
return_type: The type returned by the function call. If none, returns nothing.
extern_decl: An optional extern declaration ast node. If present,
this extern declaration will be added to the top of the program
whenever this is converted to ast.
Expand Down
8 changes: 4 additions & 4 deletions oqpy/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import functools
import inspect
from typing import Any, Callable, Sequence, TypeVar, get_type_hints
from typing import Any, Callable, Optional, Sequence, TypeVar, get_type_hints

from mypy_extensions import VarArg
from openpulse import ast
Expand Down Expand Up @@ -164,7 +164,7 @@ def wrapper(
def declare_extern(
name: str,
args: list[tuple[str, ast.ClassicalType]],
return_type: ast.ClassicalType,
return_type: Optional[ast.ClassicalType] = None,
annotations: Sequence[str | tuple[str, str]] = (),
) -> Callable[..., OQFunctionCall]:
"""Declare an extern and return a callable which adds the extern.
Expand All @@ -180,8 +180,8 @@ def declare_extern(
program.set(var, sqrt(0.5))

"""
arg_names = list(zip(*(args)))[0]
arg_types = list(zip(*(args)))[1]
arg_names = list(zip(*(args)))[0] if args else []
arg_types = list(zip(*(args)))[1] if args else []
extern_decl = ast.ExternDeclaration(
ast.Identifier(name),
[ast.ExternArgument(type=t) for t in arg_types],
Expand Down
57 changes: 52 additions & 5 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import numpy as np
import pytest
from openpulse import ast
from openpulse.printer import dumps
from openpulse.parser import QASMVisitor
from openpulse.printer import dumps

import oqpy
from oqpy import *
Expand Down Expand Up @@ -950,6 +950,54 @@ def test_set_shift_frequency():
_check_respects_type_hints(prog)


def test_declare_extern():
program = Program()

# Test an extern with one input and output
sqrt = declare_extern("sqrt", [("x", float32)], float32)

# Test an extern with two inputs and one output
arctan = declare_extern("arctan", [("x", float32), ("y", float32)], float32)

# Test an extern with no input and one output
time = declare_extern("time", [], int32)

# Test an extern with one input and no output
set_global_voltage = declare_extern("set_voltage", [("voltage", int32)])

# Test an extern with no input and no output
fire_bazooka = declare_extern("fire_bazooka", [])

f = oqpy.FloatVar(name="f", init_expression=0.0)
i = oqpy.IntVar(name="i", init_expression=5)

program.set(f, sqrt(f))
program.set(f, arctan(f, f))
program.set(i, time())
program.do_expression(set_global_voltage(i))
program.do_expression(fire_bazooka())

expected = textwrap.dedent(
"""
OPENQASM 3.0;
extern sqrt(float[32]) -> float[32];
extern arctan(float[32], float[32]) -> float[32];
extern time() -> int[32];
extern set_voltage(int[32]);
extern fire_bazooka();
float[64] f = 0.0;
int[32] i = 5;
f = sqrt(f);
f = arctan(f, f);
i = time();
set_voltage(i);
fire_bazooka();
"""
).strip()

assert program.to_qasm() == expected


def test_defcals():
prog = Program()
constant = declare_waveform_generator("constant", [("length", duration), ("iq", complex128)])
Expand Down Expand Up @@ -1530,7 +1578,7 @@ def test_needs_declaration():
).strip()

declared_vars = {}
undeclared_vars= ["i1", "i2", "f1", "f2", "q1", "q2"]
undeclared_vars = ["i1", "i2", "f1", "f2", "q1", "q2"]
statement_ast = [
ast.ClassicalAssignment(
lvalue=ast.Identifier(name="i1"),
Expand Down Expand Up @@ -1745,10 +1793,10 @@ def test_in_place_subroutine_declaration():
@subroutine(annotations=["inline", ("optimize", "-O3")])
def f(prog: Program, x: IntVar) -> IntVar:
return x

prog = Program()
i = IntVar(0, name="i")
prog.declare([i,f])
prog.declare([i, f])
prog.increment(i, 1)

expected = textwrap.dedent(
Expand Down Expand Up @@ -2124,7 +2172,6 @@ def f(prog: oqpy.Program) -> oqpy.IntVar:
def g(prog: oqpy.Program) -> oqpy.IntVar:
return f(prog)


prog = oqpy.Program()
x = oqpy.IntVar(name="x")
prog.set(x, g(prog))
Expand Down