Skip to content

Commit e3c9c82

Browse files
hanhanW and dcaballe authored
[mlir][MemRef] Extend memref.subview sub-byte type emulation support. (llvm#94045)
In some cases (see iree-org/iree#16285), `memref.subview` ops can't be folded into transfer ops and sub-byte type emulation fails. This issue has been blocking a few things, including the enablement of vector flattening transformations (iree-org/iree#16456). This PR extends the existing sub-byte type emulation support of `memref.subview` to handle multi-dimensional subviews with dynamic offsets and addresses the issues for some of the `memref.subview` cases that can't be folded. Co-authored-by: Diego Caballero <[email protected]>
1 parent 4c416a9 commit e3c9c82

File tree

3 files changed

+105
-34
lines changed

3 files changed

+105
-34
lines changed

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

+67-29
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,14 @@ using namespace mlir;
3232
// Utility functions
3333
//===----------------------------------------------------------------------===//
3434

35-
/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
36-
/// type. The result MemRefType of the old op must have a rank and stride of 1,
37-
/// with static offset and size. The number of bits in the offset must evenly
38-
/// divide the bitwidth of the new converted type.
39-
template <typename MemRefOpTy>
40-
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
41-
typename MemRefOpTy::Adaptor adaptor,
42-
MemRefOpTy op, MemRefType newTy) {
43-
static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
44-
std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
45-
"Expected only memref::SubViewOp or memref::ReinterpretCastOp");
46-
35+
/// Converts a memref::ReinterpretCastOp to the converted type. The result
36+
/// MemRefType of the old op must have a rank and stride of 1, with static
37+
/// offset and size. The number of bits in the offset must evenly divide the
38+
/// bitwidth of the new converted type.
39+
static LogicalResult
40+
convertCastingOp(ConversionPatternRewriter &rewriter,
41+
memref::ReinterpretCastOp::Adaptor adaptor,
42+
memref::ReinterpretCastOp op, MemRefType newTy) {
4743
auto convertedElementType = newTy.getElementType();
4844
auto oldElementType = op.getType().getElementType();
4945
int srcBits = oldElementType.getIntOrFloatBitWidth();
@@ -67,24 +63,22 @@ static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
6763
[](int64_t size) { return size == ShapedType::kDynamic; }) ||
6864
offset == ShapedType::kDynamic) {
6965
return rewriter.notifyMatchFailure(
70-
op->getLoc(), "dynamic size or offset is not supported");
66+
op, "dynamic size or offset is not supported");
7167
}
7268

7369
int elementsPerByte = dstBits / srcBits;
7470
if (offset % elementsPerByte != 0) {
7571
return rewriter.notifyMatchFailure(
76-
op->getLoc(), "offset not multiple of elementsPerByte is not "
77-
"supported");
72+
op, "offset not multiple of elementsPerByte is not supported");
7873
}
7974

8075
SmallVector<int64_t> size;
8176
if (sizes.size())
8277
size.push_back(ceilDiv(sizes[0], elementsPerByte));
8378
offset = offset / elementsPerByte;
8479

85-
rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
86-
*adaptor.getODSOperands(0).begin(),
87-
offset, size, op.getStaticStrides());
80+
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
81+
op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
8882
return success();
8983
}
9084

@@ -402,29 +396,73 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
402396

403397
/// Emulating narrow ints on subview have limited support, supporting only
404398
/// static offset and size and stride of 1. Ideally, the subview should be
405-
/// folded away before running narrow type emulation, and this pattern would
406-
/// never run. This pattern is mostly used for testing pruposes.
399+
/// folded away before running narrow type emulation, and this pattern should
400+
/// only run for cases that can't be folded.
407401
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
408402
using OpConversionPattern::OpConversionPattern;
409403

410404
LogicalResult
411-
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
405+
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
412406
ConversionPatternRewriter &rewriter) const override {
413-
MemRefType newTy =
414-
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
407+
MemRefType newTy = dyn_cast<MemRefType>(
408+
getTypeConverter()->convertType(subViewOp.getType()));
415409
if (!newTy) {
416410
return rewriter.notifyMatchFailure(
417-
op->getLoc(),
418-
llvm::formatv("failed to convert memref type: {0}", op.getType()));
411+
subViewOp->getLoc(),
412+
llvm::formatv("failed to convert memref type: {0}",
413+
subViewOp.getType()));
419414
}
420415

