@@ -58,10 +58,13 @@ def to_ast(self, program: Program) -> ast.Expression:
58
58
59
59
@staticmethod
60
60
def _to_binary (
61
- op_name : str , first : AstConvertible , second : AstConvertible
61
+ op_name : str ,
62
+ first : AstConvertible ,
63
+ second : AstConvertible ,
64
+ result_type : ast .ClassicalType | None = None ,
62
65
) -> OQPyBinaryExpression :
63
66
"""Helper method to produce a binary expression."""
64
- return OQPyBinaryExpression (ast .BinaryOperator [op_name ], first , second )
67
+ return OQPyBinaryExpression (ast .BinaryOperator [op_name ], first , second , result_type )
65
68
66
69
@staticmethod
67
70
def _to_unary (op_name : str , exp : AstConvertible ) -> OQPyUnaryExpression :
@@ -93,16 +96,20 @@ def __rmod__(self, other: AstConvertible) -> OQPyBinaryExpression:
93
96
return self ._to_binary ("%" , other , self )
94
97
95
98
def __mul__ (self , other : AstConvertible ) -> OQPyBinaryExpression :
96
- return self ._to_binary ("*" , self , other )
99
+ result_type = compute_product_types (self , other )
100
+ return self ._to_binary ("*" , self , other , result_type )
97
101
98
102
def __rmul__ (self , other : AstConvertible ) -> OQPyBinaryExpression :
99
- return self ._to_binary ("*" , other , self )
103
+ result_type = compute_product_types (other , self )
104
+ return self ._to_binary ("*" , other , self , result_type )
100
105
101
106
def __truediv__ (self , other : AstConvertible ) -> OQPyBinaryExpression :
102
- return self ._to_binary ("/" , self , other )
107
+ result_type = compute_quotient_types (self , other )
108
+ return self ._to_binary ("/" , self , other , result_type )
103
109
104
110
def __rtruediv__ (self , other : AstConvertible ) -> OQPyBinaryExpression :
105
- return self ._to_binary ("/" , other , self )
111
+ result_type = compute_quotient_types (other , self )
112
+ return self ._to_binary ("/" , other , self , result_type )
106
113
107
114
def __pow__ (self , other : AstConvertible ) -> OQPyBinaryExpression :
108
115
return self ._to_binary ("**" , self , other )
@@ -168,6 +175,128 @@ def __bool__(self) -> bool:
168
175
)
169
176
170
177
178
+ def _get_type (val : AstConvertible ) -> ast .ClassicalType :
179
+ if isinstance (val , OQPyExpression ):
180
+ return val .type
181
+ elif isinstance (val , int ):
182
+ return ast .IntType ()
183
+ elif isinstance (val , float ):
184
+ return ast .FloatType ()
185
+ elif isinstance (val , complex ):
186
+ return ast .ComplexType (ast .FloatType ())
187
+ else :
188
+ raise ValueError (f"Cannot multiply/divide oqpy expression with with { type (val )} " )
189
+
190
+
191
+ def compute_product_types (left : AstConvertible , right : AstConvertible ) -> ast .ClassicalType :
192
+ """Find the result type for a product of two terms."""
193
+ left_type = _get_type (left )
194
+ right_type = _get_type (right )
195
+
196
+ types_map = {
197
+ (ast .FloatType , ast .FloatType ): left_type ,
198
+ (ast .FloatType , ast .IntType ): left_type ,
199
+ (ast .FloatType , ast .UintType ): left_type ,
200
+ (ast .FloatType , ast .DurationType ): right_type ,
201
+ (ast .FloatType , ast .AngleType ): right_type ,
202
+ (ast .FloatType , ast .ComplexType ): right_type ,
203
+ (ast .IntType , ast .FloatType ): right_type ,
204
+ (ast .IntType , ast .IntType ): left_type ,
205
+ (ast .IntType , ast .UintType ): left_type ,
206
+ (ast .IntType , ast .DurationType ): right_type ,
207
+ (ast .IntType , ast .AngleType ): right_type ,
208
+ (ast .IntType , ast .ComplexType ): right_type ,
209
+ (ast .UintType , ast .FloatType ): right_type ,
210
+ (ast .UintType , ast .IntType ): right_type ,
211
+ (ast .UintType , ast .UintType ): left_type ,
212
+ (ast .UintType , ast .DurationType ): right_type ,
213
+ (ast .UintType , ast .AngleType ): right_type ,
214
+ (ast .UintType , ast .ComplexType ): right_type ,
215
+ (ast .DurationType , ast .FloatType ): left_type ,
216
+ (ast .DurationType , ast .IntType ): left_type ,
217
+ (ast .DurationType , ast .UintType ): left_type ,
218
+ (ast .DurationType , ast .DurationType ): TypeError (
219
+ "Cannot multiply two durations. You may need to re-group computations to eliminate this."
220
+ ),
221
+ (ast .DurationType , ast .AngleType ): TypeError ("Cannot multiply duration and angle" ),
222
+ (ast .DurationType , ast .ComplexType ): TypeError ("Cannot multiply duration and complex" ),
223
+ (ast .AngleType , ast .FloatType ): left_type ,
224
+ (ast .AngleType , ast .IntType ): left_type ,
225
+ (ast .AngleType , ast .UintType ): left_type ,
226
+ (ast .AngleType , ast .DurationType ): TypeError ("Cannot multiply angle and duration" ),
227
+ (ast .AngleType , ast .AngleType ): TypeError ("Cannot multiply two angles" ),
228
+ (ast .AngleType , ast .ComplexType ): TypeError ("Cannot multiply angle and complex" ),
229
+ (ast .ComplexType , ast .FloatType ): left_type ,
230
+ (ast .ComplexType , ast .IntType ): left_type ,
231
+ (ast .ComplexType , ast .UintType ): left_type ,
232
+ (ast .ComplexType , ast .DurationType ): TypeError ("Cannot multiply complex and duration" ),
233
+ (ast .ComplexType , ast .AngleType ): TypeError ("Cannot multiply complex and angle" ),
234
+ (ast .ComplexType , ast .ComplexType ): left_type ,
235
+ }
236
+
237
+ try :
238
+ result_type = types_map [type (left_type ), type (right_type )]
239
+ except KeyError as e :
240
+ raise TypeError (f"Could not identify types for product { left } and { right } " ) from e
241
+ if isinstance (result_type , Exception ):
242
+ raise result_type
243
+ return result_type
244
+
245
+
246
+ def compute_quotient_types (left : AstConvertible , right : AstConvertible ) -> ast .ClassicalType :
247
+ """Find the result type for a quotient of two terms."""
248
+ left_type = _get_type (left )
249
+ right_type = _get_type (right )
250
+ float_type = ast .FloatType ()
251
+
252
+ types_map = {
253
+ (ast .FloatType , ast .FloatType ): left_type ,
254
+ (ast .FloatType , ast .IntType ): left_type ,
255
+ (ast .FloatType , ast .UintType ): left_type ,
256
+ (ast .FloatType , ast .DurationType ): TypeError ("Cannot divide float by duration" ),
257
+ (ast .FloatType , ast .AngleType ): TypeError ("Cannot divide float by angle" ),
258
+ (ast .FloatType , ast .ComplexType ): right_type ,
259
+ (ast .IntType , ast .FloatType ): right_type ,
260
+ (ast .IntType , ast .IntType ): float_type ,
261
+ (ast .IntType , ast .UintType ): float_type ,
262
+ (ast .IntType , ast .DurationType ): TypeError ("Cannot divide int by duration" ),
263
+ (ast .IntType , ast .AngleType ): TypeError ("Cannot divide int by angle" ),
264
+ (ast .IntType , ast .ComplexType ): right_type ,
265
+ (ast .UintType , ast .FloatType ): right_type ,
266
+ (ast .UintType , ast .IntType ): float_type ,
267
+ (ast .UintType , ast .UintType ): float_type ,
268
+ (ast .UintType , ast .DurationType ): TypeError ("Cannot divide uint by duration" ),
269
+ (ast .UintType , ast .AngleType ): TypeError ("Cannot divide uint by angle" ),
270
+ (ast .UintType , ast .ComplexType ): right_type ,
271
+ (ast .DurationType , ast .FloatType ): left_type ,
272
+ (ast .DurationType , ast .IntType ): left_type ,
273
+ (ast .DurationType , ast .UintType ): left_type ,
274
+ (ast .DurationType , ast .DurationType ): ast .FloatType (),
275
+ (ast .DurationType , ast .AngleType ): TypeError ("Cannot divide duration by angle" ),
276
+ (ast .DurationType , ast .ComplexType ): TypeError ("Cannot divide duration by complex" ),
277
+ (ast .AngleType , ast .FloatType ): left_type ,
278
+ (ast .AngleType , ast .IntType ): left_type ,
279
+ (ast .AngleType , ast .UintType ): left_type ,
280
+ (ast .AngleType , ast .DurationType ): TypeError ("Cannot divide by duration" ),
281
+ (ast .AngleType , ast .AngleType ): float_type ,
282
+ (ast .AngleType , ast .ComplexType ): TypeError ("Cannot divide by angle by complex" ),
283
+ (ast .ComplexType , ast .FloatType ): left_type ,
284
+ (ast .ComplexType , ast .IntType ): left_type ,
285
+ (ast .ComplexType , ast .UintType ): left_type ,
286
+ (ast .ComplexType , ast .DurationType ): TypeError ("Cannot divide by duration" ),
287
+ (ast .ComplexType , ast .AngleType ): TypeError ("Cannot divide by angle" ),
288
+ (ast .ComplexType , ast .ComplexType ): left_type ,
289
+ }
290
+
291
+ try :
292
+ result_type = types_map [type (left_type ), type (right_type )]
293
+ except KeyError as e :
294
+ raise TypeError (f"Could not identify types for quotient { left } and { right } " ) from e
295
+ if isinstance (result_type , Exception ):
296
+ raise result_type
297
+ return result_type
298
+
299
+
171
300
def logical_and (first : AstConvertible , second : AstConvertible ) -> OQPyBinaryExpression :
172
301
"""Logical AND."""
173
302
return OQPyBinaryExpression (ast .BinaryOperator ["&&" ], first , second )
@@ -227,30 +356,38 @@ def to_ast(self, program: Program) -> ast.UnaryExpression:
227
356
class OQPyBinaryExpression (OQPyExpression ):
228
357
"""An expression consisting of two subexpressions joined by an operator."""
229
358
230
- def __init__ (self , op : ast .BinaryOperator , lhs : AstConvertible , rhs : AstConvertible ):
359
+ def __init__ (
360
+ self ,
361
+ op : ast .BinaryOperator ,
362
+ lhs : AstConvertible ,
363
+ rhs : AstConvertible ,
364
+ ast_type : ast .ClassicalType | None = None ,
365
+ ):
231
366
super ().__init__ ()
232
367
self .op = op
233
368
self .lhs = lhs
234
369
self .rhs = rhs
235
- # TODO (#50 ): More robust type checking which considers both arguments
370
+ # TODO (#9 ): More robust type checking which considers both arguments
236
371
# types, as well as the operator.
237
- if isinstance (lhs , OQPyExpression ):
238
- self .type = lhs .type
239
- elif isinstance (rhs , OQPyExpression ):
240
- self .type = rhs .type
241
- else :
242
- raise TypeError ("Neither lhs nor rhs is an expression?" )
372
+ if ast_type is None :
373
+ if isinstance (lhs , OQPyExpression ):
374
+ ast_type = lhs .type
375
+ elif isinstance (rhs , OQPyExpression ):
376
+ ast_type = rhs .type
377
+ else :
378
+ raise TypeError ("Neither lhs nor rhs is an expression?" )
379
+ self .type = ast_type
243
380
244
381
# Adding floats to durations is not allowed. So we promote types as necessary.
245
382
if isinstance (self .type , ast .DurationType ) and self .op in [
246
383
ast .BinaryOperator ["+" ],
247
384
ast .BinaryOperator ["-" ],
248
385
]:
249
386
# Late import to avoid circular imports.
250
- from oqpy .timing import make_duration
387
+ from oqpy .timing import convert_float_to_duration
251
388
252
- self .lhs = make_duration (self .lhs )
253
- self .rhs = make_duration (self .rhs )
389
+ self .lhs = convert_float_to_duration (self .lhs )
390
+ self .rhs = convert_float_to_duration (self .rhs )
254
391
255
392
def to_ast (self , program : Program ) -> ast .BinaryExpression :
256
393
"""Converts the OQpy expression into an ast node."""
0 commit comments