Skip to content

transpose opts #52

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 142 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1911,6 +1911,146 @@ struct NoNan : public OpRewritePattern<mlir::stablehlo::CompareOp> {
}
};

struct TransposeTranspose
: public OpRewritePattern<mlir::stablehlo::TransposeOp> {
using OpRewritePattern<mlir::stablehlo::TransposeOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op,
PatternRewriter &rewriter) const final {
if (auto definingTranspose =
op.getOperand().getDefiningOp<mlir::stablehlo::TransposeOp>()) {
llvm::ArrayRef<int64_t> thisPermutation = op.getPermutation();
llvm::ArrayRef<int64_t> prevPermutation =
definingTranspose.getPermutation();

SmallVector<int64_t> newPermutation;
newPermutation.resize(thisPermutation.size());
for (unsigned i = 0, e = thisPermutation.size(); i != e; ++i) {
newPermutation[i] = prevPermutation[thisPermutation[i]];
}

rewriter.modifyOpInPlace(op, [&]() {
op.setPermutation(newPermutation);
op.setOperand(definingTranspose.getOperand());
});

return success();
}
return rewriter.notifyMatchFailure(op, "not a transpose(transpose)");
}
};

struct TransposeConvert : public OpRewritePattern<mlir::stablehlo::ConvertOp> {
using OpRewritePattern<mlir::stablehlo::ConvertOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ConvertOp op,
PatternRewriter &rewriter) const final {
auto resultType = op.getResult().getType().cast<TensorType>();
auto operandType = op.getOperand().getType().cast<TensorType>();
if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
return failure();
if (resultType.getNumElements() * resultType.getElementTypeBitWidth() >=
operandType.getNumElements() * operandType.getElementTypeBitWidth())
return failure();

auto transpose =
op.getOperand().getDefiningOp<mlir::stablehlo::TransposeOp>();
if (!transpose || !llvm::hasSingleElement(transpose->getUsers()))
return failure();

auto newConvert = rewriter.create<stablehlo::ConvertOp>(
op.getLoc(), transpose.getOperand(), resultType.getElementType());
auto newTranspose = rewriter.create<stablehlo::TransposeOp>(
transpose.getLoc(), newConvert.getResult(), transpose.getPermutation());
rewriter.replaceOp(op, newTranspose);
rewriter.eraseOp(transpose);

return success();
}
};

struct BroadcastReduce : public OpRewritePattern<mlir::stablehlo::ReduceOp> {
using OpRewritePattern<mlir::stablehlo::ReduceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::ReduceOp op,
PatternRewriter &rewriter) const final {
if (op.getInputs().size() != 1 || op.getInitValues().size() != 1) {
return rewriter.notifyMatchFailure(
op, "only single-operand single-init reduce is supported");
}
// TODO: min/max can also be an option since they are dropped
if (!isa<stablehlo::AddOp>(op.getRegion().getBlocks().front().front())) {
return rewriter.notifyMatchFailure(op, "only add is currently supported");
}

Value input = op.getInputs()[0];
auto inputType = input.getType().cast<TensorType>();
auto broadcast = input.getDefiningOp<mlir::stablehlo::BroadcastInDimOp>();
if (!broadcast) {
return rewriter.notifyMatchFailure(op,
"input source is not a broadcast op");
}

// If any of the dimensions that are being reduced was initially
// broadcasted, we can multiply the result with the dimension instead.
ArrayRef<int64_t> broadcastDims = broadcast.getBroadcastDimensions();
SmallVector<int64_t> broadcastFromNothingDims, broadcastFromOneDims;
auto broadcastSourceType =
broadcast.getOperand().getType().cast<TensorType>();
for (int64_t reductionDim : op.getDimensions()) {
if (inputType.isDynamicDim(reductionDim))
continue;
auto it = llvm::find(broadcastDims, reductionDim);
if (it == broadcastDims.end()) {
broadcastFromNothingDims.push_back(reductionDim);
continue;
}
size_t originalDim = std::distance(broadcastDims.begin(), it);
if (broadcastSourceType.getDimSize(originalDim) == 1 &&
inputType.getDimSize(reductionDim) != 1) {
broadcastFromOneDims.push_back(reductionDim);
}
}
if (broadcastFromNothingDims.empty() && broadcastFromOneDims.empty())
return rewriter.notifyMatchFailure(op, "no dimensions to remove");

int64_t size = 1;
for (int64_t dim : broadcastFromNothingDims) {
size *= inputType.getDimSize(dim);
}
for (int64_t dim : broadcastFromOneDims) {
size *= inputType.getDimSize(dim);
}

int64_t numRemoved = 0;
SmallVector<int64_t> newReduceDimensions;
llvm::sort(broadcastFromNothingDims);
for (int64_t reductionDim : op.getDimensions()) {
if (llvm::is_contained(broadcastFromNothingDims, reductionDim)) {
numRemoved++;
continue;
}
newReduceDimensions.push_back(reductionDim - numRemoved);
}

auto newReduction = rewriter.create<stablehlo::ReduceOp>(
op.getLoc(), op->getResultTypes(), ValueRange{broadcast.getOperand()},
op.getInitValues(), newReduceDimensions);
newReduction.getRegion().takeBody(op.getRegion());

auto newResultType = newReduction.getResult(0).getType().cast<TensorType>();
auto constantInt = rewriter.create<stablehlo::ConstantOp>(
op.getLoc(),
makeAttr(newResultType.clone(rewriter.getI64Type()), size));
auto converted = rewriter.create<stablehlo::ConvertOp>(
op.getLoc(), constantInt, newResultType.getElementType());
rewriter.replaceOpWithNewOp<stablehlo::MulOp>(op, newReduction.getResult(0),
converted.getResult());

return success();
}
};

struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {

void runOnOperation() override {
Expand All @@ -1930,7 +2070,8 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
DivSimplify, PowSimplify, BinBroadcastSplat<stablehlo::AddOp>,
BinBroadcastSplat<stablehlo::SubtractOp>,
BinBroadcastSplat<stablehlo::DivOp>,
BinBroadcastSplat<stablehlo::MulOp>>(context);
BinBroadcastSplat<stablehlo::MulOp>, TransposeTranspose,
TransposeConvert, BroadcastReduce>(context);
patterns.add<IotaSimplify, BroadcastInDimSimplify>(max_constant_expansion,
context);
if (all_finite)
Expand Down
46 changes: 46 additions & 0 deletions test/lit_tests/broadcastreduce.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// CHECK-LABEL: @one
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+01> : tensor<f32>
// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
func.func @one(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
%211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<1x3072x32xf32>
%212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<1x3072x32xf32>, tensor<f32>) -> tensor<f32>
return %212 : tensor<f32>
}

// CHECK-LABEL: @two
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+01> : tensor<f32>
// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
func.func @two(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
%211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<32x3072xf32>
%212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1] : (tensor<32x3072xf32>, tensor<f32>) -> tensor<f32>
return %212 : tensor<f32>
}

// CHECK-LABEL: @three
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+02> : tensor<f32>
// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
func.func @three(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
%211 = stablehlo.broadcast_in_dim %154, dims = [2, 1] : (tensor<1x3072xf32>) -> tensor<32x3072x10xf32>
%212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<32x3072x10xf32>, tensor<f32>) -> tensor<f32>
return %212 : tensor<f32>
}

// CHECK-LABEL: @four
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x3072xf32>, %[[ARG1:.+]]: tensor<f32>)
func.func @four(%154: tensor<1x3072xf32>, %151: tensor<f32>) -> tensor<f32> {
// CHECK: %[[V0:.+]] = stablehlo.constant dense<3.200000e+02> : tensor<f32>
// CHECK: %[[V1:.+]] = stablehlo.reduce(%[[ARG0]] init: %[[ARG1]]) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x3072xf32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[V2:.+]] = stablehlo.multiply %[[V1]], %[[V0]] : tensor<f32>
%211 = stablehlo.broadcast_in_dim %154, dims = [0, 1] : (tensor<1x3072xf32>) -> tensor<32x3072x10xf32>
%212 = stablehlo.reduce(%211 init: %151) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<32x3072x10xf32>, tensor<f32>) -> tensor<f32>
return %212 : tensor<f32>
}

17 changes: 17 additions & 0 deletions test/lit_tests/transposetranspose.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// CHECK-LABEL: @transpose
// CHECK-NOT: stablehlo.transpose
func.func @transpose(%1845: tensor<32x32768xf32>) -> tensor<32x32768xf32> {
%1846 = stablehlo.transpose %1845, dims = [1, 0] : (tensor<32x32768xf32>) -> tensor<32768x32xf32>
%1847 = stablehlo.transpose %1846, dims = [1, 0] : (tensor<32768x32xf32>) -> tensor<32x32768xf32>
return %1847 : tensor<32x32768xf32>
}

// CHECK-LABEL: @transpose2
// CHECK: stablehlo.transpose %{{.*}}, dims = [1, 0, 2] : (tensor<2x3x4xf32>) -> tensor<3x2x4xf32>
func.func @transpose2(%arg: tensor<2x3x4xf32>) -> tensor<3x2x4xf32> {
%0 = stablehlo.transpose %arg, dims = [2, 0, 1] : (tensor<2x3x4xf32>) -> tensor<4x2x3xf32>
%1 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<4x2x3xf32>) -> tensor<3x2x4xf32>
return %1 : tensor<3x2x4xf32>
}