
Commit 7edf97c

[AMD] Introduce tilesPerWarp and scale preshuffling
1 parent a259f0a commit 7edf97c

File tree: 16 files changed, +298 -190 lines

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 6 additions & 4 deletions

@@ -282,10 +282,12 @@ LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
                                          int numWarps);
 
 // Create LinearLayout for scale in scaled mfma.
-LinearLayout chooseScaledMfmaScaleLayout(
-    MLIRContext *ctx, int dotOperandIdx,
-    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
-    ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
+LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                         ArrayRef<int64_t> dotOperandShape,
+                                         unsigned mfmaMDim,
+                                         ArrayRef<unsigned> tilesPerWarp,
+                                         ArrayRef<unsigned> warpsPerCTA,
+                                         bool preshuffleScales);
 
 // Create a LinearLayout similar to mfmaLayout, but changing each thread to hold
 // 8 elements. This layout is useful for emitting the widest 128-bit global
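
For reference, a minimal call-site sketch for the reworked signature (not part of this commit): the shape and layout values are illustrative, and the mlir::triton::gpu namespace is assumed from the surrounding header.

#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

// Builds the scale layout for a hypothetical A-operand scale (aScale) of an
// M = 128, K = 1024 scaled dot; mxfp scales cover 32 K-elements each, so the
// scale tensor shape is [128, 32].
LinearLayout buildAScaleLayout(MLIRContext *ctx) {
  SmallVector<int64_t> scaleShape{128, 32};
  SmallVector<unsigned> tilesPerWarp{1, 1}; // one MFMA tile per warp per dim
  SmallVector<unsigned> warpsPerCTA{4, 1};  // 4 warps along M
  return chooseScaledMfmaScaleLayout(ctx, /*dotOperandIdx=*/0, scaleShape,
                                     /*mfmaMDim=*/16, tilesPerWarp, warpsPerCTA,
                                     /*preshuffleScales=*/false);
}

Compared to the old interface, callers now pass tilesPerWarp and warpsPerCTA directly instead of a precomputed dotOperandWarpBasis, and opt into the preshuffled register order with the trailing flag.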

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions

@@ -1016,6 +1016,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
     "unsigned": $versionMajor,
     "unsigned": $versionMinor,
     ArrayRefParameter<"unsigned">:$warpsPerCTA,
+    ArrayRefParameter<"unsigned">:$tilesPerWarp,
     "unsigned":$MDim,
     "unsigned":$NDim,
     "bool":$isTransposed,

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,8 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_HIP_ASYNC_COPY_OVERLAP",
     "TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
+    "TRITON_HIP_PRESHUFFLE_SCALES",
+    "TRITON_HIP_BYPASS_LDS_FOR_SCALES",
     "TRITON_HIP_USE_IN_THREAD_TRANSPOSE",
     "TRITON_HIP_ASYNC_FAST_SWIZZLE",
     "TRITON_LLVM_DEBUG_ONLY",

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 13 additions & 7 deletions

@@ -1287,6 +1287,7 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
   unsigned versionMajor = 0;
   unsigned versionMinor = 0;
   SmallVector<unsigned> warpsPerCTA;
+  SmallVector<unsigned> tilesPerWarp;
   SmallVector<unsigned> instrShape;
   bool isTransposed;
   std::optional<SmallVector<unsigned>> CTAsPerCGA;
@@ -1306,6 +1307,11 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
       if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
         return {};
     }
+    if (attr.getName() == "tilesPerWarp") {
+      if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp")
+              .failed())
+        return {};
+    }
     if (attr.getName() == "instrShape") {
       if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
         return {};
@@ -1339,27 +1345,27 @@ Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
 
   return parser.getChecked<AMDMfmaEncodingAttr>(
       parser.getContext(), versionMajor, versionMinor, warpsPerCTA,
-      instrShape[0], instrShape[1], isTransposed, *CTALayout);
+      tilesPerWarp, instrShape[0], instrShape[1], isTransposed, *CTALayout);
 }
 
 void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
           << "versionMajor = " << getVersionMajor()                      //
           << ", versionMinor = " << getVersionMinor()                    //
           << ", warpsPerCTA = [" << getWarpsPerCTA() << "]"              //
+          << ", tilesPerWarp = [" << getTilesPerWarp() << "]"            //
           << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]" //
           << ", isTransposed = " << getIsTransposed();
   maybePrintCTALayout(getContext(), printer, getCTALayout(),
                       /*rank=*/getRank());
   printer << "}>";
 }
 
-LogicalResult
-AMDMfmaEncodingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
-                            unsigned versionMajor, unsigned versionMinor,
-                            llvm::ArrayRef<unsigned int> warpsPerCTA,
-                            unsigned mDim, unsigned nDim, bool isTransposed,
-                            mlir::triton::gpu::CTALayoutAttr) {
+LogicalResult AMDMfmaEncodingAttr::verify(
+    function_ref<mlir::InFlightDiagnostic()> emitError, unsigned versionMajor,
+    unsigned versionMinor, llvm::ArrayRef<unsigned int> warpsPerCTA,
+    llvm::ArrayRef<unsigned int> tilesPerWarp, unsigned mDim, unsigned nDim,
+    bool isTransposed, mlir::triton::gpu::CTALayoutAttr) {
   if (!(versionMajor >= 0 && versionMajor <= 4)) {
     return emitError() << "major version must be in the [0, 4] range";
   }
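
With the print method extended above, the attribute's textual form now carries a tilesPerWarp field. A round-trip sketch follows; the #ttg.amd_mfma spelling and the example values are assumptions for illustration, not taken from this commit.

#include "mlir/AsmParser/AsmParser.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Parses a textual AMD MFMA encoding that includes the new tilesPerWarp field.
Attribute parseExampleMfmaEncoding(MLIRContext *ctx) {
  ctx->getOrLoadDialect<triton::gpu::TritonGPUDialect>();
  const char *text = "#ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, "
                     "warpsPerCTA = [2, 2], tilesPerWarp = [1, 1], "
                     "instrShape = [16, 16], isTransposed = true}>";
  return parseAttribute(text, ctx);
}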

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 80 additions & 77 deletions

@@ -462,6 +462,11 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
          {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
   }
+
+  auto tilesPerWarp = getTilesPerWarp();
+  tileLayout *=
+      mlir::triton::identityStandardND(kRegister, tilesPerWarp, order);
+
   if (hasBatchDim) {
     assert(order[2] == 0);
     // Extend the base vector with one value to accommodate for the batch
@@ -637,31 +642,6 @@ LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
 
 LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
                                    ArrayRef<int64_t> shape) {
-
-  // Current linear layout conversion for dot operand is only necessary to
-  // enable LDS bypass for operand B in the MFMA dot path. To achieve
-  // performance gains from bypassing LDS, the following conditions must be met:
-  //
-  // 1) opIdx == 1: Currently, only the B tensor (e.g. weights in moe-like
-  //    kernels) bypasses LDS. This constraint is not strict and support for
-  //    bypassing operand A (e.g. Q tensor in flash attention) will be added in
-  //    the future.
-  //
-  // 2) B tensor must be column major: This is required to support vectorized
-  //    global load instructions, as MFMA instructions expect threads to hold B
-  //    operand elements along the K dimension.
-  //
-  // 3) kWidth == 8: Ensures maximum global load vectorization for fp16
-  //    operations.
-  //    TODO: Generalize conversion to handle maximum kWidth for other types
-  //    (i.e. fp8).
-  //
-  // 4) warpsPerCTA[mDim] == 1: This guarantees that every B tensor element is
-  //    held by exactly one thread, maintaining the same number of global loads
-  //    as in a blocked layout.
-  //
-  // Other use of Linear layout is a support of rare corner cases,
-  // for example one instruction tile is larger than tensor
   auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
 
   auto rank = shape.size();
@@ -672,6 +652,8 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
   int32_t kSize = shape[kDim];
   auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
+  auto tilesPerWarp = mfmaLayout.getTilesPerWarp();
+  auto tilePerWarpNonK = tilesPerWarp[kDim];
 
   MLIRContext *ctx = dotMfmaLayout.getContext();
   SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
@@ -725,6 +707,11 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
     registerBase.emplace_back(std::vector<int32_t>{elem, 0});
 
+  // Add repeats of registers along non-K dimension to register base vectors
+  for (int32_t elem = mfmaLayout.getMDim();
+       elem < tilePerWarpNonK * mfmaLayout.getMDim(); elem *= 2)
+    registerBase.emplace_back(std::vector<int32_t>{0, elem});
+
   // Base vectors above are defined in a fixed order [non-k-dim, k-dim].
   // To assign them to actual matrix dimensions `order` array is used.
   // For operand A: non-k-dim -> dim0, k-dim -> dim1
@@ -745,7 +732,9 @@ LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
   LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
                            warpLayout.transposeOuts(outDimNames);
 
-  return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+  auto finalLayout =
+      combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
+  return finalLayout;
 }
 
 LinearLayout
@@ -1446,10 +1435,12 @@ LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
   return chooseDotDsReadB64TrLayout(dot, shape, elemBitWidth);
 }
 
-LinearLayout chooseScaledMfmaScaleLayout(
-    MLIRContext *ctx, int dotOperandIdx,
-    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
-    ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim) {
+LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                         ArrayRef<int64_t> dotOperandShape,
+                                         unsigned mfmaMDim,
+                                         ArrayRef<unsigned> tilesPerWarp,
+                                         ArrayRef<unsigned> warpsPerCTA,
+                                         bool preshuffleScales) {
   using basisT = std::vector<std::vector<int32_t>>;
   unsigned rank = dotOperandShape.size();
   auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true);
@@ -1458,31 +1449,16 @@ LinearLayout chooseScaledMfmaScaleLayout(
   StringAttr kLane = StringAttr::get(ctx, "lane");
   StringAttr kWarp = StringAttr::get(ctx, "warp");
   StringAttr kBlock = StringAttr::get(ctx, "block");
-  // Init register layout. Will be adjusted later
-  auto regs = mlir::triton::identityStandardND(kRegister, {1, 1}, order);
-  LinearLayout lanes = LinearLayout::empty();
+  auto kDim = dotOperandIdx == 0 ? rank - 1 : rank - 2;
+  auto tilePerWarpNonK = tilesPerWarp[kDim];
+
   // In scaled dot, the shapes of operands(without batch dimension) are,
   // respectively:
   // - A: [M, K]
   // - B: [K, N]
   // - aScale: [M, K / 32]
   // - bScale: [N, K / 32]
   //
-  // To correctly feed A/B and its scale into instruction, we need to
-  // distribute aScale/bScale among warps in the same way as A/B. But bScale
-  // is not transposed like B. So we need to transpose the warp layout of
-  // bScale.
-  //
-  // The tricky part is, our desired outputs are [dim0, dim1], but
-  // at this position, the layouts are transposed to [dim1, dim0]. So
-  // instead of reverse bScale's layout, we need to reverse aScale's. There
-  // will be a transpose in the end to correct everything.
-  basisT warps = dotOperandWarpBasis;
-  if (dotOperandIdx == 0) {
-    for (auto &basis : warps) {
-      std::reverse(basis.begin(), basis.end());
-    }
-  }
   // In general, for both 32x32 and 16x16 scaled mfma, and no matter what
   // data type the A/B operand is, each lane takes 32 elements from A/B
   // alone K dim, and 1 or 2 elements from scale accordingly. The number of
@@ -1492,43 +1468,70 @@ LinearLayout chooseScaledMfmaScaleLayout(
   // For mxfp4, these 32 elements are consecutive, so only 1 scale element
   // is required. But for mxfp6/mxfp8, there are 2 16-consecutive elements
   // blocks, so 2 scale elements are required.
+  int32_t kSize = dotOperandShape[1];
+
+  std::vector<std::vector<int32_t>> registerBase;
+  std::vector<std::vector<int32_t>> laneBase;
+
+  auto kTileSize = mfmaMDim == 32 ? 2 : 4;
+
+  if (preshuffleScales) {
+    auto sizePerThreadPerTile = 1;
+    auto numKTiles = kSize / kTileSize;
+    for (int32_t elem = 1;
+         elem < sizePerThreadPerTile * numKTiles * tilePerWarpNonK; elem *= 2)
+      registerBase.emplace_back(std::vector<int32_t>{elem, 0});
+  } else {
+    for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
+      registerBase.emplace_back(std::vector<int32_t>{elem, 0});
+
+    for (int32_t elem = mfmaMDim; elem < tilePerWarpNonK * mfmaMDim; elem *= 2)
+      registerBase.emplace_back(std::vector<int32_t>{0, elem});
+  }
   if (mfmaMDim == 32) {
+    if (preshuffleScales) {
+      assert(false && "Preshuffling scales not yet implemented for mDim == 32");
+    }
     // For ROCDL::mfma_scale_f32_32x32x64_f8f6f4 with fp4 input, each lane
     // takes 32 consecutive elements from A alone K dimension. The first
     // 32 lanes collectively handle A[0:32][0:32], and the other 32 lanes
     // collectively handle A[0:32][32:64]. Each lane take 1 scale element
    // accordingly. Similar to B and bScale.
-    lanes = LinearLayout(
-        {{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}}},
-         {kWarp, warps},
-         {kBlock, {}}},
-        {standardOutDims[order[0]], standardOutDims[order[1]]});
+    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}};
   } else {
     assert(mfmaMDim == 16);
-    // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane
-    // takes 32 consecutive elements from A alone K dimension. The first
-    // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes
-    // collectively handle A[0:16][32:64] and so on. Each lane take 1 scale
-    // element accordingly. Similar to B and bScale.
-    lanes =
-        LinearLayout({{kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
-                      {kWarp, warps},
-                      {kBlock, {}}},
-                     {standardOutDims[order[0]], standardOutDims[order[1]]});
-  }
-  LinearLayout newLL = regs * lanes;
-
-  // Adjust register-level layout to fill the shape, at this level, both
-  // aScale and bScale should align with A operand.
-  SmallVector<int, 2> repOrder = {1, 0};
-  for (auto d : repOrder) {
-    auto outDim = standardOutDims[d];
-    auto dimSize = newLL.getOutDimSize(outDim);
-    newLL *= LinearLayout::identity1D(dotOperandShape[d] / dimSize, kRegister,
-                                      outDim);
-  }
-  newLL = newLL.transposeOuts(standardOutDims);
-  return newLL;
+    if (preshuffleScales) {
+      laneBase = {{4, 0}, {0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}};
+    } else {
+      // For ROCDL::mfma_scale_f32_16x16x128_f8f6f4 with fp4 input, each lane
+      // takes 32 consecutive elements from A alone K dimension. The first
+      // 16 lanes collectively handle A[0:16][0:32], and another 16 lanes
+      // collectively handle A[0:16][32:64] and so on. Each lane take 1 scale
+      // element accordingly. Similar to B and bScale.
+      laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}};
+    }
+  }
+
+  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
+  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
+                          {outDimNames[order[0]], outDimNames[order[1]]});
+
+  SmallVector<unsigned> warpsPerCTANew{warpsPerCTA[0], warpsPerCTA[1]};
+  SmallVector<unsigned> warpOrder{1, 0};
+
+  if (dotOperandIdx == 1) {
+    std::swap(warpsPerCTANew[0], warpsPerCTANew[1]);
+    std::swap(warpOrder[0], warpOrder[1]);
+  }
+
+  LinearLayout warpLayout =
+      identityStandardND(kWarp, warpsPerCTANew, warpOrder);
+  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
+                           warpLayout.transposeOuts(outDimNames);
+
+  auto ctaLay = CTALayoutAttr::get(/*context=*/ctx, /*CTAsPerCGA=*/{1, 1},
+                                   /*CTASplitNum=*/{1, 1}, /*CTAOrder=*/{1, 0});
+  return combineCtaCgaWithShape(ctaLayout, ctaLay, dotOperandShape);
 }
 
 std::optional<LinearLayout>
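
To make the new register-basis construction in chooseScaledMfmaScaleLayout concrete, here is a standalone sketch (plain C++, no MLIR dependencies) that mirrors the non-preshuffled loops above for one illustrative configuration: mfmaMDim = 16 (so kTileSize = 4), kSize = 32 scale elements along K, and tilePerWarpNonK = 2.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int32_t mfmaMDim = 16;       // 16x16 scaled MFMA
  const int32_t kTileSize = 4;       // scale elements along K per MFMA tile (2 for 32x32, 4 for 16x16)
  const int32_t kSize = 32;          // K extent of the scale operand
  const int32_t tilePerWarpNonK = 2; // tilesPerWarp along the non-K dimension

  std::vector<std::vector<int32_t>> registerBase;

  // Repeats along K: one basis vector per power-of-two step past a single k-tile.
  for (int32_t elem = kTileSize; elem < kSize; elem *= 2)
    registerBase.push_back({elem, 0});
  // Repeats along non-K introduced by tilesPerWarp.
  for (int32_t elem = mfmaMDim; elem < tilePerWarpNonK * mfmaMDim; elem *= 2)
    registerBase.push_back({0, elem});

  // Prints {4, 0} {8, 0} {16, 0} {0, 16}: three K repeats plus one non-K repeat.
  for (const auto &b : registerBase)
    std::cout << "{" << b[0] << ", " << b[1] << "} ";
  std::cout << "\n";
}

For the same configuration, the preshuffled branch instead emits the single contiguous run {1, 0} {2, 0} {4, 0} {8, 0} (its loop bound is numKTiles * tilePerWarpNonK = 16 elements per thread), i.e. each thread's scale elements become consecutive along the fastest dimension.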

python/triton/knobs.py

Lines changed: 1 addition & 0 deletions

@@ -442,6 +442,7 @@ class amd_knobs(base_knobs):
     # We use strs so that we can have a default value based on other runtime info
     use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
     use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")
+    preshuffle_scales: env_opt_bool = env_opt_bool("TRITON_HIP_PRESHUFFLE_SCALES")
 
     global_prefetch: env_int = env_int("TRITON_HIP_GLOBAL_PREFETCH")
     local_prefetch: env_int = env_int("TRITON_HIP_LOCAL_PREFETCH")

third_party/amd/backend/compiler.py

Lines changed: 2 additions & 1 deletion

@@ -212,6 +212,7 @@ def make_ttir(mod, metadata, options):
 def make_ttgir(mod, metadata, options):
     pm = ir.pass_manager(mod.context)
     pm.enable_debug()
+    preshuffle_scales = knobs.amd.preshuffle_scales
     passes.ttir.add_convert_to_ttgpuir(pm, f"hip:{options.arch}", options.num_warps, options.warp_size,
                                        options.num_ctas)
     pm.run(mod)
@@ -220,7 +221,7 @@ def make_ttgir(mod, metadata, options):
     passes.ttgpuir.add_coalesce(pm)
     passes.ttgpuir.add_remove_layout_conversions(pm)
     passes.ttgpuir.add_optimize_thread_locality(pm)
-    amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack)
+    amd.passes.ttgpuir.add_accelerate_matmul(pm, options.arch, options.matrix_instr_nonkdim, options.kpack, preshuffle_scales)
     passes.ttgpuir.add_remove_layout_conversions(pm)
     amd.passes.ttgpuir.add_optimize_epilogue(pm)
     passes.ttgpuir.add_optimize_dot_operands(pm, True)

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 3 additions & 4 deletions

@@ -13,10 +13,9 @@ createTritonAMDGPUStreamPipelinePass(int numStages = 2, int globalPrefetch = 0,
                                      int localPrefetch = 0,
                                      bool useAsyncCopy = false);
 
-std::unique_ptr<Pass>
-createTritonAMDGPUAccelerateMatmulPass(std::string archGenName = std::string(),
-                                       int matrixInstructionSize = 0,
-                                       int kpack = 1);
+std::unique_ptr<Pass> createTritonAMDGPUAccelerateMatmulPass(
+    std::string archGenName = std::string(), int matrixInstructionSize = 0,
+    int kpack = 1, bool preshuffleScales = false);
 
 std::unique_ptr<Pass> createTritonAMDGPUCanonicalizeLoopsPass();

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 5 additions & 1 deletion

@@ -52,7 +52,11 @@ def TritonAMDGPUAccelerateMatmul : Pass<"tritonamdgpu-accelerate-matmul", "mlir:
            "enforce matrix instruction MN size">,
     Option<"kPack", "kPack",
            "int32_t", /*default*/"1",
-           "KWidth / kBase">
+           "KWidth / kBase">,
+    Option<"preshuffleScales", "preshuffle-scales",
+           "bool", /*default*/"false",
+           "preshuffle scaledDot scales">
+
   ];
 }
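
The new option flows from the TRITON_HIP_PRESHUFFLE_SCALES knob through compiler.py's add_accelerate_matmul binding into this pass. A C++ sketch of enabling it programmatically follows; the pass-manager boilerplate and the gfx950 arch string are illustrative, and the factory function is assumed to live in the mlir namespace as declared in Passes.h above.

#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

// Schedules accelerate-matmul with scale preshuffling turned on.
void addAccelerateMatmulWithPreshuffling(PassManager &pm) {
  pm.addPass(createTritonAMDGPUAccelerateMatmulPass(
      /*archGenName=*/"gfx950", /*matrixInstructionSize=*/0, /*kpack=*/1,
      /*preshuffleScales=*/true));
}

On the command line the same switch is exposed as the preshuffle-scales option of tritonamdgpu-accelerate-matmul, defaulting to false.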
