
Commit 62a8438

[AMD] Added bufferOps refinement
1 parent 0341d75 commit 62a8438

1 file changed (+101, -0)

third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp

Lines changed: 101 additions & 0 deletions
@@ -616,6 +616,106 @@ struct LoadOpPattern : public RefineRewritePattern<triton::LoadOp> {
   }
 };

+// In contrast to `tt.load`, which operates on a tensor of pointers,
+// `ttg.buffer_load` operates on a tensor descriptor plus offsets,
+// which are a tensor of integers. `ttg.buffer_load` also takes
+// optional `mask` and `other` tensors; when provided, these must be
+// sliced as well. It is difficult to unify the conversion patterns
+// for `tt.load` and `ttg.buffer_load`, so we provide a dedicated
+// pattern to refine `ttg.buffer_load` ops.
+struct AMDGCNBufferLoadOp
+    : public RefineRewritePattern<triton::amdgpu::BufferLoadOp> {
+  AMDGCNBufferLoadOp(MLIRContext *context, PatternBenefit benefit = 1)
+      : RefineRewritePattern(context, benefit) {}
+
+  LogicalResult apply(triton::amdgpu::BufferLoadOp op,
+                      PatternRewriter &rewriter) const override {
+    auto ctx = op->getContext();
+    auto loc = op.getLoc();
+
+    auto origBasePtr = op.getPtr();
+    auto origElementType =
+        cast<PointerType>(origBasePtr.getType()).getPointeeType();
+    auto origOffsets = op.getOffsets();
+    auto origEncoding =
+        cast<RankedTensorType>(origOffsets.getType()).getEncoding();
+    if (!origEncoding)
+      return failure();
+
+    auto origStride = op.getStride();
+    auto origCache = op.getCache();
+    auto origMask = op.getMask();
+    auto origOtherTensor = op.getOther();
+
+    rewriter.setInsertionPointAfter(op);
+
+    auto refineTensor = [&](mlir::Value tensor) {
+      auto tensorType = cast<RankedTensorType>(tensor.getType());
+      auto origShape = tensorType.getShape();
+      auto elemType = tensorType.getElementType();
+      auto encoding = dyn_cast<BlockedEncodingAttr>(tensorType.getEncoding());
+      assert(encoding != nullptr);
+
+      RefinedBlock refinedBlock(origShape, elemType, encoding);
+
+      AMD::CoordinateMapper coordsMapper(refinedBlock.numPerDims);
+      SmallVector<Value> slices;
+      for (size_t linearIdx = 0; linearIdx < refinedBlock.numSubTiles;
+           ++linearIdx) {
+        auto coords = coordsMapper.map(linearIdx);
+        SmallVector<int64_t> offset(refinedBlock.numDims, 0);
+        for (auto [dim, coord] : llvm::enumerate(coords)) {
+          offset[dim] = coord * refinedBlock.elementsPerWorkGroup[dim];
+        }
+
+        auto slice = rewriter.create<triton::amdgpu::ExtractSliceOp>(
+            loc, Type{refinedBlock.tensorType}, Value{tensor}, offset);
+
+        slices.push_back(slice);
+      }
+
+      return std::tuple(slices, refinedBlock.refinedShape,
+                        refinedBlock.numPerDims);
+    };
+
+    auto [slicedOffsets, refinedShape, numPerDims] = refineTensor(origOffsets);
+    std::optional<SmallVector<Value>> slicedMasks;
+    if (origMask) {
+      slicedMasks = std::get<0>(refineTensor(origMask));
+      assert(slicedMasks.value().size() == slicedOffsets.size());
+    }
+
+    std::optional<SmallVector<Value>> slicedOtherTensors;
+    if (origOtherTensor) {
+      slicedOtherTensors = std::get<0>(refineTensor(origOtherTensor));
+      assert(slicedOtherTensors.value().size() == slicedOffsets.size());
+    }
+
+    Type refinedTensorType =
+        RankedTensorType::get(refinedShape, origElementType, origEncoding);
+
+    SmallVector<Value> refinedOps;
+    for (size_t i = 0; i < slicedOffsets.size(); ++i) {
+      Value slicedOffset = slicedOffsets[i];
+      Value slicedMask = slicedMasks ? slicedMasks.value()[i] : nullptr;
+      Value slicedOtherTensor =
+          slicedOtherTensors ? slicedOtherTensors.value()[i] : nullptr;
+
+      auto refinedOp = rewriter.create<triton::amdgpu::BufferLoadOp>(
+          loc, refinedTensorType, origBasePtr, slicedOffset, origStride,
+          origCache, slicedMask, slicedOtherTensor);
+      refinedOps.push_back(refinedOp);
+    }
+
+    Value origResult = op.getResult();
+    auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
+        loc, origResult.getType(), refinedOps);
+
+    origResult.replaceAllUsesWith(joinedResult);
+    return success();
+  }
+};
+
 struct LocalStoreOpPattern
     : public RefineRewritePattern<triton::gpu::LocalStoreOp> {
   LocalStoreOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
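Note on the slicing loop in the hunk above: each linear sub-tile index is mapped through AMD::CoordinateMapper to per-dimension coordinates, which are then scaled by the sub-tile extents to produce the element offsets passed to ExtractSliceOp. The standalone C++ sketch below illustrates only that index arithmetic; the shapes, the row-major (last-dimension-fastest) ordering, and the helper program itself are assumptions for illustration and are not taken from RefineOps.cpp.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Assumed example shapes; in the pass these come from RefinedBlock,
  // which derives them from the tensor's blocked encoding.
  std::vector<int64_t> shape = {128, 128};   // full tensor shape
  std::vector<int64_t> subTile = {64, 32};   // elements per sub-tile

  // Number of sub-tiles along each dimension and in total.
  std::vector<int64_t> numPerDims;
  int64_t numSubTiles = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    numPerDims.push_back(shape[d] / subTile[d]);
    numSubTiles *= numPerDims.back();
  }

  // Map each linear sub-tile index to element offsets, mirroring
  // `offset[dim] = coord * elementsPerWorkGroup[dim]` in the diff.
  // The ordering (last dimension fastest) is an assumption; the real
  // AMD::CoordinateMapper defines its own layout.
  for (int64_t linearIdx = 0; linearIdx < numSubTiles; ++linearIdx) {
    std::vector<int64_t> offset(shape.size(), 0);
    int64_t rem = linearIdx;
    for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
      int64_t coord = rem % numPerDims[d];
      rem /= numPerDims[d];
      offset[d] = coord * subTile[d];
    }
    std::printf("sub-tile %lld -> offset [%lld, %lld]\n",
                static_cast<long long>(linearIdx),
                static_cast<long long>(offset[0]),
                static_cast<long long>(offset[1]));
  }
  return 0;
}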
@@ -1212,6 +1312,7 @@ struct TritonAMDGPURefineOps
     patterns.add<LocalLoadOpPattern>(context, /*benefit=*/1);
     patterns.add<DotOpPattern>(context, /*benefit=*/1);
     patterns.add<LoadOpPattern>(context, /*benefit=*/1);
+    patterns.add<AMDGCNBufferLoadOp>(context, /*benefit=*/1);
     patterns.add<LocalStoreOpPattern>(context, /*benefit=*/1);
     patterns.add<ReduceOpPattern>(context, /*benefit=*/1);
     patterns.add<ExpandDimsOpPattern>(context, /*benefit=*/1);
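For reference, the hunk above registers the new pattern into the RewritePatternSet built by TritonAMDGPURefineOps. The sketch below shows a common way such a set is driven, assuming MLIR's greedy rewrite driver; the helper name runRefinementPatterns is hypothetical, and the actual pass may apply its patterns differently.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper, not part of RefineOps.cpp: shows the usual way a
// RewritePatternSet like the one above is applied to a module.
static mlir::LogicalResult runRefinementPatterns(mlir::ModuleOp module) {
  mlir::MLIRContext *context = module.getContext();
  mlir::RewritePatternSet patterns(context);
  // Registration mirrors the diff; only the newly added buffer-load
  // refinement pattern is shown here.
  patterns.add<AMDGCNBufferLoadOp>(context, /*benefit=*/1);
  // Rewrite until no pattern applies anymore.
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}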
