Skip to content

Commit b021dda

Browse files
authored
Convert durations to float and vice versa where appropriate (#43)
* WIP * Propagate expression types in multiplication/division. Coerce durations into floats and vice versa * lint checks * Review comments * Add exception for int where float is requested * Add comment to set_phase * rename make_float/make_duration * Add docstring to make_duration * Fix after bad merge conflict resolution
1 parent ffcf004 commit b021dda

File tree

7 files changed

+293
-96
lines changed

7 files changed

+293
-96
lines changed

oqpy/base.py

Lines changed: 154 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,13 @@ def to_ast(self, program: Program) -> ast.Expression:
5858

5959
@staticmethod
6060
def _to_binary(
61-
op_name: str, first: AstConvertible, second: AstConvertible
61+
op_name: str,
62+
first: AstConvertible,
63+
second: AstConvertible,
64+
result_type: ast.ClassicalType | None = None,
6265
) -> OQPyBinaryExpression:
6366
"""Helper method to produce a binary expression."""
64-
return OQPyBinaryExpression(ast.BinaryOperator[op_name], first, second)
67+
return OQPyBinaryExpression(ast.BinaryOperator[op_name], first, second, result_type)
6568

6669
@staticmethod
6770
def _to_unary(op_name: str, exp: AstConvertible) -> OQPyUnaryExpression:
@@ -93,16 +96,20 @@ def __rmod__(self, other: AstConvertible) -> OQPyBinaryExpression:
9396
return self._to_binary("%", other, self)
9497

9598
def __mul__(self, other: AstConvertible) -> OQPyBinaryExpression:
96-
return self._to_binary("*", self, other)
99+
result_type = compute_product_types(self, other)
100+
return self._to_binary("*", self, other, result_type)
97101

98102
def __rmul__(self, other: AstConvertible) -> OQPyBinaryExpression:
99-
return self._to_binary("*", other, self)
103+
result_type = compute_product_types(other, self)
104+
return self._to_binary("*", other, self, result_type)
100105

101106
def __truediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
102-
return self._to_binary("/", self, other)
107+
result_type = compute_quotient_types(self, other)
108+
return self._to_binary("/", self, other, result_type)
103109

104110
def __rtruediv__(self, other: AstConvertible) -> OQPyBinaryExpression:
105-
return self._to_binary("/", other, self)
111+
result_type = compute_quotient_types(other, self)
112+
return self._to_binary("/", other, self, result_type)
106113

107114
def __pow__(self, other: AstConvertible) -> OQPyBinaryExpression:
108115
return self._to_binary("**", self, other)
@@ -168,6 +175,128 @@ def __bool__(self) -> bool:
168175
)
169176

170177

