diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..e9545c3146b2f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1518,6 +1518,31 @@ LogicalResult arith::TruncIOp::verify() {
 /// can be represented without precision loss.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  // Simplify truncf(extf(a)); this is only attempted when the extf is lossless.
+  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+    Value src = extOp.getIn();
+    auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
+    auto intermediateType =
+        cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
+    // Every srcType value must be exactly representable in intermediateType;
+    // otherwise the extf is not a no-op on the value and must be kept.
+    if (llvm::APFloatBase::isRepresentableBy(
+            srcType.getFloatSemantics(),
+            intermediateType.getFloatSemantics())) {
+      // truncf(extf(a)) -> truncf(a): the source is still wider than the
+      // result, so truncate it directly. The operand is rewired in place and
+      // the op's own result is returned instead of a new value.
+      if (srcType.getWidth() > resElemType.getWidth()) {
+        setOperand(src);
+        return getResult();
+      }
+
+      // truncf(extf(a)) -> a: the result type equals the source type.
+      if (srcType == resElemType)
+        return src;
+    }
+  }
+
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
       adaptor.getOperands(), getType(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..f0b2731707d18 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,45 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
   return %0 : vector<2xf128>
 }
 
+// CHECK-LABEL: @truncExtf
+// CHECK-NOT: truncf
+// CHECK: return %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %trunc = arith.truncf %extf : f64 to f32
+  return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf1
+// CHECK-NOT: truncf
+// CHECK: return %arg0
+func.func @truncExtf1(%arg0: bf16) -> bf16 {
+  %extf = arith.extf %arg0 : bf16 to f32
+  %trunc = arith.truncf %extf : f32 to bf16
+  return %trunc : bf16
+}
+
+// CHECK-LABEL: @truncExtf2
+// CHECK-SAME: %[[ARG0:.+]]: bf16
+// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG0]] : bf16 to f32
+// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[EXTF]] : f32 to f16
+// CHECK: return %[[TRUNCF]]
+func.func @truncExtf2(%arg0: bf16) -> f16 {
+  %extf = arith.extf %arg0 : bf16 to f32
+  %trunc = arith.truncf %extf : f32 to f16
+  return %trunc : f16
+}
+
+// CHECK-LABEL: @truncExtf3
+// CHECK-SAME: %[[ARG0:.+]]: f32
+// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[ARG0]] : f32 to f16
+// CHECK: return %[[TRUNCF]]
+func.func @truncExtf3(%arg0: f32) -> f16 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %truncf = arith.truncf %extf : f64 to f16
+  return %truncf : f16
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 