@@ -833,6 +833,105 @@ struct AddPad final : OpRewritePattern<mlir::stablehlo::AddOp> {
   }
 };
 
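+// Replaces concatenate(reshape(x), reshape(y)), where each reshape only
+// prepends a size-1 leading dimension (optionally followed by a convert),
+// with reshape(concatenate(x, y)): the concat is performed on the original
+// operands and the reshape/convert are applied once to the result.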
+struct ConcatAppendingReshape final
+    : OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 2)
+      return failure();
+
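+    // Pre-reshape operands, and the element type each operand is converted to
+    // (nullptr when an operand has no convert).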
845
+    SmallVector<Value> lhs;
+
+    SmallVector<Type> converts;
+
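+    // Each operand must be a reshape (possibly wrapped in a convert) that only
+    // prepends a leading dimension of size 1 to its input.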
849
+    size_t frontSize = 0;
+    for (auto v : op.getOperands()) {
+      if (auto t = v.getDefiningOp<stablehlo::ConvertOp>()) {
+        converts.push_back(
+            t.getType().cast<RankedTensorType>().getElementType());
+        v = t.getOperand();
+      } else
+        converts.push_back(nullptr);
+      if (auto t = v.getDefiningOp<stablehlo::ReshapeOp>()) {
+        lhs.push_back(t->getOperand(0));
+
+        auto prevshape = t.getOperand().getType().getShape();
+        auto postshape = t.getType().getShape();
+        if (prevshape.size() + 1 != postshape.size())
+          return failure();
+        if (postshape[0] != 1)
+          return failure();
+
+        frontSize += prevshape[0];
+
+        for (auto en : llvm::enumerate(prevshape)) {
+          if (en.value() != postshape[1 + en.index()])
+            return failure();
+        }
+
+      } else
+        return failure();
+    }
+
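+    // All operands must agree on the convert: either none has one, or they all
+    // convert to the same element type.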
878
+    Type typeconvert = converts[0];
+    for (auto c : converts)
+      if (c != typeconvert)
+        return failure();
+
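+    // Compute the result type and concat dimension of the new concatenate over
+    // the pre-reshape operands.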
883
+    RankedTensorType nextType = op.getType();
+    auto nextDim = op.getDimension();
+    if (nextDim == 0) {
+      SmallVector<int64_t> nextShape(nextType.getShape().begin() + 1,
+                                     nextType.getShape().end());
+
+      nextShape[0] = frontSize;
+      nextType = RankedTensorType::get(
+          nextShape, typeconvert ? typeconvert : nextType.getElementType());
+      nextDim = 0;
+    } else {
+      nextType = RankedTensorType::get(nextType.getShape().drop_front(),
+                                       typeconvert ? typeconvert
+                                                   : nextType.getElementType());
+      nextDim--;
+    }
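+    // Concatenate the original operands, then reshape back to the shape of the
+    // original concatenate result.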
899
+    auto lhs2 = rewriter.create<stablehlo::ConcatenateOp>(op.getLoc(), nextType,
+                                                          lhs, nextDim);
+
+    Value res2 = rewriter.create<stablehlo::ReshapeOp>(
+        op.getLoc(),
+        RankedTensorType::get(op.getType().getShape(),
+                              nextType.getElementType()),
+        lhs2);
+
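+    // If the operands were converted, re-apply the convert after the reshape.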
908
+    if (typeconvert)
+      res2 = rewriter.create<stablehlo::ConvertOp>(
+          op.getLoc(),
+          RankedTensorType::get(
+              res2.getType().cast<RankedTensorType>().getShape(), typeconvert),
+          res2);
+
+    rewriter.replaceOp(op, res2);
+    return success();
+  }
+};
+
 template <typename T>
 struct ConcatPushBinop final
     : OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
@@ -1835,10 +1919,11 @@ struct EnzymeHLOOptPass : public EnzymeHLOOptPassBase<EnzymeHLOOptPass> {
   void runOnOperation() override {
     auto context = getOperation()->getContext();
     RewritePatternSet patterns(context);
-    patterns.add<ConvertConcat, DynamicSliceToStatic, DynamicUpdateSliceElim,
-                 DynamicUpdateToConcat, SliceOfDynamicUpdate, SlicePad,
-                 SliceSlice, AddPad, PadSimplify, DotReshapeDot,
-                 ConcatConstProp, ConcatFuse, ConcatPushBinop<stablehlo::AddOp>,
+    patterns.add<ConcatAppendingReshape, ConvertConcat, DynamicSliceToStatic,
+                 DynamicUpdateSliceElim, DynamicUpdateToConcat,
+                 SliceOfDynamicUpdate, SlicePad, SliceSlice, AddPad,
+                 PadSimplify, DotReshapeDot, ConcatConstProp, ConcatFuse,
+                 ConcatPushBinop<stablehlo::AddOp>,
                  ConcatPushBinop<stablehlo::MulOp>,
                  /* ScatterToPad, */ BroadcastToReshape, ReduceToReshape,
                  ConvertSimplify, ReshapeSimplify, SliceSimplify, ReduceConcat,