Skip to content
This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit fffb0b5

Browse files
author
Srinath Avadhanula
committed
Add support for integer division to TCP
1 parent 36418dc commit fffb0b5

File tree

7 files changed

+110
-16
lines changed

7 files changed

+110
-16
lines changed

include/mlir-tcp/Dialect/IR/TcpEnums.td

+18
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,22 @@ def Tcp_Signedness : I32EnumAttr<"Signedness",
3232

3333
def Tcp_SignednessAttr : EnumAttr<Tcp_Dialect, Tcp_Signedness, "signedness">;
3434

35+
// TCP rounding mode
//
// Selects how the quotient of an integer division is rounded toward an
// integer result (truncate toward zero, round toward -inf, or toward +inf).
def Tcp_RoundingMode_Trunc : I32EnumAttrCase<"Trunc", 0>;
def Tcp_RoundingMode_Floor : I32EnumAttrCase<"Floor", 1>;
def Tcp_RoundingMode_Ceil : I32EnumAttrCase<"Ceil", 2>;

def Tcp_RoundingMode : I32EnumAttr<"RoundingMode",
    "Rounding mode for integer operations which need a rounding mode",
    [
      Tcp_RoundingMode_Trunc,
      Tcp_RoundingMode_Floor,
      Tcp_RoundingMode_Ceil
    ]> {
  // Use the generic EnumAttr wrapper below instead of a generated
  // specialized attribute class.
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::tcp";
}

def Tcp_RoundingModeAttr : EnumAttr<Tcp_Dialect, Tcp_RoundingMode, "roundingMode">;
52+
3553
#endif // TCP_ENUMS

include/mlir-tcp/Dialect/IR/TcpOps.td

+21
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,27 @@ def Tcp_DivFOp : Tcp_BinaryElementwiseOp<"divf", [SameOperandsAndResultElementTy
160160
let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
161161
}
162162

163+
// Elementwise integer division. Signedness and rounding behavior are carried
// as attributes so a single op covers divsi/divui with trunc/floor/ceil.
def Tcp_DivIOp : Tcp_BinaryElementwiseOp<"divi", [SameOperandsAndResultElementType]> {
  let summary = "Computes elementwise integer division";

  let description = [{
    Computes the integer division of `in1` and `in2`.
  }];

  let arguments = (ins
    Tcp_IntTensor:$in1,
    Tcp_IntTensor:$in2,
    Tcp_SignednessAttr:$signedness,
    Tcp_RoundingModeAttr:$rounding_mode
  );

  let results = (outs
    Tcp_IntTensor:$out
  );

  let assemblyFormat = "$in1 `,` $in2 attr-dict `:` type($in1) `,` type($in2) `->` type($out)";
}
183+
163184
def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
164185
let summary = "Constant op";
165186

lib/Conversion/TcpToLinalg/Elementwise.cpp

+28-2
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,35 @@ createLinalgPayloadForElementwiseOp(Operation *op,
190190
if (isa<DivFOp>(op)) {
  // Float operands lower directly to arith.divf.
  if (elemType.isa<mlir::FloatType>())
    return {b.create<arith::DivFOp>(loc, payloadArgs[0], payloadArgs[1])};
  // NOTE(review): integer operands under tcp.divf fall back to signed
  // integer division here — presumably a legacy path now that tcp.divi
  // exists; confirm this is still intended.
  else if (elemType.isa<mlir::IntegerType>()) {
    return {b.create<arith::DivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
  }
}
197+
198+
if (auto divOp = dyn_cast<DivIOp>(op)) {
  // tcp.divi is only defined for integer element types.
  if (!elemType.isa<mlir::IntegerType>())
    llvm_unreachable("unsupported element type in "
                     "createLinalgPayloadForElementwiseOp for tcp.divi");
  if (divOp.getSignedness() == Signedness::Unsigned) {
    // An unsigned quotient is never negative, so truncation and floor
    // rounding coincide; both map to arith.divui.
    if (divOp.getRoundingMode() == RoundingMode::Trunc ||
        divOp.getRoundingMode() == RoundingMode::Floor)
      return {b.create<arith::DivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
    else
      return {
          b.create<arith::CeilDivUIOp>(loc, payloadArgs[0], payloadArgs[1])};
  } else if (divOp.getSignedness() == Signedness::Signed) {
    if (divOp.getRoundingMode() == RoundingMode::Trunc)
      return {b.create<arith::DivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
    else if (divOp.getRoundingMode() == RoundingMode::Ceil)
      // Fix: signed ceil division must use arith.ceildivsi; the previous
      // arith.ceildivui reinterprets negative operands as huge unsigned
      // values and yields wrong results.
      return {
          b.create<arith::CeilDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
    else
      return {
          b.create<arith::FloorDivSIOp>(loc, payloadArgs[0], payloadArgs[1])};
  } else {
    // Signless integers carry no division semantics; callers must have
    // resolved signedness before lowering.
    llvm_unreachable("unsupported signedness in "
                     "createLinalgPayloadForElementwiseOp for tcp.divi");
  }
}
197223

198224
if (isa<Atan2Op>(op)) {

lib/Conversion/TorchToTcp/Elementwise.cpp

+28-13
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
290290
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
291291
ConversionPatternRewriter &rewriter) const override {
292292
Value lhs = adaptor.getSelf();
293-
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
293+
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
294294

295295
Value rhs = adaptor.getOther();
296296

@@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
303303
return rewriter.notifyMatchFailure(
304304
op, "Only Ranked Tensor types are supported in TCP");
305305

306-
// TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are
307-
// added
308-
if (resultType.getElementType().isa<mlir::IntegerType>()) {
309-
return rewriter.notifyMatchFailure(
310-
op, "Only floating point division supported for now");
311-
}
312-
313306
auto inputAType = op.getSelf()
314307
.getType()
315308
.template dyn_cast<torch::Torch::ValueTensorType>()
@@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
318311
.template dyn_cast<torch::Torch::ValueTensorType>()
319312
.getDtype();
320313

314+
Type inputBType = nullptr;
321315
if (isa<AtenDivScalarOp>(op)) {
316+
inputBType = adaptor.getOther().getType();
317+
322318
rhs = convertScalarOperandToTensor(rewriter, op, op.getOther(),
323319
adaptor.getOther(), outputType,
324320
resultType.getElementType());
325321
if (!rhs)
326322
return rewriter.notifyMatchFailure(op, "Unsupported rhs data type");
327323
} else {
328-
auto inputBType = op.getOther()
329-
.getType()
330-
.template dyn_cast<torch::Torch::ValueTensorType>()
331-
.getDtype();
324+
inputBType = op.getOther()
325+
.getType()
326+
.template dyn_cast<torch::Torch::ValueTensorType>()
327+
.getDtype();
332328
rhs = torch_to_tcp::castTensorToDtype(rewriter, inputBType, outputType,
333329
rhs, resultType.getElementType());
334330
}
@@ -337,7 +333,26 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
337333
std::tie(lhs, rhs) =
338334
torch_to_tcp::broadcastToMatchShape(rewriter, lhs, rhs);
339335

340-
rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
336+
if (isa<mlir::FloatType>(outputType)) {
337+
rewriter.replaceOpWithNewOp<tcp::DivFOp>(op, resultType, lhs, rhs);
338+
} else {
339+
auto in1IntType = cast<mlir::IntegerType>(inputAType);
340+
auto in2IntType = cast<mlir::IntegerType>(inputBType);
341+
auto outIntType = cast<mlir::IntegerType>(outputType);
342+
if ((in1IntType.getSignedness() != in2IntType.getSignedness()) ||
343+
(in1IntType.getSignedness() != outIntType.getSignedness()))
344+
return rewriter.notifyMatchFailure(op,
345+
"Mixed signedness not supported");
346+
if (in1IntType.getSignedness() ==
347+
mlir::IntegerType::SignednessSemantics::Signless)
348+
return rewriter.notifyMatchFailure(
349+
op, "Signless division not supported in TCP");
350+
351+
rewriter.replaceOpWithNewOp<tcp::DivIOp>(
352+
op, resultType, lhs, rhs,
353+
torch_to_tcp::getTcpSignedness(outIntType.getSignedness()),
354+
tcp::RoundingMode::Trunc);
355+
}
341356
return success();
342357
}
343358
};

lib/Conversion/TorchToTcp/Utils.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ getTcpSignednessAttr(MLIRContext *context,
3838
return SignednessAttr::get(context, Signedness::Unsigned);
3939
}
4040

41+
// Maps the builtin IntegerType signedness semantics onto the TCP
// Signedness enum; anything other than Signless/Signed is Unsigned.
Signedness getTcpSignedness(IntegerType::SignednessSemantics signednessInfo) {
  switch (signednessInfo) {
  case IntegerType::SignednessSemantics::Signless:
    return Signedness::Signless;
  case IntegerType::SignednessSemantics::Signed:
    return Signedness::Signed;
  default:
    return Signedness::Unsigned;
  }
}
48+
4149
// The parameter input is expected to be of RankedTensorType.
4250
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
4351
Value input, int64_t rankIncrease) {

lib/Conversion/TorchToTcp/Utils.h

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ mlir::tcp::SignednessAttr
2323
getTcpSignednessAttr(MLIRContext *context,
2424
IntegerType::SignednessSemantics signednessInfo);
2525

26+
mlir::tcp::Signedness
27+
getTcpSignedness(IntegerType::SignednessSemantics signednessInfo);
28+
2629
// Helper function to expand the rank of the input tensor. Works by
2730
// adding 1-dim shape to the leading dims using `tensor::ExpandShapeOp`.
2831
Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,

test/Pipeline/torch_to_tcp_pipeline.mlir

+4-1
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,11 @@ func.func @torch.aten.div.Tensor$mixed_type_fp(%arg0: !torch.vtensor<[?, ?],f32>
108108

109109
// -----
110110

111+
// CHECK: func.func @torch.aten.div.Tensor$mixed_type_int(%[[ARG0:.+]]: tensor<?x?xi16>, %[[ARG1:.+]]: tensor<?x?xi32>) -> tensor<?x?xi32> {
112+
// CHECK: %[[V0:.+]] = tcp.cast %[[ARG0]] {in_int_signedness = #tcp<signedness Signed>, out_int_signedness = #tcp<signedness Signed>} : tensor<?x?xi16> -> tensor<?x?xi32>
113+
// CHECK: %[[V1:.+]] = tcp.divi %[[V0]], %[[ARG1]] {rounding_mode = #tcp<roundingMode Trunc>, signedness = #tcp<signedness Signed>} : tensor<?x?xi32>, tensor<?x?xi32> -> tensor<?x?xi32>
114+
// CHECK: return %[[V1]] : tensor<?x?xi32>
111115
func.func @torch.aten.div.Tensor$mixed_type_int(%arg0: !torch.vtensor<[?, ?],si16>, %arg1: !torch.vtensor<[?, ?],si32>) -> !torch.vtensor<[?, ?],si32> {
112-
// expected-error @below {{failed to legalize operation 'torch.aten.div.Tensor' that was explicitly marked illegal}}
113116
%0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?, ?],si16>, !torch.vtensor<[?, ?],si32> -> !torch.vtensor<[?, ?],si32>
114117
return %0 : !torch.vtensor<[?, ?],si32>
115118
}

0 commit comments

Comments
 (0)