421-
// Only support offset for 1-D subview.
422-
if (op.getType().getRank() != 1) {
416+
Location loc = subViewOp.getLoc();
417+
Type convertedElementType = newTy.getElementType();
418+
Type oldElementType = subViewOp.getType().getElementType();
419+
int srcBits = oldElementType.getIntOrFloatBitWidth();
420+
int dstBits = convertedElementType.getIntOrFloatBitWidth();
421+
if (dstBits % srcBits != 0)
423422
return rewriter.notifyMatchFailure(
424-
op->getLoc(), "subview with rank > 1 is not supported");
423+
subViewOp, "only dstBits % srcBits == 0 supported");
424+
425+
// Only support stride of 1.
426+
if (llvm::any_of(subViewOp.getStaticStrides(),
427+
[](int64_t stride) { return stride != 1; })) {
428+
return rewriter.notifyMatchFailure(subViewOp->getLoc(),
429+
"stride != 1 is not supported");
425430
}
426431

427-
return convertCastingOp(rewriter, adaptor, op, newTy);
432+
if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
433+
return rewriter.notifyMatchFailure(
434+
subViewOp, "the result memref type is not contiguous");
435+
}
436+
437+
auto sizes = subViewOp.getStaticSizes();
438+
int64_t lastOffset = subViewOp.getStaticOffsets().back();
439+
// Only support static sizes and offsets.
440+
if (llvm::any_of(
441+
sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
442+
lastOffset == ShapedType::kDynamic) {
443+
return rewriter.notifyMatchFailure(
444+
subViewOp->getLoc(), "dynamic size or offset is not supported");
445+
}
446+
447+
// Transform the offsets, sizes and strides according to the emulation.
448+
auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
449+
loc, subViewOp.getViewSource());
450+
451+
OpFoldResult linearizedIndices;
452+
auto strides = stridedMetadata.getConstifiedMixedStrides();
453+
memref::LinearizedMemRefInfo linearizedInfo;
454+
std::tie(linearizedInfo, linearizedIndices) =
455+
memref::getLinearizedMemRefOffsetAndSize(
456+
rewriter, loc, srcBits, dstBits,
457+
stridedMetadata.getConstifiedMixedOffset(),
458+
subViewOp.getMixedSizes(), strides,
459+
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
460+
rewriter));
461+
462+
rewriter.replaceOpWithNewOp<memref::SubViewOp>(
463+
subViewOp, newTy, adaptor.getSource(), linearizedIndices,
464+
linearizedInfo.linearizedSize, strides.back());
465+
return success();
428466
}
429467
};
430468

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
6868
AffineExpr mulMap = builder.getAffineConstantExpr(1);
6969

7070
SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
71-
SmallVector<OpFoldResult> sizeValues(sourceRank);
7271

7372
for (unsigned i = 0; i < sourceRank; ++i) {
7473
unsigned offsetIdx = 2 * i;
@@ -79,8 +78,7 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
7978
mulMap = mulMap * symbols[i];
8079
}
8180

82-
// Adjust linearizedIndices, size and offset by the scale factor (dstBits /
83-
// srcBits).
81+
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
8482
int64_t scaler = dstBits / srcBits;
8583
addMulMap = addMulMap.floorDiv(scaler);
8684
mulMap = mulMap.floorDiv(scaler);

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

+37-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
2-
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
1+
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
2+
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32
33

44
// Expect no conversions.
55
func.func @memref_i8() -> i8 {
@@ -177,6 +177,41 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
177177

178178
// -----
179179

180+
func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
181+
%c0 = arith.constant 0 : index
182+
%arr = memref.alloc() : memref<512x64x8x16xi4>
183+
%subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
184+
to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
185+
%ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
186+
return %ld : i4
187+
}
188+
189+
// CHECK-LABEL: func.func @memref_subview_dynamic_offset_i4(
190+
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
191+
// CHECK: %[[IDX:.*]] = affine.apply
192+
// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
193+
// CHECK: memref.load %[[SUBVIEW]]
194+
195+
// CHECK32-LABEL: func.func @memref_subview_dynamic_offset_i4(
196+
// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
197+
// CHECK32: %[[IDX:.*]] = affine.apply
198+
// CHECK32: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
199+
// CHECK32: memref.load %[[SUBVIEW]]
200+
201+
// -----
202+
203+
204+
func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
205+
%c0 = arith.constant 0 : index
206+
%arr = memref.alloc() : memref<40x40xi4>
207+
// expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
208+
%subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
209+
%ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
210+
return %ld : i4
211+
}
212+
213+
// -----
214+
180215
func.func @reinterpret_cast_memref_load_0D() -> i4 {
181216
%0 = memref.alloc() : memref<5xi4>
182217
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>

0 commit comments

Comments
 (0)