@@ -1517,23 +1517,25 @@ LogicalResult arith::TruncIOp::verify() {
1517
1517
// / Perform safe const propagation for truncf, i.e., only propagate if FP value
1518
1518
// / can be represented without precision loss.
1519
1519
OpFoldResult arith::TruncFOp::fold (FoldAdaptor adaptor) {
1520
+ auto resElemType = cast<FloatType>(getElementTypeOrSelf (getType ()));
1520
1521
if (auto extOp = getOperand ().getDefiningOp <arith::ExtFOp>()) {
1521
1522
Value src = extOp.getIn ();
1522
- Type srcType = getElementTypeOrSelf (src.getType ());
1523
- Type dstType = getElementTypeOrSelf (getType ());
1524
- // truncf(extf(a)) -> truncf(a)
1525
- if (llvm::cast<FloatType>(srcType).getWidth () >
1526
- llvm::cast<FloatType>(dstType).getWidth ()) {
1527
- setOperand (src);
1528
- return getResult ();
1529
- }
1523
+ auto srcType = cast<FloatType>(getElementTypeOrSelf (src.getType ()));
1524
+ auto intermediateType = cast<FloatType>(getElementTypeOrSelf (extOp.getType ()));
1525
+ // Check if the srcType is representable in the intermediateType
1526
+ if (llvm::APFloatBase::isRepresentableBy (srcType.getFloatSemantics (), intermediateType.getFloatSemantics ())) {
1527
+ // truncf(extf(a)) -> truncf(a)
1528
+ if (srcType.getWidth () > resElemType.getWidth ()) {
1529
+ setOperand (src);
1530
+ return getResult ();
1531
+ }
1530
1532
1531
- // truncf(extf(a)) -> a
1532
- if (srcType == dstType)
1533
- return src;
1533
+ // truncf(extf(a)) -> a
1534
+ if (srcType == resElemType)
1535
+ return src;
1536
+ }
1534
1537
}
1535
1538
1536
- auto resElemType = cast<FloatType>(getElementTypeOrSelf (getType ()));
1537
1539
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics ();
1538
1540
return constFoldCastOp<FloatAttr, FloatAttr>(
1539
1541
adaptor.getOperands (), getType (),
0 commit comments