@@ -247,6 +247,34 @@ func.func @test_onnx_to_zhigh_gru0_bidir_dyn(%X: tensor<?x?x?xf32>, %W: tensor<2
247
247
248
248
// -----
249
249
250
+ func.func @gru_with_len (%arg0: tensor <2 x2 x1 xf32 >, %arg1: tensor <1 x3 x1 xf32 >, %arg2 : tensor <1 x3 x1 xf32 >) -> (tensor <*xf32 >, tensor <*xf32 >) {
251
+ %lens = onnx.Constant dense <[2 , 1 ]> : tensor <2 xi32 >
252
+ %cst = " onnx.NoValue" () {value } : () -> none
253
+ %res:2 = " onnx.GRU" (%arg0 , %arg1 , %arg2 , %cst , %lens , %cst ) {layout = 0 : si64 , linear_before_reset = 1 : si64 }
254
+ : ( tensor <2 x2 x1 xf32 >, tensor <1 x3 x1 xf32 >, tensor <1 x3 x1 xf32 >, none , tensor <2 xi32 >, none ) -> (tensor <*xf32 >, tensor <*xf32 >)
255
+ onnx.Return %res#0 , %res#1 : tensor <*xf32 >, tensor <*xf32 >
256
+
257
+ // CHECK-LABEL: func.func @gru_with_len
258
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x2x1xf32>, [[PARAM_1_:%.+]]: tensor<1x3x1xf32>, [[PARAM_2_:%.+]]: tensor<1x3x1xf32>) -> (tensor<2x1x2x1xf32>, tensor<1x2x1xf32>) {
259
+ // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[2, 1]> : tensor<2xi32>
260
+ // CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
261
+ // CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<2x2x1xf32>) -> tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>>
262
+ // CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Transpose"([[PARAM_1_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32>
263
+ // CHECK: [[VAR_4_:%.+]]:3 = "onnx.SplitV11"([[VAR_3_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>)
264
+ // CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.StickForGRU"([[VAR_4_]]#0, [[VAR_4_]]#1, [[VAR_4_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16>
265
+ // CHECK-DAG: [[VAR_6_:%.+]] = "onnx.Transpose"([[PARAM_2_]]) {perm = [0, 2, 1]} : (tensor<1x3x1xf32>) -> tensor<1x1x3xf32>
266
+ // CHECK: [[VAR_7_:%.+]]:3 = "onnx.SplitV11"([[VAR_6_]]) {axis = 2 : si64} : (tensor<1x1x3xf32>) -> (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>)
267
+ // CHECK: [[VAR_8_:%.+]] = "zhigh.StickForGRU"([[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_7_]]#2) : (tensor<1x1x1xf32>, tensor<1x1x1xf32>, tensor<1x1x1xf32>) -> tensor<*xf16>
268
+ // CHECK: [[VAR_9_:%.+]] = "zhigh.GRU"([[VAR_2_]], [[VAR_1_]], [[VAR_5_]], [[VAR_1_]], [[VAR_8_]], [[VAR_1_]]) {direction = "forward", hidden_size = 1 : si64, return_all_steps = -1 : si64} : (tensor<2x2x1xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none, tensor<*xf16>, none, tensor<*xf16>, none) -> tensor<*xf16>
269
+ // CHECK: [[VAR_10_:%.+]] = "zhigh.Unstick"([[VAR_9_]]) : (tensor<*xf16>) -> tensor<2x1x2x1xf32>
270
+ // CHECK-DAG: [[VAR_11_:%.+]] = "zhigh.FixGRUY"([[VAR_10_]], [[VAR_0_]], [[VAR_1_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>, none) -> tensor<2x1x2x1xf32>
271
+ // CHECK-DAG: [[VAR_12_:%.+]] = "zhigh.FixGRUYh"([[VAR_10_]], [[VAR_0_]]) : (tensor<2x1x2x1xf32>, tensor<2xi32>) -> tensor<1x2x1xf32>
272
+ // CHECK: onnx.Return [[VAR_11_]], [[VAR_12_]] : tensor<2x1x2x1xf32>, tensor<1x2x1xf32>
273
+ // CHECK: }
274
+ }
275
+
276
+ // -----
277
+
250
278
// COM : Maximum hidden_size in GRU is 10880. Not lowered when using 10881.
251
279
252
280
func.func @test_onnx_to_zhigh_gru_exceed_num_hidden (%X: tensor <7 x2000 x204 xf32 >, %W: tensor <1 x16384 x204 xf32 >, %R: tensor <1 x16384 x10881 xf32 >, %B: tensor <1 x16386 xf32 >) -> (tensor <7 x1 x2000 x10881 xf32 >, tensor <1 x2000 x10881 xf32 >) {
0 commit comments