File tree Expand file tree Collapse file tree 1 file changed +16
-1
lines changed Expand file tree Collapse file tree 1 file changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -441,7 +441,22 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
441
441
identityStandardND (S (" warp" ), getWarpsPerCTA (), order);
442
442
LinearLayout ctaLayout = tileLayout * warpLayout;
443
443
444
- return combineCtaCgaWithShape (ctaLayout, getCTALayout (), shape);
444
+ auto combinedLayout =
445
+ combineCtaCgaWithShape (ctaLayout, getCTALayout (), shape);
446
+
447
+ auto bases = combinedLayout.getBases ();
448
+ std::vector<std::vector<int >> newRegBases;
449
+ for (const auto &basis : bases[S (" register" )]) {
450
+ if (llvm::any_of (basis, [](int b) { return b != 0 ; })) {
451
+ newRegBases.push_back (basis);
452
+ }
453
+ }
454
+ bases[S (" register" )] = newRegBases;
455
+
456
+ auto result = LinearLayout (std::move (bases),
457
+ llvm::to_vector (combinedLayout.getOutDimNames ()));
458
+
459
+ return result;
445
460
}
446
461
447
462
LinearLayout chooseDotDsReadB64TrLayout (DotOperandEncodingAttr dotMfmaLayout,
You can’t perform that action at this time.
0 commit comments