@@ -120,3 +120,93 @@ func.func @test_inputs_with_multiple_uses(%arg0 : tensor<5xi32>) -> tensor<5xi32
  }) : () -> tensor<5xi32>
  return %10 : tensor<5xi32>
}
+
+
+ // -----
+
+ // Isolate tcp.group ops in the presence of nested regions.
+
+ // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+ // CHECK: module {
+ // CHECK: func.func @forward(%[[ARG0:.+]]: tensor<?x4096xf32>, %[[ARG1:.+]]: tensor<?x4096xf32>, %[[ARG2:.+]]: tensor<?x4096xf32>) -> tensor<?x4096xf32> {
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x4096xf32>
+ // CHECK: %[[V0:.+]] = tcp.isolated_group %[[DIM]], %[[ARG0]], %[[ARG1]] attributes {group_type = "codegen_group"} {
+ // CHECK: ^bb0(%[[ARG3:.+]]: index, %[[ARG4:.+]]: tensor<?x4096xf32>, %[[ARG5:.+]]: tensor<?x4096xf32>):
+ // CHECK: %[[V1:.+]] = tensor.empty(%[[ARG3]]) : tensor<?x4096xf32>
+ // CHECK: %[[V2:.+]] = scf.forall (%[[ARG6:.+]], %[[ARG7:.+]]) in (%[[ARG3]], 4096) shared_outs(%[[ARG8:.+]] = %[[V1]]) -> (tensor<?x4096xf32>) {
+ // CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG4]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
+ // CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG5]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
+ // CHECK: %[[V3:.+]] = tensor.empty() : tensor<1x1xf32>
+ // CHECK: %[[V4:.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]] : tensor<1x1xf32>, tensor<1x1xf32>) outs(%[[V3]] : tensor<1x1xf32>) {
+ // CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
+ // CHECK: %[[V5:.+]] = arith.mulf %[[IN]], %[[IN_1]] : f32
+ // CHECK: linalg.yield %[[V5]] : f32
+ // CHECK: } -> tensor<1x1xf32>
+ // CHECK: scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[V4]] into %[[ARG8]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<?x4096xf32>
+ // CHECK: }
+ // CHECK: }
+ // CHECK: tcp.yield %[[V2]] : tensor<?x4096xf32>
+ // CHECK: } : index, tensor<?x4096xf32>, tensor<?x4096xf32> -> tensor<?x4096xf32>
+ // CHECK: return %[[V0]] : tensor<?x4096xf32>
+ // CHECK: }
+ // CHECK: }
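+ // The pass should capture %dim, %arg0, and %arg1 as explicit arguments of the isolated group while leaving the nested regions intact.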
+ #map = affine_map<(d0, d1) -> (d0, d1)>
+ func.func @forward(%arg0: tensor<?x4096xf32>, %arg1: tensor<?x4096xf32>, %arg2: tensor<?x4096xf32>) -> tensor<?x4096xf32> {
+   %c0 = arith.constant 0 : index
+   %dim = tensor.dim %arg0, %c0 : tensor<?x4096xf32>
+   %0 = tcp.group attributes {group_type = "codegen_group"} {
+     %1 = tensor.empty(%dim) : tensor<?x4096xf32>
+     %2 = scf.forall (%arg3, %arg4) in (%dim, 4096) shared_outs(%arg5 = %1) -> (tensor<?x4096xf32>) {
+       %extracted_slice = tensor.extract_slice %arg0[%arg3, %arg4] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
+       %extracted_slice_0 = tensor.extract_slice %arg1[%arg3, %arg4] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
+       %3 = tensor.empty() : tensor<1x1xf32>
+       %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<1x1xf32>, tensor<1x1xf32>) outs(%3 : tensor<1x1xf32>) {
+       ^bb0(%in: f32, %in_4: f32, %out: f32):
+         %8 = arith.mulf %in, %in_4 : f32
+         linalg.yield %8 : f32
+       } -> tensor<1x1xf32>
+       scf.forall.in_parallel {
+         tensor.parallel_insert_slice %4 into %arg5[%arg3, %arg4] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<?x4096xf32>
+       }
+     }
+     tcp.yield %2 : tensor<?x4096xf32>
+   } : tensor<?x4096xf32>
+   return %0 : tensor<?x4096xf32>
+ }
+
+ // -----
+
+ // Ensure that we correctly drop `tcp.bind_symbolic_shape` ops within the
+ // newly created tcp.isolated_group region.
+
+ // CHECK: func.func @test_symbolic_shape_ops(%[[ARG0:.+]]: tensor<?x3xf32>) -> tensor<?x3xf32> {
+ // CHECK: %[[V0:.+]] = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
+ // CHECK: tcp.bind_symbolic_shape %[[ARG0]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
+ // CHECK: %[[V1:.+]] = tcp.isolated_group %[[ARG0]] {
+ // CHECK: ^bb0(%[[ARG1:.+]]: tensor<?x3xf32>):
+ // CHECK: %[[V2:.+]] = tcp.add %[[ARG1]], %[[ARG1]] : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
+ // CHECK-NOT: tcp.bind_symbolic_shape
+ // CHECK: %[[V3:.+]] = tcp.mul %[[V2]], %[[V2]] : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
+ // CHECK: tcp.yield %[[V3]] : tensor<?x3xf32>
+ // CHECK: } : tensor<?x3xf32> -> tensor<?x3xf32>
+ // CHECK: tcp.bind_symbolic_shape %[[V1]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
+ // CHECK: return %[[V1]] : tensor<?x3xf32>
+ // CHECK: }
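+ // bind_symbolic_shape ops outside the group are preserved; the one inside the group body should be dropped.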
+ func.func @test_symbolic_shape_ops(%arg0 : tensor<?x3xf32>) -> tensor<?x3xf32> {
+   %0 = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
+   tcp.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
+   %10 = "tcp.group"() ({
+   ^bb0():
+     %2 = tcp.add %arg0, %arg0 : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
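+     // This binding on a value defined inside the group should be dropped (checked by CHECK-NOT above).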
+     tcp.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
+     %3 = tcp.mul %2, %2 : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
+     tcp.yield %3 : tensor<?x3xf32>
+   }) : () -> tensor<?x3xf32>
+   tcp.bind_symbolic_shape %10, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
+   return %10 : tensor<?x3xf32>
+ }