Skip to content

Commit 6466426

Browse files
committed
types.py: use match/case
The main issues not allowing this to be done in the review for google#1825: - Bare words are treated as a variable capture unless scoped (with a module prefix, e.g., `mt.I1`) or instantiated (`I1()`). - The "MLIRType" values passed to these functions are not actually instances of MLIRType (the type hint is still incorrect in this PR), but rather instances of `class`. We cannot rectify this by instantiating the MLIRType to simplify the match syntax (e.g., use `I1()` instead of `mt.I1`) because google#1825 added a member variable on MLIRType so the type class could be used as a cast. Possibly an improvement would be to add a no-value constructor to MLIRType so that `I1()` can be instantiated as a standalone type, and then call these helpers with actual instances of MLIRType rather than class constructor objects.
1 parent 509b627 commit 6466426

File tree

2 files changed

+81
-58
lines changed

2 files changed

+81
-58
lines changed

frontend/heir/mlir/types.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Defines Python type annotations for MLIR types."""
22

33
from abc import ABC, abstractmethod
4-
from typing import Generic, Self, TypeVar, TypeVarTuple, get_args, get_origin
4+
from typing import Generic, TypeVar, TypeVarTuple, get_args, get_origin, Optional
55
from numba.core.types import Type as NumbaType
66
from numba.core.types import boolean, int8, int16, int32, int64, float32, float64
7-
from numba.extending import typeof_impl, type_callable
7+
from numba.extending import type_callable
88

99
T = TypeVar("T")
1010
Ts = TypeVarTuple("Ts")
@@ -22,8 +22,19 @@ def check_for_value(a: "MLIRType"):
2222

2323
class MLIRType(ABC):
2424

25-
def __init__(self, value: int):
26-
self.value = value
25+
def __init__(self, value: Optional[int] = None):
26+
# MLIRType subclasses are used in two ways:
27+
#
28+
# 1. As an explicit cast, in which case the result of the cast is a value
29+
# that can be further operated on.
30+
# 2. As a type, in which case there is no explicit value and the class
31+
# represents a standalone type.
32+
#
33+
# (2) is useful for match/case when the program is being analyzed for its
34+
# types. (1) is useful when allowing a program typed with heir to also run
35+
# as standard Python code.
36+
if value is not None:
37+
self.value = value
2738

2839
def __int__(self):
2940
check_for_value(self)

frontend/heir/mlir_emitter.py

Lines changed: 66 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,66 +4,79 @@
44
from dataclasses import dataclass
55
import operator
66
import textwrap
7-
from typing import Any, NewType
7+
from typing import Any
88

99
from numba.core import ir
10-
from numba.core import types
1110
from numba.core import bytecode
1211
from numba.core import controlflow
13-
from numba.core.types import Type as NumbaType
12+
import numba.core.types as nt
1413

15-
from heir.mlir.types import MLIRType, MLIR_TYPES, I1, I8, I16, I32, I64, F32, F64
16-
from heir.interfaces import CompilerError, DebugMessage, InternalCompilerError
14+
from heir.mlir import I1, I8, I16, I32, I64, MLIRType, MLIR_TYPES
15+
from heir.interfaces import CompilerError, InternalCompilerError
16+
17+
NumbaType = nt.Type
1718

1819

1920
def mlirType(numba_type: NumbaType) -> str:
20-
if isinstance(numba_type, types.Integer):
21-
# TODO (#1162): fix handling of signedness
22-
# Since `arith` only allows signless integers, we ignore signedness here.
23-
return "i" + str(numba_type.bitwidth)
24-
if isinstance(numba_type, types.RangeType):
25-
return mlirType(numba_type.dtype)
26-
if isinstance(numba_type, types.Boolean):
27-
return "i1"
28-
if isinstance(numba_type, types.Float):
29-
return "f" + str(numba_type.bitwidth)
30-
if isinstance(numba_type, types.Complex):
31-
return "complex<" + str(numba_type.bitwidth) + ">"
32-
if isinstance(numba_type, types.Array):
33-
shape = None
34-
if hasattr(numba_type, "shape"):
35-
shape = "x".join(str(s) for s in numba_type.shape) # type: ignore
36-
return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">"
37-
raise InternalCompilerError("Unsupported type: " + str(numba_type))
21+
match numba_type:
22+
case nt.Integer():
23+
# TODO (#1162): fix handling of signedness
24+
# Since `arith` only allows signless integers, we ignore signedness here.
25+
return "i" + str(numba_type.bitwidth)
26+
case nt.RangeType():
27+
return mlirType(numba_type.dtype)
28+
case nt.Boolean():
29+
return "i1"
30+
case nt.Float():
31+
return "f" + str(numba_type.bitwidth)
32+
case nt.Complex():
33+
return "complex<" + str(numba_type.bitwidth) + ">"
34+
case nt.Array():
35+
return (
36+
"tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">"
37+
)
38+
case _:
39+
raise InternalCompilerError("Unsupported type: " + str(numba_type))
3840

3941

4042
def isIntegerLike(typ: NumbaType | MLIRType) -> bool:
41-
if isinstance(typ, type) and issubclass(typ, MLIRType):
42-
return typ in {I1, I8, I16, I32, I64}
43-
if isinstance(typ, NumbaType):
44-
return isinstance(typ, types.Integer) or isinstance(typ, types.Boolean)
45-
raise InternalCompilerError(f"Encountered unexpected type {typ}")
43+
match typ:
44+
case I1() | I8() | I16() | I32() | I64() | nt.Integer() | nt.Boolean():
45+
return True
46+
case MLIRType() | nt.Type:
47+
return False
48+
case _:
49+
raise InternalCompilerError(
50+
f"Encountered unexpected type {typ} of type {type(typ)}"
51+
)
4652

4753

4854
def isFloatLike(typ: NumbaType | MLIRType) -> bool:
49-
if isinstance(typ, type) and issubclass(type, MLIRType):
50-
return typ in {F32, F64}
51-
if isinstance(typ, NumbaType):
52-
return isinstance(typ, types.Float)
53-
raise InternalCompilerError(f"Encountered unexpected type {typ}")
55+
match typ:
56+
case nt.F32 | nt.F64 | nt.Float():
57+
return True
58+
case MLIRType() | nt.Type:
59+
return False
60+
case _:
61+
raise InternalCompilerError(
62+
f"Encountered unexpected type {typ} of type {type(typ)}"
63+
)
5464

5565

5666
# Needed because, e.g. Boolean doesn't have a bitwidth
5767
def getBitwidth(typ: NumbaType | MLIRType) -> int:
58-
if isinstance(typ, type) and issubclass(typ, MLIRType):
59-
if typ in {I1, I8, I16, I32, I64}:
68+
match typ:
69+
case I1() | I8() | I16() | I32() | I64():
6070
# e.g., <class 'heir.mlir.types.I32'>.__name__ -> "I32" -> "32"
6171
return int(typ.__name__[1:])
62-
if isinstance(typ, types.Integer):
63-
return typ.bitwidth
64-
if isinstance(typ, types.Boolean):
65-
return 1
66-
raise InternalCompilerError(f"unexpected type {typ} ({type(typ)})")
72+
case nt.Integer() as int_ty:
73+
return int_ty.bitwidth
74+
case nt.Boolean():
75+
return 1
76+
case _:
77+
raise InternalCompilerError(
78+
f"Encountered unexpected type {typ} of type {type(typ)}"
79+
)
6780

6881

6982
def mlirCastOp(
@@ -112,19 +125,15 @@ def mlirLoc(loc: ir.Loc) -> str:
112125

113126
def arithSuffix(numba_type: NumbaType) -> str:
114127
"""Helper to translate numba types to the associated arith dialect operation suffixes"""
115-
if isinstance(numba_type, types.Integer):
116-
return "i"
117-
if isinstance(numba_type, types.Boolean):
118-
return "i"
119-
if isinstance(numba_type, types.Float):
120-
return "f"
121-
if isinstance(numba_type, types.Complex):
122-
raise InternalCompilerError(
123-
"Complex numbers not supported in `arith` dialect"
124-
)
125-
if isinstance(numba_type, types.Array):
126-
return arithSuffix(numba_type.dtype)
127-
raise InternalCompilerError("Unsupported type: " + str(numba_type))
128+
match numba_type:
129+
case nt.Integer() | nt.Boolean():
130+
return "i"
131+
case nt.Float():
132+
return "f"
133+
case nt.Array():
134+
return arithSuffix(numba_type.dtype)
135+
case _:
136+
raise InternalCompilerError("Unsupported type: " + str(numba_type))
128137

129138

130139
class HeaderInfo:
@@ -514,9 +523,12 @@ def emit_assign(self, assign):
514523
)
515524
target_ssa = self.get_or_create_name(assign.target)
516525
ssa_id = self.get_or_create_name(assign.value.args[0])
526+
527+
# Construct an instance of the MLIR type in question
528+
mlir_type_instance = global_()
517529
cast = mlirCastOp(
518530
self.typemap.get(value),
519-
global_,
531+
mlir_type_instance,
520532
ssa_id,
521533
assign.loc,
522534
)

0 commit comments

Comments
 (0)