@@ -462,6 +462,11 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
462
462
{kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, /* gap*/ {4 , 0 }, {8 , 0 }}}},
463
463
{outDimNames[order[0 ]], outDimNames[order[1 ]]});
464
464
}
465
+
466
+ auto tilesPerWarp = getTilesPerWarp ();
467
+ tileLayout *=
468
+ mlir::triton::identityStandardND (kRegister , tilesPerWarp, order);
469
+
465
470
if (hasBatchDim) {
466
471
assert (order[2 ] == 0 );
467
472
// Extend the base vector with one value to accommodate for the batch
@@ -637,31 +642,6 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
637
642
638
643
LinearLayout mfmaDotToLinearLayout (DotOperandEncodingAttr dotMfmaLayout,
639
644
ArrayRef<int64_t > shape) {
640
-
641
- // Current linear layout conversion for dot operand is only necessary to
642
- // enable LDS bypass for operand B in the MFMA dot path. To achieve
643
- // performance gains from bypassing LDS, the following conditions must be met:
644
- //
645
- // 1) opIdx == 1: Currently, only the B tensor (e.g. weights in moe-like
646
- // kernels) bypasses LDS. This constraint is not strict and support for
647
- // bypassing operand A (e.g. Q tensor in flash attention) will be added in
648
- // the future.
649
- //
650
- // 2) B tensor must be column major: This is required to support vectorized
651
- // global load instructions, as MFMA instructions expect threads to hold B
652
- // operand elements along the K dimension.
653
- //
654
- // 3) kWidth == 8: Ensures maximum global load vectorization for fp16
655
- // operations.
656
- // TODO: Generalize conversion to handle maximum kWidth for other types
657
- // (i.e. fp8).
658
- //
659
- // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
660
- // held by exactly one thread, maintaining the same number of global loads
661
- // as in a blocked layout.
662
- //
663
- // Other use of Linear layout is a support of rare corner cases,
664
- // for example one instruction tile is larger than tensor
665
645
auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent ());
666
646
667
647
auto rank = shape.size ();
@@ -672,6 +652,8 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
672
652
auto kDim = dotMfmaLayout.getOpIdx () == 0 ? rank - 1 : rank - 2 ;
673
653
int32_t kSize = shape[kDim ];
674
654
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA ();
655
+ auto tilesPerWarp = mfmaLayout.getTilesPerWarp ();
656
+ auto tilePerWarpNonK = tilesPerWarp[kDim ];
675
657
676
658
MLIRContext *ctx = dotMfmaLayout.getContext ();
677
659
SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
@@ -725,6 +707,11 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
725
707
for (int32_t elem = kTileSize ; elem < kSize ; elem *= 2 )
726
708
registerBase.emplace_back (std::vector<int32_t >{elem, 0 });
727
709
710
+ // Add repeats of registers along non-K dimension to register base vectors
711
+ for (int32_t elem = mfmaLayout.getMDim ();
712
+ elem < tilePerWarpNonK * mfmaLayout.getMDim (); elem *= 2 )
713
+ registerBase.emplace_back (std::vector<int32_t >{0 , elem});
714
+
728
715
// Base vectors above are defined in a fixed order [non-k-dim, k-dim].
729
716
// To assign them to actual matrix dimensions `order` array is used.
730
717
// For operand A: non-k-dim -> dim0, k-dim -> dim1
@@ -745,7 +732,9 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
745
732
LinearLayout ctaLayout = tileLayout.transposeOuts (outDimNames) *
746
733
warpLayout.transposeOuts (outDimNames);
747
734
748
- return combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape);
735
+ auto finalLayout =
736
+ combineCtaCgaWithShape (ctaLayout, mfmaLayout.getCTALayout (), shape);
737
+ return finalLayout;
749
738
}
750
739
751
740
LinearLayout
@@ -1446,10 +1435,12 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
1446
1435
return chooseDotDsReadB64TrLayout (dot, shape, elemBitWidth);
1447
1436
}
1448
1437
1449
- LinearLayout chooseScaledMfmaScaleLayout (
1450
- MLIRContext *ctx, int dotOperandIdx,
1451
- const std::vector<std::vector<int32_t >> &dotOperandWarpBasis,
1452
- ArrayRef<int64_t > dotOperandShape, unsigned mfmaMDim) {
1438
+ LinearLayout chooseScaledMfmaScaleLayout (MLIRContext *ctx, int dotOperandIdx,
1439
+ ArrayRef<int64_t > dotOperandShape,
1440
+ unsigned mfmaMDim,
1441
+ ArrayRef<unsigned > tilesPerWarp,
1442
+ ArrayRef<unsigned > warpsPerCTA,
1443
+ bool preshuffleScales) {
1453
1444
using basisT = std::vector<std::vector<int32_t >>;
1454
1445
unsigned rank = dotOperandShape.size ();
1455
1446
auto order = mlir::triton::gpu::getMatrixOrder (rank, /* rowMajor=*/ true );
@@ -1458,31 +1449,16 @@ LinearLayout chooseScaledMfmaScaleLayout(
1458
1449
StringAttr kLane = StringAttr::get (ctx, " lane" );
1459
1450
StringAttr kWarp = StringAttr::get (ctx, " warp" );
1460
1451
StringAttr kBlock = StringAttr::get (ctx, " block" );
1461
- // Init register layout. Will be adjusted later
1462
- auto regs = mlir::triton::identityStandardND ( kRegister , { 1 , 1 }, order) ;
1463
- LinearLayout lanes = LinearLayout::empty ();
1452
+ auto kDim = dotOperandIdx == 0 ? rank - 1 : rank - 2 ;
1453
+ auto tilePerWarpNonK = tilesPerWarp[ kDim ] ;
1454
+
1464
1455
// In scaled dot, the shapes of operands(without batch dimension) are,
1465
1456
// respectively:
1466
1457
// - A: [M, K]
1467
1458
// - B: [K, N]
1468
1459
// - aScale: [M, K / 32]
1469
1460
// - bScale: [N, K / 32]
1470
1461
//
1471
- // To correctly feed A/B and its scale into instruction, we need to
1472
- // distribute aScale/bScale among warps in the same way as A/B. But bScale
1473
- // is not transposed like B. So we need to transpose the warp layout of
1474
- // bScale.
1475
- //
1476
- // The tricky part is, our desired outputs are [dim0, dim1], but
1477
- // at this position, the layouts are transposed to [dim1, dim0]. So
1478
- // instead of reverse bScale's layout, we need to reverse aScale's. There
1479
- // will be a transpose in the end to correct everything.
1480
- basisT warps = dotOperandWarpBasis;
1481
- if (dotOperandIdx == 0 ) {
1482
- for (auto &basis : warps) {
1483
- std::reverse (basis.begin (), basis.end ());
1484
- }
1485
- }
1486
1462
// In general, for both 32x32 and 16x16 scaled mfma, and no matter what
1487
1463
// data type the A/B operand is, each lane takes 32 elements from A/B
1488
1464
// along K dim, and 1 or 2 elements from scale accordingly. The number of
@@ -1492,43 +1468,70 @@ LinearLayout chooseScaledMfmaScaleLayout(
1492
1468
// For mxfp4, these 32 elements are consecutive, so only 1 scale element
1493
1469
// is required. But for mxfp6/mxfp8, there are 2 16-consecutive elements
1494
1470
// blocks, so 2 scale elements are required.
1471
+ int32_t kSize = dotOperandShape[1 ];
1472
+
1473
+ std::vector<std::vector<int32_t >> registerBase;
1474
+ std::vector<std::vector<int32_t >> laneBase;
1475
+
1476
+ auto kTileSize = mfmaMDim == 32 ? 2 : 4 ;
1477
+
1478
+ if (preshuffleScales) {
1479
+ auto sizePerThreadPerTile = 1 ;
1480
+ auto numKTiles = kSize / kTileSize ;
1481
+ for (int32_t elem = 1 ;
1482
+ elem < sizePerThreadPerTile * numKTiles * tilePerWarpNonK; elem *= 2 )
1483
+ registerBase.emplace_back (std::vector<int32_t >{elem, 0 });
1484
+ } else {
1485
+ for (int32_t elem = kTileSize ; elem < kSize ; elem *= 2 )
1486
+ registerBase.emplace_back (std::vector<int32_t >{elem, 0 });
1487
+
1488
+ for (int32_t elem = mfmaMDim; elem < tilePerWarpNonK * mfmaMDim; elem *= 2 )
1489
+ registerBase.emplace_back (std::vector<int32_t >{0 , elem});
1490
+ }
1495
1491
if (mfmaMDim == 32 ) {
1492
+ if (preshuffleScales) {
1493
+ assert (false && " Preshuffling scales not yet implemented for mDim == 32" );
1494
+ }
1496
1495
// For ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with fp4 input, each lane
1497
1496
// takes 32 consecutive elements from A along K dimension. The first
1498
1497
// 32 lanes collectively handle A[0:32][0:32], and the other 32 lanes
1499
1498
// collectively handle A[0:32][32:64]. Each lane takes 1 scale element
1500
1499
// accordingly. Similar to B and bScale.
1501
- lanes = LinearLayout (
1502
- {{kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }, {1 , 0 }}},
1503
- {kWarp , warps},
1504
- {kBlock , {}}},
1505
- {standardOutDims[order[0 ]], standardOutDims[order[1 ]]});
1500
+ laneBase = {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }, {1 , 0 }};
1506
1501
} else {
1507
1502
assert (mfmaMDim == 16 );
1508
- // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane
1509
- // takes 32 consecutive elements from A alone K dimension. The first
1510
- // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes
1511
- // collectively handle A[0:16][32:64] and so on. Each lane take 1 scale
1512
- // element accordingly. Similar to B and bScale.
1513
- lanes =
1514
- LinearLayout ({{kLane , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {1 , 0 }, {2 , 0 }}},
1515
- {kWarp , warps},
1516
- {kBlock , {}}},
1517
- {standardOutDims[order[0 ]], standardOutDims[order[1 ]]});
1518
- }
1519
- LinearLayout newLL = regs * lanes;
1520
-
1521
- // Adjust register-level layout to fill the shape, at this level, both
1522
- // aScale and bScale should align with A operand.
1523
- SmallVector<int , 2 > repOrder = {1 , 0 };
1524
- for (auto d : repOrder) {
1525
- auto outDim = standardOutDims[d];
1526
- auto dimSize = newLL.getOutDimSize (outDim);
1527
- newLL *= LinearLayout::identity1D (dotOperandShape[d] / dimSize, kRegister ,
1528
- outDim);
1529
- }
1530
- newLL = newLL.transposeOuts (standardOutDims);
1531
- return newLL;
1503
+ if (preshuffleScales) {
1504
+ laneBase = {{4 , 0 }, {0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {0 , 16 }};
1505
+ } else {
1506
+ // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane
1507
+ // takes 32 consecutive elements from A along K dimension. The first
1508
+ // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes
1509
+ // collectively handle A[0:16][32:64] and so on. Each lane takes 1 scale
1510
+ // element accordingly. Similar to B and bScale.
1511
+ laneBase = {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }, {1 , 0 }, {2 , 0 }};
1512
+ }
1513
+ }
1514
+
1515
+ SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
1516
+ LinearLayout tileLayout ({{kRegister , registerBase}, {kLane , laneBase}},
1517
+ {outDimNames[order[0 ]], outDimNames[order[1 ]]});
1518
+
1519
+ SmallVector<unsigned > warpsPerCTANew{warpsPerCTA[0 ], warpsPerCTA[1 ]};
1520
+ SmallVector<unsigned > warpOrder{1 , 0 };
1521
+
1522
+ if (dotOperandIdx == 1 ) {
1523
+ std::swap (warpsPerCTANew[0 ], warpsPerCTANew[1 ]);
1524
+ std::swap (warpOrder[0 ], warpOrder[1 ]);
1525
+ }
1526
+
1527
+ LinearLayout warpLayout =
1528
+ identityStandardND (kWarp , warpsPerCTANew, warpOrder);
1529
+ LinearLayout ctaLayout = tileLayout.transposeOuts (outDimNames) *
1530
+ warpLayout.transposeOuts (outDimNames);
1531
+
1532
+ auto ctaLay = CTALayoutAttr::get (/* context=*/ ctx, /* CTAsPerCGA=*/ {1 , 1 },
1533
+ /* CTASplitNum=*/ {1 , 1 }, /* CTAOrder=*/ {1 , 0 });
1534
+ return combineCtaCgaWithShape (ctaLayout, ctaLay, dotOperandShape);
1532
1535
}
1533
1536
1534
1537
std::optional<LinearLayout>
0 commit comments