+#include "../TritonAMDGPUToLLVM/Utility.h"
 #include "Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -49,6 +50,7 @@ using namespace mlir::triton;
 // clang-format on
 
 namespace {
+
 struct ExtractSliceOpConversion
     : public ConvertOpToLLVMPattern<amdgpu::ExtractSliceOp> {
   explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter,
@@ -60,61 +62,61 @@ struct ExtractSliceOpConversion
                               ConversionPatternRewriter &rewriter) const {
     Location loc = op->getLoc();
     auto srcTy = cast<RankedTensorType>(op.getSource().getType());
-    auto srcLayout = srcTy.getEncoding();
+    auto dstTy = cast<RankedTensorType>(op.getType());
     auto srcShape = srcTy.getShape();
-    auto resultTy = cast<RankedTensorType>(op.getType());
-    auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
-    auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy);
-    auto contigPerThread = triton::gpu::getContigPerThread(srcTy);
-    auto totalContigPerThread = product<unsigned>(contigPerThread);
-    auto order = triton::gpu::getOrder(srcTy);
+    auto dstShape = dstTy.getShape();
 
-    // Calculate valid total number of workers in each dimension
+    auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
     auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy);
-    shapePerCTATile[0] =
-        std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
-    shapePerCTATile[1] =
-        std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);
-
-    // Rank == 2 checked in the verifier
-    SmallVector<int64_t, 2> sizes;
-    for (auto i = 0; i < 2; ++i) {
-      sizes.push_back(resultTy.getDimSize(i));
-    }
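+    // Number of shapePerCTATile-sized CTA tiles along each dimension of the
+    // source and destination tensors.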
+    auto srcCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+        srcShape, shapePerCTATile, std::divides<unsigned>());
+    auto dstCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+        dstShape, shapePerCTATile, std::divides<unsigned>());
 
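+    // Total number of CTA tiles covering the destination (sliced) tensor.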
+    auto numCTATiles = std::accumulate(dstCTAShape.begin(), dstCTAShape.end(),
+                                       1, std::multiplies<>());
     auto offsets = op.getStaticOffsets();
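+    // CTA-tile coordinate of the slice origin within the source tensor
+    // (the static offsets are given in elements, hence the division).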
+    auto firstTileCoordinate =
+        LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+            offsets, shapePerCTATile, std::divides<unsigned>());
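+    // For example (hypothetical shapes): srcShape = 128x128, dstShape = 32x64,
+    // shapePerCTATile = 32x32 and offsets = [32, 64] give srcCTAShape = {4, 4},
+    // dstCTAShape = {1, 2}, numCTATiles = 2 and firstTileCoordinate = {1, 2}.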
 
-    // Calculate offsets and sizes in terms of CTA units.
-    std::array<int64_t, 2> CTAOffsets{offsets[0] / shapePerCTATile[0],
-                                      offsets[1] / shapePerCTATile[1]};
-    std::array<int64_t, 2> CTASizes{sizes[0] / shapePerCTATile[0],
-                                    sizes[1] / shapePerCTATile[1]};
-    std::array<int64_t, 2> CTAPerShape{srcShape[0] / shapePerCTATile[0],
-                                       srcShape[1] / shapePerCTATile[1]};
-
-    // The diagram above illustrates the graphical representation of the
-    // skipElems, tensorStride, and lastIdx variables.
-    auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] *
-                                             contigPerThread[order[1]]) +
-                     CTAOffsets[order[0]] * totalContigPerThread;
-    auto tensorStride =
-        (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread;
-    auto lastIdx =
-        (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
-            elemsPerThread[order[0]] * contigPerThread[order[1]] +
-        (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread;
-
-    assert(lastIdx <= vals.size());
+    Attribute srcEncoding = srcTy.getEncoding();
+    Attribute dstEncoding = dstTy.getEncoding();
+    auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding);
+    auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding);
 
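+    // Dimension order of CTA tiles for each layout, derived from the linear
+    // layouts; used to linearize/delinearize tile indices below.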
+    auto srcCTAOrder =
+        LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutSrc);
+    auto dstCTAOrder =
+        LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutDst);
+
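+    // Number of register values each thread holds for a single CTA tile.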
+    unsigned elemsPerThreadPerCTA =
+        triton::gpu::getTotalElemsPerThread(srcTy) /
+        std::accumulate(srcCTAShape.begin(), srcCTAShape.end(), 1,
+                        std::multiplies<>());
+
+    // 1. Process CTA tiles in the destination tensor according to the
+    //    destination's linear layout order of CTA tiles.
+    // 2. For each tile position in the destination tensor, compute its
+    //    corresponding position in the source tensor.
+    // 3. Copy the values from the source tile to the destination slice.
     SmallVector<Value> resultVals;
-    for (int i = skipElems; i < lastIdx; i += tensorStride) {
-      for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) {
-        assert(i < lastIdx);
-        resultVals.push_back(vals[i]);
+    for (size_t i = 0; i < numCTATiles; i++) {
+      auto coordInDstTensor =
+          mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder);
+      auto coordInSrcTensor =
+          LLVM::AMD::multiDimElementwise<unsigned, unsigned>(
+              coordInDstTensor, firstTileCoordinate, std::plus<unsigned>());
+      auto linearIdxInSrcTensor =
+          mlir::LLVM::linearize(coordInSrcTensor, srcCTAShape, srcCTAOrder);
+
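+      // The values of one CTA tile occupy a contiguous run of
+      // elemsPerThreadPerCTA entries in vals, indexed by the tile's linear
+      // position within the source tensor.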
+      for (size_t j = 0; j < elemsPerThreadPerCTA; j++) {
+        resultVals.push_back(
+            vals[linearIdxInSrcTensor * elemsPerThreadPerCTA + j]);
       }
     }
     Value ret = packLLElements(loc, this->getTypeConverter(), resultVals,
-                               rewriter, resultTy);
+                               rewriter, dstTy);
 
     rewriter.replaceOp(op, ret);
     return success();
@@ -124,11 +126,7 @@ struct ExtractSliceOpConversion
   matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcTy = op.getSource().getType();
-    if (isa<BlockedEncodingAttr, AMDMfmaEncodingAttr>(
-            op.getSource().getType().getEncoding())) {
-      return processLayout(op, adaptor, rewriter);
-    }
-    return failure();
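+    // All supported source encodings now go through the linear-layout based
+    // path above, so the per-encoding whitelist is no longer needed.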
+    return processLayout(op, adaptor, rewriter);
   }
 };
 } // namespace