178+
def _get_type(val: AstConvertible) -> ast.ClassicalType:
179+
if isinstance(val, OQPyExpression):
180+
return val.type
181+
elif isinstance(val, int):
182+
return ast.IntType()
183+
elif isinstance(val, float):
184+
return ast.FloatType()
185+
elif isinstance(val, complex):
186+
return ast.ComplexType(ast.FloatType())
187+
else:
188+
raise ValueError(f"Cannot multiply/divide oqpy expression with with {type(val)}")
189+
190+
191+
def compute_product_types(left: AstConvertible, right: AstConvertible) -> ast.ClassicalType:
192+
"""Find the result type for a product of two terms."""
193+
left_type = _get_type(left)
194+
right_type = _get_type(right)
195+
196+
types_map = {
197+
(ast.FloatType, ast.FloatType): left_type,
198+
(ast.FloatType, ast.IntType): left_type,
199+
(ast.FloatType, ast.UintType): left_type,
200+
(ast.FloatType, ast.DurationType): right_type,
201+
(ast.FloatType, ast.AngleType): right_type,
202+
(ast.FloatType, ast.ComplexType): right_type,
203+
(ast.IntType, ast.FloatType): right_type,
204+
(ast.IntType, ast.IntType): left_type,
205+
(ast.IntType, ast.UintType): left_type,
206+
(ast.IntType, ast.DurationType): right_type,
207+
(ast.IntType, ast.AngleType): right_type,
208+
(ast.IntType, ast.ComplexType): right_type,
209+
(ast.UintType, ast.FloatType): right_type,
210+
(ast.UintType, ast.IntType): right_type,
211+
(ast.UintType, ast.UintType): left_type,
212+
(ast.UintType, ast.DurationType): right_type,
213+
(ast.UintType, ast.AngleType): right_type,
214+
(ast.UintType, ast.ComplexType): right_type,
215+
(ast.DurationType, ast.FloatType): left_type,
216+
(ast.DurationType, ast.IntType): left_type,
217+
(ast.DurationType, ast.UintType): left_type,
218+
(ast.DurationType, ast.DurationType): TypeError(
219+
"Cannot multiply two durations. You may need to re-group computations to eliminate this."
220+
),
221+
(ast.DurationType, ast.AngleType): TypeError("Cannot multiply duration and angle"),
222+
(ast.DurationType, ast.ComplexType): TypeError("Cannot multiply duration and complex"),
223+
(ast.AngleType, ast.FloatType): left_type,
224+
(ast.AngleType, ast.IntType): left_type,
225+
(ast.AngleType, ast.UintType): left_type,
226+
(ast.AngleType, ast.DurationType): TypeError("Cannot multiply angle and duration"),
227+
(ast.AngleType, ast.AngleType): TypeError("Cannot multiply two angles"),
228+
(ast.AngleType, ast.ComplexType): TypeError("Cannot multiply angle and complex"),
229+
(ast.ComplexType, ast.FloatType): left_type,
230+
(ast.ComplexType, ast.IntType): left_type,
231+
(ast.ComplexType, ast.UintType): left_type,
232+
(ast.ComplexType, ast.DurationType): TypeError("Cannot multiply complex and duration"),
233+
(ast.ComplexType, ast.AngleType): TypeError("Cannot multiply complex and angle"),
234+
(ast.ComplexType, ast.ComplexType): left_type,
235+
}
236+
237+
try:
238+
result_type = types_map[type(left_type), type(right_type)]
239+
except KeyError as e:
240+
raise TypeError(f"Could not identify types for product {left} and {right}") from e
241+
if isinstance(result_type, Exception):
242+
raise result_type
243+
return result_type
244+
245+
246+
def compute_quotient_types(left: AstConvertible, right: AstConvertible) -> ast.ClassicalType:
247+
"""Find the result type for a quotient of two terms."""
248+
left_type = _get_type(left)
249+
right_type = _get_type(right)
250+
float_type = ast.FloatType()
251+
252+
types_map = {
253+
(ast.FloatType, ast.FloatType): left_type,
254+
(ast.FloatType, ast.IntType): left_type,
255+
(ast.FloatType, ast.UintType): left_type,
256+
(ast.FloatType, ast.DurationType): TypeError("Cannot divide float by duration"),
257+
(ast.FloatType, ast.AngleType): TypeError("Cannot divide float by angle"),
258+
(ast.FloatType, ast.ComplexType): right_type,
259+
(ast.IntType, ast.FloatType): right_type,
260+
(ast.IntType, ast.IntType): float_type,
261+
(ast.IntType, ast.UintType): float_type,
262+
(ast.IntType, ast.DurationType): TypeError("Cannot divide int by duration"),
263+
(ast.IntType, ast.AngleType): TypeError("Cannot divide int by angle"),
264+
(ast.IntType, ast.ComplexType): right_type,
265+
(ast.UintType, ast.FloatType): right_type,
266+
(ast.UintType, ast.IntType): float_type,
267+
(ast.UintType, ast.UintType): float_type,
268+
(ast.UintType, ast.DurationType): TypeError("Cannot divide uint by duration"),
269+
(ast.UintType, ast.AngleType): TypeError("Cannot divide uint by angle"),
270+
(ast.UintType, ast.ComplexType): right_type,
271+
(ast.DurationType, ast.FloatType): left_type,
272+
(ast.DurationType, ast.IntType): left_type,
273+
(ast.DurationType, ast.UintType): left_type,
274+
(ast.DurationType, ast.DurationType): ast.FloatType(),
275+
(ast.DurationType, ast.AngleType): TypeError("Cannot divide duration by angle"),
276+
(ast.DurationType, ast.ComplexType): TypeError("Cannot divide duration by complex"),
277+
(ast.AngleType, ast.FloatType): left_type,
278+
(ast.AngleType, ast.IntType): left_type,
279+
(ast.AngleType, ast.UintType): left_type,
280+
(ast.AngleType, ast.DurationType): TypeError("Cannot divide by duration"),
281+
(ast.AngleType, ast.AngleType): float_type,
282+
(ast.AngleType, ast.ComplexType): TypeError("Cannot divide by angle by complex"),
283+
(ast.ComplexType, ast.FloatType): left_type,
284+
(ast.ComplexType, ast.IntType): left_type,
285+
(ast.ComplexType, ast.UintType): left_type,
286+
(ast.ComplexType, ast.DurationType): TypeError("Cannot divide by duration"),
287+
(ast.ComplexType, ast.AngleType): TypeError("Cannot divide by angle"),
288+
(ast.ComplexType, ast.ComplexType): left_type,
289+
}
290+
291+
try:
292+
result_type = types_map[type(left_type), type(right_type)]
293+
except KeyError as e:
294+
raise TypeError(f"Could not identify types for quotient {left} and {right}") from e
295+
if isinstance(result_type, Exception):
296+
raise result_type
297+
return result_type
298+
299+
171300
def logical_and(first: AstConvertible, second: AstConvertible) -> OQPyBinaryExpression:
172301
"""Logical AND."""
173302
return OQPyBinaryExpression(ast.BinaryOperator["&&"], first, second)
@@ -227,30 +356,38 @@ def to_ast(self, program: Program) -> ast.UnaryExpression:
227356
class OQPyBinaryExpression(OQPyExpression):
228357
"""An expression consisting of two subexpressions joined by an operator."""
229358

