Skip to content

Commit 380b8b7

Browse files
committed
Fix one input case
Signed-off-by: Sam <[email protected]>
1 parent a18907b commit 380b8b7

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -992,11 +992,13 @@ struct SumToAddPattern : public OpRewritePattern<ONNXSumOp> {
992992
LogicalResult matchAndRewrite(
993993
ONNXSumOp sumOp, PatternRewriter &rewriter) const final {
994994
SmallVector<Value> inputs(sumOp.getData_0());
995-
assert(inputs.size() >= 2 && "expected at least two inputs");
995+
assert(inputs.size() > 0 && "expected at least one input");
996996
Value result = inputs[0];
997-
inputs.erase(inputs.begin());
998-
for (auto input : inputs) {
999-
result = rewriter.create<ONNXAddOp>(sumOp.getLoc(), result, input);
997+
if (inputs.size() > 1) {
998+
inputs.erase(inputs.begin());
999+
for (auto input : inputs) {
1000+
result = rewriter.create<ONNXAddOp>(sumOp.getLoc(), result, input);
1001+
}
10001002
}
10011003
rewriter.replaceOp(sumOp, result);
10021004
return success();

test/mlir/onnx/onnx_decompose.mlir

+10
Original file line numberDiff line numberDiff line change
@@ -710,3 +710,13 @@ func.func @test_sum(%arg0: tensor<64x128x10xf32>, %arg1: tensor<128x10xf32>, %ar
710710
// CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
711711
// CHECK-NEXT: onnx.Return %[[SUM2]]
712712
}
713+
714+
// -----
715+
716+
func.func @test_sum_single_input(%arg0: tensor<64x128x10xf32>) -> tensor<64x128x10xf32> {
717+
%0 = "onnx.Sum"(%arg0) : (tensor<64x128x10xf32>) -> tensor<64x128x10xf32>
718+
onnx.Return %0 : tensor<64x128x10xf32>
719+
// CHECK-LABEL: func @test_sum_single_input
720+
// CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
721+
// CHECK-NEXT: onnx.Return %[[ARG0]]
722+
}

0 commit comments

Comments
 (0)