@@ -616,6 +616,106 @@ struct LoadOpPattern : public RefineRewritePattern<triton::LoadOp> {
  }
};

+// In contrast to `tt.load`, which operates on a tensor of pointers,
+// `ttg.buffer_load` operates on a tensor descriptor plus offsets,
+// which form a tensor of integers. `ttg.buffer_load` also takes optional
+// `mask` and `other` tensors; when provided, these must be sliced as well.
+// It is difficult to unify the conversion patterns for `tt.load` and
+// `ttg.buffer_load`, so we provide a dedicated pattern to refine
+// `ttg.buffer_load` ops.
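+//
+// The pattern below slices the offsets (and the mask/other tensors, when
+// present) into sub-tiles derived from the blocked layout, emits one
+// refined `ttg.buffer_load` per sub-tile, and concatenates the partial
+// results back into a tensor of the original shape.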
+struct AMDGCNBufferLoadOp
+    : public RefineRewritePattern<triton::amdgpu::BufferLoadOp> {
+  AMDGCNBufferLoadOp(MLIRContext *context, PatternBenefit benefit = 1)
+      : RefineRewritePattern(context, benefit) {}
+
+  LogicalResult apply(triton::amdgpu::BufferLoadOp op,
+                      PatternRewriter &rewriter) const override {
+    auto ctx = op->getContext();
+    auto loc = op.getLoc();
+
+    auto origBasePtr = op.getPtr();
+    auto origElementType =
+        cast<PointerType>(origBasePtr.getType()).getPointeeType();
+    auto origOffsets = op.getOffsets();
+    auto origEncoding =
+        cast<RankedTensorType>(origOffsets.getType()).getEncoding();
+    if (!origEncoding)
+      return failure();
+
+    auto origStride = op.getStride();
+    auto origCache = op.getCache();
+    auto origMask = op.getMask();
+    auto origOtherTensor = op.getOther();
+
+    rewriter.setInsertionPointAfter(op);
+
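+    // Slice `tensor` into sub-tiles according to its blocked layout:
+    // compute the start offset of each sub-tile and extract it with an
+    // `extract_slice` op. Returns the slices together with the refined
+    // (per-tile) shape and the number of tiles per dimension.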
+    auto refineTensor = [&](mlir::Value tensor) {
+      auto tensorType = cast<RankedTensorType>(tensor.getType());
+      auto origShape = tensorType.getShape();
+      auto elemType = tensorType.getElementType();
+      auto encoding = dyn_cast<BlockedEncodingAttr>(tensorType.getEncoding());
+      assert(encoding != nullptr);
+
+      RefinedBlock refinedBlock(origShape, elemType, encoding);
+
+      AMD::CoordinateMapper coordsMapper(refinedBlock.numPerDims);
+      SmallVector<Value> slices;
+      for (size_t linearIdx = 0; linearIdx < refinedBlock.numSubTiles;
+           ++linearIdx) {
+        auto coords = coordsMapper.map(linearIdx);
+        SmallVector<int64_t> offset(refinedBlock.numDims, 0);
+        for (auto [dim, coord] : llvm::enumerate(coords)) {
+          offset[dim] = coord * refinedBlock.elementsPerWorkGroup[dim];
+        }
+
+        auto slice = rewriter.create<triton::amdgpu::ExtractSliceOp>(
+            loc, Type{refinedBlock.tensorType}, Value{tensor}, offset);
+
+        slices.push_back(slice);
+      }
+
+      return std::tuple(slices, refinedBlock.refinedShape,
+                        refinedBlock.numPerDims);
+    };
+
+    auto [slicedOffsets, refinedShape, numPerDims] = refineTensor(origOffsets);
+    std::optional<SmallVector<Value>> slicedMasks;
+    if (origMask) {
+      slicedMasks = std::get<0>(refineTensor(origMask));
+      assert(slicedMasks.value().size() == slicedOffsets.size());
+    }
+
+    std::optional<SmallVector<Value>> slicedOtherTensors;
+    if (origOtherTensor) {
+      slicedOtherTensors = std::get<0>(refineTensor(origOtherTensor));
+      assert(slicedOtherTensors.value().size() == slicedOffsets.size());
+    }
+
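+    // Emit one refined `ttg.buffer_load` per sub-tile, reusing the original
+    // base pointer, stride, and cache modifier and pairing each offset slice
+    // with its corresponding mask/other slice when present.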
+    Type refinedTensorType =
+        RankedTensorType::get(refinedShape, origElementType, origEncoding);
+
+    SmallVector<Value> refinedOps;
+    for (size_t i = 0; i < slicedOffsets.size(); ++i) {
+      Value slicedOffset = slicedOffsets[i];
+      Value slicedMask = slicedMasks ? slicedMasks.value()[i] : nullptr;
+      Value slicedOtherTensor =
+          slicedOtherTensors ? slicedOtherTensors.value()[i] : nullptr;
+
+      auto refinedOp = rewriter.create<triton::amdgpu::BufferLoadOp>(
+          loc, refinedTensorType, origBasePtr, slicedOffset, origStride,
+          origCache, slicedMask, slicedOtherTensor);
+      refinedOps.push_back(refinedOp);
+    }
+
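+    // Concatenate the per-tile results back into a tensor of the original
+    // type and route all uses of the original op to the joined value.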
+    Value origResult = op.getResult();
+    auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
+        loc, origResult.getType(), refinedOps);
+
+    origResult.replaceAllUsesWith(joinedResult);
+    return success();
+  }
+};
+
struct LocalStoreOpPattern
    : public RefineRewritePattern<triton::gpu::LocalStoreOp> {
  LocalStoreOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
@@ -1212,6 +1312,7 @@ struct TritonAMDGPURefineOps
    patterns.add<LocalLoadOpPattern>(context, /*benefit=*/1);
    patterns.add<DotOpPattern>(context, /*benefit=*/1);
    patterns.add<LoadOpPattern>(context, /*benefit=*/1);
+    patterns.add<AMDGCNBufferLoadOp>(context, /*benefit=*/1);
    patterns.add<LocalStoreOpPattern>(context, /*benefit=*/1);
    patterns.add<ReduceOpPattern>(context, /*benefit=*/1);
    patterns.add<ExpandDimsOpPattern>(context, /*benefit=*/1);