@@ -698,3 +698,50 @@ func.func @test_castlike(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf16>) -> tensor
698
698
// CHECK: onnx.Return [[RES]] : tensor<*xf16>
699
699
}
700
700
701
+ // -----
702
+
703
+ func.func @test_sum (%arg0: tensor <128 x10 xf32 >, %arg1: tensor <64 x128 x10 xf32 >, %arg2: tensor <10 xf32 >, %arg3: tensor <64 x1 x1 xf32 >) -> tensor <64 x128 x10 xf32 > {
704
+ %0 = " onnx.Sum" (%arg0 , %arg1 , %arg2 , %arg3 ) : (tensor <128 x10 xf32 >, tensor <64 x128 x10 xf32 >, tensor <10 xf32 >, tensor <64 x1 x1 xf32 >) -> tensor <64 x128 x10 xf32 >
705
+ onnx.Return %0 : tensor <64 x128 x10 xf32 >
706
+ // CHECK-LABEL: func @test_sum
707
+ // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
708
+ // CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
709
+ // CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]])
710
+ // CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
711
+ // CHECK-NEXT: onnx.Return %[[SUM2]]
712
+ }
713
+
714
+ // -----
715
+
716
+ func.func @test_sum_to_unranked (%arg0: tensor <128 x10 xf32 >, %arg1: tensor <64 x128 x10 xf32 >, %arg2: tensor <10 xf32 >, %arg3: tensor <64 x1 x1 xf32 >) -> tensor <*xf32 > {
717
+ %0 = " onnx.Sum" (%arg0 , %arg1 , %arg2 , %arg3 ) : (tensor <128 x10 xf32 >, tensor <64 x128 x10 xf32 >, tensor <10 xf32 >, tensor <64 x1 x1 xf32 >) -> tensor <*xf32 >
718
+ onnx.Return %0 : tensor <*xf32 >
719
+ // CHECK-LABEL: func @test_sum
720
+ // CHECK-SAME: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}})
721
+ // CHECK-NEXT: %[[SUM0:.*]] = "onnx.Add"(%[[ARG0]], %[[ARG1]])
722
+ // CHECK-NEXT: %[[SUM1:.*]] = "onnx.Add"(%[[SUM0]], %[[ARG2]])
723
+ // CHECK-NEXT: %[[SUM2:.*]] = "onnx.Add"(%[[SUM1]], %[[ARG3]])
724
+ // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[SUM2]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
725
+ // CHECK-NEXT: onnx.Return %[[CAST]]
726
+ }
727
+
728
+ // -----
729
+
730
+ func.func @test_sum_single_input (%arg0: tensor <64 x128 x10 xf32 >) -> tensor <64 x128 x10 xf32 > {
731
+ %0 = " onnx.Sum" (%arg0 ) : (tensor <64 x128 x10 xf32 >) -> tensor <64 x128 x10 xf32 >
732
+ onnx.Return %0 : tensor <64 x128 x10 xf32 >
733
+ // CHECK-LABEL: func @test_sum_single_input
734
+ // CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
735
+ // CHECK-NEXT: onnx.Return %[[ARG0]]
736
+ }
737
+
738
+ // -----
739
+
740
+ func.func @test_sum_single_input_to_unranked (%arg0: tensor <64 x128 x10 xf32 >) -> tensor <*xf32 > {
741
+ %0 = " onnx.Sum" (%arg0 ) : (tensor <64 x128 x10 xf32 >) -> tensor <*xf32 >
742
+ onnx.Return %0 : tensor <*xf32 >
743
+ // CHECK-LABEL: func @test_sum_single_input_to_unranked
744
+ // CHECK-SAME: (%[[ARG0:.*]]: {{.*}})
745
+ // CHECK-NEXT: %[[CAST:.*]] = "onnx.Cast"(%[[ARG0]]) {saturate = 1 : si64, to = f32} : (tensor<64x128x10xf32>) -> tensor<*xf32>
746
+ // CHECK-NEXT: onnx.Return %[[CAST]]
747
+ }
0 commit comments