Skip to content

Convert durations to float and vice versa where appropriate #43

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 22, 2023
171 changes: 154 additions & 17 deletions oqpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ def to_ast(self, program: Program) -> ast.Expression:

@staticmethod
def _to_binary(
op_name: str, first: AstConvertible, second: AstConvertible
op_name: str,
first: AstConvertible,
second: AstConvertible,
result_type: ast.ClassicalType | None = None,
) -> OQPyBinaryExpression:
"""Helper method to produce a binary expression."""
return OQPyBinaryExpression(ast.BinaryOperator[op_name], first, second)
return OQPyBinaryExpression(ast.BinaryOperator[op_name], first, second, result_type)

@staticmethod
def _to_unary(op_name: str, exp: AstConvertible) -> OQPyUnaryExpression:
Expand Down Expand Up @@ -93,16 +96,20 @@ def __rmod__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("%", other, self)

def __mul__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("*", self, other)
result_type = compute_product_types(self, other)
return self._to_binary("*", self, other, result_type)

def __rmul__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("*", other, self)
result_type = compute_product_types(other, self)
return self._to_binary("*", other, self, result_type)

def __truediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("/", self, other)
result_type = compute_quotient_types(self, other)
return self._to_binary("/", self, other, result_type)

def __rtruediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("/", other, self)
result_type = compute_quotient_types(other, self)
return self._to_binary("/", other, self, result_type)

def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression:
return self._to_binary("**", self, other)
Expand Down Expand Up @@ -168,6 +175,128 @@ def __bool__(self) -> bool:
)


def _get_type(val: AstConvertible) -> ast.ClassicalType:
if isinstance(val, OQPyExpression):
return val.type
elif isinstance(val, int):
return ast.IntType()
elif isinstance(val, float):
return ast.FloatType()
elif isinstance(val, complex):
return ast.ComplexType(ast.FloatType())
else:
raise ValueError(f"Cannot multiply/divide oqpy expression with with {type(val)}")


def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.ClassicalType:
"""Find the result type for a product of two terms."""
left_type = _get_type(left)
right_type = _get_type(right)

types_map = {
(ast.FloatType, ast.FloatType): left_type,
(ast.FloatType, ast.IntType): left_type,
(ast.FloatType, ast.UintType): left_type,
(ast.FloatType, ast.DurationType): right_type,
(ast.FloatType, ast.AngleType): right_type,
(ast.FloatType, ast.ComplexType): right_type,
(ast.IntType, ast.FloatType): right_type,
(ast.IntType, ast.IntType): left_type,
(ast.IntType, ast.UintType): left_type,
(ast.IntType, ast.DurationType): right_type,
(ast.IntType, ast.AngleType): right_type,
(ast.IntType, ast.ComplexType): right_type,
(ast.UintType, ast.FloatType): right_type,
(ast.UintType, ast.IntType): right_type,
(ast.UintType, ast.UintType): left_type,
(ast.UintType, ast.DurationType): right_type,
(ast.UintType, ast.AngleType): right_type,
(ast.UintType, ast.ComplexType): right_type,
(ast.DurationType, ast.FloatType): left_type,
(ast.DurationType, ast.IntType): left_type,
(ast.DurationType, ast.UintType): left_type,
(ast.DurationType, ast.DurationType): TypeError(
"Cannot multiply two durations. You may need to re-group computations to eliminate this."
),
(ast.DurationType, ast.AngleType): TypeError("Cannot multiply duration and angle"),
(ast.DurationType, ast.ComplexType): TypeError("Cannot multiply duration and complex"),
(ast.AngleType, ast.FloatType): left_type,
(ast.AngleType, ast.IntType): left_type,
(ast.AngleType, ast.UintType): left_type,
(ast.AngleType, ast.DurationType): TypeError("Cannot multiply angle and duration"),
(ast.AngleType, ast.AngleType): TypeError("Cannot multiply two angles"),
(ast.AngleType, ast.ComplexType): TypeError("Cannot multiply angle and complex"),
(ast.ComplexType, ast.FloatType): left_type,
(ast.ComplexType, ast.IntType): left_type,
(ast.ComplexType, ast.UintType): left_type,
(ast.ComplexType, ast.DurationType): TypeError("Cannot multiply complex and duration"),
(ast.ComplexType, ast.AngleType): TypeError("Cannot multiply complex and angle"),
(ast.ComplexType, ast.ComplexType): left_type,
}

