-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][math] Fix intrinsic conversions to LLVM for 0D-vector types #141020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
`vector<t>` types are not compatible with the LLVM type system, and must be explicitly converted into `vector<1xt>` when lowering. Employ this rule within the conversion pattern for `math.ctlz`, `.cttz` and `.absi` intrinsics. Signed-off-by: Artem Gindinson <[email protected]>
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir Author: Artem Gindinson (AGindinson) Changes
Full diff: https://github.com/llvm/llvm-project/pull/141020.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 97da96afac4cd..19cd960b15294 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -84,6 +84,15 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
auto loc = op.getLoc();
auto resultType = op.getResult().getType();
+ const auto &typeConverter = *this->getTypeConverter();
+ if (!LLVM::isCompatibleType(resultType)) {
+ resultType = typeConverter.convertType(resultType);
+ if (!resultType)
+ return failure();
+ }
+ if (operandType != resultType)
+ return rewriter.notifyMatchFailure(
+ op, "compatible result type doesn't match operand type");
if (!isa<LLVM::LLVMArrayType>(operandType)) {
rewriter.replaceOpWithNewOp<LLVMOp>(op, resultType, adaptor.getOperand(),
@@ -96,7 +105,7 @@ struct IntOpWithFlagLowering : public ConvertOpToLLVMPattern<MathOp> {
return failure();
return LLVM::detail::handleMultidimensionalVectors(
- op.getOperation(), adaptor.getOperands(), *this->getTypeConverter(),
+ op.getOperation(), adaptor.getOperands(), typeConverter,
[&](Type llvm1DVectorTy, ValueRange operands) {
return rewriter.create<LLVMOp>(loc, llvm1DVectorTy, operands[0],
false);
diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
index 974743a55932b..73325a3fd913e 100644
--- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
+++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir
@@ -19,6 +19,8 @@ func.func @ops(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64) {
// -----
+// CHECK-LABEL: func @absi(
+// CHECK-SAME: i32
func.func @absi(%arg0: i32) -> i32 {
// CHECK: = "llvm.intr.abs"(%{{.*}}) <{is_int_min_poison = false}> : (i32) -> i32
%0 = math.absi %arg0 : i32
@@ -27,6 +29,17 @@ func.func @absi(%arg0: i32) -> i32 {
// -----
+// CHECK-LABEL: func @absi_0d_vec(
+// CHECK-SAME: i32
+func.func @absi_0d_vec(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.abs"(%[[CAST]]) <{is_int_min_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.absi %arg0 : vector<i32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @log1p(
// CHECK-SAME: f32
func.func @log1p(%arg0 : f32) {
@@ -201,6 +214,15 @@ func.func @ctlz(%arg0 : i32) {
func.return
}
+// CHECK-LABEL: func @ctlz_0d_vec(
+// CHECK-SAME: i32
+func.func @ctlz_0d_vec(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.ctlz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.ctlz %arg0 : vector<i32>
+ func.return
+}
+
// -----
// CHECK-LABEL: func @cttz(
@@ -213,6 +235,17 @@ func.func @cttz(%arg0 : i32) {
// -----
+// CHECK-LABEL: func @cttz_0d_vec(
+// CHECK-SAME: i32
+func.func @cttz_0d_vec(%arg0 : vector<i32>) {
+ // CHECK: %[[CAST:.+]] = builtin.unrealized_conversion_cast %arg0 : vector<i32> to vector<1xi32>
+ // CHECK: "llvm.intr.cttz"(%[[CAST]]) <{is_zero_poison = false}> : (vector<1xi32>) -> vector<1xi32>
+ %0 = math.cttz %arg0 : vector<i32>
+ func.return
+}
+
+// -----
+
// CHECK-LABEL: func @cttz_vec(
// CHECK-SAME: i32
func.func @cttz_vec(%arg0 : vector<4xi32>) {
|
@banach-space, @dcaballe, @Groverkss, @vzakhari, could you please take a look? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Artem - thanks for sending this!
Rank-0 vectors are tricky. Support across MLIR is still somewhat patchy, and our recent discussion on their role in the broader ecosystem was inconclusive:
I’m bringing this up because I don’t see existing examples of rank-0 vectors in math-to-llvm.mlir
, which suggests this may not have been explored yet in the context of MathToLLVM
.
A couple of high-level questions:
- In practice, is there anything that will eliminate the
builtin.unrealized_conversion_cast
being inserted here? If not, could we instead lower the rank-0 vector to a scalar earlier in the pipeline? - Why the focus on
absi
,ctlz
, andcttz
specifically? I’d expect this to apply to most (if not all) math ops that accept vector inputs.
In principle, I’m not opposed to this change - MLIR does support rank-0 vectors - but I want to make sure we consider this holistically. Ideally, we’d ensure that this integrates cleanly with the rest of the stack and doesn’t introduce hard-to-track edge cases later.
vector<t>
types are not compatible with the LLVM type system, and must be explicitly converted intovector<1xt>
when lowering. Employ this rule within the conversion pattern formath.ctlz
,.cttz
and.absi
intrinsics.