@@ -5,7 +5,7 @@ func.func @test_relu(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
  "func.return"(%0) : (tensor<10x10xf32>) -> ()
// CHECK-LABEL: func @test_relu
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<10x10xf32> {
- // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<10x10xf32>) -> tensor<10x10xf32>
+ // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<10x10xf32>) -> tensor<10x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<10x10xf32>
// CHECK-NEXT: }
}
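The change in this hunk tracks the updated tosa.clamp signature: the four min_fp/max_fp/min_int/max_int attributes collapse into a single typed min_val/max_val pair. A minimal sketch of the new form as a standalone float ReLU, using the same f32 bounds as the CHECK line above (the function name is illustrative):

func.func @relu_sketch(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> {
  // ReLU lowered as a clamp to [0, +max_f32]; integer bounds are no longer
  // spelled out on a float clamp.
  %0 = tosa.clamp %arg0 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<10x10xf32>) -> tensor<10x10xf32>
  return %0 : tensor<10x10xf32>
}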
@@ -17,7 +17,7 @@ func.func @test_relu_dynamic(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
  "func.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: func @test_relu_dynamic
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x10xf32>) -> tensor<?x10xf32> {
- // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<?x10xf32>) -> tensor<?x10xf32>
+ // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK-NEXT: return [[VAR_0_]] : tensor<?x10xf32>
// CHECK-NEXT: }
}
@@ -60,7 +60,8 @@ func.func @test_add_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>)
  "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func.func @test_add_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
- // CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+ // CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: [[VAR_1_:%.+]] = tosa.add [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
}
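This hunk and the ones that follow track the updated tosa.reshape signature: the new_shape attribute is replaced by a !tosa.shape operand materialized with tosa.const_shape. A minimal sketch of the new broadcast pattern, mirroring the + lines above (the function and value names are illustrative):

func.func @broadcast_sketch(%lhs : tensor<13x21x1xf32>, %rhs : tensor<1xf32>) -> tensor<13x21x1xf32> {
  // The target shape 1x1x1 is now a first-class shape value, not an attribute.
  %shape = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
  %r = tosa.reshape %rhs, %shape : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
  %sum = tosa.add %lhs, %r : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
  return %sum : tensor<13x21x1xf32>
}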
@@ -83,7 +84,8 @@ func.func @test_sub_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>)
  "func.return"(%0) : (tensor<13x21x1xf32>) -> ()
// CHECK-LABEL: func.func @test_sub_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
- // CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
+ // CHECK: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: [[VAR_1_:%.+]] = tosa.sub [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
// CHECK: return [[VAR_1_]] : tensor<13x21x1xf32>
}
@@ -106,7 +108,8 @@ func.func @test_div_broadcast(%arg0: tensor<13x21x1xi32>, %arg1: tensor<1xi32>)
  "func.return"(%0) : (tensor<13x21x1xi32>) -> ()
// CHECK-LABEL: func @test_div_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xi32>, [[PARAM_1_:%.+]]: tensor<1xi32>) -> tensor<13x21x1xi32> {
- // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xi32>) -> tensor<1x1x1xi32>
+ // CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reshape [[PARAM_1_]], [[SHAPE]] : (tensor<1xi32>, !tosa.shape<3>) -> tensor<1x1x1xi32>
// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.int_div [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xi32>, tensor<1x1x1xi32>) -> tensor<13x21x1xi32>
}
@@ -118,7 +121,8 @@ func.func @test_div_decomposed(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x1
// CHECK-LABEL: func @test_div_decomposed
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
- // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
+ // CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
}
// -----
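The mul hunks follow the same pattern for tosa.mul: the shift = 0 : i8 attribute becomes a third operand, a tensor<1xi8> constant (the shift only matters for integer multiplies, so the float lowering pins it to zero). A minimal sketch of the decomposed division, mirroring the + lines (names are illustrative):

func.func @div_sketch(%lhs : tensor<13x21x1xf32>, %rhs : tensor<13x21x1xf32>) -> tensor<13x21x1xf32> {
  // Division decomposed as multiplication by the reciprocal.
  %recip = tosa.reciprocal %rhs : (tensor<13x21x1xf32>) -> tensor<13x21x1xf32>
  // The shift operand replaces the old {shift = 0 : i8} attribute on tosa.mul.
  %zero = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %q = tosa.mul %lhs, %recip, %zero : (tensor<13x21x1xf32>, tensor<13x21x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
  return %q : tensor<13x21x1xf32>
}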
@@ -129,6 +133,8 @@ func.func @test_div_decomposed_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tens
// CHECK-LABEL: func @test_div_decomposed_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x1xf32>, [[PARAM_1_:%.+]]: tensor<1xf32>) -> tensor<13x21x1xf32> {
// CHECK-NEXT: [[VAR_0_:%.+]] = tosa.reciprocal [[PARAM_1_]] : (tensor<1xf32>) -> tensor<1xf32>
- // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array<i64: 1, 1, 1>} : (tensor<1xf32>) -> tensor<1x1x1xf32>
- // CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32>
+ // CHECK-NEXT: [[SHAPE:%.+]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape [[VAR_0_]], [[SHAPE]] : (tensor<1xf32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+ // CHECK-NEXT: [[ZERO:%.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // CHECK-NEXT: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_1_]], [[ZERO]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x1xf32>
}