try:
result_type = types_map[type(left_type), type(right_type)]
except KeyError as e:
raise TypeError(f"Could not identify types for product {left} and {right}") from e
if isinstance(result_type, Exception):
raise result_type
return result_type


def compute_quotient_types(left: AstConvertible, right: AstConvertible) -> ast.ClassicalType:
"""Find the result type for a quotient of two terms."""
left_type = _get_type(left)
right_type = _get_type(right)
float_type = ast.FloatType()

types_map = {
(ast.FloatType, ast.FloatType): left_type,
(ast.FloatType, ast.IntType): left_type,
(ast.FloatType, ast.UintType): left_type,
(ast.FloatType, ast.DurationType): TypeError("Cannot divide float by duration"),
(ast.FloatType, ast.AngleType): TypeError("Cannot divide float by angle"),
(ast.FloatType, ast.ComplexType): right_type,
(ast.IntType, ast.FloatType): right_type,
(ast.IntType, ast.IntType): float_type,
(ast.IntType, ast.UintType): float_type,
(ast.IntType, ast.DurationType): TypeError("Cannot divide int by duration"),
(ast.IntType, ast.AngleType): TypeError("Cannot divide int by angle"),
(ast.IntType, ast.ComplexType): right_type,
(ast.UintType, ast.FloatType): right_type,
(ast.UintType, ast.IntType): float_type,
(ast.UintType, ast.UintType): float_type,
(ast.UintType, ast.DurationType): TypeError("Cannot divide uint by duration"),
(ast.UintType, ast.AngleType): TypeError("Cannot divide uint by angle"),
(ast.UintType, ast.ComplexType): right_type,
(ast.DurationType, ast.FloatType): left_type,
(ast.DurationType, ast.IntType): left_type,
(ast.DurationType, ast.UintType): left_type,
(ast.DurationType, ast.DurationType): ast.FloatType(),
(ast.DurationType, ast.AngleType): TypeError("Cannot divide duration by angle"),
(ast.DurationType, ast.ComplexType): TypeError("Cannot divide duration by complex"),
(ast.AngleType, ast.FloatType): left_type,
(ast.AngleType, ast.IntType): left_type,
(ast.AngleType, ast.UintType): left_type,
(ast.AngleType, ast.DurationType): TypeError("Cannot divide by duration"),
(ast.AngleType, ast.AngleType): float_type,
(ast.AngleType, ast.ComplexType): TypeError("Cannot divide by angle by complex"),
(ast.ComplexType, ast.FloatType): left_type,
(ast.ComplexType, ast.IntType): left_type,
(ast.ComplexType, ast.UintType): left_type,
(ast.ComplexType, ast.DurationType): TypeError("Cannot divide by duration"),
(ast.ComplexType, ast.AngleType): TypeError("Cannot divide by angle"),
(ast.ComplexType, ast.ComplexType): left_type,
}

try:
result_type = types_map[type(left_type), type(right_type)]
except KeyError as e:
raise TypeError(f"Could not identify types for quotient {left} and {right}") from e
if isinstance(result_type, Exception):
raise result_type
return result_type


