Skip to content

Commit 67b292c

Browse files
committed
[AMD] Addressed comments of the PR
1 parent 1b2a86b commit 67b292c

File tree

4 files changed

+121
-53
lines changed

4 files changed

+121
-53
lines changed

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "../TritonAMDGPUToLLVM/Utility.h"
12
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
23
#include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
34
#include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -49,6 +50,7 @@ using namespace mlir::triton;
4950
// clang-format on
5051

5152
namespace {
53+
5254
struct ExtractSliceOpConversion
5355
: public ConvertOpToLLVMPattern<amdgpu::ExtractSliceOp> {
5456
explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter,
@@ -60,61 +62,61 @@ struct ExtractSliceOpConversion
6062
ConversionPatternRewriter &rewriter) const {
6163
Location loc = op->getLoc();
6264
auto srcTy = cast<RankedTensorType>(op.getSource().getType());
63-
auto srcLayout = srcTy.getEncoding();
65+
auto dstTy = cast<RankedTensorType>(op.getType());
6466
auto srcShape = srcTy.getShape();
65-
auto resultTy = cast<RankedTensorType>(op.getType());
66-
auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
67-
auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy);
68-
auto contigPerThread = triton::gpu::getContigPerThread(srcTy);
69-
auto totalContigPerThread = product<unsigned>(contigPerThread);
70-
auto order = triton::gpu::getOrder(srcTy);
67+
auto dstShape = dstTy.getShape();
7168

72-
// Calculate valid total number of workers in each dimension
69+
auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
7370
auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy);
74-
shapePerCTATile[0] =
75-
std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
76-
shapePerCTATile[1] =
77-
std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);
78-
79-
// Rank == 2 checked in the verifier
80-
SmallVector<int64_t, 2> sizes;
81-
for (auto i = 0; i < 2; ++i) {
82-
sizes.push_back(resultTy.getDimSize(i));
83-
}
71+
auto srcCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
72+
srcShape, shapePerCTATile, std::divides<unsigned>());
73+
auto dstCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
74+
dstShape, shapePerCTATile, std::divides<unsigned>());
8475

76+
auto numCTATiles = std::accumulate(dstCTAShape.begin(), dstCTAShape.end(),
77+
1, std::multiplies<>());
8578
auto offsets = op.getStaticOffsets();
79+
auto firstTileCoordinate =
80+
LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
81+
offsets, shapePerCTATile, std::divides<unsigned>());
8682

87-
// Calculate offsets and sizes in terms of CTA units.
88-
std::array<int64_t, 2> CTAOffsets{offsets[0] / shapePerCTATile[0],
89-
offsets[1] / shapePerCTATile[1]};
90-
std::array<int64_t, 2> CTASizes{sizes[0] / shapePerCTATile[0],
91-
sizes[1] / shapePerCTATile[1]};
92-
std::array<int64_t, 2> CTAPerShape{srcShape[0] / shapePerCTATile[0],
93-
srcShape[1] / shapePerCTATile[1]};
94-
95-
// The diagram above illustrates the graphical representation of the
96-
// skipElems, tensorStride, and lastIdx variables.
97-
auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] *
98-
contigPerThread[order[1]]) +
99-
CTAOffsets[order[0]] * totalContigPerThread;
100-
auto tensorStride =
101-
(CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread;
102-
auto lastIdx =
103-
(CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
104-
elemsPerThread[order[0]] * contigPerThread[order[1]] +
105-
(CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread;
106-
107-
assert(lastIdx <= vals.size());
83+
Attribute srcEncoding = srcTy.getEncoding();
84+
Attribute dstEncoding = dstTy.getEncoding();
85+
auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding);
86+
auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding);
10887

88+
auto srcCTAOrder =
89+
LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutSrc);
90+
auto dstCTAOrder =
91+
LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutDst);
92+
93+
unsigned elemsPerThreadPerCTA =
94+
triton::gpu::getTotalElemsPerThread(srcTy) /
95+
std::accumulate(srcCTAShape.begin(), srcCTAShape.end(), 1,
96+
std::multiplies<>());
97+
98+
// 1. Process CTA tiles in the destination tensor according to the
99+
// destination's linear layout order of CTA tiles.
100+
// 2. For each tile position in the destination tensor, compute its
101+
// corresponding position in the source tensor.
102+
// 3. Copy the values from the source tile to the destination slice.
109103
SmallVector<Value> resultVals;
110-
for (int i = skipElems; i < lastIdx; i += tensorStride) {
111-
for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) {
112-
assert(i < lastIdx);
113-
resultVals.push_back(vals[i]);
104+
for (size_t i = 0; i < numCTATiles; i++) {
105+
auto coordInDstTensor =
106+
mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder);
107+
auto coordInSrcTensor =
108+
LLVM::AMD::multiDimElementwise<unsigned, unsigned>(
109+
coordInDstTensor, firstTileCoordinate, std::plus<unsigned>());
110+
auto linearIdxInSrcTensor =
111+
mlir::LLVM::linearize(coordInSrcTensor, srcCTAShape, srcCTAOrder);
112+
113+
for (size_t j = 0; j < elemsPerThreadPerCTA; j++) {
114+
resultVals.push_back(
115+
vals[linearIdxInSrcTensor * elemsPerThreadPerCTA + j]);
114116
}
115117
}
116118
Value ret = packLLElements(loc, this->getTypeConverter(), resultVals,
117-
rewriter, resultTy);
119+
rewriter, dstTy);
118120

