@@ -3801,3 +3801,47 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5
3801
3801
// CHECK: }
3802
3802
}
3803
3803
3804
+ // -----
3805
+
3806
+ //===----------------------------------------------------------------------===//
3807
+ /// Test shape inference for Parallel and Fork.
3808
+ //===----------------------------------------------------------------------===//
3809
+
3810
+ func.func @test_parallel_fork_1 (%arg0: tensor <8 x64 x32 xf32 >, %arg1: tensor <32 x32 xf32 >) -> (tensor <*xf32 >, tensor <*xf32 >) {
3811
+ %c0 = onnx.Constant dense <1.0 > : tensor <32 x32 xf32 >
3812
+ %c1 = onnx.Constant dense <1.0 > : tensor <32 xf32 >
3813
+ %c2 = onnx.Constant dense <1.0 > : tensor <32 x32 xf32 >
3814
+
3815
+ %0:2 = " onnx.Parallel" () ({
3816
+ %00 = " onnx.Fork" () ({
3817
+ %01 = " onnx.MatMul" (%arg0 , %c0 ) : (tensor <8 x64 x32 xf32 >, tensor <32 x32 xf32 >) -> tensor <*xf32 >
3818
+ onnx.Yield %01 : tensor <*xf32 >
3819
+ }) {id = 0 : si64 } : () -> tensor <*xf32 >
3820
+ %01 = " onnx.Fork" () ({
3821
+ %01 = " onnx.MatMul" (%arg0 , %c2 ) : (tensor <8 x64 x32 xf32 >, tensor <32 x32 xf32 >) -> tensor <*xf32 >
3822
+ onnx.Yield %01 : tensor <*xf32 >
3823
+ }) {id = 1 : si64 } : () -> tensor <*xf32 >
3824
+ " onnx.Yield" (%00 , %01 ) : (tensor <*xf32 >, tensor <*xf32 >) -> ()
3825
+ }) : () -> (tensor <*xf32 >, tensor <*xf32 >)
3826
+ " onnx.Return" (%0#0 ,%0#1 ): (tensor <*xf32 >, tensor <*xf32 >) -> ()
3827
+
3828
+ // CHECK-LABEL: func.func @test_parallel_fork_1
3829
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x64x32xf32>, [[PARAM_1_:%.+]]: tensor<32x32xf32>) -> (tensor<8x64x32xf32>, tensor<8x64x32xf32>) {
3830
+ // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<32x32xf32>
3831
+ // CHECK-DAG: [[VAR_1_:%.+]]:2 = "onnx.Parallel"() ({
3832
+ // CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Fork"() ({
3833
+ // CHECK: [[VAR_4_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32>
3834
+ // CHECK: onnx.Yield [[VAR_4_]] : tensor<8x64x32xf32>
3835
+ // CHECK: }) {id = 0 : si64} : () -> tensor<8x64x32xf32>
3836
+ // CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Fork"() ({
3837
+ // CHECK-DAG: [[VAR_4_1_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32>
3838
+ // CHECK: onnx.Yield [[VAR_4_1_]] : tensor<8x64x32xf32>
3839
+ // CHECK: }) {id = 1 : si64} : () -> tensor<8x64x32xf32>
3840
+ // CHECK: onnx.Yield [[VAR_2_]], [[VAR_3_]] : tensor<8x64x32xf32>, tensor<8x64x32xf32>
3841
+ // CHECK: }) : () -> (tensor<8x64x32xf32>, tensor<8x64x32xf32>)
3842
+ // CHECK: onnx.Return [[VAR_1_]]#0, [[VAR_1_]]#1 : tensor<8x64x32xf32>, tensor<8x64x32xf32>
3843
+ // CHECK: }
3844
+ }
3845
+
3846
+ // -----
3847
+
0 commit comments