Skip to content

Commit b029d06

Browse files
(WIP) frontend: allow use of MLIR Types as "casts"
1 parent 92a955b commit b029d06

File tree

2 files changed

+178
-11
lines changed

2 files changed

+178
-11
lines changed

frontend/heir/mlir/types.py

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,64 @@
44
from typing import Generic, Self, TypeVar, TypeVarTuple, get_args, get_origin
55
from numba.core.types import Type as NumbaType
66
from numba.core.types import boolean, int8, int16, int32, int64, float32, float64
7+
from numba.extending import typeof_impl, type_callable
78

89
T = TypeVar("T")
910
Ts = TypeVarTuple("Ts")
1011

11-
operator_error_message = "MLIRType should only be used for annotations."
12+
# List of all MLIR types we define here, for use in other parts of the compiler
13+
MLIR_TYPES = [] # populated via MLIRType's __init_subclass__
14+
15+
16+
def check_for_value(a: "MLIRType"):
17+
if not hasattr(a, "value"):
18+
raise RuntimeError(
19+
"Trying to use an operator on an MLIRType without a value."
20+
)
1221

1322

1423
class MLIRType(ABC):
1524

25+
def __init__(self, value: int):
26+
self.value = value
27+
28+
def __init_subclass__(cls, **kwargs):
29+
super().__init_subclass__(**kwargs)
30+
MLIR_TYPES.append(cls)
31+
1632
@staticmethod
1733
@abstractmethod
1834
def numba_type() -> NumbaType:
1935
raise NotImplementedError("No numba type exists for a generic MLIRType")
2036

21-
def __add__(self, other) -> Self:
22-
raise RuntimeError(operator_error_message)
37+
@staticmethod
38+
@abstractmethod
39+
def mlir_type() -> str:
40+
raise NotImplementedError("No mlir type exists for a generic MLIRType")
41+
42+
def __add__(self, other):
43+
check_for_value(self)
44+
return self.value + other
2345

24-
def __sub__(self, other) -> Self:
25-
raise RuntimeError(operator_error_message)
46+
def __radd__(self, other):
47+
check_for_value(self)
48+
return other + self.value
2649

27-
def __mul__(self, other) -> Self:
28-
raise RuntimeError(operator_error_message)
50+
def __sub__(self, other):
51+
check_for_value(self)
52+
return self.value - other
53+
54+
def __rsub__(self, other):
55+
check_for_value(self)
56+
return other - self.value
57+
58+
def __mul__(self, other):
59+
check_for_value(self)
60+
return self.value * other
61+
62+
def __rmul__(self, other):
63+
check_for_value(self)
64+
return other * self.value
2965

3066

3167
class Secret(Generic[T], MLIRType):
@@ -34,62 +70,106 @@ class Secret(Generic[T], MLIRType):
3470
def numba_type() -> NumbaType:
3571
raise NotImplementedError("No numba type exists for a generic Secret")
3672

73+
@staticmethod
74+
def mlir_type() -> str:
75+
raise NotImplementedError("No mlir type exists for a generic Secret")
76+
3777

3878
class Tensor(Generic[*Ts], MLIRType):
3979

4080
@staticmethod
4181
def numba_type() -> NumbaType:
4282
raise NotImplementedError("No numba type exists for a generic Tensor")
4383

84+
@staticmethod
85+
def mlir_type() -> str:
86+
raise NotImplementedError("No mlir type exists for a generic Tensor")
87+
4488

4589
class F32(MLIRType):
4690
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
4791
@staticmethod
4892
def numba_type() -> NumbaType:
4993
return float32
5094

95+
@staticmethod
96+
def mlir_type() -> str:
97+
return "f32"
98+
5199

52100
class F64(MLIRType):
53101
# TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable @staticmethod
54102
@staticmethod
55103
def numba_type() -> NumbaType:
56104
return float64
57105

106+
@staticmethod
107+
def mlir_type() -> str:
108+
return "f64"
109+
58110

