@@ -32,46 +32,46 @@ func.func @test_softmax_dynamic(%arg0 : tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
  "func.return"(%0) : (tensor<?x20x30xf32>) -> ()
}

- //TODO: Renable dynamic shape test
- // func.func @test_softmax_dynamic
- // ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
- // [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
- // [[CST_2_:%.+]] = arith.constant 2 : index
- // [[CST_1_:%.+]] = arith.constant 1 : index
- // [[CST_0_:%.+]] = arith.constant 0 : index
- // [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
- // separator of consecutive DAGs
- // [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
- // [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
- // separator of consecutive DAGs
- // [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
- // [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
- // [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
- // [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
- // [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
- // [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
- // [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
- // [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
- // [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
- // [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
- // [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
- // [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
- // [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
- // [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
- // separator of consecutive DAGs
- // [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
- // [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
- // [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
- // [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
- // [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
- // [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
- // [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
- // [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
- // [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
- // [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
- // [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
- // return [[VAR_28_]] : tensor<?x20x30xf32>
- // }
+ // CHECK-LABEL: func.func @test_softmax_dynamic
+ // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x20x30xf32>) -> tensor<?x20x30xf32> {
+ // CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+ // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index
+ // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
+ // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+ // CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
+ // CHECK-NOT: separator of consecutive DAGs
+ // CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
+ // CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+ // CHECK-NOT: separator of consecutive DAGs
+ // CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index
+ // CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index
+ // CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index
+ // CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex>
+ // CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
+ // CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+ // CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor<?x1x30xf32> -> tensor<3xindex>
+ // CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+ // CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+ // CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+ // CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor<?x20x30xf32>
+ // CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor<?x20x30xf32>
+ // CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor<?x20x30xf32>, tensor<f32>) -> tensor<?x30xf32>
+ // CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+ // CHECK-NOT: separator of consecutive DAGs
+ // CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index
+ // CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index
+ // CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index
+ // CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex>
+ // CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor<?x30xf32>, tensor<3xindex>) -> tensor<?x1x30xf32>
+ // CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor<?x20x30xf32> -> tensor<3xindex>
+ // CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor<?x1x30xf32> -> tensor<3xindex>
+ // CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
+ // CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x20x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+ // CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor<?x1x30xf32>, tensor<3xindex>) -> tensor<?x20x30xf32>
+ // CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor<?x20x30xf32>
+ // CHECK: return [[VAR_28_]] : tensor<?x20x30xf32>
+ // CHECK: }
+

// -----