Skip to content
This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit a684bd6

Browse files
committed
update custom op conversions
1 parent 7b53fe4 commit a684bd6

File tree

4 files changed

+197
-1
lines changed

4 files changed

+197
-1
lines changed

lib/Conversion/TorchToTcp/Misc.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,15 @@ class ConvertAtenBroadcastLikeOps : public OpConversionPattern<AtenOpT> {
127127
}
128128
}
129129

130+
// fold the broadcast if no axes are found
131+
if (axes.size() == 0) {
132+
rewriter.replaceOp(op, input);
133+
return success();
134+
}
130135
RankedTensorType resultType =
131136
OpConversionPattern<AtenOpT>::getTypeConverter()
132137
->convertType(op->getResult(0).getType())
133138
.template cast<RankedTensorType>();
134-
135139
auto axesAttr = rewriter.getI64ArrayAttr(axes);
136140
rewriter.replaceOpWithNewOp<tcp::BroadcastOp>(op, resultType, input,
137141
resultShape, axesAttr);

lib/Conversion/TorchToTcp/TcpCustomOp.cpp

+120
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414

1515
#include "PopulatePatterns.h"
1616
#include "Utils.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
18+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1719
#include "torch-mlir/Conversion/Utils/Utils.h"
1820
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
1921
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
22+
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
2023

2124
#include "llvm/ADT/StringSet.h"
2225

@@ -211,6 +214,118 @@ class ConvertAtenFakeQuantizePerChannelAffineOp
211214
}
212215
};
213216

217+
// Lowers `aten.sort` to a `tcp.custom_op("torch.aten.sort")`. The sort
// configuration travels as attributes; the custom op produces both results
// (sorted values and indices) in place of the aten op.
class ConvertAtenSortOp : public OpConversionPattern<AtenSortOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenSortOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper customOpHelper{
        op, rewriter, getTypeConverter()};

    // The tensor being sorted is the only tensor operand.
    customOpHelper.addOperand("self", adaptor.getSelf());

    // `dim` and `descending` are compile-time arguments; attach them as
    // attributes on the custom op.
    customOpHelper.addIntAttr("dim", op.getDim());
    customOpHelper.addBoolAttr("descending", op.getDescending());

    return customOpHelper.replace();
  }
};
234+
235+
// Lowers `aten.cumsum` to a `tcp.custom_op("torch.aten.cumsum")`.
// Only the default dtype (a `torch.constant.none` dtype argument) is
// supported; an explicit dtype causes a match failure.
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Bail out early on an unsupported dtype. `getDefiningOp()` returns
    // null when the dtype value is a block argument, so guard with
    // `isa_and_nonnull` instead of `isa` to avoid dereferencing null.
    if (!isa_and_nonnull<Torch::ConstantNoneOp>(
            op.getDtype().getDefiningOp()))
      return rewriter.notifyMatchFailure(op, "Unsupported dtype argument");

    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
                                                            getTypeConverter()};
    // Tensor input plus the scan dimension as an attribute.
    helper.addOperand("self", adaptor.getSelf());
    helper.addIntAttr("dim", op.getDim());

    return helper.replace();
  }
};
253+
254+
// Lowers `aten.min.dim` to a `tcp.custom_op("torch.aten.min.dim")`. The
// reduction dimension and `keepdim` flag are carried as attributes; the
// custom op yields both the minimum values and their indices.
class ConvertAtenMinDimOp : public OpConversionPattern<AtenMinDimOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenMinDimOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper customOpHelper{
        op, rewriter, getTypeConverter()};

    // The reduced tensor is the only tensor operand.
    customOpHelper.addOperand("self", adaptor.getSelf());

    // Reduction configuration as attributes.
    customOpHelper.addIntAttr("dim", op.getDim());
    customOpHelper.addBoolAttr("keepdim", op.getKeepdim());

    return customOpHelper.replace();
  }
};
271+
272+
// Lowers dynamic-shape `aten.view` to a `tcp.custom_op("torch.aten.view")`.
// Fully static views are rejected (they are handled elsewhere). Each dynamic
// size element must trace back through
//   torch_c.from_i64 <- arith.index_cast <- tensor.dim
// and the originating `tensor.dim` value is passed as an extra operand named
// "idx_<N>" so the backend can resolve the runtime dimension.
class ConvertAtenViewOp : public OpConversionPattern<AtenViewOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenViewOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
                                                            getTypeConverter()};
    Value self = adaptor.getSelf();
    auto srcType = self.getType().cast<RankedTensorType>();
    auto resultType =
        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

    SmallVector<int64_t> size;
    if (matchPattern(op.getSize(), m_TorchListOfConstantInts(size)) &&
        srcType.hasStaticShape() && resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "only dynamic shape is supported");
    // `matchPattern` may have (partially or fully) populated `size` above;
    // reset it so the loop below does not append duplicate/stale entries.
    size.clear();

    helper.addOperand("self", self);
    // `dyn_cast_or_null` also rejects block arguments (null defining op).
    auto listConstruct = dyn_cast_or_null<Torch::PrimListConstructOp>(
        op.getSize().getDefiningOp());
    if (!listConstruct) {
      return rewriter.notifyMatchFailure(
          op, "Size must come from PrimListConstructOp");
    }
    int idx = 0;
    for (Value value : listConstruct.getElements()) {
      int64_t dimSize;
      if (!matchPattern(value, m_TorchConstantInt(&dimSize))) {
        size.push_back(ShapedType::kDynamic);
        // Walk the expected chain producing the dynamic extent.
        auto conversionOp = dyn_cast_or_null<TorchConversion::FromI64Op>(
            value.getDefiningOp());
        if (!conversionOp)
          return rewriter.notifyMatchFailure(
              op, "dynamic dim size should come from FromI64Op");
        auto indexCastOp = dyn_cast_or_null<arith::IndexCastOp>(
            conversionOp.getOperand().getDefiningOp());
        if (!indexCastOp)
          return rewriter.notifyMatchFailure(
              op, "dynamic dim size should come from IndexCastOp");
        auto dimOp =
            dyn_cast_or_null<tensor::DimOp>(indexCastOp.getIn().getDefiningOp());
        if (!dimOp)
          return rewriter.notifyMatchFailure(
              op, "dynamic dim size should come from DimOp");
        // Pass the runtime dimension as an extra (index-typed) operand.
        helper.addOperand("idx_" + std::to_string(idx), dimOp);
      } else {
        size.push_back(dimSize);
      }
      idx++;
    }
    helper.addDenseIntArrayAttr("size", size);

    return helper.replace();
  }
};
328+
214329
} // namespace
215330

