Skip to content

Revert "[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions." #97652

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 3, 2024

Conversation

hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Jul 3, 2024

Reverts #92934 because it breaks some lowering. To repro: mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir

func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32>
  %0 = arith.mulf %arg0, %cst : vector<1x1xf32>
  return %0 : vector<1x1xf32>
}

@hanhanW hanhanW requested a review from nujaa July 3, 2024 23:03
@hanhanW hanhanW requested a review from nicolasvasilache as a code owner July 3, 2024 23:03
@hanhanW hanhanW merged commit eaabd76 into main Jul 3, 2024
8 of 10 checks passed
@hanhanW hanhanW deleted the revert-92934-hugo.dropUnitDimsGen branch July 3, 2024 23:03
@llvmbot
Copy link
Member

llvmbot commented Jul 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Han-Chung Wang (hanhanW)

Changes

Reverts llvm/llvm-project#92934 because it breaks some lowering. To repro: mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir

func.func @<!-- -->unit_dim_folding(%arg0: vector&lt;1x1xf32&gt;) -&gt; vector&lt;1x1xf32&gt; {
  %cst = arith.constant dense&lt;0.000000e+00&gt; : vector&lt;1x1xf32&gt;
  %0 = arith.mulf %arg0, %cst : vector&lt;1x1xf32&gt;
  return %0 : vector&lt;1x1xf32&gt;
}

Full diff: https://github.com/llvm/llvm-project/pull/97652.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+26-29)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (-36)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c7d3022eff4d3..da5954b70a2ec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1622,27 +1622,7 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
   }
 };
 
-// Scalable unit dimensions are not supported. Folding such dimensions would
-// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
-// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
-// future.
-static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
-  auto inVecShape = inVecTy.getShape();
-  SmallVector<int64_t> newShape;
-  SmallVector<bool> newScalableDims;
-  for (auto [dim, isScalable] :
-       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
-    if (dim == 1 && !isScalable)
-      continue;
-
-    newShape.push_back(dim);
-    newScalableDims.push_back(isScalable);
-  }
-
-  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
-}
-
-/// For vectors with at least an unit dim, replaces:
+/// For vectors with either leading or trailing unit dim, replaces:
 ///   elementwise(a, b)
 /// with:
 ///   sc_a = shape_cast(a)
@@ -1654,16 +1634,20 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
 /// required to be rank > 1.
 ///
 /// Ex:
+/// ```
 ///  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
 ///  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
 ///
 /// gets converted to:
 ///
+/// ```
 ///  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
 ///  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
 ///  %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
 ///  %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
 ///
 /// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
 /// `%cast`.
@@ -1683,29 +1667,42 @@ struct DropUnitDimFromElementwiseOps final
     // guaranteed to have identical shapes (with some exceptions such as
     // `arith.select`) and it suffices to only check one of them.
     auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
-    if (!sourceVectorType || sourceVectorType.getRank() < 2)
+    if (!sourceVectorType)
+      return failure();
+    if (sourceVectorType.getRank() < 2)
+      return failure();
+
+    bool hasTrailingDimUnitFixed =
+        ((sourceVectorType.getShape().back() == 1) &&
+         (!sourceVectorType.getScalableDims().back()));
+    bool hasLeadingDimUnitFixed =
+        ((sourceVectorType.getShape().front() == 1) &&
+         (!sourceVectorType.getScalableDims().front()));
+    if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
       return failure();
 
+    // Drop leading/trailing unit dim by applying vector.shape_cast to all
+    // operands
+    int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
+
     SmallVector<Value> newOperands;
     auto loc = op->getLoc();
     for (auto operand : op->getOperands()) {
       auto opVectorType = cast<VectorType>(operand.getType());
-      auto newVType = dropNonScalableUnitDimFromType(opVectorType);
-      if (newVType == opVectorType)
-        return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
-
+      VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
       auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
       newOperands.push_back(opSC);
     }
 
     VectorType newResultVectorType =
-        dropNonScalableUnitDimFromType(resultVectorType);
-    // Create an updated elementwise Op without unit dim.
+        VectorType::Builder(resultVectorType).dropDim(dim);
+    // Create an updated elementwise Op without leading/trailing unit dim
     Operation *elementwiseOp =
         rewriter.create(loc, op->getName().getIdentifier(), newOperands,
                         newResultVectorType, op->getAttrs());
 
-    // Restore the unit dim by applying vector.shape_cast to the result.
+    // Restore the leading/trailing unit dim by applying vector.shape_cast
+    // to the result
     rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
                                              elementwiseOp->getResult(0));
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3a5041fca53fc..5fd3cbd54aa58 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -604,42 +604,6 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
 
 // -----
 
-func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
-                              %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
-   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
-   %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
-   %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
-   return %res : vector<8x3xf128>
-}
-
-// CHECK-LABEL: func.func @fold_inner_unit_dim(
-// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x3xf128>,
-// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
-// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
-// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
-// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
-// CHECK:         return %[[VAL_4]] : vector<8x3xf128>
-
-// -----
-
-func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
-                              %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
-   %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
-   %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
-   %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
-   return %res : vector<8x[1]x3xf128>
-}
-
-// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
-// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
-// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
-// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
-// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
-// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
-// CHECK:         return %[[VAL_4]] : vector<8x[1]x3xf128>
-
-// -----
-
 func.func @negative_out_of_bound_transfer_read(
     %arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
   %c0 = arith.constant 0 : index

kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
…n leading / trailing dimensions." (llvm#97652)

Reverts llvm#92934 because it breaks some lowering. To
repro: `mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir`

```mlir
func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> {
  %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32>
  %0 = arith.mulf %arg0, %cst : vector<1x1xf32>
  return %0 : vector<1x1xf32>
}
```
@banach-space
Copy link
Contributor

@nujaa Are you investigating the repro?

@nujaa
Copy link
Contributor

nujaa commented Jul 10, 2024

Hi, Yes, I recently came back from some time off. I'm getting back on track and I'll be at it.

@nujaa
Copy link
Contributor

nujaa commented Jul 11, 2024

Thanks for noticing and reverting the bug in my absence.

nujaa added a commit that referenced this pull request Jul 17, 2024
…g / trailing dimensions. (#98455)

Generalizes DropUnitDimFromElementwiseOps to support inner unit
dimensions.
This change stems from improving lowering of contractionOps for Arm SME.
Where we end up with inner unit dimensions on MulOp, BroadcastOp and
TransposeOp, preventing the generation of outerproducts.
discussed
[here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa).

Fix after : #97652 showed an
unhandled edge case when all dimensions are one. The generated target
VectorType would be `vector<f32>` which is apparently not supported by
the mulf.
In case all dimensions are dropped, the target vectorType is
vector<1xf32>

---------

Co-authored-by: Benjamin Maxwell <[email protected]>
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…g / trailing dimensions. (#98455)

Summary:
Generalizes DropUnitDimFromElementwiseOps to support inner unit
dimensions.
This change stems from improving lowering of contractionOps for Arm SME.
Where we end up with inner unit dimensions on MulOp, BroadcastOp and
TransposeOp, preventing the generation of outerproducts.
discussed
[here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa).

Fix after : #97652 showed an
unhandled edge case when all dimensions are one. The generated target
VectorType would be `vector<f32>` which is apparently not supported by
the mulf.
In case all dimensions are dropped, the target vectorType is
vector<1xf32>

---------

Co-authored-by: Benjamin Maxwell <[email protected]>

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251689
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants