
Commit 8590ec0

[AMD] Added bufferOps refinement
1 parent dd0a348 commit 8590ec0
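This commit adds a refinement rewrite pattern for triton::amdgpu::BufferLoadOp. The pattern tiles one large buffer load into several smaller ones: the offsets tensor (and the optional mask and "other" tensors) is cut into sub-tiles with triton::amdgpu::ExtractSliceOp, a refined BufferLoadOp is emitted per sub-tile, and the partial results are reassembled with triton::amdgpu::ConcatOp, which then replaces all uses of the original load.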

File tree

1 file changed (+95, −0)

third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp

Lines changed: 95 additions & 0 deletions
@@ -626,6 +626,100 @@ struct LoadOpPattern : public RefineRewritePattern<triton::LoadOp> {
   }
 };
 
+struct AMDGCNBufferLoadOp
+    : public RefineRewritePattern<triton::amdgpu::BufferLoadOp> {
+  AMDGCNBufferLoadOp(MLIRContext *context, PatternBenefit benefit = 1)
+      : RefineRewritePattern(context, benefit) {}
+
+  LogicalResult apply(triton::amdgpu::BufferLoadOp op,
+                      PatternRewriter &rewriter) const override {
+    auto ctx = op->getContext();
+    auto loc = op.getLoc();
+
+    auto origBasePtr = op.getPtr();
+    auto origElementType =
+        cast<PointerType>(origBasePtr.getType()).getPointeeType();
+    auto origOffsets = op.getOffsets();
+    auto origEncoding =
+        cast<RankedTensorType>(origOffsets.getType()).getEncoding();
+    if (!origEncoding)
+      return failure();
+
+    auto origStride = op.getStride();
+    auto origCache = op.getCache();
+    auto origMask = op.getMask();
+    auto origOtherTensor = op.getOther();
+
+    rewriter.setInsertionPointAfter(op);
+
+    auto refineTensor = [&](mlir::Value tensor) {
+      auto tensorType = cast<RankedTensorType>(tensor.getType());
+      auto origShape = tensorType.getShape();
+      auto elemType = tensorType.getElementType();
+      auto encoding = dyn_cast<BlockedEncodingAttr>(tensorType.getEncoding());
+      assert(encoding != nullptr);
+
+      RefinedBlock refinedBlock(origShape, elemType, encoding);
+
+      AMD::CoordinateMapper coordsMapper(refinedBlock.numPerDims);
+      SmallVector<Value> slices;
+      for (size_t linearIdx = 0; linearIdx < refinedBlock.numSubTiles;
+           ++linearIdx) {
+        auto coords = coordsMapper.map(linearIdx);
+        SmallVector<int64_t> offset(refinedBlock.numDims, 0);
+        for (auto [dim, coord] : llvm::enumerate(coords)) {
+          offset[dim] = coord * refinedBlock.elementsPerWorkGroup[dim];
+        }
+
+        auto slice = rewriter.create<triton::amdgpu::ExtractSliceOp>(
+            loc, Type{refinedBlock.tensorType}, Value{tensor}, offset);
+
+        slices.push_back(slice);
+      }
+
+      return std::tuple(slices, refinedBlock.refinedShape,
+                        refinedBlock.numPerDims);
+    };
+
+    auto [slicedOffsets, refinedShape, numPerDims] = refineTensor(origOffsets);
+    std::optional<SmallVector<Value>> slicedMasks;
+    if (origMask) {
+      slicedMasks = std::get<0>(refineTensor(origMask));
+      assert(slicedMasks.value().size() == slicedOffsets.size());
+    }
+
+    std::optional<SmallVector<Value>> slicedOtherTensors;
+    if (origOtherTensor) {
+      slicedOtherTensors = std::get<0>(refineTensor(origOtherTensor));
+      assert(slicedOtherTensors.value().size() == slicedOffsets.size());
+    }
+
+    Type refinedTensorType =
+        RankedTensorType::get(refinedShape, origElementType, origEncoding);
+
+    SmallVector<Value> refinedOps;
+    for (size_t i = 0; i < slicedOffsets.size(); ++i) {
+      Value slicedOffset = slicedOffsets[i];
+      Value slicedMask = slicedMasks ? slicedMasks.value()[i] : nullptr;
+      Value slicedOtherTensor =
+          slicedOtherTensors ? slicedOtherTensors.value()[i] : nullptr;
+
+      auto refinedOp = rewriter.create<triton::amdgpu::BufferLoadOp>(
+          loc, refinedTensorType, origBasePtr, slicedOffset, origStride,
+          origCache, slicedMask, slicedOtherTensor);
+      refinedOps.push_back(refinedOp);
+    }
+
+    auto concatDims = DenseI64ArrayAttr::get(ctx, numPerDims);
+    Value origResult = op.getResult();
+    auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
+        loc, origResult.getType(), refinedOps, concatDims);
+
+    origResult.replaceAllUsesWith(joinedResult);
+    return success();
+  }
+};
+
 struct LocalStoreOpPattern
     : public RefineRewritePattern<triton::gpu::LocalStoreOp> {
   LocalStoreOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
@@ -1138,6 +1232,7 @@ struct TritonAMDGPURefineOps
     patterns.add<LocalLoadOpPattern>(context, /*benefit=*/1);
     patterns.add<DotOpPattern>(context, /*benefit=*/1);
     patterns.add<LoadOpPattern>(context, /*benefit=*/1);
+    patterns.add<AMDGCNBufferLoadOp>(context, /*benefit=*/1);
     patterns.add<LocalStoreOpPattern>(context, /*benefit=*/1);
     patterns.add<ReduceOpPattern>(context, /*benefit=*/1);
     patterns.add<ExpandDimsOpPattern>(context, /*benefit=*/1);
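
The index arithmetic driving the sub-tile loop is easiest to see with concrete numbers. Below is a minimal standalone sketch of what AMD::CoordinateMapper::map plus the offset computation produce. RefinedBlock and CoordinateMapper are helpers defined elsewhere in RefineOps.cpp, so the row-major decoding and the 128x128 tensor split into 64x64 sub-tiles used here are illustrative assumptions, not the upstream implementation.

// Standalone illustration (not the upstream helpers): enumerate the
// per-sub-tile offsets that the refinement loop feeds to ExtractSliceOp.
#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed behavior of AMD::CoordinateMapper::map: decode a linear sub-tile
// index into per-dimension coordinates (row-major, last dim fastest).
static std::vector<int64_t> delinearize(size_t linearIdx,
                                        const std::vector<int64_t> &numPerDims) {
  std::vector<int64_t> coords(numPerDims.size(), 0);
  for (int dim = static_cast<int>(numPerDims.size()) - 1; dim >= 0; --dim) {
    coords[dim] = linearIdx % numPerDims[dim];
    linearIdx /= numPerDims[dim];
  }
  return coords;
}

int main() {
  // Hypothetical refinement: a 128x128 offsets tensor split into 64x64
  // sub-tiles, i.e. numPerDims = {2, 2} and four refined BufferLoadOps.
  std::vector<int64_t> numPerDims = {2, 2};
  std::vector<int64_t> elementsPerWorkGroup = {64, 64};
  size_t numSubTiles = 4; // product of numPerDims

  for (size_t linearIdx = 0; linearIdx < numSubTiles; ++linearIdx) {
    auto coords = delinearize(linearIdx, numPerDims);
    // Offset of this sub-tile inside the original tensor, as passed to
    // triton::amdgpu::ExtractSliceOp in the pattern above.
    std::printf("sub-tile %zu -> offset [%lld, %lld]\n", linearIdx,
                static_cast<long long>(coords[0] * elementsPerWorkGroup[0]),
                static_cast<long long>(coords[1] * elementsPerWorkGroup[1]));
  }
  return 0;
}

This prints offsets [0, 0], [0, 64], [64, 0], [64, 64]. In the pattern, each refined BufferLoadOp loads one 64x64 tile at such an offset, and ConcatOp (with concatDims = numPerDims = {2, 2}) stitches the four partial results back into the original 128x128 result.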