216331
void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
@@ -227,6 +342,11 @@ void torch_to_tcp::populateTcpCustomOpPatternsAndLegality(
227342
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(
228343
AtenFakeQuantizePerTensorAffineTensorQparamsOp);
229344
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenFakeQuantizePerChannelAffineOp);
345+
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenSortOp);
346+
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenCumsumOp);
347+
INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN(AtenMinDimOp);
348+
// AtenViewOp can still live after torch-to-tcp conversion
349+
patterns.add<ConvertAtenViewOp>(typeConverter, patterns.getContext());
230350
#undef INSERT_ATEN_TO_TCP_CUSTOM_OP_PATTERN
231351

232352
// Torch -> TOSA doesn't handle transposed convolutions; map them to

lib/InitAll.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
1919
#include "mlir/IR/Dialect.h"
2020
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
21+
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
2122

2223
// Registers every dialect (and dialect extension) that TCP-based pipelines
// depend on into the given registry.
void mlir::tcp::registerAllDialects(mlir::DialectRegistry &registry) {
  // One variadic insert registers the same set of dialects as three
  // individual calls would.
  registry.insert<tcp::TcpDialect, torch::Torch::TorchDialect,
                  torch::TorchConversion::TorchConversionDialect>();
  mlir::func::registerInlinerExtension(registry);
  mlir::tcp::registerTilingInterfaceExternalModels(registry);
}

test/Conversion/TorchToTcp/tcp_custom_ops.mlir

+70
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,73 @@ func.func @torch.aten.fake_quantize_per_channel_affine_zero_like(%input: !torch.
254254
%output = torch.aten.fake_quantize_per_channel_affine %input, %scale, %zero_point, %int1, %int0, %int255 : !torch.vtensor<[1,3,32,32],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,3,32,32],f32>
255255
return %output : !torch.vtensor<[1,3,32,32],f32>
256256
}

// -----

// Verify `aten.sort` lowers to a two-result tcp.custom_op carrying the
// dim/descending configuration as attributes.
// CHECK-LABEL: func.func @torch.aten.sort(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,2304],f32> -> tensor<?x2304xf32>
// CHECK: %[[CUSTOM:.*]]:2 = tcp.custom_op("torch.aten.sort") %[[T0]] {descending = true, dim = -1 : i64, torch_operand_names = ["self"]} :
// CHECK-SAME: tensor<?x2304xf32> -> tensor<?x2304xf32>, tensor<?x2304xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]]#0 : tensor<?x2304xf32> -> !torch.vtensor<[?,2304],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,2304],f32>
func.func @torch.aten.sort(%input: !torch.vtensor<[?,2304],f32>) -> !torch.vtensor<[?,2304],f32> {
  %int-1 = torch.constant.int -1
  %true = torch.constant.bool true
  %output0, %output1 = torch.aten.sort %input, %int-1, %true : !torch.vtensor<[?,2304],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,2304],f32>, !torch.vtensor<[?,2304],si64>
  return %output0 : !torch.vtensor<[?,2304],f32>
}

