diff --git a/oqpy/base.py b/oqpy/base.py index ac99c9e..d69383b 100644 --- a/oqpy/base.py +++ b/oqpy/base.py @@ -374,9 +374,11 @@ def detect_and_convert_constants(val: float | np.floating[Any], program: Program """Construct a float ast expression which is either a literal or an expression using constants.""" if val == 0: return ast.FloatLiteral(val) + if val < 0.5 or val > 100: + return ast.FloatLiteral(val) x = val / (math.pi / 4.0) rx = round(x) - if rx > 100 or not math.isclose(x, rx, rel_tol=1e-12): + if not math.isclose(x, rx, rel_tol=1e-12): return ast.FloatLiteral(val) term: OQPyExpression if rx == 4: diff --git a/oqpy/classical_types.py b/oqpy/classical_types.py index e8dfc08..252b831 100644 --- a/oqpy/classical_types.py +++ b/oqpy/classical_types.py @@ -45,6 +45,8 @@ from oqpy.timing import make_duration if TYPE_CHECKING: + from typing import Literal + from oqpy.program import Program __all__ = [ @@ -177,7 +179,7 @@ class _ClassicalVar(Var, OQPyExpression): def __init__( self, - init_expression: AstConvertible | None = None, + init_expression: AstConvertible | Literal["input", "output"] | None = None, name: str | None = None, needs_declaration: bool = True, annotations: Sequence[str | tuple[str, str]] = (), @@ -196,6 +198,10 @@ def to_ast(self, program: Program) -> ast.Identifier: def make_declaration_statement(self, program: Program) -> ast.Statement: """Make an ast statement that declares the OQpy variable.""" + if isinstance(self.init_expression, str) and self.init_expression in ("input", "output"): + return ast.IODeclaration( + ast.IOKeyword[self.init_expression], self.type, self.to_ast(program) + ) init_expression_ast = optional_ast(program, self.init_expression) stmt = ast.ClassicalDeclaration(self.type, self.to_ast(program), init_expression_ast) stmt.annotations = make_annotations(self.annotations) @@ -295,7 +301,7 @@ def __class_getitem__(cls, item: ast.FloatType) -> Callable[..., ComplexVar]: def __init__( self, - init_expression: AstConvertible | None = None, + init_expression: AstConvertible | Literal["input", "output"] | None = None, *args: Any, base_type: ast.FloatType = float64, **kwargs: Any, @@ -303,7 +309,7 @@ def __init__( assert isinstance(base_type, ast.FloatType) self.base_type = base_type - if not isinstance(init_expression, (complex, type(None), OQPyExpression)): + if not isinstance(init_expression, (complex, type(None), str, OQPyExpression)): init_expression = complex(init_expression) # type: ignore[arg-type] super().__init__(init_expression, *args, **kwargs, base_type=base_type) @@ -315,12 +321,12 @@ class DurationVar(_ClassicalVar): def __init__( self, - init_expression: AstConvertible | None = None, + init_expression: AstConvertible | Literal["input", "output"] | None = None, name: str | None = None, *args: Any, **type_kwargs: Any, ) -> None: - if init_expression is not None: + if init_expression is not None and not isinstance(init_expression, str): init_expression = make_duration(init_expression) super().__init__(init_expression, name, *args, **type_kwargs) diff --git a/tests/test_directives.py b/tests/test_directives.py index 0381a97..d0a78e3 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -1949,3 +1949,31 @@ def test_oqpy_range(): ).strip() assert prog.to_qasm() == expected _check_respects_type_hints(prog) + + +def test_io_declaration(): + x = oqpy.DurationVar("input", name="x") + y = oqpy.FloatVar("output", name="y") + wf = oqpy.WaveformVar("input", name="wf") + port = oqpy.PortVar(name="my_port", init_expression="input") + frame = oqpy.FrameVar(port, 5e9, 0, name="my_frame") + + prog = Program() + prog.declare(x) + prog.set(y, 1) + prog.play(frame, wf) + + expected = textwrap.dedent( + """ + OPENQASM 3.0; + input port my_port; + output float[64] y; + frame my_frame = newframe(my_port, 5000000000.0, 0); + input waveform wf; + input duration x; + y = 1; + play(my_frame, wf); + """ + ).strip() + assert prog.to_qasm() == expected + _check_respects_type_hints(prog)