Skip to content

Commit fb5738d

Browse files
authored
generalize mlir zeroing (#43)
1 parent 3df5205 commit fb5738d

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

src/enzyme_ad/jax/primitives.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,18 @@ def maketup(ty):
465465
return (tystr, ty.shape)
466466

467467

468-
def to_jax(ty):
469-
tystr = ty.__str__()
470-
return {"f32": jnp.float32, "f64": jnp.float64}[tystr]
468+
def make_mlir_zero(ty):
469+
from jax._src.interpreters import mlir
470+
471+
if type(ty) != mlir.ir.RankedTensorType:
472+
ty = jax_mlir.dtype_to_ir_type(ty)
473+
elty = ty.element_type
474+
elem = (
475+
ir.FloatAttr.get(elty, 0.0)
476+
if type(elty) != ir.IntegerType
477+
else ir.IntegerAttr.get(elty, 0)
478+
)
479+
return stablehlo.ConstantOp(ir.DenseElementsAttr.get_splat(ty, elem)).results[0]
471480

472481

473482
def arg_activity_from_pipeline(pass_pipeline):
@@ -610,11 +619,6 @@ def _enzyme_primal_lowering(
610619
print(tmpBuf, out_shapes, "\n", results, "\n", str(custom_call))
611620
assert len(results) == len(out_shapes)
612621

613-
def zero(ty):
614-
from jax._src.interpreters import mlir
615-
616-
return mlir.ir_constant(jnp.zeros(ty.shape, dtype=to_jax(ty.element_type)))
617-
618622
results2 = []
619623
residx = 0
620624
for k in sorted(out_idx_map):
@@ -623,7 +627,7 @@ def zero(ty):
623627
results2.append(results[residx])
624628
residx += 1
625629
else:
626-
z = zero(orig_types[v])
630+
z = make_mlir_zero(orig_types[v])
627631
results2.append(z)
628632

629633
results = tuple(results2)
@@ -857,18 +861,7 @@ def _enzyme_rev_lowering(
857861
results.append(custom_call.results[cur_idx])
858862
cur_idx += 1
859863
else:
860-
ty = ir.RankedTensorType(ty)
861-
shape = ty.shape
862-
element_type = ty.element_type
863-
import numpy as np
864-
865-
results.append(
866-
stablehlo.ConstantOp(
867-
ir.DenseElementsAttr.get(
868-
np.zeros(shape, dtype=to_jax(element_type))
869-
)
870-
).results[0]
871-
)
864+
results.append(make_mlir_zero(ty))
872865
return results
873866

874867

0 commit comments

Comments
 (0)