-
Notifications
You must be signed in to change notification settings - Fork 13.5k
Revert "[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions." #97652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Revert "[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions." This reverts commit 2c06fb8.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Han-Chung Wang (hanhanW) Changes: Reverts llvm/llvm-project#92934 because it breaks some lowering. To repro: func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> {
%cst = arith.constant dense<0.000000e+00> : vector<1x1xf32>
%0 = arith.mulf %arg0, %cst : vector<1x1xf32>
return %0 : vector<1x1xf32>
} Full diff: https://github.com/llvm/llvm-project/pull/97652.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c7d3022eff4d3..da5954b70a2ec 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1622,27 +1622,7 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};
-// Scalable unit dimensions are not supported. Folding such dimensions would
-// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
-// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
-// future.
-static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
- auto inVecShape = inVecTy.getShape();
- SmallVector<int64_t> newShape;
- SmallVector<bool> newScalableDims;
- for (auto [dim, isScalable] :
- llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
- if (dim == 1 && !isScalable)
- continue;
-
- newShape.push_back(dim);
- newScalableDims.push_back(isScalable);
- }
-
- return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
-}
-
-/// For vectors with at least an unit dim, replaces:
+/// For vectors with either leading or trailing unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
@@ -1654,16 +1634,20 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
/// required to be rank > 1.
///
/// Ex:
+/// ```
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
///
/// gets converted to:
///
+/// ```
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
+/// ```
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
@@ -1683,29 +1667,42 @@ struct DropUnitDimFromElementwiseOps final
// guaranteed to have identical shapes (with some exceptions such as
// `arith.select`) and it suffices to only check one of them.
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
- if (!sourceVectorType || sourceVectorType.getRank() < 2)
+ if (!sourceVectorType)
+ return failure();
+ if (sourceVectorType.getRank() < 2)
+ return failure();
+
+ bool hasTrailingDimUnitFixed =
+ ((sourceVectorType.getShape().back() == 1) &&
+ (!sourceVectorType.getScalableDims().back()));
+ bool hasLeadingDimUnitFixed =
+ ((sourceVectorType.getShape().front() == 1) &&
+ (!sourceVectorType.getScalableDims().front()));
+ if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
return failure();
+ // Drop leading/trailing unit dim by applying vector.shape_cast to all
+ // operands
+ int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
+
SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
- auto newVType = dropNonScalableUnitDimFromType(opVectorType);
- if (newVType == opVectorType)
- return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
-
+ VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}
VectorType newResultVectorType =
- dropNonScalableUnitDimFromType(resultVectorType);
- // Create an updated elementwise Op without unit dim.
+ VectorType::Builder(resultVectorType).dropDim(dim);
+ // Create an updated elementwise Op without leading/trailing unit dim
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
newResultVectorType, op->getAttrs());
- // Restore the unit dim by applying vector.shape_cast to the result.
+ // Restore the leading/trailing unit dim by applying vector.shape_cast
+ // to the result
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3a5041fca53fc..5fd3cbd54aa58 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -604,42 +604,6 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// -----
-func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
- %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
- %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
- %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
- %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
- return %res : vector<8x3xf128>
-}
-
-// CHECK-LABEL: func.func @fold_inner_unit_dim(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
-// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
-// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
-// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
-// CHECK: return %[[VAL_4]] : vector<8x3xf128>
-
-// -----
-
-func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
- %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
- %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
- %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
- %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
- return %res : vector<8x[1]x3xf128>
-}
-
-// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
-// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
-// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
-// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
-// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
-// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
-// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>
-
-// -----
-
func.func @negative_out_of_bound_transfer_read(
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
|
…n leading / trailing dimensions." (llvm#97652) Reverts llvm#92934 because it breaks some lowering. To repro: `mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir` ```mlir func.func @unit_dim_folding(%arg0: vector<1x1xf32>) -> vector<1x1xf32> { %cst = arith.constant dense<0.000000e+00> : vector<1x1xf32> %0 = arith.mulf %arg0, %cst : vector<1x1xf32> return %0 : vector<1x1xf32> } ```
@nujaa Are you investigating the repro? |
Hi, Yes, I recently came back from some time off. I'm getting back on track and I'll be at it. |
Thanks for noticing and reverting the bug in my absence. |
…g / trailing dimensions. (#98455) Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions. This change stems from improving the lowering of contractionOps for Arm SME, where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts. Discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). Fix after: #97652 showed an unhandled edge case when all dimensions are one. The generated target VectorType would be `vector<f32>`, which is apparently not supported by the mulf. In case all dimensions are dropped, the target vectorType is vector<1xf32>. --------- Co-authored-by: Benjamin Maxwell <[email protected]>
…g / trailing dimensions. (#98455) Summary: Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions. This change stems from improving the lowering of contractionOps for Arm SME, where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts. Discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). Fix after: #97652 showed an unhandled edge case when all dimensions are one. The generated target VectorType would be `vector<f32>`, which is apparently not supported by the mulf. In case all dimensions are dropped, the target vectorType is vector<1xf32>. --------- Co-authored-by: Benjamin Maxwell <[email protected]> Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251689
Reverts #92934 because it breaks some lowering. To repro:
mlir-opt -test-vector-transfer-flatten-patterns ~/repro.mlir