@@ -21,6 +21,33 @@ func.func @test_gather_axis0(%arg0 : tensor<3x2xf32>) -> tensor<2x2x2xf32> {
21
21
22
22
// -----
23
23
24
+ func.func @test_gather_dynamic_axis0 (%arg0 : tensor <?x?xf32 >) -> tensor <2 x2 x?xf32 > {
25
+ %indices = " onnx.Constant" () {value = dense <[[0 , 1 ], [1 , 2 ]]> : tensor <2 x2 xi64 >} : () -> tensor <2 x2 xi64 >
26
+ %0 = " onnx.Gather" (%arg0 , %indices ) {axis = 0 : si64 } : (tensor <?x?xf32 >, tensor <2 x2 xi64 >) -> tensor <2 x2 x?xf32 >
27
+ " func.return" (%0 ) : (tensor <2 x2 x?xf32 >) -> ()
28
+ }
29
+
30
+ // CHECK-LABEL: func.func @test_gather_dynamic_axis0
31
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x?xf32>) -> tensor<2x2x?xf32> {
32
+ // CHECK-DAG: [[C0:%.+]] = arith.constant 0 : index
33
+ // CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<{{.}}[0, 1], [1, 2]{{.}}> : tensor<2x2xi64>
34
+ // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0> : tensor<2x2xi64>
35
+ // CHECK-DAG: [[INDICES_SHAPE_:%.+]] = shape.const_shape [2, 2] : tensor<2xindex>
36
+ // CHECK-DAG: [[SHAPE_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x?xf32> -> tensor<2xindex>
37
+ // CHECK-DAG: [[DIM_:%.+]] = shape.get_extent [[SHAPE_]], [[C0]] : tensor<2xindex>, index -> index
38
+ // CHECK-DAG: [[DIM_CAST_:%.+]] = arith.index_cast [[DIM_]] : index to i64
39
+ // CHECK-DAG: [[DIM_TENSOR_:%.+]] = tensor.from_elements [[DIM_CAST_]] : tensor<i64>
40
+ // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[DIM_TENSOR_]], [[INDICES_SHAPE_]], dims = [] : (tensor<i64>, tensor<2xindex>) -> tensor<2x2xi64>
41
+ // CHECK-NOT: separator of consecutive DAGs
42
+ // CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.compare LT, [[VAR_0_]], [[VAR_1_]], NOTYPE : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi1>
43
+ // CHECK-DAG: [[VAR_4_:%.+]] = stablehlo.add [[VAR_0_]], [[VAR_2_]] : tensor<2x2xi64>
44
+ // CHECK: [[VAR_5_:%.+]] = stablehlo.select [[VAR_3_]], [[VAR_4_]], [[VAR_0_]] : tensor<2x2xi1>, tensor<2x2xi64>
45
+ // CHECK: [[VAR_6_:%.+]] = "stablehlo.torch_index_select"([[PARAM_0_]], [[VAR_5_]]) {batch_dims = 0 : i64, dim = 0 : i64} : (tensor<?x?xf32>, tensor<2x2xi64>) -> tensor<2x2x?xf32>
46
+ // CHECK: return [[VAR_6_]] : tensor<2x2x?xf32>
47
+ // CHECK: }
48
+
49
+ // -----
50
+
24
51
func.func @test_gather_axis0neg (%arg0 : tensor <3 x2 xf32 >) -> tensor <2 x2 x2 xf32 > {
25
52
%indices = " onnx.Constant" () {value = dense <[[0 , -1 ], [1 , 2 ]]> : tensor <2 x2 xi64 >} : () -> tensor <2 x2 xi64 >
26
53
%0 = " onnx.Gather" (%arg0 , %indices ) {axis = 0 : si64 } : (tensor <3 x2 xf32 >, tensor <2 x2 xi64 >) -> tensor <2 x2 x2 xf32 >
0 commit comments