Skip to content

Commit 7f9bbc9

Browse files
committed
WIP: TensorDescToBlockPtr updates
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 9eb16ec commit 7f9bbc9

File tree

2 files changed

+9
-190
lines changed

2 files changed

+9
-190
lines changed

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

+2 −172
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,12 @@ struct TritonIntelTensorDescToBlockPointer
6262

6363
moduleOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
6464
return TypeSwitch<Operation *, WalkResult>(op)
65-
#if 1
6665
.Case<tt::MakeTensorDescOp>([&](auto makeTensorDescOp) {
6766
if (failed(rewriteMakeTensorDescriptorOp(makeTensorDescOp)))
6867
makeTensorDescOp->emitRemark(
6968
"TritonIntelTensorDescToBlockPointer: Failed to rewrite");
7069
return WalkResult::advance();
7170
})
72-
#endif
7371
.Case<tt::DescriptorLoadOp, tt::DescriptorStoreOp>(
7472
[&](auto loadOrStoreOp) {
7573
if (failed(rewriteDescriptorLoadOrStoreOp(loadOrStoreOp)))
@@ -85,106 +83,13 @@ struct TritonIntelTensorDescToBlockPointer
8583
}
8684

8785
private:
88-
tt::MakeTensorDescOp getMakeTensorDescOp(Value base) const {
89-
assert(base && isa<tt::TensorDescType>(base.getType()) &&
90-
"Expecting tensor desc");
91-
92-
Operation *defOp = base.getDefiningOp();
93-
if (!defOp) {
94-
BlockArgument blockArg = cast<BlockArgument>(base);
95-
Operation *parentOp = blockArg.getOwner()->getParentOp();
96-
if (scf::ForOp forOp = dyn_cast<scf::ForOp>(parentOp)) {
97-
unsigned numIVs = forOp.getNumInductionVars();
98-
int initArgIdx = blockArg.getArgNumber() - numIVs;
99-
if (isModifiedInLoop(forOp, blockArg)) {
100-
LLVM_DEBUG(llvm::dbgs() << blockArg << " is loop variant\n");
101-
return nullptr;
102-
}
103-
Operation::operand_range initArgs = forOp.getInitArgs();
104-
assert(initArgIdx >= 0 && initArgIdx < initArgs.size() &&
105-
"Unexpected 'initArgIdx' value");
106-
return getMakeTensorDescOp(initArgs[initArgIdx]);
107-
}
108-
LLVM_DEBUG(llvm::dbgs()
109-
<< "TODO: Unhandled non operation: " << base << "\n");
110-
return nullptr;
111-
}
112-
113-
if (defOp->getNumRegions() != 0) {
114-
LLVM_DEBUG(llvm::dbgs() << "TODO: defOp with region: " << *defOp << "\n");
115-
return nullptr;
116-
}
117-
if (auto makeTensorDescOp = dyn_cast<tt::MakeTensorDescOp>(defOp))
118-
return makeTensorDescOp;
119-
120-
llvm_unreachable("TODO: Unhandled defOp kind");
121-
return nullptr;
122-
}
123-
124-
bool isModifiedInLoop(scf::ForOp forOp, BlockArgument &blockArg) const {
125-
unsigned argNo = blockArg.getArgNumber();
126-
unsigned numIVs = forOp.getNumInductionVars();
127-
int initArgIdx = blockArg.getArgNumber() - numIVs;
128-
Value yieldedVal = forOp.getYieldedValues()[initArgIdx];
129-
return (yieldedVal != blockArg);
130-
}
131-
13286
// Create a new block pointer if a suitable one doesn't already exist.
13387
// Otherwise, return the existing one. The function takes the base, shape,
13488
// strides, offsets, sizes of the block pointer to create/lookup and its
13589
// tensor element type (to ensure the block pointer has the tensor layout).
13690
Value findOrCreateMakeTensorPtr(Location loc, Value base, ValueRange shape,
13791
ValueRange strides, ValueRange offsets,
138-
ArrayRef<int32_t> sizes,
139-
RankedTensorType tensorType,
140-
OpBuilder &builder) {
141-
Block *block = builder.getInsertionBlock();
142-
const Block::iterator insertPoint = builder.getInsertionPoint();
143-
auto ptrType = tt::PointerType::get(
144-
tensorType, tt::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup);
145-
146-
auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) {
147-
if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
148-
triton::PointerType resType = makeTensorPtrOp.getResult().getType();
149-
auto tensorType = cast<RankedTensorType>(resType.getPointeeType());
150-
auto sameShape = [](ArrayRef<int64_t> arr1, ArrayRef<int32_t> arr2) {
151-
for (auto [dim1, dim2] : llvm::zip(arr1, arr2)) {
152-
if (dim1 != dim2)
153-
return false;
154-
}
155-
return true;
156-
};
157-
158-
return makeTensorPtrOp.getType() == ptrType &&
159-
makeTensorPtrOp.getBase() == base &&
160-
makeTensorPtrOp.getShape() == shape &&
161-
makeTensorPtrOp.getStrides() == strides &&
162-
makeTensorPtrOp.getOffsets() == offsets &&
163-
sameShape(tensorType.getShape(), sizes);
164-
}
165-
return false;
166-
});
167-
168-
auto makeTensorPtrOp = [&]() {
169-
Value makeTensorPtr = builder.create<tt::MakeTensorPtrOp>(
170-
loc, base, shape, strides, offsets, sizes,
171-
builder.getDenseI32ArrayAttr({1, 0}));
172-
makeTensorPtr.setType(ptrType);
173-
return makeTensorPtr;
174-
};
175-
176-
return (it != insertPoint) ? cast<tt::MakeTensorPtrOp>(*it)
177-
: makeTensorPtrOp();
178-
}
179-
180-
// Create a new block pointer if a suitable one doesn't already exist.
181-
// Otherwise, return the existing one. The function takes the base, shape,
182-
// strides, offsets, sizes of the block pointer to create/lookup and its
183-
// tensor element type (to ensure the block pointer has the tensor layout).
184-
Value findOrCreateMakeTensorPtrTmp(Location loc, Value base, ValueRange shape,
185-
ValueRange strides, ValueRange offsets,
186-
ArrayRef<int32_t> sizes,
187-
OpBuilder &builder) {
92+
ArrayRef<int32_t> sizes, OpBuilder &builder) {
18893
Block *block = builder.getInsertionBlock();
18994
const Block::iterator insertPoint = builder.getInsertionPoint();
19095
auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) {
@@ -245,7 +150,7 @@ struct TritonIntelTensorDescToBlockPointer
245150
sizes.push_back(static_cast<int32_t>(size));
246151
}
247152

