This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit 3cd4c61

Commit message: tests
Parent: 00eb62c
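Summary of the change: the result-type handling in ConvertSymbolicIntOp (lib/Conversion/TorchToTcp/Misc.cpp) is relaxed, and symbolic-shape test coverage is added: one TorchToTcp lit case in misc.mlir plus two fx-importer Python tests (with and without a custom op) that exercise import_symbolic_shape_expressions=True.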

4 files changed: 202 additions (+), 2 deletions (-)

lib/Conversion/TorchToTcp/Misc.cpp (+1, -2)
@@ -282,8 +282,7 @@ class ConvertSymbolicIntOp : public OpConversionPattern<Torch::SymbolicIntOp> {
   LogicalResult
   matchAndRewrite(Torch::SymbolicIntOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    RankedTensorType resultType =
-        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
+    Type resultType = getTypeConverter()->convertType(op.getType());

     rewriter.replaceOpWithNewOp<tcp::SymbolicIntOp>(
         op, resultType, adaptor.getSymbolNameAttr(), adaptor.getMinValAttr(),
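Note on the change: the previous code unconditionally cast the converted result type to RankedTensorType, but, as the misc.mlir test below shows, the converted result of torch.symbolic_int is a plain integer (i64) rather than a tensor, so the pattern now forwards the converted type as an opaque Type to tcp::SymbolicIntOp.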

test/Conversion/TorchToTcp/misc.mlir (+37)
@@ -418,3 +418,40 @@ func.func @torch.aten.broadcast_to_dynamic_dim(%arg0: !torch.vtensor<[1,2],f32>,
   %2 = torch.aten.broadcast_to %arg0, %1 : !torch.vtensor<[1,2],f32>, !torch.list<int> -> !torch.vtensor<[?,2],f32>
   return %2 : !torch.vtensor<[?,2],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @symbolic_shape_ops(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,3],f32>, %[[ARG2:.*]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
+// CHECK: %[[S0:.*]] = tcp.symbolic_int "s0" {min_val = 5, max_val = 10} : i64
+// CHECK: %[[S1:.*]] = tcp.symbolic_int "s1" {min_val = 0, max_val = 100} : i64
+// CHECK: %[[S3:.*]] = tcp.symbolic_int "s3" {min_val = 0, max_val = 50} : i64
+// CHECK: %[[S5:.*]] = tcp.symbolic_int "s5" {min_val = 0, max_val = {{[0-9]+}}} : i64
+// CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+// CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+// CHECK: tcp.bind_symbolic_shape %{{.*}}, [%[[S0]], %[[S5]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+// CHECK: %[[TANH:.*]] = tcp.tanh %{{.*}} : tensor<?x?x3xf32> -> tensor<?x?x3xf32>
+// CHECK: tcp.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+// CHECK: %[[SIGM:.*]] = tcp.sigmoid %{{.*}} : tensor<?x?x3xf32> -> tensor<?x?x3xf32>
+// CHECK: tcp.bind_symbolic_shape %[[SIGM]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : tensor<?x?x3xf32>
+// CHECK: %[[CAT:.*]] = tensor.concat dim(1) %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (tensor<?x?x3xf32>, tensor<?x?x3xf32>, tensor<?x?x3xf32>, tensor<?x?x3xf32>) -> tensor<?x?x3xf32>
+// CHECK: tcp.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S3]], %[[S5]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : tensor<?x?x3xf32>
+// CHECK: return %{{.*}} : !torch.vtensor<[?,?,3],f32>
+func.func @symbolic_shape_ops(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[?,?,3],f32>, %arg2: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
+  %0 = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
+  %1 = torch.symbolic_int "s1" {min_val = 0, max_val = 100} : !torch.int
+  %2 = torch.symbolic_int "s3" {min_val = 0, max_val = 50} : !torch.int
+  %3 = torch.symbolic_int "s5" {min_val = 0, max_val = 9223372036854775806} : !torch.int
+  torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+  torch.bind_symbolic_shape %arg1, [%0, %2], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+  torch.bind_symbolic_shape %arg2, [%0, %3], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+  %4 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
+  torch.bind_symbolic_shape %4, [%0, %1], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+  %5 = torch.aten.sigmoid %arg1 : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
+  torch.bind_symbolic_shape %5, [%0, %2], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
+  %6 = torch.prim.ListConstruct %4, %4, %5, %arg2 : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
+  %int1 = torch.constant.int 1
+  %7 = torch.aten.cat %6, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
+  torch.bind_symbolic_shape %7, [%0, %1, %2, %3], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
+  return %7 : !torch.vtensor<[?,?,3],f32>
+}
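The new case follows the existing FileCheck conventions of this file: torch.symbolic_int and torch.bind_symbolic_shape are expected to lower to their tcp counterparts, torch.aten.cat lowers to tensor.concat, and the affine_map shape expressions are carried through unchanged. For reference, a typical driver for such a lit test looks like the sketch below; the real RUN line sits at the top of misc.mlir and is not part of this hunk, so the exact tool name and pass flags here are an assumption.

// Hypothetical RUN line (the actual one is outside this diff; flags are assumed):
// RUN: tcp-opt %s -convert-torch-to-tcp -split-input-file | FileCheck %s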
New Python test file (+86)
@@ -0,0 +1,86 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

import torch
import torch.nn as nn
from torch.export import Dim
from torch.library import Library, impl, impl_abstract

from torch_mlir import fx


def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
# CHECK-LABEL: test_tanh_sigmoid_cat_custom_op
# CHECK: func.func @main(
# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[OP:.+]] = torch.operator "torch.my_custom_library.tanh_sigmoid_cat_op"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[OP]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: return %[[OP]] : !torch.vtensor<[?,?,3],f32>
def test_tanh_sigmoid_cat_custom_op():

    m = Library("my_custom_library", "DEF")
    m.define("tanh_sigmoid_cat_op(Tensor x, Tensor y, Tensor z) -> Tensor")

    @impl(m, "tanh_sigmoid_cat_op", "CompositeExplicitAutograd")
    def custom_op(x, y, z):
        a = torch.tanh(x)
        b = torch.sigmoid(y)
        return torch.cat((a, a, b, z), dim=1)

    @impl_abstract("my_custom_library::tanh_sigmoid_cat_op")
    def custom_op_meta(x, y, z):
        result = custom_op(x, y, z)
        return torch.empty_like(result)

    class TanhSigmoidCatCustomOp(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y, z):
            return torch.ops.my_custom_library.tanh_sigmoid_cat_op(x, y, z)

    # Sample inputs
    x = torch.randn(5, 2, 3)
    y = torch.randn(5, 6, 3)
    z = torch.randn(5, 4, 3)

    # Dynamic dim constraints
    dim_n = Dim("n", min=5, max=10)
    dim_x1 = Dim("x1", max=100)
    dim_y1 = Dim("y1", max=50)
    dim_z1 = Dim("z1")
    dynamic_shapes = {
        "x": {0: dim_n, 1: dim_x1},
        "y": {0: dim_n, 1: dim_y1},
        "z": {0: dim_n, 1: dim_z1},
    }

    m = fx.export_and_import(
        TanhSigmoidCatCustomOp(),
        x,
        y,
        z,
        dynamic_shapes=dynamic_shapes,
        import_symbolic_shape_expressions=True,
    )
    print(m)
New Python test file (+78)
@@ -0,0 +1,78 @@
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s

import torch
import torch.export
import torch.nn as nn
from torch.export import Dim

from torch_mlir import fx


def run(f):
    print(f"{f.__name__}")
    print("-" * len(f.__name__))
    f()
    print()


@run
# CHECK-LABEL: test_tanh_sigmoid_cat
# CHECK: func.func @main(
# CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>,
# CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[?,?,3],f32>) -> !torch.vtensor<[?,?,3],f32> {
# CHECK: %[[S0:.+]] = torch.symbolic_int "s0" {min_val = 5, max_val = 10} : !torch.int
# CHECK: %[[S1:.+]] = torch.symbolic_int "s1" {min_val = {{[0-9]+}}, max_val = 100} : !torch.int
# CHECK: %[[S2:.+]] = torch.symbolic_int "s3" {min_val = {{[0-9]+}}, max_val = 50} : !torch.int
# CHECK: %[[S3:.+]] = torch.symbolic_int "s5" {min_val = {{[0-9]+}}, max_val = {{[0-9]+}}} : !torch.int
# CHECK: torch.bind_symbolic_shape %[[ARG0]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG1]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[ARG2]], [%[[S0]], %[[S3]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[TANH:.+]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[TANH]], [%[[S0]], %[[S1]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[SIG:.+]] = torch.aten.sigmoid %[[ARG1]] : !torch.vtensor<[?,?,3],f32> -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[SIG]], [%[[S0]], %[[S2]]], affine_map<()[s0, s1] -> (s0, s1, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[TANH]], %[[TANH]], %[[SIG]], %[[ARG2]] : (!torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[?,?,3],f32>) -> !torch.list<vtensor>
# CHECK: %[[CAT:.+]] = torch.aten.cat %[[LIST]], {{.*}} : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?,3],f32>
# CHECK: torch.bind_symbolic_shape %[[CAT]], [%[[S0]], %[[S1]], %[[S2]], %[[S3]]], affine_map<()[s0, s1, s2, s3] -> (s0, s2 + s3 + s1 * 2, 3)> : !torch.vtensor<[?,?,3],f32>
# CHECK: return %[[CAT]] : !torch.vtensor<[?,?,3],f32>
def test_tanh_sigmoid_cat():
    class TanhSigmoidCat(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x, y, z):
            a = torch.tanh(x)
            b = torch.sigmoid(y)
            return torch.cat((a, a, b, z), dim=1)

    # Sample inputs
    x = torch.randn(5, 2, 3)
    y = torch.randn(5, 6, 3)
    z = torch.randn(5, 4, 3)

    # Dynamic dim constraints
    dim_n = Dim("n", min=5, max=10)
    dim_x1 = Dim("x1", max=100)
    dim_y1 = Dim("y1", max=50)
    dim_z1 = Dim("z1")
    dynamic_shapes = {
        "x": {0: dim_n, 1: dim_x1},
        "y": {0: dim_n, 1: dim_y1},
        "z": {0: dim_n, 1: dim_z1},
    }

    m = fx.export_and_import(
        TanhSigmoidCat(),
        x,
        y,
        z,
        dynamic_shapes=dynamic_shapes,
        import_symbolic_shape_expressions=True,
    )
    print(m)
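Taken together, the two Python tests cover the same computation along two import paths: test_tanh_sigmoid_cat traces the aten ops directly, while test_tanh_sigmoid_cat_custom_op routes it through a custom op registered via torch.library, which imports as an opaque torch.operator call. In both cases, import_symbolic_shape_expressions=True is expected to emit torch.symbolic_int ops for the dynamic dims and torch.bind_symbolic_shape bindings on the arguments, intermediates, and result.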
