Skip to content

Commit 8b5579e

Browse files
committed
lit test for shape inference.
Signed-off-by: Haruki Imai <[email protected]>
1 parent 409d9d2 commit 8b5579e

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

test/mlir/onnx/onnx_shape_inference.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3801,3 +3801,47 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5
38013801
// CHECK: }
38023802
}
38033803

3804+
// -----
3805+
3806+
//===----------------------------------------------------------------------===//
3807+
/// Test shape inference for Parallel and Fork.
3808+
//===----------------------------------------------------------------------===//
3809+
3810+
func.func @test_parallel_fork_1(%arg0: tensor<8x64x32xf32>, %arg1: tensor<32x32xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
3811+
%c0 = onnx.Constant dense<1.0> : tensor<32x32xf32>
3812+
%c1 = onnx.Constant dense<1.0> : tensor<32xf32>
3813+
%c2 = onnx.Constant dense<1.0> : tensor<32x32xf32>
3814+
3815+
%0:2 = "onnx.Parallel"() ({
3816+
%00 = "onnx.Fork"() ({
3817+
%01 = "onnx.MatMul"(%arg0, %c0) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<*xf32>
3818+
onnx.Yield %01 : tensor<*xf32>
3819+
}) {id = 0 : si64} : () -> tensor<*xf32>
3820+
%01 = "onnx.Fork"() ({
3821+
%01 = "onnx.MatMul"(%arg0, %c2) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<*xf32>
3822+
onnx.Yield %01 : tensor<*xf32>
3823+
}) {id = 1 : si64} : () -> tensor<*xf32>
3824+
"onnx.Yield"(%00, %01) : (tensor<*xf32>, tensor<*xf32>) -> ()
3825+
}) : () -> (tensor<*xf32>, tensor<*xf32>)
3826+
"onnx.Return"(%0#0,%0#1): (tensor<*xf32>, tensor<*xf32>) -> ()
3827+
3828+
// CHECK-LABEL: func.func @test_parallel_fork_1
3829+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x64x32xf32>, [[PARAM_1_:%.+]]: tensor<32x32xf32>) -> (tensor<8x64x32xf32>, tensor<8x64x32xf32>) {
3830+
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<32x32xf32>
3831+
// CHECK-DAG: [[VAR_1_:%.+]]:2 = "onnx.Parallel"() ({
3832+
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Fork"() ({
3833+
// CHECK: [[VAR_4_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32>
3834+
// CHECK: onnx.Yield [[VAR_4_]] : tensor<8x64x32xf32>
3835+
// CHECK: }) {id = 0 : si64} : () -> tensor<8x64x32xf32>
3836+
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Fork"() ({
3837+
// CHECK-DAG: [[VAR_4_1_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32>
3838+
// CHECK: onnx.Yield [[VAR_4_1_]] : tensor<8x64x32xf32>
3839+
// CHECK: }) {id = 1 : si64} : () -> tensor<8x64x32xf32>
3840+
// CHECK: onnx.Yield [[VAR_2_]], [[VAR_3_]] : tensor<8x64x32xf32>, tensor<8x64x32xf32>
3841+
// CHECK: }) : () -> (tensor<8x64x32xf32>, tensor<8x64x32xf32>)
3842+
// CHECK: onnx.Return [[VAR_1_]]#0, [[VAR_1_]]#1 : tensor<8x64x32xf32>, tensor<8x64x32xf32>
3843+
// CHECK: }
3844+
}
3845+
3846+
// -----
3847+

0 commit comments

Comments
 (0)