Skip to content

Commit 0377562

Browse files
[mlir]Add a check to ensure bailing out when reducing to a scalar (llvm#129694)
Fixes issue llvm#64075 Referencing this comment for more detailed view -> llvm#64075 (comment) **Minimal example crashing :** ``` func.func @multi_reduction(%0: vector<4x2xf32>, %acc1: f32) -> f32 { %2 = vector.multi_reduction <add>, %0, %acc1 [0, 1] : vector<4x2xf32> to f32 return %2 : f32 } ```
1 parent 46d218d commit 0377562

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ struct UnrollMultiReductionPattern
355355

356356
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reductionOp,
357357
PatternRewriter &rewriter) const override {
358+
auto resultType = reductionOp->getResult(0).getType();
359+
if (resultType.isIntOrFloat()) {
360+
return rewriter.notifyMatchFailure(reductionOp,
361+
"Unrolling scalars is not supported");
362+
}
358363
std::optional<SmallVector<int64_t>> targetShape =
359364
getTargetShape(options, reductionOp);
360365
if (!targetShape)

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,15 @@ func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) ->
222222
// CHECK: %[[V2:.*]] = vector.insert_strided_slice %[[R5]], %[[V1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
223223
// CHECK: return %[[V2]] : vector<4xf32>
224224

225+
// This is a negative test case to ensure that further unrolling is not performed. Since the vector.multi_reduction
226+
// operation has already been unrolled, attempting additional unrolling should not be allowed.
227+
func.func @negative_vector_multi_reduction(%v: vector<4x2xf32>, %acc: f32) -> f32 {
228+
%0 = vector.multi_reduction #vector.kind<add>, %v, %acc [0, 1] : vector<4x2xf32> to f32
229+
return %0 : f32
230+
}
231+
// CHECK-LABEL: func @negative_vector_multi_reduction
232+
// CHECK-NEXT: %[[R0:.*]] = vector.multi_reduction <add>, %{{.*}}, %{{.*}} [0, 1] : vector<4x2xf32> to f32
233+
// CHECK-NEXT: return %[[R0]] : f32
225234

226235
func.func @vector_reduction(%v : vector<8xf32>) -> f32 {
227236
%0 = vector.reduction <add>, %v : vector<8xf32> into f32

0 commit comments

Comments
 (0)