@@ -32,18 +32,14 @@ using namespace mlir;
32
32
// Utility functions
33
33
// ===----------------------------------------------------------------------===//
34
34
35
- // / Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
36
- // / type. The result MemRefType of the old op must have a rank and stride of 1,
37
- // / with static offset and size. The number of bits in the offset must evenly
38
- // / divide the bitwidth of the new converted type.
39
- template <typename MemRefOpTy>
40
- static LogicalResult convertCastingOp (ConversionPatternRewriter &rewriter,
41
- typename MemRefOpTy::Adaptor adaptor,
42
- MemRefOpTy op, MemRefType newTy) {
43
- static_assert (std::is_same<MemRefOpTy, memref::SubViewOp>() ||
44
- std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
45
- " Expected only memref::SubViewOp or memref::ReinterpretCastOp" );
46
-
35
+ // / Converts a memref::ReinterpretCastOp to the converted type. The result
36
+ // / MemRefType of the old op must have a rank and stride of 1, with static
37
+ // / offset and size. The number of bits in the offset must evenly divide the
38
+ // / bitwidth of the new converted type.
39
+ static LogicalResult
40
+ convertCastingOp (ConversionPatternRewriter &rewriter,
41
+ memref::ReinterpretCastOp::Adaptor adaptor,
42
+ memref::ReinterpretCastOp op, MemRefType newTy) {
47
43
auto convertedElementType = newTy.getElementType ();
48
44
auto oldElementType = op.getType ().getElementType ();
49
45
int srcBits = oldElementType.getIntOrFloatBitWidth ();
@@ -67,24 +63,22 @@ static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
67
63
[](int64_t size) { return size == ShapedType::kDynamic ; }) ||
68
64
offset == ShapedType::kDynamic ) {
69
65
return rewriter.notifyMatchFailure (
70
- op-> getLoc () , " dynamic size or offset is not supported" );
66
+ op, " dynamic size or offset is not supported" );
71
67
}
72
68
73
69
int elementsPerByte = dstBits / srcBits;
74
70
if (offset % elementsPerByte != 0 ) {
75
71
return rewriter.notifyMatchFailure (
76
- op->getLoc (), " offset not multiple of elementsPerByte is not "
77
- " supported" );
72
+ op, " offset not multiple of elementsPerByte is not supported" );
78
73
}
79
74
80
75
SmallVector<int64_t > size;
81
76
if (sizes.size ())
82
77
size.push_back (ceilDiv (sizes[0 ], elementsPerByte));
83
78
offset = offset / elementsPerByte;
84
79
85
- rewriter.replaceOpWithNewOp <MemRefOpTy>(op, newTy,
86
- *adaptor.getODSOperands (0 ).begin (),
87
- offset, size, op.getStaticStrides ());
80
+ rewriter.replaceOpWithNewOp <memref::ReinterpretCastOp>(
81
+ op, newTy, adaptor.getSource (), offset, size, op.getStaticStrides ());
88
82
return success ();
89
83
}
90
84
@@ -402,29 +396,73 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
402
396
403
397
// / Emulating narrow ints on subview have limited support, supporting only
404
398
// / static offset and size and stride of 1. Ideally, the subview should be
405
- // / folded away before running narrow type emulation, and this pattern would
406
- // / never run. This pattern is mostly used for testing pruposes .
399
+ // / folded away before running narrow type emulation, and this pattern should
400
+ // / only run for cases that can't be folded .
407
401
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
408
402
using OpConversionPattern::OpConversionPattern;
409
403
410
404
LogicalResult
411
- matchAndRewrite (memref::SubViewOp op , OpAdaptor adaptor,
405
+ matchAndRewrite (memref::SubViewOp subViewOp , OpAdaptor adaptor,
412
406
ConversionPatternRewriter &rewriter) const override {
413
- MemRefType newTy =
414
- dyn_cast<MemRefType>( getTypeConverter ()->convertType (op .getType ()));
407
+ MemRefType newTy = dyn_cast<MemRefType>(
408
+ getTypeConverter ()->convertType (subViewOp .getType ()));
415
409
if (!newTy) {
416
410
return rewriter.notifyMatchFailure (
417
- op->getLoc (),
418
- llvm::formatv (" failed to convert memref type: {0}" , op.getType ()));
411
+ subViewOp->getLoc (),
412
+ llvm::formatv (" failed to convert memref type: {0}" ,
413
+ subViewOp.getType ()));
419
414
}
420
415
421
- // Only support offset for 1-D subview.
422
- if (op.getType ().getRank () != 1 ) {
416
+ Location loc = subViewOp.getLoc ();
417
+ Type convertedElementType = newTy.getElementType ();
418
+ Type oldElementType = subViewOp.getType ().getElementType ();
419
+ int srcBits = oldElementType.getIntOrFloatBitWidth ();
420
+ int dstBits = convertedElementType.getIntOrFloatBitWidth ();
421
+ if (dstBits % srcBits != 0 )
423
422
return rewriter.notifyMatchFailure (
424
- op->getLoc (), " subview with rank > 1 is not supported" );
423
+ subViewOp, " only dstBits % srcBits == 0 supported" );
424
+
425
+ // Only support stride of 1.
426
+ if (llvm::any_of (subViewOp.getStaticStrides (),
427
+ [](int64_t stride) { return stride != 1 ; })) {
428
+ return rewriter.notifyMatchFailure (subViewOp->getLoc (),
429
+ " stride != 1 is not supported" );
425
430
}
426
431
427
- return convertCastingOp (rewriter, adaptor, op, newTy);
432
+ if (!memref::isStaticShapeAndContiguousRowMajor (subViewOp.getType ())) {
433
+ return rewriter.notifyMatchFailure (
434
+ subViewOp, " the result memref type is not contiguous" );
435
+ }
436
+
437
+ auto sizes = subViewOp.getStaticSizes ();
438
+ int64_t lastOffset = subViewOp.getStaticOffsets ().back ();
439
+ // Only support static sizes and offsets.
440
+ if (llvm::any_of (
441
+ sizes, [](int64_t size) { return size == ShapedType::kDynamic ; }) ||
442
+ lastOffset == ShapedType::kDynamic ) {
443
+ return rewriter.notifyMatchFailure (
444
+ subViewOp->getLoc (), " dynamic size or offset is not supported" );
445
+ }
446
+
447
+ // Transform the offsets, sizes and strides according to the emulation.
448
+ auto stridedMetadata = rewriter.create <memref::ExtractStridedMetadataOp>(
449
+ loc, subViewOp.getViewSource ());
450
+
451
+ OpFoldResult linearizedIndices;
452
+ auto strides = stridedMetadata.getConstifiedMixedStrides ();
453
+ memref::LinearizedMemRefInfo linearizedInfo;
454
+ std::tie (linearizedInfo, linearizedIndices) =
455
+ memref::getLinearizedMemRefOffsetAndSize (
456
+ rewriter, loc, srcBits, dstBits,
457
+ stridedMetadata.getConstifiedMixedOffset (),
458
+ subViewOp.getMixedSizes (), strides,
459
+ getMixedValues (adaptor.getStaticOffsets (), adaptor.getOffsets (),
460
+ rewriter));
461
+
462
+ rewriter.replaceOpWithNewOp <memref::SubViewOp>(
463
+ subViewOp, newTy, adaptor.getSource (), linearizedIndices,
464
+ linearizedInfo.linearizedSize , strides.back ());
465
+ return success ();
428
466
}
429
467
};
430
468
0 commit comments