def logical_and(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression:
"""Logical AND."""
return OQPyBinaryExpression(ast.BinaryOperator["&&"], first, second)
Expand Down Expand Up @@ -227,30 +356,38 @@ def to_ast(self, program: Program) -> ast.UnaryExpression:
class OQPyBinaryExpression(OQPyExpression):
"""An expression consisting of two subexpressions joined by an operator."""

def __init__(self, op: ast.BinaryOperator, lhs: AstConvertible, rhs: AstConvertible):
def __init__(
self,
op: ast.BinaryOperator,
lhs: AstConvertible,
rhs: AstConvertible,
ast_type: ast.ClassicalType | None = None,
):
super().__init__()
self.op = op
self.lhs = lhs
self.rhs = rhs
# TODO (#50): More robust type checking which considers both arguments
# TODO (#9): More robust type checking which considers both arguments
# types, as well as the operator.
if isinstance(lhs, OQPyExpression):
self.type = lhs.type
elif isinstance(rhs, OQPyExpression):
self.type = rhs.type
else:
raise TypeError("Neither lhs nor rhs is an expression?")
if ast_type is None:
if isinstance(lhs, OQPyExpression):
ast_type = lhs.type
elif isinstance(rhs, OQPyExpression):
ast_type = rhs.type
else:
raise TypeError("Neither lhs nor rhs is an expression?")
self.type = ast_type

# Adding floats to durations is not allowed. So we promote types as necessary.
if isinstance(self.type, ast.DurationType) and self.op in [
ast.BinaryOperator["+"],
ast.BinaryOperator["-"],
]:
# Late import to avoid circular imports.
from oqpy.timing import make_duration
from oqpy.timing import convert_float_to_duration

self.lhs = make_duration(self.lhs)
self.rhs = make_duration(self.rhs)
self.lhs = convert_float_to_duration(self.lhs)
self.rhs = convert_float_to_duration(self.rhs)

def to_ast(self, program: Program) -> ast.BinaryExpression:
"""Converts the OQpy expression into an ast node."""
Expand Down
14 changes: 8 additions & 6 deletions oqpy/classical_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
optional_ast,
to_ast,
)
from oqpy.timing import make_duration
from oqpy.timing import convert_float_to_duration

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -158,15 +158,15 @@ class Identifier(OQPyExpression):

name: str

def __init__(self, name: str) -> None:
self.type = None
def __init__(self, name: str, ast_type: ast.ClassicalType) -> None:
self.name = name
self.type = ast_type

def to_ast(self, program: Program) -> ast.Expression:
return ast.Identifier(name=self.name)


pi = Identifier(name="pi")
pi = Identifier(name="pi", ast_type=ast.FloatType())


class _ClassicalVar(Var, OQPyExpression):
Expand Down Expand Up @@ -327,7 +327,7 @@ def __init__(
**type_kwargs: Any,
) -> None:
if init_expression is not None and not isinstance(init_expression, str):
init_expression = make_duration(init_expression)
init_expression = convert_float_to_duration(init_expression)
super().__init__(init_expression, name, *args, **type_kwargs)


