|
4 | 4 | from dataclasses import dataclass
|
5 | 5 | import operator
|
6 | 6 | import textwrap
|
7 |
| -from typing import Any, NewType |
| 7 | +from typing import Any |
8 | 8 |
|
9 | 9 | from numba.core import ir
|
10 |
| -from numba.core import types |
11 | 10 | from numba.core import bytecode
|
12 | 11 | from numba.core import controlflow
|
13 |
| -from numba.core.types import Type as NumbaType |
| 12 | +import numba.core.types as nt |
14 | 13 |
|
15 |
| -from heir.mlir.types import MLIRType, MLIR_TYPES, I1, I8, I16, I32, I64, F32, F64 |
16 |
| -from heir.interfaces import CompilerError, DebugMessage, InternalCompilerError |
| 14 | +from heir.mlir import I1, I8, I16, I32, I64, MLIRType, MLIR_TYPES |
| 15 | +from heir.interfaces import CompilerError, InternalCompilerError |
| 16 | + |
| 17 | +NumbaType = nt.Type |
17 | 18 |
|
18 | 19 |
|
19 | 20 | def mlirType(numba_type: NumbaType) -> str:
|
20 |
| - if isinstance(numba_type, types.Integer): |
21 |
| - # TODO (#1162): fix handling of signedness |
22 |
| - # Since `arith` only allows signless integers, we ignore signedness here. |
23 |
| - return "i" + str(numba_type.bitwidth) |
24 |
| - if isinstance(numba_type, types.RangeType): |
25 |
| - return mlirType(numba_type.dtype) |
26 |
| - if isinstance(numba_type, types.Boolean): |
27 |
| - return "i1" |
28 |
| - if isinstance(numba_type, types.Float): |
29 |
| - return "f" + str(numba_type.bitwidth) |
30 |
| - if isinstance(numba_type, types.Complex): |
31 |
| - return "complex<" + str(numba_type.bitwidth) + ">" |
32 |
| - if isinstance(numba_type, types.Array): |
33 |
| - shape = None |
34 |
| - if hasattr(numba_type, "shape"): |
35 |
| - shape = "x".join(str(s) for s in numba_type.shape) # type: ignore |
36 |
| - return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" |
37 |
| - raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
| 21 | + match numba_type: |
| 22 | + case nt.Integer(): |
| 23 | + # TODO (#1162): fix handling of signedness |
| 24 | + # Since `arith` only allows signless integers, we ignore signedness here. |
| 25 | + return "i" + str(numba_type.bitwidth) |
| 26 | + case nt.RangeType(): |
| 27 | + return mlirType(numba_type.dtype) |
| 28 | + case nt.Boolean(): |
| 29 | + return "i1" |
| 30 | + case nt.Float(): |
| 31 | + return "f" + str(numba_type.bitwidth) |
| 32 | + case nt.Complex(): |
| 33 | + return "complex<" + str(numba_type.bitwidth) + ">" |
| 34 | + case nt.Array(): |
| 35 | + return ( |
| 36 | + "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" |
| 37 | + ) |
| 38 | + case _: |
| 39 | + raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
38 | 40 |
|
39 | 41 |
|
40 | 42 | def isIntegerLike(typ: NumbaType | MLIRType) -> bool:
|
41 |
| - if isinstance(typ, type) and issubclass(typ, MLIRType): |
42 |
| - return typ in {I1, I8, I16, I32, I64} |
43 |
| - if isinstance(typ, NumbaType): |
44 |
| - return isinstance(typ, types.Integer) or isinstance(typ, types.Boolean) |
45 |
| - raise InternalCompilerError(f"Encountered unexpected type {typ}") |
| 43 | + match typ: |
| 44 | + case I1() | I8() | I16() | I32() | I64() | nt.Integer() | nt.Boolean(): |
| 45 | + return True |
| 46 | + case MLIRType() | nt.Type: |
| 47 | + return False |
| 48 | + case _: |
| 49 | + raise InternalCompilerError( |
| 50 | + f"Encountered unexpected type {typ} of type {type(typ)}" |
| 51 | + ) |
46 | 52 |
|
47 | 53 |
|
48 | 54 | def isFloatLike(typ: NumbaType | MLIRType) -> bool:
|
49 |
| - if isinstance(typ, type) and issubclass(type, MLIRType): |
50 |
| - return typ in {F32, F64} |
51 |
| - if isinstance(typ, NumbaType): |
52 |
| - return isinstance(typ, types.Float) |
53 |
| - raise InternalCompilerError(f"Encountered unexpected type {typ}") |
| 55 | + match typ: |
| 56 | + case nt.F32 | nt.F64 | nt.Float(): |
| 57 | + return True |
| 58 | + case MLIRType() | nt.Type: |
| 59 | + return False |
| 60 | + case _: |
| 61 | + raise InternalCompilerError( |
| 62 | + f"Encountered unexpected type {typ} of type {type(typ)}" |
| 63 | + ) |
54 | 64 |
|
55 | 65 |
|
56 | 66 | # Needed because, e.g. Boolean doesn't have a bitwidth
|
57 | 67 | def getBitwidth(typ: NumbaType | MLIRType) -> int:
|
58 |
| - if isinstance(typ, type) and issubclass(typ, MLIRType): |
59 |
| - if typ in {I1, I8, I16, I32, I64}: |
| 68 | + match typ: |
| 69 | + case I1() | I8() | I16() | I32() | I64(): |
60 | 70 | # e.g., <class 'heir.mlir.types.I32'>.__name__ -> "I32" -> "32"
|
61 | 71 | return int(typ.__name__[1:])
|
62 |
| - if isinstance(typ, types.Integer): |
63 |
| - return typ.bitwidth |
64 |
| - if isinstance(typ, types.Boolean): |
65 |
| - return 1 |
66 |
| - raise InternalCompilerError(f"unexpected type {typ} ({type(typ)})") |
| 72 | + case nt.Integer() as int_ty: |
| 73 | + return int_ty.bitwidth |
| 74 | + case nt.Boolean(): |
| 75 | + return 1 |
| 76 | + case _: |
| 77 | + raise InternalCompilerError( |
| 78 | + f"Encountered unexpected type {typ} of type {type(typ)}" |
| 79 | + ) |
67 | 80 |
|
68 | 81 |
|
69 | 82 | def mlirCastOp(
|
@@ -112,19 +125,15 @@ def mlirLoc(loc: ir.Loc) -> str:
|
112 | 125 |
|
113 | 126 | def arithSuffix(numba_type: NumbaType) -> str:
|
114 | 127 | """Helper to translate numba types to the associated arith dialect operation suffixes"""
|
115 |
| - if isinstance(numba_type, types.Integer): |
116 |
| - return "i" |
117 |
| - if isinstance(numba_type, types.Boolean): |
118 |
| - return "i" |
119 |
| - if isinstance(numba_type, types.Float): |
120 |
| - return "f" |
121 |
| - if isinstance(numba_type, types.Complex): |
122 |
| - raise InternalCompilerError( |
123 |
| - "Complex numbers not supported in `arith` dialect" |
124 |
| - ) |
125 |
| - if isinstance(numba_type, types.Array): |
126 |
| - return arithSuffix(numba_type.dtype) |
127 |
| - raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
| 128 | + match numba_type: |
| 129 | + case nt.Integer() | nt.Boolean(): |
| 130 | + return "i" |
| 131 | + case nt.Float(): |
| 132 | + return "f" |
| 133 | + case nt.Array(): |
| 134 | + return arithSuffix(numba_type.dtype) |
| 135 | + case _: |
| 136 | + raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
128 | 137 |
|
129 | 138 |
|
130 | 139 | class HeaderInfo:
|
@@ -514,9 +523,12 @@ def emit_assign(self, assign):
|
514 | 523 | )
|
515 | 524 | target_ssa = self.get_or_create_name(assign.target)
|
516 | 525 | ssa_id = self.get_or_create_name(assign.value.args[0])
|
| 526 | + |
| 527 | + # Construct an instance of the MLIR type in question |
| 528 | + mlir_type_instance = global_() |
517 | 529 | cast = mlirCastOp(
|
518 | 530 | self.typemap.get(value),
|
519 |
| - global_, |
| 531 | + mlir_type_instance, |
520 | 532 | ssa_id,
|
521 | 533 | assign.loc,
|
522 | 534 | )
|
|
0 commit comments