@@ -626,6 +626,100 @@ struct LoadOpPattern : public RefineRewritePattern<triton::LoadOp> {
 }
 };

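+// Rewrites a triton::amdgpu::BufferLoadOp on a large tile into a set of
+// buffer loads on sub-tiles whose results are concatenated back together.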
+struct AMDGCNBufferLoadOp
+    : public RefineRewritePattern<triton::amdgpu::BufferLoadOp> {
+  AMDGCNBufferLoadOp(MLIRContext *context, PatternBenefit benefit = 1)
+      : RefineRewritePattern(context, benefit) {}
+
+  LogicalResult apply(triton::amdgpu::BufferLoadOp op,
+                      PatternRewriter &rewriter) const override {
+    auto ctx = op->getContext();
+    auto loc = op.getLoc();
+
+    auto origBasePtr = op.getPtr();
+    auto origElementType =
+        cast<PointerType>(origBasePtr.getType()).getPointeeType();
+    auto origOffsets = op.getOffsets();
+    auto origEncoding =
+        cast<RankedTensorType>(origOffsets.getType()).getEncoding();
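+    // Only offset tensors that carry a layout encoding can be refined.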
+    if (!origEncoding)
+      return failure();
+
+    auto origStride = op.getStride();
+    auto origCache = op.getCache();
+    auto origMask = op.getMask();
+    auto origOtherTensor = op.getOther();
+
+    rewriter.setInsertionPointAfter(op);
+
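+    // Slices a blocked tensor into sub-tiles; returns the slices together
+    // with the refined (per-sub-tile) shape and the sub-tile count per dim.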
+    auto refineTensor = [&](mlir::Value tensor) {
+      auto tensorType = cast<RankedTensorType>(tensor.getType());
+      auto origShape = tensorType.getShape();
+      auto elemType = tensorType.getElementType();
+      auto encoding = dyn_cast<BlockedEncodingAttr>(tensorType.getEncoding());
+      assert(encoding != nullptr);
+
+      RefinedBlock refinedBlock(origShape, elemType, encoding);
+
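+      // Extract every sub-tile at its multi-dimensional offset within the
+      // original tensor.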
+      AMD::CoordinateMapper coordsMapper(refinedBlock.numPerDims);
+      SmallVector<Value> slices;
+      for (size_t linearIdx = 0; linearIdx < refinedBlock.numSubTiles;
+           ++linearIdx) {
+        auto coords = coordsMapper.map(linearIdx);
+        SmallVector<int64_t> offset(refinedBlock.numDims, 0);
+        for (auto [dim, coord] : llvm::enumerate(coords)) {
+          offset[dim] = coord * refinedBlock.elementsPerWorkGroup[dim];
+        }
+
+        auto slice = rewriter.create<triton::amdgpu::ExtractSliceOp>(
+            loc, Type{refinedBlock.tensorType}, Value{tensor}, offset);
+
+        slices.push_back(slice);
+      }
+
+      return std::tuple(slices, refinedBlock.refinedShape,
+                        refinedBlock.numPerDims);
+    };
+
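+    // Slice the offsets, plus the optional mask and "other" tensors, into
+    // matching sub-tiles.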
+    auto [slicedOffsets, refinedShape, numPerDims] = refineTensor(origOffsets);
+    std::optional<SmallVector<Value>> slicedMasks;
+    if (origMask) {
+      slicedMasks = std::get<0>(refineTensor(origMask));
+      assert(slicedMasks.value().size() == slicedOffsets.size());
+    }
+
+    std::optional<SmallVector<Value>> slicedOtherTensors;
+    if (origOtherTensor) {
+      slicedOtherTensors = std::get<0>(refineTensor(origOtherTensor));
+      assert(slicedOtherTensors.value().size() == slicedOffsets.size());
+    }
+
+    Type refinedTensorType =
+        RankedTensorType::get(refinedShape, origElementType, origEncoding);
+
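+    // Emit one refined buffer load per sub-tile; the base pointer, stride,
+    // and cache modifier are shared across all of them.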
+    SmallVector<Value> refinedOps;
+    for (size_t i = 0; i < slicedOffsets.size(); ++i) {
+      Value slicedOffset = slicedOffsets[i];
+      Value slicedMask = slicedMasks ? slicedMasks.value()[i] : nullptr;
+      Value slicedOtherTensor =
+          slicedOtherTensors ? slicedOtherTensors.value()[i] : nullptr;
+
+      auto refinedOp = rewriter.create<triton::amdgpu::BufferLoadOp>(
+          loc, refinedTensorType, origBasePtr, slicedOffset, origStride,
+          origCache, slicedMask, slicedOtherTensor);
+      refinedOps.push_back(refinedOp);
+    }
+
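+    // Concatenate the per-sub-tile results back into the original tensor
+    // type and replace all uses of the original load.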
+    auto concatDims = DenseI64ArrayAttr::get(ctx, numPerDims);
+    Value origResult = op.getResult();
+    auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
+        loc, origResult.getType(), refinedOps, concatDims);
+
+    origResult.replaceAllUsesWith(joinedResult);
+    return success();
+  }
+};
+
 struct LocalStoreOpPattern
     : public RefineRewritePattern<triton::gpu::LocalStoreOp> {
   LocalStoreOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
@@ -1229,6 +1323,7 @@ struct TritonAMDGPURefineOps
     patterns.add<LocalLoadOpPattern>(context, /*benefit=*/1);
     patterns.add<DotOpPattern>(context, /*benefit=*/1);
     patterns.add<LoadOpPattern>(context, /*benefit=*/1);
+    patterns.add<AMDGCNBufferLoadOp>(context, /*benefit=*/1);
     patterns.add<LocalStoreOpPattern>(context, /*benefit=*/1);
     patterns.add<ReduceOpPattern>(context, /*benefit=*/1);
     patterns.add<ExpandDimsOpPattern>(context, /*benefit=*/1);