@@ -418,3 +418,40 @@ func.func @torch.aten.broadcast_to_dynamic_dim(%arg0: !torch.vtensor<[1,2],f32>,
  %2 = torch.aten.broadcast_to %arg0, %1 : !torch.vtensor<[1,2],f32>, !torch.list<int> -> !torch.vtensor<[?,2],f32>
  return %2 : !torch.vtensor<[?,2],f32>
}
+
+ // -----
+
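+ // Verifies that torch.symbolic_int and torch.bind_symbolic_shape ops on dynamic
+ // dimensions are converted to tcp.symbolic_int and tcp.bind_symbolic_shape,
+ // preserving the min/max bounds and the affine_map shape expressions.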
+ // CHECK-LABEL: @symbolic_shape_ops(
+ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,3],f32>, %[[ARG2:.*]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
+ // CHECK: %[[S0:.*]] = tcp.symbolic_int "s0" {min_val = 5, max_val = 10} : i64
+ // CHECK: %[[S1:.*]] = tcp.symbolic_int "s1" {min_val = 0, max_val = 100} : i64
+ // CHECK: %[[S3:.*]] = tcp.symbolic_int "s3" {min_val = 0, max_val = 50} : i64
+ // CHECK: %[[S5:.*]] = tcp.symbolic_int "s5" {min_val = 0, max_val = {{[0-9]+}}} : i64
+ // CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+ // CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+ // CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S5]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+ // CHECK: %[[TANH:.*]] = tcp.tanh %{{.*}} : tensor<?x?x3xf32> -> tensor<?x?x3xf32>
+ // CHECK: tcp.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+ // CHECK: %[[SIGM:.*]] = tcp.sigmoid %{{.*}} : tensor<?x?x3xf32> -> tensor<?x?x3xf32>
+ // CHECK: tcp.bind_symbolic_shape %[[SIGM]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+ // CHECK: %[[CAT:.*]] = tensor.concat dim(1) %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (tensor<?x?x3xf32>, tensor<?x?x3xf32>, tensor<?x?x3xf32>, tensor<?x?x3xf32>) -> tensor<?x?x3xf32>
+ // CHECK: tcp.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S3]], %[[S5]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : tensor<?x?x3xf32>
+ // CHECK: return %{{.*}} : !torch.vtensor<[?,?,3],f32>
+ func.func @symbolic_shape_ops(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>, %arg2: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
+   %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
+   %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int
+   %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int
+   %3 = torch.symbolic_int "s5" {min_val = 0, max_val = 9223372036854775806} : !torch.int
+   torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+   torch.bind_symbolic_shape %arg1, [%0, %2], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+   torch.bind_symbolic_shape %arg2, [%0, %3], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+   %4 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
+   torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+   %5 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
+   torch.bind_symbolic_shape %5, [%0, %2], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+   %6 = torch.prim.ListConstruct %4, %4, %5, %arg2 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
+   %int1 = torch.constant.int 1
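+   // The four concatenated operands have dim-1 sizes s1, s1, s3, and s5 (torch symbols),
+   // so dim 1 of the result is s1 * 2 + s3 + s5, as expressed by the composed affine_map below.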
+   %7 = torch.aten.cat %6, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
+   torch.bind_symbolic_shape %7, [%0, %1, %2, %3], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
+   return %7 : !torch.vtensor<[?,?,3],f32>
+ }