230-
def __init__(self, op: ast.BinaryOperator, lhs: AstConvertible, rhs: AstConvertible):
359+
def __init__(
360+
self,
361+
op: ast.BinaryOperator,
362+
lhs: AstConvertible,
363+
rhs: AstConvertible,
364+
ast_type: ast.ClassicalType | None = None,
365+
):
231366
super().__init__()
232367
self.op = op
233368
self.lhs = lhs
234369
self.rhs = rhs
235-
# TODO (#50): More robust type checking which considers both arguments
370+
# TODO (#9): More robust type checking which considers both arguments
236371
# types, as well as the operator.
237-
if isinstance(lhs, OQPyExpression):
238-
self.type = lhs.type
239-
elif isinstance(rhs, OQPyExpression):
240-
self.type = rhs.type
241-
else:
242-
raise TypeError("Neither lhs nor rhs is an expression?")
372+
if ast_type is None:
373+
if isinstance(lhs, OQPyExpression):
374+
ast_type = lhs.type
375+
elif isinstance(rhs, OQPyExpression):
376+
ast_type = rhs.type
377+
else:
378+
raise TypeError("Neither lhs nor rhs is an expression?")
379+
self.type = ast_type
243380

244381
# Adding floats to durations is not allowed. So we promote types as necessary.
245382
if isinstance(self.type, ast.DurationType) and self.op in [
246383
ast.BinaryOperator["+"],
247384
ast.BinaryOperator["-"],
248385
]:
249386
# Late import to avoid circular imports.
250-
from oqpy.timing import make_duration
387+
from oqpy.timing import convert_float_to_duration
251388

252-
self.lhs = make_duration(self.lhs)
253-
self.rhs = make_duration(self.rhs)
389+
self.lhs = convert_float_to_duration(self.lhs)
390+
self.rhs = convert_float_to_duration(self.rhs)
254391

255392
def to_ast(self, program: Program) -> ast.BinaryExpression:
256393
"""Converts the OQpy expression into an ast node."""

oqpy/classical_types.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
optional_ast,
4343
to_ast,
4444
)
45-
from oqpy.timing import make_duration
45+
from oqpy.timing import convert_float_to_duration
4646