248-
Value tensorPtr = findOrCreateMakeTensorPtrTmp(
153+
Value tensorPtr = findOrCreateMakeTensorPtr(
249154
loc, op.getBase(), shapes, strides, offsets, sizes, builder);
250155
LLVM_DEBUG({
251156
llvm::dbgs() << "With:\n";
@@ -276,81 +181,6 @@ struct TritonIntelTensorDescToBlockPointer
276181
return success();
277182
}
278183

279-
template <typename OpTy,
280-
std::enable_if_t<llvm::is_one_of<OpTy, tt::DescriptorLoadOp,
281-
tt::DescriptorStoreOp>::value,
282-
bool> = true>
283-
LogicalResult rewriteDescriptorLoadOrStoreOpOld(OpTy op) {
284-
assert(op && "Expecting a valid operation");
285-
LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << op << "\n");
286-
287-
OpBuilder builder(op);
288-
Location loc = op.getLoc();
289-
TypedValue<tt::TensorDescType> tDesc = op.getDesc();
290-
tt::TensorDescType tDescType = tDesc.getType();
291-
tt::MakeTensorDescOp makeTensorDescOp = getMakeTensorDescOp(tDesc);
292-
293-
if (!makeTensorDescOp) {
294-
LLVM_DEBUG(llvm::dbgs()
295-
<< "could not find tt.make_tensor_descriptor defining: "
296-
<< tDesc << "\n");
297-
return failure();
298-
}
299-
300-
LLVM_DEBUG(llvm::dbgs() << "which has tdesc: " << makeTensorDescOp << "\n");
301-
302-
// Create a new block pointer if a suitable one doesn't already exist.
303-
SmallVector<Value> shapes, strides, offsets;
304-
SmallVector<int32_t> sizes;
305-
for (const auto [shape, stride, offset, size] :
306-
llvm::zip(makeTensorDescOp.getShape(), makeTensorDescOp.getStrides(),
307-
op.getIndices(), tDescType.getBlockType().getShape())) {
308-
shapes.push_back(findOrCreateCast(
309-
loc, shape, builder.getIntegerType(shapeAndStridesBitwidth),
310-
builder));
311-
strides.push_back(findOrCreateCast(
312-
loc, stride, builder.getIntegerType(shapeAndStridesBitwidth),
313-
builder));
314-
offsets.push_back(findOrCreateCast(
315-
loc, offset, builder.getIntegerType(offsetBitwidth), builder));
316-
sizes.push_back(static_cast<int32_t>(size));
317-
}
318-
319-
constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
320-
RankedTensorType tensorType;
321-
if constexpr (isLoad)
322-
tensorType = op.getResult().getType();
323-
else
324-
tensorType = op.getSrc().getType();
325-
326-
Value makeTensorPtrOp =
327-
findOrCreateMakeTensorPtr(loc, makeTensorDescOp.getBase(), shapes,
328-
strides, offsets, sizes, tensorType, builder);
329-
330-
LLVM_DEBUG({
331-
llvm::dbgs() << "With:\n";
332-
llvm::dbgs().indent(2) << makeTensorPtrOp << "\n";
333-
});
334-
335-
if constexpr (isLoad) {
336-
auto loadOp = builder.createOrFold<tt::LoadOp>(
337-
loc, makeTensorPtrOp, op.getCache(), op.getEvict(),
338-
/*volatile*/ false);
339-
LLVM_DEBUG(llvm::dbgs().indent(2) << loadOp << "\n");
340-
op.replaceAllUsesWith(loadOp);
341-
} else {
342-
[[maybe_unused]] auto storeOp = builder.createOrFold<tt::StoreOp>(
343-
loc, makeTensorPtrOp, op.getSrc(), tt::CacheModifier::NONE,
344-
tt::EvictionPolicy::NORMAL);
345-
LLVM_DEBUG(llvm::dbgs().indent(2) << storeOp << "\n");
346-
}
347-
348-
cleanUp.insert(op);
349-
cleanUp.insert(makeTensorDescOp);
350-
351-
return success();
352-
}
353-
354184
template <typename OpTy,
355185
std::enable_if_t<llvm::is_one_of<OpTy, tt::DescriptorLoadOp,
356186
tt::DescriptorStoreOp>::value,

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

+7 −18
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) {
695695
// Recursively update the operands in a chain of AdvanceOps, after setting the
696696
// pointer operand of the first one.
697697
static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp,
698-
Value data) {
698+
Value dataToStore) {
699699
OpBuilder rewriter(advanceOp);
700700
auto newAdvanceOp =
701701
rewriter.create<AdvanceOp>(advanceOp.getLoc(), makeTensorPtrOp.getType(),
@@ -704,12 +704,10 @@ static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp,
704704
SmallVector<Operation *> advanceOpUsers(advanceOp->getUsers());
705705
for (Operation *user : advanceOpUsers) {
706706
if (auto storeOp = dyn_cast<StoreOp>(user)) {
707-
// Update the StoreOp operands.
708707
storeOp.setOperand(0, newAdvanceOp);
709-
storeOp.setOperand(1, data);
710-
} else if (auto nextAdvanceOp = dyn_cast<AdvanceOp>(user)) {
711-
// Recursive call to handle the next AdvanceOp in the chain.
712-
updateAdvanceOpChain(nextAdvanceOp, makeTensorPtrOp, data);
708+
storeOp.setOperand(1, dataToStore);
709+
} else if (auto advanceOp = dyn_cast<AdvanceOp>(user)) {
710+
updateAdvanceOpChain(advanceOp, makeTensorPtrOp, dataToStore);
713711
}
714712
}
715713
}
@@ -790,26 +788,17 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
790788
makeTensorPtrOp.getShape(), makeTensorPtrOp.getStrides(),
791789
makeTensorPtrOp.getOffsets(), makeTensorPtrOp.getOrderAttr());
792790

