@@ -1911,6 +1911,146 @@ struct NoNan : public OpRewritePattern<mlir::stablehlo::CompareOp> {
1911
1911
}
1912
1912
};
1913
1913
1914
+ struct TransposeTranspose
1915
+ : public OpRewritePattern<mlir::stablehlo::TransposeOp> {
1916
+ using OpRewritePattern<mlir::stablehlo::TransposeOp>::OpRewritePattern;
1917
+
1918
+ LogicalResult matchAndRewrite (mlir::stablehlo::TransposeOp op,
1919
+ PatternRewriter &rewriter) const final {
1920
+ if (auto definingTranspose =
1921
+ op.getOperand ().getDefiningOp <mlir::stablehlo::TransposeOp>()) {
1922
+ llvm::ArrayRef<int64_t > thisPermutation = op.getPermutation ();
1923
+ llvm::ArrayRef<int64_t > prevPermutation =
1924
+ definingTranspose.getPermutation ();
1925
+
1926
+ SmallVector<int64_t > newPermutation;
1927
+ newPermutation.resize (thisPermutation.size ());
1928
+ for (unsigned i = 0 , e = thisPermutation.size (); i != e; ++i) {
1929
+ newPermutation[i] = prevPermutation[thisPermutation[i]];
1930
+ }
1931
+
1932
+ rewriter.modifyOpInPlace (op, [&]() {
1933
+ op.setPermutation (newPermutation);
1934
+ op.setOperand (definingTranspose.getOperand ());
1935
+ });
1936
+
1937
+ return success ();
1938
+ }
1939
+ return rewriter.notifyMatchFailure (op, " not a transpose(transpose)" );
1940
+ }
1941
+ };
1942
+
1943
+ struct TransposeConvert : public OpRewritePattern <mlir::stablehlo::ConvertOp> {
1944
+ using OpRewritePattern<mlir::stablehlo::ConvertOp>::OpRewritePattern;
1945
+
1946
+ LogicalResult matchAndRewrite (mlir::stablehlo::ConvertOp op,
1947
+ PatternRewriter &rewriter) const final {
1948
+ auto resultType = op.getResult ().getType ().cast <TensorType>();
1949
+ auto operandType = op.getOperand ().getType ().cast <TensorType>();
1950
+ if (!resultType.hasStaticShape () || !operandType.hasStaticShape ())
1951
+ return failure ();
1952
+ if (resultType.getNumElements () * resultType.getElementTypeBitWidth () >=
1953
+ operandType.getNumElements () * operandType.getElementTypeBitWidth ())
1954
+ return failure ();
1955
+
1956
+ auto transpose =
1957
+ op.getOperand ().getDefiningOp <mlir::stablehlo::TransposeOp>();
1958
+ if (!transpose || !llvm::hasSingleElement (transpose->getUsers ()))
1959
+ return failure ();
1960
+
1961
+ auto newConvert = rewriter.create <stablehlo::ConvertOp>(
1962
+ op.getLoc (), transpose.getOperand (), resultType.getElementType ());
1963
+ auto newTranspose = rewriter.create <stablehlo::TransposeOp>(
1964
+ transpose.getLoc (), newConvert.getResult (), transpose.getPermutation ());
1965
+ rewriter.replaceOp (op, newTranspose);
1966
+ rewriter.eraseOp (transpose);
1967
+
1968
+ return success ();
1969
+ }
1970
+ };
1971
+
1972
+ struct BroadcastReduce : public OpRewritePattern <mlir::stablehlo::ReduceOp> {
1973
+ using OpRewritePattern<mlir::stablehlo::ReduceOp>::OpRewritePattern;
1974
+
1975
+ LogicalResult matchAndRewrite (mlir::stablehlo::ReduceOp op,
1976
+ PatternRewriter &rewriter) const final {
1977
+ if (op.getInputs ().size () != 1 || op.getInitValues ().size () != 1 ) {
1978
+ return rewriter.notifyMatchFailure (
1979
+ op, " only single-operand single-init reduce is supported" );
1980
+ }
1981
+ // TODO: min/max can also be an option since they are dropped
1982
+ if (!isa<stablehlo::AddOp>(op.getRegion ().getBlocks ().front ().front ())) {
1983
+ return rewriter.notifyMatchFailure (op, " only add is currently supported" );
1984
+ }
1985
+
1986
+ Value input = op.getInputs ()[0 ];
1987
+ auto inputType = input.getType ().cast <TensorType>();
1988
+ auto broadcast = input.getDefiningOp <mlir::stablehlo::BroadcastInDimOp>();
1989
+ if (!broadcast) {
1990
+ return rewriter.notifyMatchFailure (op,
1991
+ " input source is not a broadcast op" );
1992
+ }
1993
+
1994
+ // If any of the dimensions that are being reduced was initially
1995
+ // broadcasted, we can multiply the result with the dimension instead.
1996
+ ArrayRef<int64_t > broadcastDims = broadcast.getBroadcastDimensions ();
1997
+ SmallVector<int64_t > broadcastFromNothingDims, broadcastFromOneDims;
1998
+ auto broadcastSourceType =
1999
+ broadcast.getOperand ().getType ().cast <TensorType>();
2000
+ for (int64_t reductionDim : op.getDimensions ()) {
2001
+ if (inputType.isDynamicDim (reductionDim))
2002
+ continue ;
2003
+ auto it = llvm::find (broadcastDims, reductionDim);
2004
+ if (it == broadcastDims.end ()) {
2005
+ broadcastFromNothingDims.push_back (reductionDim);
2006
+ continue ;
2007
+ }
2008
+ size_t originalDim = std::distance (broadcastDims.begin (), it);
2009
+ if (broadcastSourceType.getDimSize (originalDim) == 1 &&
2010
+ inputType.getDimSize (reductionDim) != 1 ) {
2011
+ broadcastFromOneDims.push_back (reductionDim);
2012
+ }
2013
+ }
2014
+ if (broadcastFromNothingDims.empty () && broadcastFromOneDims.empty ())
2015
+ return rewriter.notifyMatchFailure (op, " no dimensions to remove" );
2016
+
2017
+ int64_t size = 1 ;
2018
+ for (int64_t dim : broadcastFromNothingDims) {
2019
+ size *= inputType.getDimSize (dim);
2020
+ }
2021
+ for (int64_t dim : broadcastFromOneDims) {
2022
+ size *= inputType.getDimSize (dim);
2023
+ }
2024
+
2025
+ int64_t numRemoved = 0 ;
2026
+ SmallVector<int64_t > newReduceDimensions;
2027
+ llvm::sort (broadcastFromNothingDims);
2028
+ for (int64_t reductionDim : op.getDimensions ()) {
2029
+ if (llvm::is_contained (broadcastFromNothingDims, reductionDim)) {
2030
+ numRemoved++;
2031
+ continue ;
2032
+ }
2033
+ newReduceDimensions.push_back (reductionDim - numRemoved);
2034
+ }
2035
+
2036
+ auto newReduction = rewriter.create <stablehlo::ReduceOp>(
2037
+ op.getLoc (), op->getResultTypes (), ValueRange{broadcast.getOperand ()},
2038
+ op.getInitValues (), newReduceDimensions);
2039
+ newReduction.getRegion ().takeBody (op.getRegion ());
2040
+
2041
+ auto newResultType = newReduction.getResult (0 ).getType ().cast <TensorType>();
2042
+ auto constantInt = rewriter.create <stablehlo::ConstantOp>(
2043
+ op.getLoc (),
2044
+ makeAttr (newResultType.clone (rewriter.getI64Type ()), size));
2045
+ auto converted = rewriter.create <stablehlo::ConvertOp>(
2046
+ op.getLoc (), constantInt, newResultType.getElementType ());
2047
+ rewriter.replaceOpWithNewOp <stablehlo::MulOp>(op, newReduction.getResult (0 ),
2048
+ converted.getResult ());
2049
+
2050
+ return success ();
2051
+ }
2052
+ };
2053
+
1914
2054
struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase <EnzymeHLOOptPass> {
1915
2055
1916
2056
void runOnOperation () override {
@@ -1930,7 +2070,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
1930
2070
DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
1931
2071
BinBroadcastSplat<stablehlo::SubtractOp>,
1932
2072
BinBroadcastSplat<stablehlo::DivOp>,
1933
- BinBroadcastSplat<stablehlo::MulOp>>(context);
2073
+ BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
2074
+ TransposeConvert, BroadcastReduce>(context);
1934
2075
patterns.add <IotaSimplify, BroadcastInDimSimplify>(max_constant_expansion,
1935
2076
context);
1936
2077
if (all_finite)
0 commit comments