|
39 | 39 |
|
40 | 40 | __all__ = [
|
41 | 41 | "pi",
|
| 42 | + "ArrayVar", |
42 | 43 | "BoolVar",
|
43 | 44 | "IntVar",
|
44 | 45 | "UintVar",
|
|
48 | 49 | "ComplexVar",
|
49 | 50 | "DurationVar",
|
50 | 51 | "OQFunctionCall",
|
| 52 | + "OQIndexExpression", |
51 | 53 | "StretchVar",
|
52 | 54 | "_ClassicalVar",
|
53 | 55 | "duration",
|
@@ -272,18 +274,20 @@ class ComplexVar(_ClassicalVar):
|
272 | 274 | """An oqpy variable with bit type."""
|
273 | 275 |
|
274 | 276 | type_cls = ast.ComplexType
|
| 277 | + base_type: ast.FloatType = float64 |
275 | 278 |
|
276 |
| - def __class_getitem__(cls, item: Type[ast.FloatType]) -> Callable[..., ComplexVar]: |
| 279 | + def __class_getitem__(cls, item: ast.FloatType) -> Callable[..., ComplexVar]: |
277 | 280 | return functools.partial(cls, base_type=item)
|
278 | 281 |
|
279 | 282 | def __init__(
|
280 | 283 | self,
|
281 | 284 | init_expression: AstConvertible | None = None,
|
282 | 285 | *args: Any,
|
283 |
| - base_type: Type[ast.FloatType] = float64, |
| 286 | + base_type: ast.FloatType = float64, |
284 | 287 | **kwargs: Any,
|
285 | 288 | ) -> None:
|
286 | 289 | assert isinstance(base_type, ast.FloatType)
|
| 290 | + self.base_type = base_type |
287 | 291 |
|
288 | 292 | if not isinstance(init_expression, (complex, type(None), OQPyExpression)):
|
289 | 293 | init_expression = complex(init_expression) # type: ignore[arg-type]
|
@@ -313,6 +317,80 @@ class StretchVar(_ClassicalVar):
|
313 | 317 | type_cls = ast.StretchType
|
314 | 318 |
|
315 | 319 |
|
| 320 | +AllowedArrayTypes = Union[_SizedVar, DurationVar, BoolVar, ComplexVar] |
| 321 | + |
| 322 | + |
| 323 | +class ArrayVar(_ClassicalVar): |
| 324 | + """An oqpy array variable.""" |
| 325 | + |
| 326 | + type_cls = ast.ArrayType |
| 327 | + dimensions: list[int] |
| 328 | + base_type: type[AllowedArrayTypes] |
| 329 | + |
| 330 | + def __class_getitem__( |
| 331 | + cls, item: tuple[type[AllowedArrayTypes], int] | type[AllowedArrayTypes] |
| 332 | + ) -> Callable[..., ArrayVar]: |
| 333 | + # Allows usage like ArrayVar[FloatVar, 32](...) or ArrayVar[FloatVar] |
| 334 | + if isinstance(item, tuple): |
| 335 | + base_type = item[0] |
| 336 | + dimensions = list(item[1:]) |
| 337 | + return functools.partial(cls, dimensions=dimensions, base_type=base_type) |
| 338 | + else: |
| 339 | + return functools.partial(cls, base_type=item) |
| 340 | + |
| 341 | + def __init__( |
| 342 | + self, |
| 343 | + *args: Any, |
| 344 | + dimensions: list[int], |
| 345 | + base_type: type[AllowedArrayTypes] = IntVar, |
| 346 | + **kwargs: Any, |
| 347 | + ) -> None: |
| 348 | + self.dimensions = dimensions |
| 349 | + self.base_type = base_type |
| 350 | + |
| 351 | + # Creating a dummy variable supports IntVar[64] etc. |
| 352 | + base_type_instance = base_type() |
| 353 | + if isinstance(base_type_instance, _SizedVar): |
| 354 | + array_base_type = base_type_instance.type_cls( |
| 355 | + size=ast.IntegerLiteral(base_type_instance.size) |
| 356 | + ) |
| 357 | + elif isinstance(base_type_instance, ComplexVar): |
| 358 | + array_base_type = base_type_instance.type_cls(base_type=base_type_instance.base_type) |
| 359 | + else: |
| 360 | + array_base_type = base_type_instance.type_cls() |
| 361 | + |
| 362 | + # Automatically handle Duration array. |
| 363 | + if base_type is DurationVar and kwargs["init_expression"]: |
| 364 | + kwargs["init_expression"] = (make_duration(i) for i in kwargs["init_expression"]) |
| 365 | + |
| 366 | + super().__init__( |
| 367 | + *args, |
| 368 | + **kwargs, |
| 369 | + dimensions=[ast.IntegerLiteral(dimension) for dimension in dimensions], |
| 370 | + base_type=array_base_type, |
| 371 | + ) |
| 372 | + |
| 373 | + def __getitem__(self, index: AstConvertible) -> OQIndexExpression: |
| 374 | + return OQIndexExpression(collection=self, index=index) |
| 375 | + |
| 376 | + |
| 377 | +class OQIndexExpression(OQPyExpression): |
| 378 | + """An oqpy expression corresponding to an index expression.""" |
| 379 | + |
| 380 | + def __init__(self, collection: AstConvertible, index: AstConvertible): |
| 381 | + self.collection = collection |
| 382 | + self.index = index |
| 383 | + |
| 384 | + if isinstance(collection, ArrayVar): |
| 385 | + self.type = collection.base_type().type_cls() |
| 386 | + |
| 387 | + def to_ast(self, program: Program) -> ast.IndexExpression: |
| 388 | + """Converts this oqpy index expression into an ast node.""" |
| 389 | + return ast.IndexExpression( |
| 390 | + collection=to_ast(program, self.collection), index=[to_ast(program, self.index)] |
| 391 | + ) |
| 392 | + |
| 393 | + |
316 | 394 | class OQFunctionCall(OQPyExpression):
|
317 | 395 | """An oqpy expression corresponding to a function call."""
|
318 | 396 |
|
|
0 commit comments