|
4 | 4 | from dataclasses import dataclass
|
5 | 5 | import operator
|
6 | 6 | import textwrap
|
7 |
| -from typing import Any, NewType |
| 7 | +from typing import Any |
8 | 8 |
|
9 | 9 | from numba.core import ir
|
10 |
| -from numba.core import types |
11 | 10 | from numba.core import bytecode
|
12 | 11 | from numba.core import controlflow
|
13 |
| -from numba.core.types import Type as NumbaType |
| 12 | +import numba.core.types as nt |
14 | 13 |
|
15 |
| -from heir.mlir.types import MLIRType, MLIR_TYPES, I1, I8, I16, I32, I64, F32, F64 |
16 |
| -from heir.interfaces import CompilerError, DebugMessage, InternalCompilerError |
| 14 | +from heir.mlir import types as mt |
| 15 | +from heir.interfaces import CompilerError, InternalCompilerError |
| 16 | + |
| 17 | +NumbaType = nt.Type |
17 | 18 |
|
18 | 19 |
|
19 | 20 | def mlirType(numba_type: NumbaType) -> str:
|
20 |
| - if isinstance(numba_type, types.Integer): |
21 |
| - # TODO (#1162): fix handling of signedness |
22 |
| - # Since `arith` only allows signless integers, we ignore signedness here. |
23 |
| - return "i" + str(numba_type.bitwidth) |
24 |
| - if isinstance(numba_type, types.RangeType): |
25 |
| - return mlirType(numba_type.dtype) |
26 |
| - if isinstance(numba_type, types.Boolean): |
27 |
| - return "i1" |
28 |
| - if isinstance(numba_type, types.Float): |
29 |
| - return "f" + str(numba_type.bitwidth) |
30 |
| - if isinstance(numba_type, types.Complex): |
31 |
| - return "complex<" + str(numba_type.bitwidth) + ">" |
32 |
| - if isinstance(numba_type, types.Array): |
33 |
| - shape = None |
34 |
| - if hasattr(numba_type, "shape"): |
35 |
| - shape = "x".join(str(s) for s in numba_type.shape) # type: ignore |
36 |
| - return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" |
37 |
| - raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
38 |
| - |
39 |
| - |
40 |
| -def isIntegerLike(typ: NumbaType | MLIRType) -> bool: |
41 |
| - if isinstance(typ, type) and issubclass(typ, MLIRType): |
42 |
| - return typ in {I1, I8, I16, I32, I64} |
43 |
| - if isinstance(typ, NumbaType): |
44 |
| - return isinstance(typ, types.Integer) or isinstance(typ, types.Boolean) |
45 |
| - raise InternalCompilerError(f"Encountered unexpected type {typ}") |
46 |
| - |
47 |
| - |
48 |
| -def isFloatLike(typ: NumbaType | MLIRType) -> bool: |
49 |
| - if isinstance(typ, type) and issubclass(type, MLIRType): |
50 |
| - return typ in {F32, F64} |
51 |
| - if isinstance(typ, NumbaType): |
52 |
| - return isinstance(typ, types.Float) |
53 |
| - raise InternalCompilerError(f"Encountered unexpected type {typ}") |
| 21 | + match numba_type: |
| 22 | + case nt.Integer(): |
| 23 | + # TODO (#1162): fix handling of signedness |
| 24 | + # Since `arith` only allows signless integers, we ignore signedness here. |
| 25 | + return "i" + str(numba_type.bitwidth) |
| 26 | + case nt.RangeType(): |
| 27 | + return mlirType(numba_type.dtype) |
| 28 | + case nt.Boolean(): |
| 29 | + return "i1" |
| 30 | + case nt.Float(): |
| 31 | + return "f" + str(numba_type.bitwidth) |
| 32 | + case nt.Complex(): |
| 33 | + return "complex<" + str(numba_type.bitwidth) + ">" |
| 34 | + case nt.Array(): |
| 35 | + return ( |
| 36 | + "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" |
| 37 | + ) |
| 38 | + case _: |
| 39 | + raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
| 40 | + |
| 41 | + |
| 42 | +def isIntegerLike(typ: NumbaType | mt.MLIRType) -> bool: |
| 43 | + # note: the match-case discrepancies here—using nt.Integer() instead of |
| 44 | + # nt.Integer and mt.I1 instead of mt.I1()—are due to the fact that the MLIR |
| 45 | + # types passed in refer to the class itself, while the Numba types refer to |
| 46 | + # instances of a class. |
| 47 | + match typ: |
| 48 | + case mt.I1 | mt.I8 | mt.I16 | mt.I32 | mt.I64 | nt.Integer() | nt.Boolean(): |
| 49 | + return True |
| 50 | + case mt.MLIRType | nt.Type: |
| 51 | + return False |
| 52 | + case _: |
| 53 | + raise InternalCompilerError( |
| 54 | + f"Encountered unexpected type {typ} of type {type(typ)}" |
| 55 | + ) |
| 56 | + |
| 57 | + |
| 58 | +def isFloatLike(typ: NumbaType | mt.MLIRType) -> bool: |
| 59 | + match typ: |
| 60 | + case nt.F32 | nt.F64 | nt.Float(): |
| 61 | + return True |
| 62 | + case mt.MLIRType | nt.Type: |
| 63 | + return False |
| 64 | + case _: |
| 65 | + raise InternalCompilerError( |
| 66 | + f"Encountered unexpected type {typ} of type {type(typ)}" |
| 67 | + ) |
54 | 68 |
|
55 | 69 |
|
56 | 70 | # Needed because, e.g. Boolean doesn't have a bitwidth
|
57 |
| -def getBitwidth(typ: NumbaType | MLIRType) -> int: |
58 |
| - if isinstance(typ, type) and issubclass(typ, MLIRType): |
59 |
| - if typ in {I1, I8, I16, I32, I64}: |
| 71 | +def getBitwidth(typ: NumbaType | mt.MLIRType) -> int: |
| 72 | + match typ: |
| 73 | + case mt.I1 | mt.I8 | mt.I16 | mt.I32 | mt.I64: |
60 | 74 | # e.g., <class 'heir.mlir.types.I32'>.__name__ -> "I32" -> "32"
|
61 | 75 | return int(typ.__name__[1:])
|
62 |
| - if isinstance(typ, types.Integer): |
63 |
| - return typ.bitwidth |
64 |
| - if isinstance(typ, types.Boolean): |
65 |
| - return 1 |
66 |
| - raise InternalCompilerError(f"unexpected type {typ} ({type(typ)})") |
| 76 | + case nt.Integer() as int_ty: |
| 77 | + return int_ty.bitwidth |
| 78 | + case nt.Boolean(): |
| 79 | + return 1 |
| 80 | + case _: |
| 81 | + raise InternalCompilerError( |
| 82 | + f"Encountered unexpected type {typ} of type {type(typ)}" |
| 83 | + ) |
67 | 84 |
|
68 | 85 |
|
69 | 86 | def mlirCastOp(
|
70 |
| - from_type: NumbaType, to_type: MLIRType, value: str, loc: ir.Loc |
| 87 | + from_type: NumbaType, to_type: mt.MLIRType, value: str, loc: ir.Loc |
71 | 88 | ) -> str:
|
72 | 89 | if isIntegerLike(from_type) and isIntegerLike(to_type):
|
73 | 90 | from_width = getBitwidth(from_type)
|
@@ -112,19 +129,15 @@ def mlirLoc(loc: ir.Loc) -> str:
|
112 | 129 |
|
113 | 130 | def arithSuffix(numba_type: NumbaType) -> str:
|
114 | 131 | """Helper to translate numba types to the associated arith dialect operation suffixes"""
|
115 |
| - if isinstance(numba_type, types.Integer): |
116 |
| - return "i" |
117 |
| - if isinstance(numba_type, types.Boolean): |
118 |
| - return "i" |
119 |
| - if isinstance(numba_type, types.Float): |
120 |
| - return "f" |
121 |
| - if isinstance(numba_type, types.Complex): |
122 |
| - raise InternalCompilerError( |
123 |
| - "Complex numbers not supported in `arith` dialect" |
124 |
| - ) |
125 |
| - if isinstance(numba_type, types.Array): |
126 |
| - return arithSuffix(numba_type.dtype) |
127 |
| - raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
| 132 | + match numba_type: |
| 133 | + case nt.Integer() | nt.Boolean(): |
| 134 | + return "i" |
| 135 | + case nt.Float(): |
| 136 | + return "f" |
| 137 | + case nt.Array(): |
| 138 | + return arithSuffix(numba_type.dtype) |
| 139 | + case _: |
| 140 | + raise InternalCompilerError("Unsupported type: " + str(numba_type)) |
128 | 141 |
|
129 | 142 |
|
130 | 143 | class HeaderInfo:
|
@@ -495,7 +508,7 @@ def emit_assign(self, assign):
|
495 | 508 | # nothing to do, forward the name to the arg of bool()
|
496 | 509 | self.forward_name(from_var=assign.target, to_var=assign.value.args[0])
|
497 | 510 | return ""
|
498 |
| - if global_ in MLIR_TYPES: |
| 511 | + if global_ in mt.MLIR_TYPES: |
499 | 512 | if len(assign.value.args) != 1:
|
500 | 513 | raise CompilerError(
|
501 | 514 | "MLIR type cast requires exactly one argument", assign.value.loc
|
|
0 commit comments