@@ -465,9 +465,18 @@ def maketup(ty):
465
465
return (tystr , ty .shape )
466
466
467
467
468
- def to_jax (ty ):
469
- tystr = ty .__str__ ()
470
- return {"f32" : jnp .float32 , "f64" : jnp .float64 }[tystr ]
468
+ def make_mlir_zero (ty ):
469
+ from jax ._src .interpreters import mlir
470
+
471
+ if type (ty ) != mlir .ir .RankedTensorType :
472
+ ty = jax_mlir .dtype_to_ir_type (ty )
473
+ elty = ty .element_type
474
+ elem = (
475
+ ir .FloatAttr .get (elty , 0.0 )
476
+ if type (elty ) != ir .IntegerType
477
+ else ir .IntegerAttr .get (elty , 0 )
478
+ )
479
+ return stablehlo .ConstantOp (ir .DenseElementsAttr .get_splat (ty , elem )).results [0 ]
471
480
472
481
473
482
def arg_activity_from_pipeline (pass_pipeline ):
@@ -610,11 +619,6 @@ def _enzyme_primal_lowering(
610
619
print (tmpBuf , out_shapes , "\n " , results , "\n " , str (custom_call ))
611
620
assert len (results ) == len (out_shapes )
612
621
613
- def zero (ty ):
614
- from jax ._src .interpreters import mlir
615
-
616
- return mlir .ir_constant (jnp .zeros (ty .shape , dtype = to_jax (ty .element_type )))
617
-
618
622
results2 = []
619
623
residx = 0
620
624
for k in sorted (out_idx_map ):
@@ -623,7 +627,7 @@ def zero(ty):
623
627
results2 .append (results [residx ])
624
628
residx += 1
625
629
else :
626
- z = zero (orig_types [v ])
630
+ z = make_mlir_zero (orig_types [v ])
627
631
results2 .append (z )
628
632
629
633
results = tuple (results2 )
@@ -857,18 +861,7 @@ def _enzyme_rev_lowering(
857
861
results .append (custom_call .results [cur_idx ])
858
862
cur_idx += 1
859
863
else :
860
- ty = ir .RankedTensorType (ty )
861
- shape = ty .shape
862
- element_type = ty .element_type
863
- import numpy as np
864
-
865
- results .append (
866
- stablehlo .ConstantOp (
867
- ir .DenseElementsAttr .get (
868
- np .zeros (shape , dtype = to_jax (element_type ))
869
- )
870
- ).results [0 ]
871
- )
864
+ results .append (make_mlir_zero (ty ))
872
865
return results
873
866
874
867
0 commit comments