@@ -254,3 +254,73 @@ func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(%input: !torch.
%output = torch.aten.fake_quantize_per_channel_affine %input, %scale, %zero_point, %int1, %int0, %int255 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
return %output : !torch.vtensor<[1,3,32,32],f32>
}
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.sort(
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
+ // CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.sort") %[[T0]] {descending = true, dim = -1 : i64, torch_operand_names = ["self"]} :
+ // CHECK-SAME: tensor<?x2304xf32> -> tensor<?x2304xf32>, tensor<?x2304xi64>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x2304xf32> -> !torch.vtensor<[?,2304],f32>
+ // CHECK: return %[[RES]] : !torch.vtensor<[?,2304],f32>
+ func.func @torch.aten.sort(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
+ %int-1 = torch.constant.int -1
+ %true = torch.constant.bool true
+ %output0, %output1 = torch.aten.sort %input, %int-1, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,2304],f32>, !torch.vtensor<[?,2304],si64>
+ return %output0 : !torch.vtensor<[?,2304],f32>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.cumsum(
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?],si32> -> tensor<?xi32>
+ // CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.cumsum") %[[T0]] {dim = 0 : i64, torch_operand_names = ["self"]} : tensor<?xi32> -> tensor<?xi64>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<?xi64> -> !torch.vtensor<[?],si64>
+ // CHECK: return %[[RES]] : !torch.vtensor<[?],si64>
+ func.func @torch.aten.cumsum(%input: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
+ %int0 = torch.constant.int 0
+ %none = torch.constant.none
+ %1 = torch.aten.cumsum %input, %int0, %none : !torch.vtensor<[?],si32>, !torch.int, !torch.none -> !torch.vtensor<[?],si64>
+ return %1 : !torch.vtensor<[?],si64>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.min.dim(
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,80],f32> -> tensor<?x80xf32>
+ // CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.min.dim") %[[T0]] {dim = 1 : i64, keepdim = false, torch_operand_names = ["self"]} :
+ // CHECK-SAME: tensor<?x80xf32> -> tensor<?xf32>, tensor<?xi64>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
+ // CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
+ func.func @torch.aten.min.dim(%input: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
+ %int1 = torch.constant.int 1
+ %false = torch.constant.bool false
+ %output0, %output1 = torch.aten.min.dim %input, %int1, %false : !torch.vtensor<[?,80],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
+ return %output0 : !torch.vtensor<[?],f32>
+ }
+
+ // -----
+
+ // CHECK-LABEL: func.func @torch.aten.view_dynamic_shape(
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,384,16],f32>, %[[ARG1:.*]]: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,384,16],f32> -> tensor<?x384x16xf32>
+ // CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x2736x16xf32>
+ // CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.view") %[[T0]], %[[DIM]] {size = array<i64: -9223372036854775808, 24, 16, 16>, torch_operand_names = ["self", "idx_0"]} :
+ // CHECK-SAME: tensor<?x384x16xf32>, index -> tensor<?x24x16x16xf32>
+ // CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM:.*]] : tensor<?x24x16x16xf32> -> !torch.vtensor<[?,24,16,16],f32>
+ // CHECK: return %[[RES]] : !torch.vtensor<[?,24,16,16],f32>
+ func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>, %arg1: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
+ %c0 = arith.constant 0 : index
+ %int24 = torch.constant.int 24
+ %int16 = torch.constant.int 16
+ %dim_32 = tensor.dim %arg1, %c0 : tensor<?x2736x16xf32>
+ %1 = arith.index_cast %dim_32 : index to i64
+ %2 = torch_c.from_i64 %1
+ %3 = torch.prim.ListConstruct %2, %int24, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+ %4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list<int> -> !torch.vtensor<[?,24,16,16],f32>
+ return %4 : !torch.vtensor<[?,24,16,16],f32>
+ }