Expand Down Expand Up @@ -381,7 +381,9 @@ def __init__(

# Automatically handle Duration array.
if base_type is DurationVar and kwargs["init_expression"] is not None:
kwargs["init_expression"] = (make_duration(i) for i in kwargs["init_expression"])
kwargs["init_expression"] = (
convert_float_to_duration(i) for i in kwargs["init_expression"]
)

super().__init__(
*args,
Expand Down
4 changes: 2 additions & 2 deletions oqpy/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_ClassicalVar,
convert_range,
)
from oqpy.timing import make_duration
from oqpy.timing import convert_float_to_duration

ClassicalVarT = TypeVar("ClassicalVarT", bound=_ClassicalVar)

Expand Down Expand Up @@ -126,7 +126,7 @@ def ForIn(
set_declaration = convert_range(program, iterator)
elif isinstance(iterator, Iterable):
if identifier_type is DurationVar:
iterator = (make_duration(i) for i in iterator)
iterator = (convert_float_to_duration(i) for i in iterator)

set_declaration = ast.DiscreteSet([to_ast(program, i) for i in iterator])
else:
Expand Down
25 changes: 16 additions & 9 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
to_ast,
)
from oqpy.pulse import FrameVar, PortVar, WaveformVar
from oqpy.timing import make_duration
from oqpy.timing import convert_duration_to_float, convert_float_to_duration

__all__ = ["Program"]

Expand Down Expand Up @@ -91,6 +91,8 @@ def add_statement(self, stmt: ast.Statement | ast.Pragma) -> None:
class Program:
"""A builder class for OpenQASM/OpenPulse programs."""

DURATION_MAX_DIGITS = 12

def __init__(self, version: Optional[str] = "3.0", simplify_constants: bool = True) -> None:
self.stack: list[ProgramState] = [ProgramState()]
self.defcals: dict[
Expand Down Expand Up @@ -364,7 +366,7 @@ def delay(
"""Apply a delay to a set of qubits or frames."""
if not isinstance(qubits_or_frames, Iterable):
qubits_or_frames = [qubits_or_frames]
ast_duration = to_ast(self, make_duration(time))
ast_duration = to_ast(self, convert_float_to_duration(time))
ast_qubits_or_frames = map_to_ast(self, qubits_or_frames)
self._add_statement(ast.DelayInstruction(ast_duration, ast_qubits_or_frames))
return self
Expand Down Expand Up @@ -393,32 +395,37 @@ def capture(self, frame: AstConvertible, kernel: AstConvertible) -> Program:

def set_phase(self, frame: AstConvertible, phase: AstConvertible) -> Program:
"""Set the phase of a particular frame."""
self.function_call("set_phase", [frame, phase])
# We use make_float to force phase to be a unitless (i.e. non-duration) quantity.
# Users are expected to keep track the units that are not expressible in openqasm
# such as s^{-1}. For instance, in 2 * oqpy.pi * tppi * DurationVar(1e-8),
# tppi is a float but has a frequency unit. This will coerce the result type
# to a float by assuming the duration should be represented in seconds."
self.function_call("set_phase", [frame, convert_duration_to_float(phase)])
return self

def shift_phase(self, frame: AstConvertible, phase: AstConvertible) -> Program:
"""Shift the phase of a particular frame."""
self.function_call("shift_phase", [frame, phase])
self.function_call("shift_phase", [frame, convert_duration_to_float(phase)])
return self

def set_frequency(self, frame: AstConvertible, freq: AstConvertible) -> Program:
"""Set the frequency of a particular frame."""
self.function_call("set_frequency", [frame, freq])
self.function_call("set_frequency", [frame, convert_duration_to_float(freq)])
return self

def shift_frequency(self, frame: AstConvertible, freq: AstConvertible) -> Program:
"""Shift the frequency of a particular frame."""
self.function_call("shift_frequency", [frame, freq])
self.function_call("shift_frequency", [frame, convert_duration_to_float(freq)])
return self

def set_scale(self, frame: AstConvertible, scale: AstConvertible) -> Program:
"""Set the amplitude scaling of a particular frame."""
self.function_call("set_scale", [frame, scale])
self.function_call("set_scale", [frame, convert_duration_to_float(scale)])
return self

def shift_scale(self, frame: AstConvertible, scale: AstConvertible) -> Program:
"""Shift the amplitude scaling of a particular frame."""
self.function_call("shift_scale", [frame, scale])
self.function_call("shift_scale", [frame, convert_duration_to_float(scale)])
return self

def returns(self, expression: AstConvertible) -> Program:
Expand Down Expand Up @@ -473,7 +480,7 @@ def pragma(self, command: str) -> Program:
def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) -> None:
"""Helper function for variable assignment operations."""
if isinstance(var, classical_types.DurationVar):
value = make_duration(value)
value = convert_float_to_duration(value)
var_ast = to_ast(self, var)
if isinstance(var_ast, ast.IndexExpression):
assert isinstance(var_ast.collection, ast.Identifier)
Expand Down
6 changes: 3 additions & 3 deletions oqpy/subroutines.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from oqpy.base import AstConvertible, OQPyExpression, make_annotations, to_ast
from oqpy.classical_types import OQFunctionCall, _ClassicalVar
from oqpy.quantum_types import Qubit
from oqpy.timing import make_duration
from oqpy.timing import convert_float_to_duration

__all__ = ["subroutine", "annotate_subroutine", "declare_extern", "declare_waveform_generator"]

Expand Down Expand Up @@ -200,14 +200,14 @@ def call_extern(*call_args: AstConvertible, **call_kwargs: AstConvertible) -> OQ
raise TypeError(f"{name}() got multiple values for argument '{k}'.")

if type(arg_types[k_idx]) == ast.DurationType:
new_args[k_idx] = make_duration(call_kwargs[k])
new_args[k_idx] = convert_float_to_duration(call_kwargs[k])
else:
new_args[k_idx] = call_kwargs[k]

# Casting floats into durations for the non-keyword arguments
for i, a in enumerate(call_args):
if type(arg_types[i]) == ast.DurationType:
new_args[i] = make_duration(a)
new_args[i] = convert_float_to_duration(a)
return OQFunctionCall(name, new_args, return_type, extern_decl=extern_decl)

return call_extern
Expand Down
Loading