// -----

// Verify `aten.cumsum` with the default (none) dtype lowers to tcp.custom_op;
// the scan dimension travels as the `dim` attribute.
// CHECK-LABEL: func.func @torch.aten.cumsum(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?],si32> -> tensor<?xi32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.cumsum") %[[T0]] {dim = 0 : i64, torch_operand_names = ["self"]} : tensor<?xi32> -> tensor<?xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<?xi64> -> !torch.vtensor<[?],si64>
// CHECK: return %[[RES]] : !torch.vtensor<[?],si64>
func.func @torch.aten.cumsum(%input: !torch.vtensor<[?],si32>) -> !torch.vtensor<[?],si64> {
  %int0 = torch.constant.int 0
  %none = torch.constant.none
  %1 = torch.aten.cumsum %input, %int0, %none : !torch.vtensor<[?],si32>, !torch.int, !torch.none -> !torch.vtensor<[?],si64>
  return %1 : !torch.vtensor<[?],si64>
}

// -----

// Verify `aten.min.dim` lowers to a two-result tcp.custom_op carrying the
// dim/keepdim configuration as attributes.
// CHECK-LABEL: func.func @torch.aten.min.dim(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,80],f32> -> tensor<?x80xf32>
// CHECK: %[[CUSTOM:.*]]:2 = tcp.custom_op("torch.aten.min.dim") %[[T0]] {dim = 1 : i64, keepdim = false, torch_operand_names = ["self"]} :
// CHECK-SAME: tensor<?x80xf32> -> tensor<?xf32>, tensor<?xi64>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]]#0 : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?],f32>
func.func @torch.aten.min.dim(%input: !torch.vtensor<[?,80],f32>) -> !torch.vtensor<[?],f32> {
  %int1 = torch.constant.int 1
  %false = torch.constant.bool false
  %output0, %output1 = torch.aten.min.dim %input, %int1, %false : !torch.vtensor<[?,80],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?],f32>, !torch.vtensor<[?],si64>
  return %output0 : !torch.vtensor<[?],f32>
}

// -----

// Verify dynamic-shape `aten.view` lowers to tcp.custom_op: the dynamic
// extent (traced back to tensor.dim) is passed as the extra "idx_0" operand
// and the static sizes appear in the `size` dense array attribute
// (kDynamic = -9223372036854775808 marks the dynamic slot).
// CHECK-LABEL: func.func @torch.aten.view_dynamic_shape(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,384,16],f32>, %[[ARG1:.*]]: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,384,16],f32> -> tensor<?x384x16xf32>
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x2736x16xf32>
// CHECK: %[[CUSTOM:.*]] = tcp.custom_op("torch.aten.view") %[[T0]], %[[DIM]] {size = array<i64: -9223372036854775808, 24, 16, 16>, torch_operand_names = ["self", "idx_0"]} :
// CHECK-SAME: tensor<?x384x16xf32>, index -> tensor<?x24x16x16xf32>
// CHECK: %[[RES:.*]] = torch_c.from_builtin_tensor %[[CUSTOM]] : tensor<?x24x16x16xf32> -> !torch.vtensor<[?,24,16,16],f32>
// CHECK: return %[[RES]] : !torch.vtensor<[?,24,16,16],f32>
func.func @torch.aten.view_dynamic_shape(%arg0: !torch.vtensor<[?,384,16],f32>, %arg1: tensor<?x2736x16xf32>) -> !torch.vtensor<[?,24,16,16],f32> {
  %c0 = arith.constant 0 : index
  %int24 = torch.constant.int 24
  %int16 = torch.constant.int 16
  %dim_32 = tensor.dim %arg1, %c0 : tensor<?x2736x16xf32>
  %1 = arith.index_cast %dim_32 : index to i64
  %2 = torch_c.from_i64 %1
  %3 = torch.prim.ListConstruct %2, %int24, %int16, %int16 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %4 = torch.aten.view %arg0, %3 : !torch.vtensor<[?,384,16],f32>, !torch.list<int> -> !torch.vtensor<[?,24,16,16],f32>
  return %4 : !torch.vtensor<[?,24,16,16],f32>
}

0 commit comments

Comments
 (0)