@@ -1241,10 +1241,10 @@ ttg::LocalAllocOp findShmemAlloc(Value operand) {
1241
1241
// come from an MemDescSubview op. Only ConvertLayout and Trans ops are
1242
1242
// allowed in between.
1243
1243
Value transitiveOperand = operand;
1244
- while (
1245
- isa_and_nonnull<ttg::ConvertLayoutOp, tt::TransOp, ttg::MemDescTransOp >(
1246
- transitiveOperand.getDefiningOp ()) ||
1247
- isa<BlockArgument>(transitiveOperand)) {
1244
+ while (isa_and_nonnull<ttg::ConvertLayoutOp, tt::TransOp, ttg::MemDescTransOp,
1245
+ ttg::MemDescReshapeOp >(
1246
+ transitiveOperand.getDefiningOp ()) ||
1247
+ isa<BlockArgument>(transitiveOperand)) {
1248
1248
if (auto blockArg = dyn_cast<BlockArgument>(transitiveOperand)) {
1249
1249
assert (isa<scf::ForOp>(blockArg.getOwner ()->getParentOp ()) &&
1250
1250
" Block argument must come from a for loop" );
@@ -1409,7 +1409,7 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
1409
1409
}
1410
1410
1411
1411
// Non-subview/trans ops will be replaced by `val`.
1412
- if (!isa<ttg::MemDescTransOp, ttg::MemDescSubviewOp>( use.getOwner ())) {
1412
+ if (!use.getOwner ()-> hasTrait <OpTrait::MemDescViewTrait>( )) {
1413
1413
operandsToReplace.push_back (&use);
1414
1414
continue ;
1415
1415
}
@@ -1427,13 +1427,15 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
1427
1427
oldType.getMemorySpace (), isMutable);
1428
1428
newVal = builder.create <ttg::MemDescSubviewOp>(
1429
1429
subview.getLoc (), newDstType, val, subview.getOffsets ());
1430
- newVal.getDefiningOp ()->setAttrs (user->getAttrs ());
1431
1430
} else if (auto trans = dyn_cast<ttg::MemDescTransOp>(user)) {
1432
1431
newVal = builder.create <ttg::MemDescTransOp>(trans.getLoc (), val,
1433
1432
trans.getOrder ());
1434
- newVal.getDefiningOp ()->setAttrs (user->getAttrs ());
1433
+ } else if (auto reshape = dyn_cast<ttg::MemDescReshapeOp>(user)) {
1434
+ newVal = builder.create <ttg::MemDescReshapeOp>(reshape.getLoc (),
1435
+ reshape.getType (), val);
1435
1436
}
1436
- assert (newVal);
1437
+ assert (newVal && " unhandled memdesc view" );
1438
+ newVal.getDefiningOp ()->setAttrs (user->getAttrs ());
1437
1439
replaceUsesAndPropagateType (builder, user, newVal);
1438
1440
opsToDelete.push_back (use.getOwner ());
1439
1441
}
0 commit comments