
Commit f4ccf55

transpose opts

1 parent 405b74d

File tree: 3 files changed, +204 -1 lines

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 141 additions & 1 deletion
@@ -1578,6 +1578,145 @@ struct NoNan : public OpRewritePattern<mlir::stablehlo::CompareOp> {
   }
 };
 
+struct TransposeTranspose
+    : public OpRewritePattern<mlir::stablehlo::TransposeOp> {
+  using OpRewritePattern<mlir::stablehlo::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op,
+                                PatternRewriter &rewriter) const final {
+    if (auto definingTranspose =
+            op.getOperand().getDefiningOp<mlir::stablehlo::TransposeOp>()) {
+      llvm::ArrayRef<int64_t> thisPermutation = op.getPermutation();
+      llvm::ArrayRef<int64_t> prevPermutation =
+          definingTranspose.getPermutation();
+
+      SmallVector<int64_t> newPermutation;
+      newPermutation.resize(thisPermutation.size());
+      for (unsigned i = 0, e = thisPermutation.size(); i != e; ++i) {
+        newPermutation[i] = prevPermutation[thisPermutation[i]];
+      }
+
+      rewriter.modifyOpInPlace(op, [&]() {
+        op.setPermutation(newPermutation);
+        op.setOperand(definingTranspose.getOperand());
+      });
+
+      return success();
+    }
+    return rewriter.notifyMatchFailure(op, "not a transpose(transpose)");
+  }
+};
+
+struct TransposeConvert : public OpRewritePattern<mlir::stablehlo::ConvertOp> {
+  using OpRewritePattern<mlir::stablehlo::ConvertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op,
+                                PatternRewriter &rewriter) const final {
+    auto resultType = op.getResult().getType().cast<TensorType>();
+    auto operandType = op.getOperand().getType().cast<TensorType>();
+    if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
+      return failure();
+    if (resultType.getNumElements() * resultType.getElementTypeBitWidth() >=
+        operandType.getNumElements() * operandType.getElementTypeBitWidth())
+      return failure();
+
+    auto transpose =
+        op.getOperand().getDefiningOp<mlir::stablehlo::TransposeOp>();
+    if (!transpose || !llvm::hasSingleElement(transpose->getUsers()))
+      return failure();
+
+    auto newConvert = rewriter.create<stablehlo::ConvertOp>(
+        op.getLoc(), transpose.getOperand(), resultType.getElementType());
+    auto newTranspose = rewriter.create<stablehlo::TransposeOp>(
+        transpose.getLoc(), newConvert.getResult(), transpose.getPermutation());
+    rewriter.replaceOp(op, newTranspose);
+    rewriter.eraseOp(transpose);
+
+    return success();
+  }
+};
+
+struct BroadcastReduce : public OpRewritePattern<mlir::stablehlo::ReduceOp> {
+  using OpRewritePattern<mlir::stablehlo::ReduceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
+                                PatternRewriter &rewriter) const final {
+    if (op.getInputs().size() != 1 || op.getInitValues().size() != 1) {
+      return rewriter.notifyMatchFailure(
+          op, "only single-operand single-init reduce is supported");
+    }
+    // TODO: min/max can also be an option since they are dropped
+    if (!isa<stablehlo::AddOp>(op.getRegion().getBlocks().front().front())) {
+      return rewriter.notifyMatchFailure(op, "only add is currently supported");
+    }
+
+    Value input = op.getInputs()[0];
+    auto inputType = input.getType().cast<TensorType>();
+    auto broadcast = input.getDefiningOp<mlir::stablehlo::BroadcastInDimOp>();
+    if (!broadcast) {
+      return rewriter.notifyMatchFailure(op,
+                                         "input source is not a broadcast op");
+    }
+
+    // If any of the dimensions that are being reduced was initially
+    // broadcasted, we can multiply the result with the dimension instead.
+    ArrayRef<int64_t> broadcastDims = broadcast.getBroadcastDimensions();
+    SmallVector<int64_t> broadcastFromNothingDims, broadcastFromOneDims;
+    auto broadcastSourceType =
+        broadcast.getOperand().getType().cast<TensorType>();
+    for (int64_t reductionDim : op.getDimensions()) {
+      if (inputType.isDynamicDim(reductionDim)) continue;
+      auto it = llvm::find(broadcastDims, reductionDim);
+      if (it == broadcastDims.end()) {
+        broadcastFromNothingDims.push_back(reductionDim);
+        continue;
+      }
+      size_t originalDim = std::distance(broadcastDims.begin(), it);
+      if (broadcastSourceType.getDimSize(originalDim) == 1 &&
+          inputType.getDimSize(reductionDim) != 1) {
+        broadcastFromOneDims.push_back(reductionDim);
+      }
+    }
+    if (broadcastFromNothingDims.empty() && broadcastFromOneDims.empty())
+      return rewriter.notifyMatchFailure(op, "no dimensions to remove");
+
+    int64_t size = 1;
+    for (int64_t dim : broadcastFromNothingDims) {
+      size *= inputType.getDimSize(dim);
+    }
+    for (int64_t dim : broadcastFromOneDims) {
+      size *= inputType.getDimSize(dim);
+    }
+
+    int64_t numRemoved = 0;
+    SmallVector<int64_t> newReduceDimensions;
+    llvm::sort(broadcastFromNothingDims);
+    for (int64_t reductionDim : op.getDimensions()) {
+      if (llvm::is_contained(broadcastFromNothingDims, reductionDim)) {
+        numRemoved++;
+        continue;
+      }
+      newReduceDimensions.push_back(reductionDim - numRemoved);
+    }
+
+    auto newReduction = rewriter.create<stablehlo::ReduceOp>(
+        op.getLoc(), op->getResultTypes(), ValueRange{broadcast.getOperand()},
+        op.getInitValues(), newReduceDimensions);
+    newReduction.getRegion().takeBody(op.getRegion());
+
+    auto newResultType = newReduction.getResult(0).getType().cast<TensorType>();
+    auto constantInt = rewriter.create<stablehlo::ConstantOp>(
+        op.getLoc(),
+        makeAttr(newResultType.clone(rewriter.getI64Type()), size));
+    auto converted = rewriter.create<stablehlo::ConvertOp>(
+        op.getLoc(), constantInt, newResultType.getElementType());
+    rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, newReduction.getResult(0),
+                                                  converted.getResult());
+
+    return success();
+  }
+};
+
 struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
 
   void runOnOperation() override {
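
The composition rule in TransposeTranspose is newPermutation[i] = prevPermutation[thisPermutation[i]]: dimension i of the outer result reads dimension thisPermutation[i] of the inner result, which in turn reads dimension prevPermutation[thisPermutation[i]] of the original operand. Worked through for the @transpose2 lit test below: with prev = [2, 0, 1] and this = [2, 1, 0], new[0] = prev[2] = 1, new[1] = prev[1] = 0, new[2] = prev[0] = 2, giving the single permutation [1, 0, 2] the test expects.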
@@ -1594,7 +1733,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
                  BinBroadcastSplat<stablehlo::AddOp>,
                  BinBroadcastSplat<stablehlo::SubtractOp>,
                  BinBroadcastSplat<stablehlo::DivOp>,
-                 BinBroadcastSplat<stablehlo::MulOp>>(context);
+                 BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
+                 TransposeConvert, BroadcastReduce>(context);
     if (all_finite)
       patterns.add<AllFinite>(context);
     if (no_nan || all_finite)
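
TransposeConvert has no lit test in this commit, so the following sketch (hypothetical function name and shapes, not from the repository) illustrates the intent: when a convert narrows the element type and its operand is a single-use transpose, converting first means the transpose moves half as many bytes in this f32 to bf16 example.

// Hypothetical input, run with --enzyme-hlo-opt as in the tests below:
func.func @convert_of_transpose(%x: tensor<8x16xf32>) -> tensor<16x8xbf16> {
  %0 = stablehlo.transpose %x, dims = [1, 0] : (tensor<8x16xf32>) -> tensor<16x8xf32>
  %1 = stablehlo.convert %0 : (tensor<16x8xf32>) -> tensor<16x8xbf16>
  return %1 : tensor<16x8xbf16>
}
// Expected rewrite: convert first, then transpose the narrower tensor:
//   %0 = stablehlo.convert %x : (tensor<8x16xf32>) -> tensor<8x16xbf16>
//   %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<8x16xbf16>) -> tensor<16x8xbf16>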

test/lit_tests/broadcastreduce.mlir

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+// CHECK-LABEL: @one
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
+// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+01> : tensor<f32>
+// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
+func.func @one(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
+  %211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<1x3072x32xf32>
+  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x3072x32xf32>, tensor<f32>) -> tensor<f32>
+  return %212 : tensor<f32>
+}
+
+// CHECK-LABEL: @two
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
+// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+01> : tensor<f32>
+// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
+func.func @two(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
+  %211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<32x3072xf32>
+  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1] : (tensor<32x3072xf32>, tensor<f32>) -> tensor<f32>
+  return %212 : tensor<f32>
+}
+
+// CHECK-LABEL: @three
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
+// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+02> : tensor<f32>
+// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
+func.func @three(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
+  %211 = stablehlo.broadcast_in_dim %154, dims = [2, 1] : (tensor<1x3072xf32>) -> tensor<32x3072x10xf32>
+  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<32x3072x10xf32>, tensor<f32>) -> tensor<f32>
+  return %212 : tensor<f32>
+}
+
+// CHECK-LABEL: @four
+// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
+func.func @four(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
+  // CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+02> : tensor<f32>
+  // CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
+  // CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
+  %211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<32x3072x10xf32>
+  %212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<32x3072x10xf32>, tensor<f32>) -> tensor<f32>
+  return %212 : tensor<f32>
+}
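
The constants in these checks come from sum-over-copies arithmetic: every element along a dimension created purely by the broadcast is a copy of the source, so summing n copies equals n times the source sum. In @one the 32 copies along dimension 2 become the multiply by 3.200000e+01; in @three the 32 * 10 copies along dimensions 0 and 2 become 3.200000e+02. A sketch of a partial reduction the tests do not cover (hypothetical function name and shapes, derived from the pattern logic rather than from the commit):

func.func @partial(%arg: tensor<3x5xf32>, %init: tensor<f32>) -> tensor<5xf32> {
  %0 = stablehlo.broadcast_in_dim %arg, dims = [0, 2] : (tensor<3x5xf32>) -> tensor<3x7x5xf32>
  %1 = stablehlo.reduce(%0 init: %init) applies stablehlo.add across dimensions = [0, 1] : (tensor<3x7x5xf32>, tensor<f32>) -> tensor<5xf32>
  return %1 : tensor<5xf32>
}
// Expected per the pattern: reduce %arg across dimensions = [0]
// ((tensor<3x5xf32>, tensor<f32>) -> tensor<5xf32>), then multiply by 7
// for the broadcast dimension that was summed away.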
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+// CHECK-LABEL: @transpose
+// CHECK-NOT: stablehlo.transpose
+func.func @transpose(%1845: tensor<32x32768xf32>) -> tensor<32x32768xf32> {
+  %1846 = stablehlo.transpose %1845, dims = [1, 0] : (tensor<32x32768xf32>) -> tensor<32768x32xf32>
+  %1847 = stablehlo.transpose %1846, dims = [1, 0] : (tensor<32768x32xf32>) -> tensor<32x32768xf32>
+  return %1847 : tensor<32x32768xf32>
+}
+
+// CHECK-LABEL: @transpose2
+// CHECK: stablehlo.transpose %{{.*}}, dims = [1, 0, 2] : (tensor<2x3x4xf32>) -> tensor<3x2x4xf32>
+func.func @transpose2(%arg: tensor<2x3x4xf32>) -> tensor<3x2x4xf32> {
+  %0 = stablehlo.transpose %arg, dims = [2, 0, 1] : (tensor<2x3x4xf32>) -> tensor<4x2x3xf32>
+  %1 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<4x2x3xf32>) -> tensor<3x2x4xf32>
+  return %1 : tensor<3x2x4xf32>
+}