119121
rewriter.replaceOp(op, ret);
120122
return success();
@@ -124,11 +126,7 @@ struct ExtractSliceOpConversion
124126
matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
125127
ConversionPatternRewriter &rewriter) const override {
126128
auto srcTy = op.getSource().getType();
127-
if (isa<BlockedEncodingAttr, AMDMfmaEncodingAttr>(
128-
op.getSource().getType().getEncoding())) {
129-
return processLayout(op, adaptor, rewriter);
130-
}
131-
return failure();
129+
return processLayout(op, adaptor, rewriter);
132130
}
133131
};
134132
} // namespace

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,4 +755,43 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp) {
755755
llLoadOp.setAliasScopes(aliasScopes);
756756
}
757757

758+
SmallVector<unsigned> getCTATileOrder(MLIRContext *ctx,
759+
const LinearLayout &layout) {
760+
auto llEnc = triton::gpu::LinearEncodingAttr::get(ctx, layout);
761+
auto regDim = StringAttr::get(ctx, "register");
762+
auto &bases = layout.getBases().find(regDim)->second;
763+
764+
// Compute number of CTA tiles in a layout.
765+
unsigned totalElems = layout.getTotalOutDimSize();
766+
auto ctaShape = llEnc.getShapePerCTATile();
767+
unsigned elemsPerCTA =
768+
std::accumulate(ctaShape.begin(), ctaShape.end(), 1, std::multiplies<>());
769+
assert((totalElems % elemsPerCTA) == 0 &&
770+
"Total elements must be divisible by elemsPerCTA");
771+
unsigned numCTAs = totalElems / elemsPerCTA;
772+
773+
// To determine the CTA tile order, start by identifying the register basis
774+
// vector that corresponds to the first element of the second CTA tile. The
775+
// nonzero index in the logical tensor it maps to indicates the most minor
776+
// dimension. Then, for each subsequent basis register (first element of
777+
// some CTA tile), extract the next nonzero index to build the full dimension
778+
// order.
779+
unsigned totalPerThread =
780+
product(llEnc.basesPerDim(regDim, /*skipBroadcast=*/false)) / numCTAs;
781+
unsigned startIndex = static_cast<unsigned>(std::log2(totalPerThread));
782+
783+
llvm::SmallSetVector<unsigned, 8> order;
784+
for (unsigned i = startIndex; i < bases.size(); ++i) {
785+
auto it = std::find_if(bases[i].begin(), bases[i].end(),
786+
[](unsigned v) { return v != 0; });
787+
if (it != bases[i].end())
788+
order.insert(std::distance(bases[i].begin(), it));
789+
}
790+
791+
// Append any dims missing from our default order.
792+
for (unsigned dim : llEnc.getOrder())
793+
order.insert(dim);
794+
795+
return SmallVector<unsigned>(order.begin(), order.end());
796+
}
758797
} // namespace mlir::LLVM::AMD

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,23 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp);
137137
// Attaches the "AsyncCopies" alias scope to llLoadDirectToLdsOp
138138
void addAsyncCopyAliasScope(AliasAnalysisOpInterface llLoadDirectToLdsOp);
139139

140+
// Determine the order in which CTA tiles are laid out across the tensor.
141+
SmallVector<unsigned> getCTATileOrder(MLIRContext *ctx,
142+
const LinearLayout &layout);
143+
144+
template <typename T, typename U, typename BinaryOp>
145+
std::vector<unsigned> multiDimElementwise(const ArrayRef<T> &lhs,
146+
const ArrayRef<U> &rhs, BinaryOp op) {
147+
assert(lhs.size() == rhs.size() && "Input dimensions must match");
148+
std::vector<unsigned> result;
149+
result.reserve(lhs.size());
150+
for (size_t i = 0, n = lhs.size(); i < n; ++i) {
151+
unsigned a = static_cast<unsigned>(lhs[i]);
152+
unsigned b = static_cast<unsigned>(rhs[i]);
153+
result.push_back(op(a, b));
154+
}
155+
return result;
156+
}
140157
} // namespace mlir::LLVM::AMD
141158

142159
#endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,12 @@ LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) {
15181518
builder.setInsertionPointAfter(dotSOps[0]);
15191519
if (sliceDotScaled(builder, loc, dotSOps[0], 4).failed())
15201520
return failure();
1521+
1522+
if (genAsyncCopySlices(builder).failed()) {
1523+
LDBG("failed to slice global-to-local async copies");
1524+
return failure();
1525+
}
1526+
15211527
updateOpInsertion(dotSliceOps[0]);
15221528

15231529
appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
@@ -1681,10 +1687,6 @@ void Pingponger::getDotPingponged() {
16811687
return;
16821688
}
16831689

1684-
if (llvm::failed(genAsyncCopySlices(builder))) {
1685-
LDBG("failed to slice global-to-local async copies");
1686-
}
1687-
16881690
auto updateSignature = updateForOpSignature(builder);
16891691
if (llvm::failed(updateSignature)) {
16901692
LDBG("failed to update forOp signature");
@@ -1695,6 +1697,18 @@ void Pingponger::getDotPingponged() {
16951697
LDBG("failed to update forOp signature");
16961698
}
16971699
}
1700+
1701+
forOp->walk([](ttg::AsyncCommitGroupOp groupOp) {
1702+
auto users = groupOp.getResult().getUsers();
1703+
if (users.empty()) {
1704+
SmallVector<Operation *> toDeleteVec;
1705+
for (auto token : groupOp.getInputTokens()) {
1706+
toDeleteVec.push_back(token.getDefiningOp());
1707+
}
1708+
groupOp->erase();
1709+
llvm::for_each(toDeleteVec, [](Operation *op) { op->erase(); });
1710+
}
1711+
});
16981712
addAsymmetricSyncToLoop(builder, loc);
16991713
return;
17001714
}

0 commit comments

Comments (0)