Skip to content

Commit 35a039b

Browse files
committed
Some fixes
1 parent f66b1d3 commit 35a039b

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,23 +1517,25 @@ LogicalResult arith::TruncIOp::verify() {
15171517
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
15181518
/// can be represented without precision loss.
15191519
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
1520+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
15201521
if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
15211522
Value src = extOp.getIn();
1522-
Type srcType = getElementTypeOrSelf(src.getType());
1523-
Type dstType = getElementTypeOrSelf(getType());
1524-
// truncf(extf(a)) -> truncf(a)
1525-
if (llvm::cast<FloatType>(srcType).getWidth() >
1526-
llvm::cast<FloatType>(dstType).getWidth()) {
1527-
setOperand(src);
1528-
return getResult();
1529-
}
1523+
auto srcType = cast<FloatType>(getElementTypeOrSelf(src.getType()));
1524+
auto intermediateType = cast<FloatType>(getElementTypeOrSelf(extOp.getType()));
1525+
// Check if the srcType is representable in the intermediateType
1526+
if(llvm::APFloatBase::isRepresentableBy(srcType.getFloatSemantics(), intermediateType.getFloatSemantics())) {
1527+
// truncf(extf(a)) -> truncf(a)
1528+
if (srcType.getWidth() > resElemType.getWidth()) {
1529+
setOperand(src);
1530+
return getResult();
1531+
}
15301532

1531-
// truncf(extf(a)) -> a
1532-
if (srcType == dstType)
1533-
return src;
1533+
// truncf(extf(a)) -> a
1534+
if (srcType == resElemType)
1535+
return src;
1536+
}
15341537
}
15351538

1536-
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
15371539
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
15381540
return constFoldCastOp<FloatAttr, FloatAttr>(
15391541
adaptor.getOperands(), getType(),

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,11 +723,31 @@ func.func @truncExtf(%arg0: f32) -> f32 {
723723
return %trunc : f32
724724
}
725725

726+
// CHECK-LABEL: @truncExtf1
727+
// CHECK-NOT: truncf
728+
// CHECK: return %arg0
729+
func.func @truncExtf1(%arg0: bf16) -> bf16 {
730+
%extf = arith.extf %arg0 : bf16 to f32
731+
%trunc = arith.truncf %extf : f32 to bf16
732+
return %trunc : bf16
733+
}
734+
726735
// CHECK-LABEL: @truncExtf2
736+
// CHECK: %[[ARG0:.+]]: bf16
737+
// CHECK: %[[EXTF:.*]] = arith.extf %[[ARG0:.+]] : bf16 to f32
738+
// CHECK: %[[TRUNCF:.*]] = arith.truncf %[[EXTF:.*]] : f32 to f16
739+
// CHECK: return %[[TRUNCF:.*]]
740+
func.func @truncExtf2(%arg0: bf16) -> f16 {
741+
%extf = arith.extf %arg0 : bf16 to f32
742+
%trunc = arith.truncf %extf : f32 to f16
743+
return %trunc : f16
744+
}
745+
746+
// CHECK-LABEL: @truncExtf3
727747
// CHECK: %[[ARG0:.+]]: f32
728748
// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
729749
// CHECK: return %[[CST:.*]]
730-
func.func @truncExtf2(%arg0: f32) -> f16 {
750+
func.func @truncExtf3(%arg0: f32) -> f16 {
731751
%extf = arith.extf %arg0 : f32 to f64
732752
%truncf = arith.truncf %extf : f64 to f16
733753
return %truncf : f16

0 commit comments

Comments
 (0)