Skip to content

Commit 4a94bb9

Browse files
authored
Generalize mktup (#44)
1 parent fb5738d commit 4a94bb9

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/enzyme_ad/jax/primitives.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,14 @@ def _enzyme_rev_abstract_eval(
461461
def maketup(ty):
462462
ty = ir.RankedTensorType(ty)
463463
tystr = ty.element_type.__str__()
464-
tystr = {"f32": "float", "f64": "double", "i32": "int32_t", "i64": "int64_t"}[tystr]
464+
tystr = {
465+
"f32": "float",
466+
"f64": "double",
467+
"i32": "int32_t",
468+
"i64": "int64_t",
469+
"ui32": "uint32_t",
470+
"ui64": "uint64_t",
471+
}[tystr]
465472
return (tystr, ty.shape)
466473

467474

0 commit comments

Comments
 (0)