59111
class I1(MLIRType):
60112

61113
@staticmethod
62114
def numba_type() -> NumbaType:
63115
return boolean
64116

117+
@staticmethod
118+
def mlir_type() -> str:
119+
return "i1"
120+
65121

66122
class I8(MLIRType):
67123

68124
@staticmethod
69125
def numba_type() -> NumbaType:
70126
return int8
71127

128+
@staticmethod
129+
def mlir_type() -> str:
130+
return "i8"
131+
72132

73133
class I16(MLIRType):
74134

75135
@staticmethod
76136
def numba_type() -> NumbaType:
77137
return int16
78138

139+
@staticmethod
140+
def mlir_type() -> str:
141+
return "i16"
142+
79143

80144
class I32(MLIRType):
81145

82146
@staticmethod
83147
def numba_type() -> NumbaType:
84148
return int32
85149

150+
@staticmethod
151+
def mlir_type() -> str:
152+
return "i32"
153+
86154

87155
class I64(MLIRType):
88156

89157
@staticmethod
90158
def numba_type() -> NumbaType:
91159
return int64
92160

161+
@staticmethod
162+
def mlir_type() -> str:
163+
return "i64"
164+
165+
166+
# Register the types defined above with Numba
167+
for typ in [I8, I16, I32, I64, I1, F32, F64]:
168+
169+
@type_callable(typ)
170+
def build_typer_function(context, typ=typ):
171+
return lambda value: typ.numba_type()
172+
93173

94174
# Helper functions
95175

frontend/heir/mlir_emitter.py

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from numba.core import controlflow
1313
from numba.core.types import Type as NumbaType
1414

15-
from heir.interfaces import InternalCompilerError
15+
from heir.mlir.types import MLIRType, MLIR_TYPES, I1, I8, I16, I32, I64, F32, F64
16+
from heir.interfaces import CompilerError, DebugMessage, InternalCompilerError
1617

1718

1819
def mlirType(numba_type: NumbaType) -> str:
@@ -36,6 +37,61 @@ def mlirType(numba_type: NumbaType) -> str:
3637
raise InternalCompilerError("Unsupported type: " + str(numba_type))
3738

3839