793-
#if 1
794791
// Update the store operation with the new layout.
795792
SmallVector<Operation *> makeTensorPtrOpUsers(makeTensorPtrOp->getUsers());
793+
auto dataToStore = getValueAs(value, encoding);
796794
for (Operation *user : makeTensorPtrOpUsers) {
797795
if (auto storeOp = dyn_cast<StoreOp>(user)) {
798796
storeOp.setOperand(0, newMakeTensorPtrOp);
799-
storeOp.setOperand(1, getValueAs(value, encoding));
797+
storeOp.setOperand(1, dataToStore);
800798
} else if (auto advanceOp = dyn_cast<AdvanceOp>(user)) {
801-
updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp,
802-
getValueAs(value, encoding));
799+
updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, dataToStore);
803800
}
804801
}
805-
#else
806-
// The encoding of the StoreOp is updated with the new operands:
807-
// - the Ptr created by the MakeTensorPtrOp with the new data type
808-
// - the forwarded DPAS encoding.
809-
Value newOperand = getValueAs(value, encoding);
810-
storeOp.setOperand(0, newMakeTensorPtrOp);
811-
storeOp.setOperand(1, newOperand);
812-
#endif
813802

814803
// If the DPAS encoding is forwarded, we do not need the
815804
// convertOp anymore if the convertOp was only used by the

0 commit comments

Comments (0)