Skip to content

Commit c56e7f2

Browse files
authored
[mlir][arith] Canonicalize sitofp(truncf) -> sitofp, and uitofp. (llvm#139925)
Add a canonicalization patterns that simplifies `truncf(sitofp(x))` to `sitofp(x)` and `truncf(uitofp(x))` to `uitofp(x)`, if truncf has default rounding mode. This assumes that the destination type of truncf is representable by the intermediate type. Note that the truncf semantics requires that the destination type is narrower than the source type, so this is true for all types I can possibly think of, but one could probably construct an artificial counter example. Somewhat related: llvm#128096
1 parent 6d8a521 commit c56e7f2

File tree

4 files changed

+40
-0
lines changed

4 files changed

+40
-0
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,7 @@ def Arith_TruncFOp :
12731273
];
12741274

12751275
let hasFolder = 1;
1276+
let hasCanonicalizer = 1;
12761277
let hasVerifier = 1;
12771278
let assemblyFormat = [{ $in ($roundingmode^)?
12781279
(`fastmath` `` $fastmath^)?

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,22 @@ def TruncIShrUIMulIToMulUIExtended :
419419
(ValueWiderThan $mul, $x),
420420
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
421421

422+
//===----------------------------------------------------------------------===//
423+
// TruncIOp
424+
//===----------------------------------------------------------------------===//
425+
426+
// truncf(sitofp(x)) -> sitofp(x) if default rounding mode.
427+
def TruncFSIToFPToSIToFP :
428+
Pat<(Arith_TruncFOp:$tr (Arith_SIToFPOp:$fp $x), $rmf, $fmf),
429+
(Arith_SIToFPOp $x),
430+
[(Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rmf)]>;
431+
432+
// truncf(uitofp(x)) -> uitofp(x) if default rounding mode.
433+
def TruncFUIToFPToUIToFP :
434+
Pat<(Arith_TruncFOp:$tr (Arith_UIToFPOp:$fp $x), $rmf, $fmf),
435+
(Arith_UIToFPOp $x),
436+
[(Constraint<CPred<"$0 == nullptr">, "default rounding mode"> $rmf)]>;
437+
422438
//===----------------------------------------------------------------------===//
423439
// MulFOp
424440
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,6 +1552,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
15521552
});
15531553
}
15541554

1555+
void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1556+
MLIRContext *context) {
1557+
patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
1558+
}
1559+
15551560
bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
15561561
return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
15571562
}

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,24 @@ func.func @truncExtf3(%arg0: f32) -> f16 {
753753
return %truncf : f16
754754
}
755755

756+
// CHECK-LABEL: @truncSitofp
757+
// CHECK: %[[SITOFP:.*]] = arith.sitofp %[[ARG0:.*]] : i32 to f32
758+
// CHECK-NOT: truncf
759+
// CHECK: return %[[SITOFP]]
760+
func.func @truncSitofp(%arg0: i32) -> f32 {
761+
%sitofp = arith.sitofp %arg0 : i32 to f64
762+
%trunc = arith.truncf %sitofp : f64 to f32
763+
return %trunc : f32
764+
}
765+
766+
// CHECK-LABEL: @truncSitofpConstrained
767+
// CHECK: truncf
768+
func.func @truncSitofpConstrained(%arg0: i32) -> f32 {
769+
%sitofp = arith.sitofp %arg0 : i32 to f64
770+
%trunc = arith.truncf %sitofp to_nearest_even : f64 to f32
771+
return %trunc : f32
772+
}
773+
756774
// TODO: We should also add a test for not folding arith.extf on information loss.
757775
// This may happen when extending f8E5M2FNUZ to f16.
758776

0 commit comments

Comments
 (0)