diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index d50b6aeca15c9..599b3b982ec7f 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1273,6 +1273,7 @@ def Arith_TruncFOp : ]; let hasFolder = 1; + let hasCanonicalizer = 1; let hasVerifier = 1; let assemblyFormat = [{ $in ($roundingmode^)? (`fastmath` `` $fastmath^)? diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 7e212df9029d1..13eb97a910bd4 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -419,6 +419,22 @@ def TruncIShrUIMulIToMulUIExtended : (ValueWiderThan $mul, $x), (TruncationMatchesShiftAmount $mul, $x, $c0)]>; +//===----------------------------------------------------------------------===// +// TruncIOp +//===----------------------------------------------------------------------===// + +// truncf(sitofp(x)) -> sitofp(x) if default rounding mode. +def TruncFSIToFPToSIToFP : + Pat<(Arith_TruncFOp:$tr (Arith_SIToFPOp:$fp $x), $rmf, $fmf), + (Arith_SIToFPOp $x), + [(Constraint, "default rounding mode"> $rmf)]>; + +// truncf(uitofp(x)) -> uitofp(x) if default rounding mode. +def TruncFUIToFPToUIToFP : + Pat<(Arith_TruncFOp:$tr (Arith_UIToFPOp:$fp $x), $rmf, $fmf), + (Arith_UIToFPOp $x), + [(Constraint, "default rounding mode"> $rmf)]>; + //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 3b308716c84dc..41f2d0f3425e2 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1552,6 +1552,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) { }); } +void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); +} + bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { return checkWidthChangeCast(inputs, outputs); } diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index d62c5b18fd041..b6188c81ff912 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -753,6 +753,24 @@ func.func @truncExtf3(%arg0: f32) -> f16 { return %truncf : f16 } +// CHECK-LABEL: @truncSitofp +// CHECK: %[[SITOFP:.*]] = arith.sitofp %[[ARG0:.*]] : i32 to f32 +// CHECK-NOT: truncf +// CHECK: return %[[SITOFP]] +func.func @truncSitofp(%arg0: i32) -> f32 { + %sitofp = arith.sitofp %arg0 : i32 to f64 + %trunc = arith.truncf %sitofp : f64 to f32 + return %trunc : f32 +} + +// CHECK-LABEL: @truncSitofpConstrained +// CHECK: truncf +func.func @truncSitofpConstrained(%arg0: i32) -> f32 { + %sitofp = arith.sitofp %arg0 : i32 to f64 + %trunc = arith.truncf %sitofp to_nearest_even : f64 to f32 + return %trunc : f32 +} + // TODO: We should also add a test for not folding arith.extf on information loss. // This may happen when extending f8E5M2FNUZ to f16.