Skip to content

Commit 5d0c5c6

Browse files
authored
[MLIR][ARITH] Adds missing foldings for truncf (#128096)
This patch adds the following missing folds for `truncf`: (1) `truncf(extf(a))` -> `a`, if `a` has the same bitwidth as the result; (2) `truncf(extf(a))` -> `truncf(a)`, if `a` has a larger bitwidth than the result.
1 parent 6038fd4 commit 5d0c5c6

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,27 @@ LogicalResult arith::TruncIOp::verify() {
15181518
/// can be represented without precision loss.
15191519
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
15201520
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
1521+
if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
1522+
Value src = extOp.getIn();
1523+
auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1524+
auto intermediateType =
1525+
cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1526+
// Check if the srcType is representable in the intermediateType.
1527+
if (llvm::APFloatBase::isRepresentableBy(
1528+
srcType.getFloatSemantics(),
1529+
intermediateType.getFloatSemantics())) {
1530+
// truncf(extf(a)) -> truncf(a)
1531+
if (srcType.getWidth() > resElemType.getWidth()) {
1532+
setOperand(src);
1533+
return getResult();
1534+
}
1535+
1536+
// truncf(extf(a)) -> a
1537+
if (srcType == resElemType)
1538+
return src;
1539+
}
1540+
}
1541+
15211542
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
15221543
return constFoldCastOp<FloatAttr, FloatAttr>(
15231544
adaptor.getOperands(), getType(),

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,45 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
714714
return %0 : vector<2xf128>
715715
}
716716

717+
// CHECK-LABEL: @truncExtf
718+
// CHECK-NOT: truncf
719+
// CHECK: return %arg0
720+
func.func @truncExtf(%arg0: f32) -> f32 {
721+
%extf = arith.extf %arg0 : f32 to f64
722+
%trunc = arith.truncf %extf : f64 to f32
723+
return %trunc : f32
724+
}
725+
726+
// CHECK-LABEL: @truncExtf1
727+
// CHECK-NOT: truncf
728+
// CHECK: return %arg0
729+
func.func @truncExtf1(%arg0: bf16) -> bf16 {
730+
%extf = arith.extf %arg0 : bf16 to f32
731+
%trunc = arith.truncf %extf : f32 to bf16
732+
return %trunc : bf16
733+
}
734+
735+
// CHECK-LABEL: @truncExtf2
736+
// CHECK: %[[ARG0:.+]]: bf16
737+
// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG0]] : bf16 to f32
738+
// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[EXTF]] : f32 to f16
739+
// CHECK: return %[[TRUNCF]]
740+
func.func @truncExtf2(%arg0: bf16) -> f16 {
741+
%extf = arith.extf %arg0 : bf16 to f32
742+
%trunc = arith.truncf %extf : f32 to f16
743+
return %trunc : f16
744+
}
745+
746+
// CHECK-LABEL: @truncExtf3
747+
// CHECK: %[[ARG0:.+]]: f32
748+
// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0]] : f32 to f16
749+
// CHECK: return %[[CST]]
750+
func.func @truncExtf3(%arg0: f32) -> f16 {
751+
%extf = arith.extf %arg0 : f32 to f64
752+
%truncf = arith.truncf %extf : f64 to f16
753+
return %truncf : f16
754+
}
755+
717756
// TODO: We should also add a test for not folding arith.extf on information loss.
718757
// This may happen when extending f8E5M2FNUZ to f16.
719758

0 commit comments

Comments
 (0)