Skip to content

[mlir][arith] Canonicalize sitofp(truncf) -> sitofp, and uitofp. #139925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 19, 2025

Conversation

chsigg
Copy link
Contributor

@chsigg chsigg commented May 14, 2025

Add a canonicalization patterns that simplifies truncf(sitofp(x)) to sitofp(x) and truncf(uitofp(x)) to uitofp(x).

This assumes that the destination type of truncf is representable by the intermediate type.

Note that the truncf semantics requires that the destination type is narrower than the source type, so this is true for all types I can possibly think of, but one could probably construct an artificial counter example.

Somewhat related: #128096

@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir-arith

Author: Christian Sigg (chsigg)

Changes

Add a canonicalization patterns that simplifies truncf(sitofp(x)) to sitofp(x) and truncf(uitofp(x)) to uitofp(x).

This assumes that the destination type of truncf is representable by the intermediate type.

Note that the truncf semantics requires that the destination type is narrower than the source type, so this is true for all types I can possibly think of, but one could probably construct an artificial counter example.

Somewhat related: #128096


Full diff: https://github.com/llvm/llvm-project/pull/139925.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+1)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+14)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+5)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index d50b6aeca15c9..599b3b982ec7f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1273,6 +1273,7 @@ def Arith_TruncFOp :
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
   let assemblyFormat = [{ $in ($roundingmode^)?
                           (`fastmath` `` $fastmath^)?
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7e212df9029d1..7be73c4343639 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -419,6 +419,20 @@ def TruncIShrUIMulIToMulUIExtended :
        (ValueWiderThan $mul, $x),
        (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
 
+//===----------------------------------------------------------------------===//
+// TruncIOp
+//===----------------------------------------------------------------------===//
+
+// truncf(sitofp(x)) -> sitofp(x).
+def TruncFSIToFPToSIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_SIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_SIToFPOp $x)>;
+
+// truncf(sitofp(x)) -> sitofp(x).
+def TruncFUIToFPToUIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_UIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_UIToFPOp $x)>;
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 3b308716c84dc..41f2d0f3425e2 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1552,6 +1552,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
+}
+
 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
 }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index d62c5b18fd041..a5dab73a62fac 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -753,6 +753,16 @@ func.func @truncExtf3(%arg0: f32) -> f16 {
   return %truncf : f16
 }
 
+// CHECK-LABEL: @truncSitofp
+//       CHECK-NOT: truncf
+//       CHECK:     %[[SITOFP:.*]] = arith.sitofp %[[ARG0:.*]] : i32 to f32
+//       CHECK:     return %[[SITOFP]]
+func.func @truncSitofp(%arg0: i32) -> f32 {
+  %sitofp = arith.sitofp %arg0 : i32 to f64
+  %trunc = arith.truncf %sitofp : f64 to f32
+  return %trunc : f32
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 

@llvmbot
Copy link
Member

llvmbot commented May 14, 2025

@llvm/pr-subscribers-mlir

Author: Christian Sigg (chsigg)

Changes

Add a canonicalization patterns that simplifies truncf(sitofp(x)) to sitofp(x) and truncf(uitofp(x)) to uitofp(x).

This assumes that the destination type of truncf is representable by the intermediate type.

Note that the truncf semantics requires that the destination type is narrower than the source type, so this is true for all types I can possibly think of, but one could probably construct an artificial counter example.

Somewhat related: #128096


Full diff: https://github.com/llvm/llvm-project/pull/139925.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+1)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td (+14)
  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+5)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index d50b6aeca15c9..599b3b982ec7f 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1273,6 +1273,7 @@ def Arith_TruncFOp :
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
   let hasVerifier = 1;
   let assemblyFormat = [{ $in ($roundingmode^)?
                           (`fastmath` `` $fastmath^)?
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 7e212df9029d1..7be73c4343639 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -419,6 +419,20 @@ def TruncIShrUIMulIToMulUIExtended :
        (ValueWiderThan $mul, $x),
        (TruncationMatchesShiftAmount $mul, $x, $c0)]>;
 
+//===----------------------------------------------------------------------===//
+// TruncIOp
+//===----------------------------------------------------------------------===//
+
+// truncf(sitofp(x)) -> sitofp(x).
+def TruncFSIToFPToSIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_SIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_SIToFPOp $x)>;
+
+// truncf(sitofp(x)) -> sitofp(x).
+def TruncFUIToFPToUIToFP :
+    Pat<(Arith_TruncFOp:$tr (Arith_UIToFPOp:$fp $x), $rmf, $fmf),
+        (Arith_UIToFPOp $x)>;
+
 //===----------------------------------------------------------------------===//
 // MulFOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 3b308716c84dc..41f2d0f3425e2 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1552,6 +1552,11 @@ OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+void arith::TruncFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                  MLIRContext *context) {
