@@ -290,7 +290,7 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
290
290
matchAndRewrite (AtenOpT op, OpAdaptor adaptor,
291
291
ConversionPatternRewriter &rewriter) const override {
292
292
Value lhs = adaptor.getSelf ();
293
- RankedTensorType lhsType = lhs. getType (). dyn_cast <RankedTensorType>();
293
+ RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs. getType () );
294
294
295
295
Value rhs = adaptor.getOther ();
296
296
@@ -303,13 +303,6 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
303
303
return rewriter.notifyMatchFailure (
304
304
op, " Only Ranked Tensor types are supported in TCP" );
305
305
306
- // TODO: Add integer conversions once `tcp.divsi` and `tcp.divui` are
307
- // added
308
- if (resultType.getElementType ().isa <mlir::IntegerType>()) {
309
- return rewriter.notifyMatchFailure (
310
- op, " Only floating point division supported for now" );
311
- }
312
-
313
306
auto inputAType = op.getSelf ()
314
307
.getType ()
315
308
.template dyn_cast <torch::Torch::ValueTensorType>()
@@ -318,17 +311,20 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
318
311
.template dyn_cast <torch::Torch::ValueTensorType>()
319
312
.getDtype ();
320
313
314
+ Type inputBType = nullptr ;
321
315
if (isa<AtenDivScalarOp>(op)) {
316
+ inputBType = adaptor.getOther ().getType ();
317
+
322
318
rhs = convertScalarOperandToTensor (rewriter, op, op.getOther (),
323
319
adaptor.getOther (), outputType,
324
320
resultType.getElementType ());
325
321
if (!rhs)
326
322
return rewriter.notifyMatchFailure (op, " Unsupported rhs data type" );
327
323
} else {
328
- auto inputBType = op.getOther ()
329
- .getType ()
330
- .template dyn_cast <torch::Torch::ValueTensorType>()
331
- .getDtype ();
324
+ inputBType = op.getOther ()
325
+ .getType ()
326
+ .template dyn_cast <torch::Torch::ValueTensorType>()
327
+ .getDtype ();
332
328
rhs = torch_to_tcp::castTensorToDtype (rewriter, inputBType, outputType,
333
329
rhs, resultType.getElementType ());
334
330
}
@@ -337,7 +333,26 @@ class ConvertAtenDivOp : public OpConversionPattern<AtenOpT> {
337
333
std::tie (lhs, rhs) =
338
334
torch_to_tcp::broadcastToMatchShape (rewriter, lhs, rhs);
339
335
340
- rewriter.replaceOpWithNewOp <tcp::DivFOp>(op, resultType, lhs, rhs);
336
+ if (isa<mlir::FloatType>(outputType)) {
337
+ rewriter.replaceOpWithNewOp <tcp::DivFOp>(op, resultType, lhs, rhs);
338
+ } else {
339
+ auto in1IntType = cast<mlir::IntegerType>(inputAType);
340
+ auto in2IntType = cast<mlir::IntegerType>(inputBType);
341
+ auto outIntType = cast<mlir::IntegerType>(outputType);
342
+ if ((in1IntType.getSignedness () != in2IntType.getSignedness ()) ||
343
+ (in1IntType.getSignedness () != outIntType.getSignedness ()))
344
+ return rewriter.notifyMatchFailure (op,
345
+ " Mixed signedness not supported" );
346
+ if (in1IntType.getSignedness () ==
347
+ mlir::IntegerType::SignednessSemantics::Signless)
348
+ return rewriter.notifyMatchFailure (
349
+ op, " Signless division not supported in TCP" );
350
+
351
+ rewriter.replaceOpWithNewOp <tcp::DivIOp>(
352
+ op, resultType, lhs, rhs,
353
+ torch_to_tcp::getTcpSignedness (outIntType.getSignedness ()),
354
+ tcp::RoundingMode::Trunc);
355
+ }
341
356
return success ();
342
357
}
343
358
};
0 commit comments