40+
def isIntegerLike(typ: NumbaType | MLIRType) -> bool:
41+
if isinstance(typ, type) and issubclass(typ, MLIRType):
42+
return typ in {I1, I8, I16, I32, I64}
43+
if isinstance(typ, NumbaType):
44+
return isinstance(typ, types.Integer) or isinstance(typ, types.Boolean)
45+
raise InternalCompilerError(f"Encountered unexpected type {typ}")
46+
47+
48+
def isFloatLike(typ: NumbaType | MLIRType) -> bool:
49+
if isinstance(typ, type) and issubclass(type, MLIRType):
50+
return typ in {F32, F64}
51+
if isinstance(typ, NumbaType):
52+
return isinstance(typ, types.Float)
53+
raise InternalCompilerError(f"Encountered unexpected type {type}")
54+
55+
56+
def mlirCastOp(
57+
from_type: NumbaType, to_type: MLIRType, value: str, loc: ir.Loc
58+
) -> str:
59+
if isIntegerLike(from_type) and isIntegerLike(to_type):
60+
if from_type.bitwidth == to_type.numba_type().bitwidth:
61+
raise CompilerError(
62+
f"Cannot create cast of {value} from {from_type} to {to_type} as they"
63+
" have the same bitwidth",
64+
loc,
65+
)
66+
if from_type.bitwidth > to_type.numba_type().bitwidth:
67+
return (
68+
f"arith.trunci {value} : {mlirType(from_type)} to"
69+
f" {to_type.mlir_type()} {mlirLoc(loc)}"
70+
)
71+
if from_type.bitwidth < to_type.numba_type().bitwidth:
72+
# FIXME: signedness for extensions?
73+
return (
74+
f"arith.extui {value} : {mlirType(from_type)} to"
75+
f" {to_type.mlir_type()} {mlirLoc(loc)}"
76+
)
77+
if isFloatLike(from_type) and isIntegerLike(to_type):
78+
# FIXME: signedness?
79+
return (
80+
f"arith.fptoui {value} : {mlirType(from_type)} to"
81+
f" {mlirType(to_type)} {mlirLoc(loc)}"
82+
)
83+
if isIntegerLike(from_type) and isFloatLike(to_type):
84+
# FIXME: signendess?
85+
return (
86+
f"arith.uitofp {value} : {mlirType(from_type)} to"
87+
f" {mlirType(to_type)} {mlirLoc(loc)}"
88+
)
89+
raise CompilerError(
90+
f"Encountered unsupported cast of {value} from {from_type} to {to_type}",
91+
loc,
92+
)
93+
94+
3995
def mlirLoc(loc: ir.Loc) -> str:
4096
return (
4197
f"loc(\"{loc.filename or '<unknown>'}\":{loc.line or 0}:{loc.col or 0})"
@@ -419,12 +475,39 @@ def emit_assign(self, assign):
419475
func = assign.value.func
420476
# if assert fails, variable was undefined
421477
assert func.name in self.globals_map
422-
if self.globals_map[func.name] == "bool":
478+
name, global_ = self.globals_map[func.name]
479+
if name == "bool":
423480
# nothing to do, forward the name to the arg of bool()
424481
self.forward_name(from_var=assign.target, to_var=assign.value.args[0])
425482
return ""
483+
if global_ in MLIR_TYPES:
484+
if len(assign.value.args) != 1:
485+
raise CompilerError(
486+
"MLIR type cast requires exactly one argument", assign.value.loc
487+
)
488+
value = assign.value.args[0].name
489+
if (
490+
mlirType(self.typemap.get(assign.target.name))
491+
!= global_.mlir_type()
492+
):
493+
raise InternalCompilerError(
494+
f"MLIR type cast of {value} from"
495+
f" {mlirType(self.typemap.get(value))} to"
496+
f" {global_.mlir_type()} is not correctly reflected in types"
497+
" inferred for the assignment, which expects"
498+
f" {mlirType(self.typemap.get(assign.target.name))}"
499+
)
500+
target_ssa = self.get_or_create_name(assign.target)
501+
ssa_id = self.get_or_create_name(assign.value.args[0])
502+
cast = mlirCastOp(
503+
self.typemap.get(value),
504+
global_,
505+
ssa_id,
506+
assign.loc,
507+
)
508+
return f"{target_ssa} = {cast}"
426509
else:
427-
raise InternalCompilerError("Unknown global " + func.name)
510+
raise InternalCompilerError("Call to unknown function " + name)
428511
case ir.Expr(op="cast"):
429512
# not sure what to do here. maybe will be needed for type conversions
430513
# when interfacing with C
@@ -446,7 +529,10 @@ def emit_assign(self, assign):
446529
self.forward_name_to_id(assign.target, name.strip("%"))
447530
return const_str
448531
case ir.Global():
449-
self.globals_map[assign.target.name] = assign.value.name
532+
self.globals_map[assign.target.name] = (
533+
assign.value.name,
534+
assign.value.value,
535+
)
450536
return ""
451537
case ir.Var():
452538
# Sometimes we need this to be assigned?
@@ -469,6 +555,7 @@ def emit_ext_if_needed(self, lhs, rhs):
469555
raise InternalCompilerError(
470556
"Extension handling for non-integer (e.g., floats, tensors) types"
471557
" is not yet supported. Please ensure (inferred) bit-widths match."
558+
f" Failed to extend {lhs_type} and {rhs_type} types."
472559
)
473560
# TODO (#1162): Support bitwidth extension for float types
474561
# (this probably requires adding support for local variable type hints,

0 commit comments

Comments
 (0)