@@ -1578,6 +1578,145 @@ struct NoNan : public OpRewritePattern<mlir::stablehlo::CompareOp> {
1578
1578
}
1579
1579
};
1580
1580
1581
+ struct TransposeTranspose
1582
+ : public OpRewritePattern<mlir::stablehlo::TransposeOp> {
1583
+ using OpRewritePattern<mlir::stablehlo::TransposeOp>::OpRewritePattern;
1584
+
1585
+ LogicalResult matchAndRewrite (mlir::stablehlo::TransposeOp op,
1586
+ PatternRewriter &rewriter) const final {
1587
+ if (auto definingTranspose =
1588
+ op.getOperand ().getDefiningOp <mlir::stablehlo::TransposeOp>()) {
1589
+ llvm::ArrayRef<int64_t > thisPermutation = op.getPermutation ();
1590
+ llvm::ArrayRef<int64_t > prevPermutation =
1591
+ definingTranspose.getPermutation ();
1592
+
1593
+ SmallVector<int64_t > newPermutation;
1594
+ newPermutation.resize (thisPermutation.size ());
1595
+ for (unsigned i = 0 , e = thisPermutation.size (); i != e; ++i) {
1596
+ newPermutation[i] = prevPermutation[thisPermutation[i]];
1597
+ }
1598
+
1599
+ rewriter.modifyOpInPlace (op, [&]() {
1600
+ op.setPermutation (newPermutation);
1601
+ op.setOperand (definingTranspose.getOperand ());
1602
+ });
1603
+
1604
+ return success ();
1605
+ }
1606
+ return rewriter.notifyMatchFailure (op, " not a transpose(transpose)" );
1607
+ }
1608
+ };
1609
+
1610
+ struct TransposeConvert : public OpRewritePattern <mlir::stablehlo::ConvertOp> {
1611
+ using OpRewritePattern<mlir::stablehlo::ConvertOp>::OpRewritePattern;
1612
+
1613
+ LogicalResult matchAndRewrite (mlir::stablehlo::ConvertOp op,
1614
+ PatternRewriter &rewriter) const final {
1615
+ auto resultType = op.getResult ().getType ().cast <TensorType>();
1616
+ auto operandType = op.getOperand ().getType ().cast <TensorType>();
1617
+ if (!resultType.hasStaticShape () || !operandType.hasStaticShape ())
1618
+ return failure ();
1619
+ if (resultType.getNumElements () * resultType.getElementTypeBitWidth () >=
1620
+ operandType.getNumElements () * operandType.getElementTypeBitWidth ())
1621
+ return failure ();
1622
+
1623
+ auto transpose =
1624
+ op.getOperand ().getDefiningOp <mlir::stablehlo::TransposeOp>();
1625
+ if (!transpose || !llvm::hasSingleElement (transpose->getUsers ()))
1626
+ return failure ();
1627
+
1628
+ auto newConvert = rewriter.create <stablehlo::ConvertOp>(
1629
+ op.getLoc (), transpose.getOperand (), resultType.getElementType ());
1630
+ auto newTranspose = rewriter.create <stablehlo::TransposeOp>(
1631
+ transpose.getLoc (), newConvert.getResult (), transpose.getPermutation ());
1632
+ rewriter.replaceOp (op, newTranspose);
1633
+ rewriter.eraseOp (transpose);
1634
+
1635
+ return success ();
1636
+ }
1637
+ };
1638
+
1639
+ struct BroadcastReduce : public OpRewritePattern <mlir::stablehlo::ReduceOp> {
1640
+ using OpRewritePattern<mlir::stablehlo::ReduceOp>::OpRewritePattern;
1641
+
1642
+ LogicalResult matchAndRewrite (mlir::stablehlo::ReduceOp op,
1643
+ PatternRewriter &rewriter) const final {
1644
+ if (op.getInputs ().size () != 1 || op.getInitValues ().size () != 1 ) {
1645
+ return rewriter.notifyMatchFailure (
1646
+ op, " only single-operand single-init reduce is supported" );
1647
+ }
1648
+ // TODO: min/max can also be an option since they are dropped
1649
+ if (!isa<stablehlo::AddOp>(op.getRegion ().getBlocks ().front ().front ())) {
1650
+ return rewriter.notifyMatchFailure (op, " only add is currently supported" );
1651
+ }
1652
+
1653
+ Value input = op.getInputs ()[0 ];
1654
+ auto inputType = input.getType ().cast <TensorType>();
1655
+ auto broadcast = input.getDefiningOp <mlir::stablehlo::BroadcastInDimOp>();
1656
+ if (!broadcast) {
1657
+ return rewriter.notifyMatchFailure (op,
1658
+ " input source is not a broadcast op" );
1659
+ }
1660
+
1661
+ // If any of the dimensions that are being reduced was initially
1662
+ // broadcasted, we can multiply the result with the dimension instead.
1663
+ ArrayRef<int64_t > broadcastDims = broadcast.getBroadcastDimensions ();
1664
+ SmallVector<int64_t > broadcastFromNothingDims, broadcastFromOneDims;
1665
+ auto broadcastSourceType =
1666
+ broadcast.getOperand ().getType ().cast <TensorType>();
1667
+ for (int64_t reductionDim : op.getDimensions ()) {
1668
+ if (inputType.isDynamicDim (reductionDim)) continue ;
1669
+ auto it = llvm::find (broadcastDims, reductionDim);
1670
+ if (it == broadcastDims.end ()) {
1671
+ broadcastFromNothingDims.push_back (reductionDim);
1672
+ continue ;
1673
+ }
1674
+ size_t originalDim = std::distance (broadcastDims.begin (), it);
1675
+ if (broadcastSourceType.getDimSize (originalDim) == 1 &&
1676
+ inputType.getDimSize (reductionDim) != 1 ) {
1677
+ broadcastFromOneDims.push_back (reductionDim);
1678
+ }
1679
+ }
1680
+ if (broadcastFromNothingDims.empty () && broadcastFromOneDims.empty ())
1681
+ return rewriter.notifyMatchFailure (op, " no dimensions to remove" );
1682
+
1683
+ int64_t size = 1 ;
1684
+ for (int64_t dim : broadcastFromNothingDims) {
1685
+ size *= inputType.getDimSize (dim);
1686
+ }
1687
+ for (int64_t dim : broadcastFromOneDims) {
1688
+ size *= inputType.getDimSize (dim);
1689
+ }
1690
+
1691
+ int64_t numRemoved = 0 ;
1692
+ SmallVector<int64_t > newReduceDimensions;
1693
+ llvm::sort (broadcastFromNothingDims);
1694
+ for (int64_t reductionDim : op.getDimensions ()) {
1695
+ if (llvm::is_contained (broadcastFromNothingDims, reductionDim)) {
1696
+ numRemoved++;
1697
+ continue ;
1698
+ }
1699
+ newReduceDimensions.push_back (reductionDim - numRemoved);
1700
+ }
1701
+
1702
+ auto newReduction = rewriter.create <stablehlo::ReduceOp>(
1703
+ op.getLoc (), op->getResultTypes (), ValueRange{broadcast.getOperand ()},
1704
+ op.getInitValues (), newReduceDimensions);
1705
+ newReduction.getRegion ().takeBody (op.getRegion ());
1706
+
1707
+ auto newResultType = newReduction.getResult (0 ).getType ().cast <TensorType>();
1708
+ auto constantInt = rewriter.create <stablehlo::ConstantOp>(
1709
+ op.getLoc (),
1710
+ makeAttr (newResultType.clone (rewriter.getI64Type ()), size));
1711
+ auto converted = rewriter.create <stablehlo::ConvertOp>(
1712
+ op.getLoc (), constantInt, newResultType.getElementType ());
1713
+ rewriter.replaceOpWithNewOp <stablehlo::MulOp>(op, newReduction.getResult (0 ),
1714
+ converted.getResult ());
1715
+
1716
+ return success ();
1717
+ }
1718
+ };
1719
+
1581
1720
struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase <EnzymeHLOOptPass> {
1582
1721
1583
1722
void runOnOperation () override {
@@ -1594,7 +1733,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
1594
1733
BinBroadcastSplat<stablehlo::AddOp>,
1595
1734
BinBroadcastSplat<stablehlo::SubtractOp>,
1596
1735
BinBroadcastSplat<stablehlo::DivOp>,
1597
- BinBroadcastSplat<stablehlo::MulOp>>(context);
1736
+ BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
1737
+ TransposeConvert, BroadcastReduce>(context);
1598
1738
if (all_finite)
1599
1739
patterns.add <AllFinite>(context);
1600
1740
if (no_nan || all_finite)
0 commit comments