12
12
from numba .core import controlflow
13
13
from numba .core .types import Type as NumbaType
14
14
15
- from heir .interfaces import InternalCompilerError
15
+ from heir .mlir .types import MLIRType , MLIR_TYPES , I1 , I8 , I16 , I32 , I64 , F32 , F64
16
+ from heir .interfaces import CompilerError , DebugMessage , InternalCompilerError
16
17
17
18
18
19
def mlirType (numba_type : NumbaType ) -> str :
@@ -36,6 +37,61 @@ def mlirType(numba_type: NumbaType) -> str:
36
37
raise InternalCompilerError ("Unsupported type: " + str (numba_type ))
37
38
38
39
40
+ def isIntegerLike (typ : NumbaType | MLIRType ) -> bool :
41
+ if isinstance (typ , type ) and issubclass (typ , MLIRType ):
42
+ return typ in {I1 , I8 , I16 , I32 , I64 }
43
+ if isinstance (typ , NumbaType ):
44
+ return isinstance (typ , types .Integer ) or isinstance (typ , types .Boolean )
45
+ raise InternalCompilerError (f"Encountered unexpected type { typ } " )
46
+
47
+
48
+ def isFloatLike (typ : NumbaType | MLIRType ) -> bool :
49
+ if isinstance (typ , type ) and issubclass (type , MLIRType ):
50
+ return typ in {F32 , F64 }
51
+ if isinstance (typ , NumbaType ):
52
+ return isinstance (typ , types .Float )
53
+ raise InternalCompilerError (f"Encountered unexpected type { type } " )
54
+
55
+
56
+ def mlirCastOp (
57
+ from_type : NumbaType , to_type : MLIRType , value : str , loc : ir .Loc
58
+ ) -> str :
59
+ if isIntegerLike (from_type ) and isIntegerLike (to_type ):
60
+ if from_type .bitwidth == to_type .numba_type ().bitwidth :
61
+ raise CompilerError (
62
+ f"Cannot create cast of { value } from { from_type } to { to_type } as they"
63
+ " have the same bitwidth" ,
64
+ loc ,
65
+ )
66
+ if from_type .bitwidth > to_type .numba_type ().bitwidth :
67
+ return (
68
+ f"arith.trunci { value } : { mlirType (from_type )} to"
69
+ f" { to_type .mlir_type ()} { mlirLoc (loc )} "
70
+ )
71
+ if from_type .bitwidth < to_type .numba_type ().bitwidth :
72
+ # FIXME: signedness for extensions?
73
+ return (
74
+ f"arith.extui { value } : { mlirType (from_type )} to"
75
+ f" { to_type .mlir_type ()} { mlirLoc (loc )} "
76
+ )
77
+ if isFloatLike (from_type ) and isIntegerLike (to_type ):
78
+ # FIXME: signedness?
79
+ return (
80
+ f"arith.fptoui { value } : { mlirType (from_type )} to"
81
+ f" { mlirType (to_type )} { mlirLoc (loc )} "
82
+ )
83
+ if isIntegerLike (from_type ) and isFloatLike (to_type ):
84
+ # FIXME: signendess?
85
+ return (
86
+ f"arith.uitofp { value } : { mlirType (from_type )} to"
87
+ f" { mlirType (to_type )} { mlirLoc (loc )} "
88
+ )
89
+ raise CompilerError (
90
+ f"Encountered unsupported cast of { value } from { from_type } to { to_type } " ,
91
+ loc ,
92
+ )
93
+
94
+
39
95
def mlirLoc (loc : ir .Loc ) -> str :
40
96
return (
41
97
f"loc(\" { loc .filename or '<unknown>' } \" :{ loc .line or 0 } :{ loc .col or 0 } )"
@@ -419,12 +475,39 @@ def emit_assign(self, assign):
419
475
func = assign .value .func
420
476
# if assert fails, variable was undefined
421
477
assert func .name in self .globals_map
422
- if self .globals_map [func .name ] == "bool" :
478
+ name , global_ = self .globals_map [func .name ]
479
+ if name == "bool" :
423
480
# nothing to do, forward the name to the arg of bool()
424
481
self .forward_name (from_var = assign .target , to_var = assign .value .args [0 ])
425
482
return ""
483
+ if global_ in MLIR_TYPES :
484
+ if len (assign .value .args ) != 1 :
485
+ raise CompilerError (
486
+ "MLIR type cast requires exactly one argument" , assign .value .loc
487
+ )
488
+ value = assign .value .args [0 ].name
489
+ if (
490
+ mlirType (self .typemap .get (assign .target .name ))
491
+ != global_ .mlir_type ()
492
+ ):
493
+ raise InternalCompilerError (
494
+ f"MLIR type cast of { value } from"
495
+ f" { mlirType (self .typemap .get (value ))} to"
496
+ f" { global_ .mlir_type ()} is not correctly reflected in types"
497
+ " inferred for the assignment, which expects"
498
+ f" { mlirType (self .typemap .get (assign .target .name ))} "
499
+ )
500
+ target_ssa = self .get_or_create_name (assign .target )
501
+ ssa_id = self .get_or_create_name (assign .value .args [0 ])
502
+ cast = mlirCastOp (
503
+ self .typemap .get (value ),
504
+ global_ ,
505
+ ssa_id ,
506
+ assign .loc ,
507
+ )
508
+ return f"{ target_ssa } = { cast } "
426
509
else :
427
- raise InternalCompilerError ("Unknown global " + func . name )
510
+ raise InternalCompilerError ("Call to unknown function " + name )
428
511
case ir .Expr (op = "cast" ):
429
512
# not sure what to do here. maybe will be needed for type conversions
430
513
# when interfacing with C
@@ -446,7 +529,10 @@ def emit_assign(self, assign):
446
529
self .forward_name_to_id (assign .target , name .strip ("%" ))
447
530
return const_str
448
531
case ir .Global ():
449
- self .globals_map [assign .target .name ] = assign .value .name
532
+ self .globals_map [assign .target .name ] = (
533
+ assign .value .name ,
534
+ assign .value .value ,
535
+ )
450
536
return ""
451
537
case ir .Var ():
452
538
# Sometimes we need this to be assigned?
@@ -469,6 +555,7 @@ def emit_ext_if_needed(self, lhs, rhs):
469
555
raise InternalCompilerError (
470
556
"Extension handling for non-integer (e.g., floats, tensors) types"
471
557
" is not yet supported. Please ensure (inferred) bit-widths match."
558
+ f" Failed to extend { lhs_type } and { rhs_type } types."
472
559
)
473
560
# TODO (#1162): Support bitwidth extension for float types
474
561
# (this probably requires adding support for local variable type hints,
0 commit comments