Skip to content

Commit 265ee60

Browse files
Implement decomposition from onnx.Sum to sequence of onnx.Add (#2964)
Signed-off-by: Sam <[email protected]> Co-authored-by: Alexandre Eichenberger <[email protected]>
1 parent 56a610c commit 265ee60

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,30 @@ struct GroupNormIntoLayerNormPattern2
985985
}
986986
};
987987

988+
/// Decompose `onnx.Sum` to a sequence of `onnx.Add`
989+
struct SumToAddPattern : public OpRewritePattern<ONNXSumOp> {
990+
using OpRewritePattern<ONNXSumOp>::OpRewritePattern;
991+
992+
LogicalResult matchAndRewrite(
993+
ONNXSumOp sumOp, PatternRewriter &rewriter) const final {
994+
SmallVector<Value> inputs(sumOp.getData_0());
995+
assert(inputs.size() > 0 && "expected at least one input");
996+
Value result = inputs[0];
997+
if (inputs.size() > 1) {
998+
inputs.erase(inputs.begin());
999+
for (auto input : inputs) {
1000+
result = rewriter.create<ONNXAddOp>(sumOp.getLoc(), result, input);
1001+
}
1002+
}
1003+
auto resultType = mlir::cast<ShapedType>(sumOp.getResult().getType());
1004+
if (resultType != result.getType())
1005+
result = rewriter.create<ONNXCastOp>(
1006+
sumOp.getLoc(), resultType, result, 1, resultType.getElementType());
1007+
rewriter.replaceOp(sumOp, result);
1008+
return success();
1009+
}
1010+
};
1011+
9881012
// =============================================================================
9891013
// Pattern for replacing CastLikeOp by CastOp.
9901014
// =============================================================================
@@ -1093,6 +1117,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
10931117
target.addIllegalOp<ONNXSplitV11Op>();
10941118
target.addIllegalOp<ONNXSplitV13Op>();
10951119
target.addIllegalOp<ONNXSqueezeV11Op>();
1120+
target.addIllegalOp<ONNXSumOp>();
10961121
target.addIllegalOp<ONNXUnsqueezeV11Op>();
10971122
target.addIllegalOp<ONNXUpsampleOp>();
10981123
target.addIllegalOp<ONNXUpsampleV7Op>();
@@ -1165,6 +1190,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
11651190
patterns.insert<InstanceNormIntoLayerNormPattern>(context);
11661191
patterns.insert<GroupNormIntoLayerNormPattern1>(context);
11671192
patterns.insert<GroupNormIntoLayerNormPattern2>(context);
1193+
patterns.insert<SumToAddPattern>(context);
11681194

11691195
// TODO: consider whether to include SoftmaxPattern here
11701196
}

test/mlir/onnx/onnx_decompose.mlir

+47
Original file line numberDiff line numberDiff line change
@@ -698,3 +698,50 @@ func.func @test_castlike(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf16>) -> tensor
698698
// CHECK: onnx.Return [[RES]] : tensor<*xf16>
699699
}
700700

701+
// -----
702+
703+
// onnx.Sum with four operands decomposes into three chained onnx.Add ops.
func.func @test_sum(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<64x128x10xf32> {
  %0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<64x128x10xf32>
  onnx.Return %0 : tensor<64x128x10xf32>
  // CHECK-LABEL: func @test_sum
  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
  // CHECK-NEXT: %[[ADD0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
  // CHECK-NEXT: %[[ADD1:.*]] = "onnx.Add"(%[[ADD0]], %[[ARG2]])
  // CHECK-NEXT: %[[ADD2:.*]] = "onnx.Add"(%[[ADD1]], %[[ARG3]])
  // CHECK-NEXT: onnx.Return %[[ADD2]]
}
713+
714+
// -----
715+
716+
// Decomposition to ranked Adds followed by a Cast back to the unranked
// result type of the original onnx.Sum.
func.func @test_sum_to_unranked(%arg0: tensor<128x10xf32>, %arg1: tensor<64x128x10xf32>, %arg2: tensor<10xf32>, %arg3: tensor<64x1x1xf32>) -> tensor<*xf32> {
  %0 = "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (tensor<128x10xf32>, tensor<64x128x10xf32>, tensor<10xf32>, tensor<64x1x1xf32>) -> tensor<*xf32>
  onnx.Return %0 : tensor<*xf32>
  // Label fixed from "@test_sum" (copy-paste error): the old label only
  // matched because it is a substring of this function's name, which
  // weakens CHECK-LABEL anchoring.
  // CHECK-LABEL: func @test_sum_to_unranked
  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
  // CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
  // CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]])
  // CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
  // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[SUM2]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
  // CHECK-NEXT: onnx.Return %[[CAST]]
}
727+
728+
// -----
729+
730+
// A single-operand onnx.Sum is the identity: no Add is emitted and the
// input is returned directly.
func.func @test_sum_single_input(%arg0: tensor<64x128x10xf32>) -> tensor<64x128x10xf32> {
  %0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<64x128x10xf32>
  onnx.Return %0 : tensor<64x128x10xf32>
  // CHECK-LABEL: func @test_sum_single_input
  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
  // CHECK-NEXT: onnx.Return %[[ARG0]]
}
737+
738+
// -----
739+
740+
// A single-operand onnx.Sum with an unranked result type reduces to a
// lone Cast of the input back to the recorded unranked type.
func.func @test_sum_single_input_to_unranked(%arg0: tensor<64x128x10xf32>) -> tensor<*xf32> {
  %0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<*xf32>
  onnx.Return %0 : tensor<*xf32>
  // CHECK-LABEL: func @test_sum_single_input_to_unranked
  // CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
  // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[ARG0]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
  // CHECK-NEXT: onnx.Return %[[CAST]]
}

0 commit comments

Comments
 (0)