Skip to content

Commit e767a18

Browse files
committed
types.py: use match/case
The main issues not allowing this to be done in the review for google#1825: - Bare words are treated as a variable capture unless scoped (with a module prefix, e.g., `mt.I1`) or instantiated (`I1()`). - The "MLIRType" values passed to these functions are not actually instances of MLIRType (the type hint is still incorrect in this PR), but rather instances of `class`. We cannot rectify this by instantiating the MLIRType to simplify the match syntax (e.g., use `I1()` instead of `mt.I1`) because google#1825 added a member variable on MLIRType so the type class could be used as a cast. Possibly an improvement would be to add a no-value constructor to MLIRType so that `I1()` can be instantiated as a standalone type, and then call these helpers with actual instances of MLIRType rather than class constructor objects.
1 parent 1fabdbc commit e767a18

File tree

2 files changed

+77
-64
lines changed

2 files changed

+77
-64
lines changed

frontend/heir/mlir/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Defines Python type annotations for MLIR types."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Generic, Self, TypeVar, TypeVarTuple, get_args, get_origin
4+
from typing import Generic, TypeVar, TypeVarTuple, get_args, get_origin
55
from numba.core.types import Type as NumbaType
66
from numba.core.types import boolean, int8, int16, int32, int64, float32, float64
7-
from numba.extending import typeof_impl, type_callable
7+
from numba.extending import type_callable
88

99
T = TypeVar("T")
1010
Ts = TypeVarTuple("Ts")

frontend/heir/mlir_emitter.py

Lines changed: 75 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -4,70 +4,87 @@
44
from dataclasses import dataclass
55
import operator
66
import textwrap
7-
from typing import Any, NewType
7+
from typing import Any
88

99
from numba.core import ir
10-
from numba.core import types
1110
from numba.core import bytecode
1211
from numba.core import controlflow
13-
from numba.core.types import Type as NumbaType
12+
import numba.core.types as nt
1413

15-
from heir.mlir.types import MLIRType, MLIR_TYPES, I1, I8, I16, I32, I64, F32, F64
16-
from heir.interfaces import CompilerError, DebugMessage, InternalCompilerError
14+
from heir.mlir import types as mt
15+
from heir.interfaces import CompilerError, InternalCompilerError
16+
17+
NumbaType = nt.Type
1718

1819

1920
def mlirType(numba_type: NumbaType) -> str:
20-
if isinstance(numba_type, types.Integer):
21-
# TODO (#1162): fix handling of signedness
22-
# Since `arith` only allows signless integers, we ignore signedness here.
23-
return "i" + str(numba_type.bitwidth)
24-
if isinstance(numba_type, types.RangeType):
25-
return mlirType(numba_type.dtype)
26-
if isinstance(numba_type, types.Boolean):
27-
return "i1"
28-
if isinstance(numba_type, types.Float):
29-
return "f" + str(numba_type.bitwidth)
30-
if isinstance(numba_type, types.Complex):
31-
return "complex<" + str(numba_type.bitwidth) + ">"
32-
if isinstance(numba_type, types.Array):
33-
shape = None
34-
if hasattr(numba_type, "shape"):
35-
shape = "x".join(str(s) for s in numba_type.shape) # type: ignore
36-
return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">"
37-
raise InternalCompilerError("Unsupported type: " + str(numba_type))
38-
39-
40-
def isIntegerLike(typ: NumbaType | MLIRType) -> bool:
41-
if isinstance(typ, type) and issubclass(typ, MLIRType):
42-
return typ in {I1, I8, I16, I32, I64}
43-
if isinstance(typ, NumbaType):
44-
return isinstance(typ, types.Integer) or isinstance(typ, types.Boolean)
45-
raise InternalCompilerError(f"Encountered unexpected type {typ}")
46-
47-
48-
def isFloatLike(typ: NumbaType | MLIRType) -> bool:
49-
if isinstance(typ, type) and issubclass(type, MLIRType):
50-
return typ in {F32, F64}
51-
if isinstance(typ, NumbaType):
52-
return isinstance(typ, types.Float)
53-
raise InternalCompilerError(f"Encountered unexpected type {typ}")
21+
match numba_type:
22+
case nt.Integer():
23+
# TODO (#1162): fix handling of signedness
24+
# Since `arith` only allows signless integers, we ignore signedness here.
25+
return "i" + str(numba_type.bitwidth)
26+
case nt.RangeType():
27+
return mlirType(numba_type.dtype)
28+
case nt.Boolean():
29+
return "i1"
30+
case nt.Float():
31+
return "f" + str(numba_type.bitwidth)
32+
case nt.Complex():
33+
return "complex<" + str(numba_type.bitwidth) + ">"
34+
case nt.Array():
35+
return (
36+
"tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">"
37+
)
38+
case _:
39+
raise InternalCompilerError("Unsupported type: " + str(numba_type))
40+
41+
42+
def isIntegerLike(typ: NumbaType | mt.MLIRType) -> bool:
43+
# note: the match-case discrepancies here—using nt.Integer() instead of
44+
# nt.Integer and mt.I1 instead of mt.I1()—are due to the fact that the MLIR
45+
# types passed in refer to the class itself, while the Numba types refer to
46+
# instances of a class.
47+
match typ:
48+
case mt.I1 | mt.I8 | mt.I16 | mt.I32 | mt.I64 | nt.Integer() | nt.Boolean():
49+
return True
50+
case mt.MLIRType | nt.Type:
51+
return False
52+
case _:
53+
raise InternalCompilerError(
54+
f"Encountered unexpected type {typ} of type {type(typ)}"
55+
)
56+
57+
58+
def isFloatLike(typ: NumbaType | mt.MLIRType) -> bool:
59+
match typ:
60+
case nt.F32 | nt.F64 | nt.Float():
61+
return True
62+
case mt.MLIRType | nt.Type:
63+
return False
64+
case _:
65+
raise InternalCompilerError(
66+
f"Encountered unexpected type {typ} of type {type(typ)}"
67+
)
5468