+  patterns.add<TruncFSIToFPToSIToFP, TruncFUIToFPToUIToFP>(context);
+}
+
 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
 }
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index d62c5b18fd041..a5dab73a62fac 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -753,6 +753,16 @@ func.func @truncExtf3(%arg0: f32) -> f16 {
   return %truncf : f16
 }
 
+// CHECK-LABEL: @truncSitofp
+//       CHECK-NOT: truncf
+//       CHECK:     %[[SITOFP:.*]] = arith.sitofp %[[ARG0:.*]] : i32 to f32
+//       CHECK:     return %[[SITOFP]]
+func.func @truncSitofp(%arg0: i32) -> f32 {
+  %sitofp = arith.sitofp %arg0 : i32 to f64
+  %trunc = arith.truncf %sitofp : f64 to f32
+  return %trunc : f32
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 

@krzysz00
Copy link
Contributor

So the one thing I'm dubious about here is post-EmulateUnsupportedFloats IR, which'll get undone by this canonicalization rule

@joker-eph
Copy link
Collaborator

What is consuming the IR post EmulateUnsupportedFloats? It may just be that canonicalization can't run in between EmulateUnsupportedFloats and the consumer?

@krzysz00
Copy link
Contributor

EmilateUnsupportedFloats is, from where I'm standing, part of the late pre-export blob that doesn't have clear phase ordering (other residents of that blob would be stuff like flowering vector.transfer_read to vector.load but also, say, integer range optimizations or stuff like one last poke with the loop invariant code motion. Different projects might order everything slightly differently, and it's reasonable for there to be a canonicalization between unsupported float emulation and ConvertTo{LLVM,SPIR-V,...}

Maybe this canonicalization is fine, but if we're in a contest where

%y = aritf.sitofp %x : i32 to bf16

doesn't exist but

‰y0 = arith.sitofp %x ; i32 up f32
%y = arith.truncf %y0 : f32 to bf16

, undoing the transformation of the former to the latter would make things more annoying in lowering.

Maybe we want to use a fastmath flag on the sitofp to indicate when this transformation is allowed?

@joker-eph
Copy link
Collaborator

joker-eph commented May 15, 2025

Different projects might order everything slightly differently, and it's reasonable for there to be a canonicalization between unsupported float emulation and ConvertTo{LLVM,SPIR-V,...}

If the pass is a "codegen prepare" pass and is generating non-canonical IR intentional to make the lowering easier, then no it is not reasonable to expect that canonicalization can run there.

, undoing the transformation of the former to the latter would make things more annoying in lowering.

That's not a good argument against the canonicalization to me: the lowering must support %y = aritf.sitofp %x : i32 to bf16, or it's the responsibility of the pipeline author to correctly set it up so that the passes necessary to the lowering are correctly ordered (same argument as above).

Something more fundamental to me is rather if we know numerically that %y = aritf.sitofp %x : i32 to bf16 is equivalent to

‰y0 = arith.sitofp %x ; i32 up f32
%y = arith.truncf %y0 : f32 to bf16

or if the rounding can differ when removing the intermediate value? The intermediate is higher precision, but still can it always represent exactly all of the values of the smaller one in all cases?

We only need a fast math flag if our contraction changes the numerics.

@chsigg
Copy link
Contributor Author

chsigg commented May 16, 2025

rounding can differ when removing the intermediate value

Good point, it probably can differ. sitofp always uses default rounding mode, so I added a check for that.

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a test for rounding modes blocking canonicalization, otherwise ... seems fine to me after the discussion

chsigg added 3 commits May 19, 2025 10:28
Add a canonicalization pattern that simplifies `truncf(sitofp(x))` to
`sitofp(x)`.

This assumes that the destination type of truncf is representable by the source type.
Note that the truncf semantics requires that the destination type is narrower than
the source type, so this is true for all fp types I can possibly think of, but one
could probably construct a artificial counter example.
@chsigg chsigg force-pushed the piper_export_cl_758681496 branch from 8215575 to bd58ab7 Compare May 19, 2025 08:28
@chsigg chsigg merged commit c56e7f2 into llvm:main May 19, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants