Commit d21507b

transpose opts (#52)
1 parent c3cd279 commit d21507b

3 files changed: +205 −1 lines changed


src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 142 additions & 1 deletion
@@ -1911,6 +1911,146 @@ struct NoNan : public OpRewritePattern<mlir::stablehlo::CompareOp> {
   }
 };
 
+struct TransposeTranspose
+    : public OpRewritePattern<mlir::stablehlo::TransposeOp> {
+  using OpRewritePattern<mlir::stablehlo::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    if (auto definingTranspose =
+            op.getOperand().getDefiningOp<mlir::stablehlo::TransposeOp>()) {
+      llvm::ArrayRef<int64_t> thisPermutation = op.getPermutation();
+      llvm::ArrayRef<int64_t> prevPermutation =
+          definingTranspose.getPermutation();
+
+      SmallVector<int64_t> newPermutation;
+      newPermutation.resize(thisPermutation.size());
+      for (unsigned i = 0, e = thisPermutation.size(); i != e; ++i) {
+        newPermutation[i] = prevPermutation[thisPermutation[i]];
+      }
+
+      rewriter.modifyOpInPlace(op, [&]() {
+        op.setPermutation(newPermutation);
+        op.setOperand(definingTranspose.getOperand());
+      });
+
+      return success();
+    }
+    return rewriter.notifyMatchFailure(op, "not a transpose(transpose)");
+  }
+};
+
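The composed permutation is built by indexing the producer's permutation with the consumer's: newPermutation[i] = prevPermutation[thisPermutation[i]]. A quick standalone dry run of that rule in plain C++ (no MLIR dependency; the two permutations are borrowed from the @transpose2 lit test added below, and the program itself is illustrative, not part of the commit):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Inner (producer) and outer (consumer) transpose permutations,
  // as in @transpose2: transpose(dims=[2,0,1]) then transpose(dims=[2,1,0]).
  std::vector<int64_t> prev = {2, 0, 1};
  std::vector<int64_t> curr = {2, 1, 0};
  std::vector<int64_t> composed(curr.size());
  for (size_t i = 0; i < curr.size(); ++i)
    composed[i] = prev[curr[i]];
  for (int64_t d : composed)
    std::printf("%lld ", static_cast<long long>(d)); // prints: 1 0 2
  return 0;
}

The result [1, 0, 2] is exactly the single transpose the lit test checks for, and because the pattern rewrites the outer op in place rather than creating a new one, it can fire repeatedly to collapse longer transpose chains.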
+struct TransposeConvert : public OpRewritePattern<mlir::stablehlo::ConvertOp> {
+  using OpRewritePattern<mlir::stablehlo::ConvertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op,
+                                PatternRewriter &rewriter) const final {
+    auto resultType = op.getResult().getType().cast<TensorType>();
+    auto operandType = op.getOperand().getType().cast<TensorType>();
+    if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
+      return failure();
+    if (resultType.getNumElements() * resultType.getElementTypeBitWidth() >=
+        operandType.getNumElements() * operandType.getElementTypeBitWidth())
+      return failure();
+
+    auto transpose =
+        op.getOperand().getDefiningOp<mlir::stablehlo::TransposeOp>();
+    if (!transpose || !llvm::hasSingleElement(transpose->getUsers()))
+      return failure();
+
+    auto newConvert = rewriter.create<stablehlo::ConvertOp>(
+        op.getLoc(), transpose.getOperand(), resultType.getElementType());
+    auto newTranspose = rewriter.create<stablehlo::TransposeOp>(
+        transpose.getLoc(), newConvert.getResult(), transpose.getPermutation());
+    rewriter.replaceOp(op, newTranspose);
+    rewriter.eraseOp(transpose);
+
+    return success();
+  }
+};
+
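The bit-width guard means the pattern only fires when the convert narrows the data: hoisting the convert in front of the transpose makes the transpose shuffle fewer bytes. For a hypothetical f64-to-f32 convert of a transposed tensor<16x8xf64>, the transpose would move 1024 bytes before the rewrite and 512 bytes after; for a widening convert the comparison fails and the order is left alone. The single-user check ensures erasing the old transpose is safe, since no other consumer depends on it.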
+struct BroadcastReduce : public OpRewritePattern<mlir::stablehlo::ReduceOp> {
+  using OpRewritePattern<mlir::stablehlo::ReduceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+                                PatternRewriter &rewriter) const final {
+    if (op.getInputs().size() != 1 || op.getInitValues().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "only single-operand single-init reduce is supported");
+    }
+    // TODO: min/max can also be an option since they are dropped
+    if (!isa<stablehlo::AddOp>(op.getRegion().getBlocks().front().front())) {
+      return rewriter.notifyMatchFailure(op, "only add is currently supported");
+    }
+
+    Value input = op.getInputs()[0];
+    auto inputType = input.getType().cast<TensorType>();
+    auto broadcast = input.getDefiningOp<mlir::stablehlo::BroadcastInDimOp>();
+    if (!broadcast) {
+      return rewriter.notifyMatchFailure(op,
+                                         "input source is not a broadcast op");
+    }
+
+    // If any of the dimensions that are being reduced was initially
+    // broadcasted, we can multiply the result with the dimension instead.
+    ArrayRef<int64_t> broadcastDims = broadcast.getBroadcastDimensions();
+    SmallVector<int64_t> broadcastFromNothingDims, broadcastFromOneDims;
+    auto broadcastSourceType =
+        broadcast.getOperand().getType().cast<TensorType>();
+    for (int64_t reductionDim : op.getDimensions()) {
+      if (inputType.isDynamicDim(reductionDim))
+        continue;
+      auto it = llvm::find(broadcastDims, reductionDim);
+      if (it == broadcastDims.end()) {
+        broadcastFromNothingDims.push_back(reductionDim);
+        continue;
+      }
+      size_t originalDim = std::distance(broadcastDims.begin(), it);
+      if (broadcastSourceType.getDimSize(originalDim) == 1 &&
+          inputType.getDimSize(reductionDim) != 1) {
+        broadcastFromOneDims.push_back(reductionDim);
+      }
+    }
+    if (broadcastFromNothingDims.empty() && broadcastFromOneDims.empty())
+      return rewriter.notifyMatchFailure(op, "no dimensions to remove");
+
+    int64_t size = 1;
+    for (int64_t dim : broadcastFromNothingDims) {
+      size *= inputType.getDimSize(dim);
+    }
+    for (int64_t dim : broadcastFromOneDims) {
+      size *= inputType.getDimSize(dim);
+    }
+
+    int64_t numRemoved = 0;
+    SmallVector<int64_t> newReduceDimensions;
+    llvm::sort(broadcastFromNothingDims);
+    for (int64_t reductionDim : op.getDimensions()) {
+      if (llvm::is_contained(broadcastFromNothingDims, reductionDim)) {
+        numRemoved++;
+        continue;
+      }
+      newReduceDimensions.push_back(reductionDim - numRemoved);
+    }
+
+    auto newReduction = rewriter.create<stablehlo::ReduceOp>(
+        op.getLoc(), op->getResultTypes(), ValueRange{broadcast.getOperand()},
+        op.getInitValues(), newReduceDimensions);
+    newReduction.getRegion().takeBody(op.getRegion());
+
+    auto newResultType = newReduction.getResult(0).getType().cast<TensorType>();
+    auto constantInt = rewriter.create<stablehlo::ConstantOp>(
+        op.getLoc(),
+        makeAttr(newResultType.clone(rewriter.getI64Type()), size));
+    auto converted = rewriter.create<stablehlo::ConvertOp>(
+        op.getLoc(), constantInt, newResultType.getElementType());
+    rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, newReduction.getResult(0),
+                                                  converted.getResult());
+
+    return success();
+  }
+};
+
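Because the reduction body is an add, summing over a dimension that broadcast_in_dim introduced (or stretched from size 1) just adds up identical copies, so the pattern reduces the broadcast's source instead and multiplies the result by the copy count. A plain-C++ dry run of the multiplier bookkeeping, using the shapes of the @three lit test below (variable names and the program are illustrative only):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // @three: source tensor<1x3072> broadcast with dims = [2, 1]
  // to tensor<32x3072x10>, reduced over dimensions [0, 1, 2].
  std::vector<int64_t> broadcastDims = {2, 1};
  std::vector<int64_t> inputShape = {32, 3072, 10};
  std::vector<int64_t> sourceShape = {1, 3072};
  std::vector<int64_t> reduceDims = {0, 1, 2};

  std::vector<int64_t> fromNothing, fromOne;
  for (int64_t d : reduceDims) {
    auto it = std::find(broadcastDims.begin(), broadcastDims.end(), d);
    if (it == broadcastDims.end()) {
      fromNothing.push_back(d); // dim exists only in the broadcast result
      continue;
    }
    size_t orig = it - broadcastDims.begin();
    if (sourceShape[orig] == 1 && inputShape[d] != 1)
      fromOne.push_back(d); // dim was stretched from size 1
  }

  int64_t size = 1;
  for (int64_t d : fromNothing) size *= inputShape[d];
  for (int64_t d : fromOne) size *= inputShape[d];
  std::printf("multiplier = %lld\n", (long long)size); // prints 320 (= 32 * 10)
  return 0;
}

The printed multiplier 320 matches the dense<3.200000e+02> constant the test expects; dimensions that were "broadcast from one" stay in the new reduce (where they now have size 1), while "broadcast from nothing" dimensions are dropped and the remaining indices are shifted down.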
 struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
 
   void runOnOperation() override {
@@ -1930,7 +2070,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
         DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
         BinBroadcastSplat<stablehlo::SubtractOp>,
         BinBroadcastSplat<stablehlo::DivOp>,
-        BinBroadcastSplat<stablehlo::MulOp>>(context);
+        BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
+        TransposeConvert, BroadcastReduce>(context);
     patterns.add<IotaSimplify, BroadcastInDimSimplify>(max_constant_expansion,
                                                        context);
     if (all_finite)
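With the three patterns registered alongside the existing simplifications, the pass's rewriter applies them wherever they match; the lit tests below drive them through enzymexlamlir-opt --enzyme-hlo-opt.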

test/lit_tests/broadcastreduce.mlir

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// CHECK-LABEL: @one
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+01> : tensor<f32>
// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
func.func @one(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
  %211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<1x3072x32xf32>
  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x3072x32xf32>, tensor<f32>) -> tensor<f32>
  return %212 : tensor<f32>
}

// CHECK-LABEL: @two
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+01> : tensor<f32>
// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
func.func @two(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
  %211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<32x3072xf32>
  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1] : (tensor<32x3072xf32>, tensor<f32>) -> tensor<f32>
  return %212 : tensor<f32>
}

// CHECK-LABEL: @three
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+02> : tensor<f32>
// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
func.func @three(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
  %211 = stablehlo.broadcast_in_dim %154, dims = [2, 1] : (tensor<1x3072xf32>) -> tensor<32x3072x10xf32>
  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<32x3072x10xf32>, tensor<f32>) -> tensor<f32>
  return %212 : tensor<f32>
}

// CHECK-LABEL: @four
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
func.func @four(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
  // CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+02> : tensor<f32>
  // CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
  // CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
  %211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<32x3072x10xf32>
  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<32x3072x10xf32>, tensor<f32>) -> tensor<f32>
  return %212 : tensor<f32>
}
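The expected constants follow directly from the shapes: in @one only the new size-32 dimension is summed away and in @two only the dimension stretched from 1 to 32, giving the multiplier 32; @three was walked through above, and @four mirrors it with a stretched size-32 dimension and a brand-new size-10 dimension, so 32 × 10 = 320. In every case the surviving reduce runs over the original tensor<1x3072xf32> source.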
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// CHECK-LABEL: @transpose
// CHECK-NOT: stablehlo.transpose
func.func @transpose(%1845: tensor<32x32768xf32>) -> tensor<32x32768xf32> {
  %1846 = stablehlo.transpose %1845, dims = [1, 0] : (tensor<32x32768xf32>) -> tensor<32768x32xf32>
  %1847 = stablehlo.transpose %1846, dims = [1, 0] : (tensor<32768x32xf32>) -> tensor<32x32768xf32>
  return %1847 : tensor<32x32768xf32>
}

// CHECK-LABEL: @transpose2
// CHECK: stablehlo.transpose %{{.*}}, dims = [1, 0, 2] : (tensor<2x3x4xf32>) -> tensor<3x2x4xf32>
func.func @transpose2(%arg: tensor<2x3x4xf32>) -> tensor<3x2x4xf32> {
  %0 = stablehlo.transpose %arg, dims = [2, 0, 1] : (tensor<2x3x4xf32>) -> tensor<4x2x3xf32>
  %1 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<4x2x3xf32>) -> tensor<3x2x4xf32>
  return %1 : tensor<3x2x4xf32>
}
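In @transpose the two [1, 0] permutations compose to the identity [0, 1], and the resulting identity transpose is presumably folded away by an existing simplification in the pass, which is why the test can assert CHECK-NOT rather than match a specific op. @transpose2 composes [2, 0, 1] with [2, 1, 0] into the single [1, 0, 2] transpose checked above.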
