Skip to content

Commit 1279786

Browse files
committed
[AMD] Fixed MFMA to Linear Layout conversion for 1D tensors
1 parent 4f0d5f1 commit 1279786

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,22 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
441441
identityStandardND(S("warp"), getWarpsPerCTA(), order);
442442
LinearLayout ctaLayout = tileLayout * warpLayout;
443443

444-
return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
444+
auto combinedLayout =
445+
combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
446+
447+
auto bases = combinedLayout.getBases();
448+
std::vector<std::vector<int>> newRegBases;
449+
for (const auto &basis : bases[S("register")]) {
450+
if (llvm::any_of(basis, [](int b) { return b != 0; })) {
451+
newRegBases.push_back(basis);
452+
}
453+
}
454+
bases[S("register")] = newRegBases;
455+
456+
auto result = LinearLayout(std::move(bases),
457+
llvm::to_vector(combinedLayout.getOutDimNames()));
458+
459+
return result;
445460
}
446461

447462
LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,

0 commit comments

Comments
 (0)