4747
if TYPE_CHECKING:
4848
from typing import Literal
@@ -158,15 +158,15 @@ class Identifier(OQPyExpression):
158158

159159
name: str
160160

161-
def __init__(self, name: str) -> None:
162-
self.type = None
161+
def __init__(self, name: str, ast_type: ast.ClassicalType) -> None:
163162
self.name = name
163+
self.type = ast_type
164164

165165
def to_ast(self, program: Program) -> ast.Expression:
166166
return ast.Identifier(name=self.name)
167167

168168

169-
pi = Identifier(name="pi")
169+
pi = Identifier(name="pi", ast_type=ast.FloatType())
170170

171171

172172
class _ClassicalVar(Var, OQPyExpression):
@@ -327,7 +327,7 @@ def __init__(
327327
**type_kwargs: Any,
328328
) -> None:
329329
if init_expression is not None and not isinstance(init_expression, str):
330-
init_expression = make_duration(init_expression)
330+
init_expression = convert_float_to_duration(init_expression)
331331
super().__init__(init_expression, name, *args, **type_kwargs)
332332

333333

@@ -381,7 +381,9 @@ def __init__(
381381

382382
# Automatically handle Duration array.
383383
if base_type is DurationVar and kwargs["init_expression"] is not None:
384-
kwargs["init_expression"] = (make_duration(i) for i in kwargs["init_expression"])
384+
kwargs["init_expression"] = (
385+
convert_float_to_duration(i) for i in kwargs["init_expression"]
386+
)
385387

386388
super().__init__(
387389
*args,

oqpy/control_flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
_ClassicalVar,
3131
convert_range,
3232
)
33-
from oqpy.timing import make_duration
33+
from oqpy.timing import convert_float_to_duration
3434

3535
ClassicalVarT = TypeVar("ClassicalVarT", bound=_ClassicalVar)
3636

@@ -126,7 +126,7 @@ def ForIn(
126126
set_declaration = convert_range(program, iterator)
127127
elif isinstance(iterator, Iterable):
128128
if identifier_type is DurationVar:
129-
iterator = (make_duration(i) for i in iterator)
129+
iterator = (convert_float_to_duration(i) for i in iterator)
130130

131131
set_declaration = ast.DiscreteSet([to_ast(program, i) for i in iterator])
132132
else:

oqpy/program.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
to_ast,
4242
)
4343
from oqpy.pulse import FrameVar, PortVar, WaveformVar
44-
from oqpy.timing import make_duration
44+
from oqpy.timing import convert_duration_to_float, convert_float_to_duration
4545

4646
__all__ = ["Program"]
4747

@@ -91,6 +91,8 @@ def add_statement(self, stmt: ast.Statement | ast.Pragma) -> None:
9191
class Program:
9292
"""A builder class for OpenQASM/OpenPulse programs."""
9393

94+
DURATION_MAX_DIGITS = 12
95+
9496
def __init__(self, version: Optional[str] = "3.0", simplify_constants: bool = True) -> None:
9597
self.stack: list[ProgramState] = [ProgramState()]
9698
self.defcals: dict[
@@ -364,7 +366,7 @@ def delay(
364366
"""Apply a delay to a set of qubits or frames."""
365367
if not isinstance(qubits_or_frames, Iterable):
366368
qubits_or_frames = [qubits_or_frames]
367-
ast_duration = to_ast(self, make_duration(time))
369+
ast_duration = to_ast(self, convert_float_to_duration(time))
368370
ast_qubits_or_frames = map_to_ast(self, qubits_or_frames)
369371
self._add_statement(ast.DelayInstruction(ast_duration, ast_qubits_or_frames))
370372
return self
@@ -393,32 +395,37 @@ def capture(self, frame: AstConvertible, kernel: AstConvertible) -> Program:
393395

394396
def set_phase(self, frame: AstConvertible, phase: AstConvertible) -> Program:
395397
"""Set the phase of a particular frame."""
396-
self.function_call("set_phase", [frame, phase])
398+
# We use make_float to force phase to be a unitless (i.e. non-duration) quantity.
399+
# Users are expected to keep track the units that are not expressible in openqasm
400+
# such as s^{-1}. For instance, in 2 * oqpy.pi * tppi * DurationVar(1e-8),
401+
# tppi is a float but has a frequency unit. This will coerce the result type
402+
# to a float by assuming the duration should be represented in seconds."
403+
self.function_call("set_phase", [frame, convert_duration_to_float(phase)])
397404
return self
398405

399406
def shift_phase(self, frame: AstConvertible, phase: AstConvertible) -> Program:
400407
"""Shift the phase of a particular frame."""
401-
self.function_call("shift_phase", [frame, phase])
408+
self.function_call("shift_phase", [frame, convert_duration_to_float(phase)])
402409
return self
403410

404411
def set_frequency(self, frame: AstConvertible, freq: AstConvertible) -> Program:
405412
"""Set the frequency of a particular frame."""
406-
self.function_call("set_frequency", [frame, freq])
413+
self.function_call("set_frequency", [frame, convert_duration_to_float(freq)])
407414
return self
408415

409416
def shift_frequency(self, frame: AstConvertible, freq: AstConvertible) -> Program:
410417
"""Shift the frequency of a particular frame."""
411-
self.function_call("shift_frequency", [frame, freq])
418+
self.function_call("shift_frequency", [frame, convert_duration_to_float(freq)])
412419
return self
413420

414421
def set_scale(self, frame: AstConvertible, scale: AstConvertible) -> Program:
415422
"""Set the amplitude scaling of a particular frame."""
416-
self.function_call("set_scale", [frame, scale])
423+
self.function_call("set_scale", [frame, convert_duration_to_float(scale)])
417424
return self
418425

419426
def shift_scale(self, frame: AstConvertible, scale: AstConvertible) -> Program:
420427
"""Shift the amplitude scaling of a particular frame."""
421-
self.function_call("shift_scale", [frame, scale])
428+
self.function_call("shift_scale", [frame, convert_duration_to_float(scale)])
422429
return self
423430

424431
def returns(self, expression: AstConvertible) -> Program:
@@ -473,7 +480,7 @@ def pragma(self, command: str) -> Program:
473480
def _do_assignment(self, var: AstConvertible, op: str, value: AstConvertible) -> None:
474481
"""Helper function for variable assignment operations."""
475482
if isinstance(var, classical_types.DurationVar):
476-
value = make_duration(value)
483+
value = convert_float_to_duration(value)
477484
var_ast = to_ast(self, var)
478485
if isinstance(var_ast, ast.IndexExpression):
479486
assert isinstance(var_ast.collection, ast.Identifier)

oqpy/subroutines.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from oqpy.base import AstConvertible, OQPyExpression, make_annotations, to_ast
2929
from oqpy.classical_types import OQFunctionCall, _ClassicalVar
3030
from oqpy.quantum_types import Qubit
31-
from oqpy.timing import make_duration
31+
from oqpy.timing import convert_float_to_duration
3232

3333
__all__ = ["subroutine", "annotate_subroutine", "declare_extern", "declare_waveform_generator"]
3434

@@ -200,14 +200,14 @@ def call_extern(*call_args: AstConvertible, **call_kwargs: AstConvertible) -> OQ
200200
raise TypeError(f"{name}() got multiple values for argument '{k}'.")
201201

202202
if type(arg_types[k_idx]) == ast.DurationType:
203-
new_args[k_idx] = make_duration(call_kwargs[k])
203+
new_args[k_idx] = convert_float_to_duration(call_kwargs[k])
204204
else:
205205
new_args[k_idx] = call_kwargs[k]
206206

207207
# Casting floats into durations for the non-keyword arguments
208208
for i, a in enumerate(call_args):
209209
if type(arg_types[i]) == ast.DurationType:
210-
new_args[i] = make_duration(a)
210+
new_args[i] = convert_float_to_duration(a)
211211
return OQFunctionCall(name, new_args, return_type, extern_decl=extern_decl)
212212

213213
return call_extern

0 commit comments

Comments
 (0)