5569

5670
# Needed because, e.g. Boolean doesn't have a bitwidth
57-
def getBitwidth(typ: NumbaType | MLIRType) -> int:
58-
if isinstance(typ, type) and issubclass(typ, MLIRType):
59-
if typ in {I1, I8, I16, I32, I64}:
71+
def getBitwidth(typ: NumbaType | mt.MLIRType) -> int:
72+
match typ:
73+
case mt.I1 | mt.I8 | mt.I16 | mt.I32 | mt.I64:
6074
# e.g., <class 'heir.mlir.types.I32'>.__name__ -> "I32" -> "32"
6175
return int(typ.__name__[1:])
62-
if isinstance(typ, types.Integer):
63-
return typ.bitwidth
64-
if isinstance(typ, types.Boolean):
65-
return 1
66-
raise InternalCompilerError(f"unexpected type {typ} ({type(typ)})")
76+
case nt.Integer() as int_ty:
77+
return int_ty.bitwidth
78+
case nt.Boolean():
79+
return 1
80+
case _:
81+
raise InternalCompilerError(
82+
f"Encountered unexpected type {typ} of type {type(typ)}"
83+
)
6784

6885

6986
def mlirCastOp(
70-
from_type: NumbaType, to_type: MLIRType, value: str, loc: ir.Loc
87+
from_type: NumbaType, to_type: mt.MLIRType, value: str, loc: ir.Loc
7188
) -> str:
7289
if isIntegerLike(from_type) and isIntegerLike(to_type):
7390
from_width = getBitwidth(from_type)
@@ -112,19 +129,15 @@ def mlirLoc(loc: ir.Loc) -> str:
112129

113130
def arithSuffix(numba_type: NumbaType) -> str:
114131
"""Helper to translate numba types to the associated arith dialect operation suffixes"""
115-
if isinstance(numba_type, types.Integer):
116-
return "i"
117-
if isinstance(numba_type, types.Boolean):
118-
return "i"
119-
if isinstance(numba_type, types.Float):
120-
return "f"
121-
if isinstance(numba_type, types.Complex):
122-
raise InternalCompilerError(
123-
"Complex numbers not supported in `arith` dialect"
124-
)
125-
if isinstance(numba_type, types.Array):
126-
return arithSuffix(numba_type.dtype)
127-
raise InternalCompilerError("Unsupported type: " + str(numba_type))
132+
match numba_type:
133+
case nt.Integer() | nt.Boolean():
134+
return "i"
135+
case nt.Float():
136+
return "f"
137+
case nt.Array():
138+
return arithSuffix(numba_type.dtype)
139+
case _:
140+
raise InternalCompilerError("Unsupported type: " + str(numba_type))
128141

129142

130143
class HeaderInfo:
@@ -495,7 +508,7 @@ def emit_assign(self, assign):
495508
# nothing to do, forward the name to the arg of bool()
496509
self.forward_name(from_var=assign.target, to_var=assign.value.args[0])
497510
return ""
498-
if global_ in MLIR_TYPES:
511+
if global_ in mt.MLIR_TYPES:
499512
if len(assign.value.args) != 1:
500513
raise CompilerError(
501514
"MLIR type cast requires exactly one argument", assign.value.loc

0 commit comments

Comments
 (0)