Skip to content

Commit d6e5d9c

Browse files
committed
Add concat reshape opt
1 parent 6340d61 commit d6e5d9c

File tree

2 files changed

+115
-4
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,90 @@ struct AddPad final : OpRewritePattern<mlir::stablehlo::AddOp> {
833833
}
834834
};
835835

836+
struct ConcatAppendingReshape final
837+
: OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
838+
using OpRewritePattern::OpRewritePattern;
839+
840+
LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op,
841+
PatternRewriter &rewriter) const override {
842+
if (op->getNumOperands() != 2)
843+
return failure();
844+
845+
SmallVector<Value> lhs;
846+
847+
SmallVector<Type> converts;
848+
849+
size_t frontSize = 0;
850+
for (auto v : op.getOperands()) {
851+
if (auto t = v.getDefiningOp<stablehlo::ConvertOp>()) {
852+
converts.push_back(
853+
t.getType().cast<RankedTensorType>().getElementType());
854+
v = t.getOperand();
855+
} else
856+
converts.push_back(nullptr);
857+
if (auto t = v.getDefiningOp<stablehlo::ReshapeOp>()) {
858+
lhs.push_back(t->getOperand(0));
859+
860+
auto prevshape = t.getOperand().getType().getShape();
861+
auto postshape = t.getType().getShape();
862+
if (prevshape.size() + 1 != postshape.size())
863+
return failure();
864+
if (postshape[0] != 1)
865+
return failure();
866+
867+
frontSize += prevshape[0];
868+
869+
for (auto en : llvm::enumerate(prevshape)) {
870+
if (en.value() != postshape[1 + en.index()])
871+
return failure();
872+
}
873+
874+
} else
875+
return failure();
876+
}
877+
878+
Type typeconvert = converts[0];
879+
for (auto c : converts)
880+
if (c != typeconvert)
881+
return failure();
882+
883+
RankedTensorType nextType = op.getType();
884+
auto nextDim = op.getDimension();
885+
if (nextDim == 0) {
886+
SmallVector<int64_t> nextShape(nextType.getShape().begin() + 1,
887+
nextType.getShape().end());
888+
889+
nextShape[0] = frontSize;
890+
nextType = RankedTensorType::get(
891+
nextShape, typeconvert ? typeconvert : nextType.getElementType());
892+
nextDim = 0;
893+
} else {
894+
nextType = RankedTensorType::get(nextType.getShape().drop_front(),
895+
typeconvert ? typeconvert
896+
: nextType.getElementType());
897+
nextDim--;
898+
}
899+
auto lhs2 = rewriter.create<stablehlo::ConcatenateOp>(op.getLoc(), nextType,
900+
lhs, nextDim);
901+
902+
Value res2 = rewriter.create<stablehlo::ReshapeOp>(
903+
op.getLoc(),
904+
RankedTensorType::get(op.getType().getShape(),
905+
nextType.getElementType()),
906+
lhs2);
907+
908+
if (typeconvert)
909+
res2 = rewriter.create<stablehlo::ConvertOp>(
910+
op.getLoc(),
911+
RankedTensorType::get(
912+
res2.getType().cast<RankedTensorType>().getShape(), typeconvert),
913+
res2);
914+
915+
rewriter.replaceOp(op, res2);
916+
return success();
917+
}
918+
};
919+
836920
template <typename T>
837921
struct ConcatPushBinop final
838922
: OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
@@ -1835,10 +1919,11 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
18351919
void runOnOperation() override {
18361920
auto context = getOperation()->getContext();
18371921
RewritePatternSet patterns(context);
1838-
patterns.add<ConvertConcat, DynamicSliceToStatic, DynamicUpdateSliceElim,
1839-
DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad,
1840-
SliceSlice, AddPad, PadSimplify, DotReshapeDot,
1841-
ConcatConstProp, ConcatFuse, ConcatPushBinop<stablehlo::AddOp>,
1922+
patterns.add<ConcatAppendingReshape, ConvertConcat, DynamicSliceToStatic,
1923+
DynamicUpdateSliceElim, DynamicUpdateToConcat,
1924+
SliceOfDynamicUpdate, SlicePad, SliceSlice, AddPad,
1925+
PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
1926+
ConcatPushBinop<stablehlo::AddOp>,
18421927
ConcatPushBinop<stablehlo::MulOp>,
18431928
/*ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
18441929
ConvertSimplify, ReshapeSimplify, SliceSimplify, ReduceConcat,

test/lit_tests/concatreshape.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

module {

  // concatenate of two unit-dim-prepending reshapes along dim 0 should fold
  // into a single concatenate of the pre-reshape values plus one reshape.
  func.func @main(%a : tensor<3x4xf32>, %b : tensor<3x4xf32>) -> tensor<2x3x4xf32> {
    %u = stablehlo.reshape %a : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
    %v = stablehlo.reshape %b : (tensor<3x4xf32>) -> tensor<1x3x4xf32>
    %concat = stablehlo.concatenate %u, %v, dim=0 : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<2x3x4xf32>
    return %concat : tensor<2x3x4xf32>
  }

  // TODO: also fold reshapes that insert the unit dim in a non-leading
  // position (concat along dim 1 here); no CHECK lines for @main2 yet.
  func.func @main2(%a : tensor<3x4xf32>, %b : tensor<3x4xf32>) -> tensor<3x2x4xf32> {
    %u = stablehlo.reshape %a : (tensor<3x4xf32>) -> tensor<3x1x4xf32>
    %v = stablehlo.reshape %b : (tensor<3x4xf32>) -> tensor<3x1x4xf32>
    %concat = stablehlo.concatenate %u, %v, dim=1 : (tensor<3x1x4xf32>, tensor<3x1x4xf32>) -> tensor<3x2x4xf32>
    return %concat : tensor<3x2x4xf32>
  }
}


// CHECK: func.func @main(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK-NEXT: %0 = stablehlo.concatenate %arg0, %arg1, dim = 0 : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<6x4xf32>
// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<6x4xf32>) -> tensor<2x3x4xf32>
// CHECK-NEXT: return %1 : tensor<2x3x4xf32>
// CHECK-NEXT: }

0 commit comments

Comments
 (0)