@@ -62,14 +62,12 @@ struct TritonIntelTensorDescToBlockPointer
62
62
63
63
moduleOp->walk <WalkOrder::PreOrder>([&](Operation *op) {
64
64
return TypeSwitch<Operation *, WalkResult>(op)
65
- #if 1
66
65
.Case <tt::MakeTensorDescOp>([&](auto makeTensorDescOp) {
67
66
if (failed (rewriteMakeTensorDescriptorOp (makeTensorDescOp)))
68
67
makeTensorDescOp->emitRemark (
69
68
" TritonIntelTensorDescToBlockPointer: Failed to rewrite" );
70
69
return WalkResult::advance ();
71
70
})
72
- #endif
73
71
.Case <tt::DescriptorLoadOp, tt::DescriptorStoreOp>(
74
72
[&](auto loadOrStoreOp) {
75
73
if (failed (rewriteDescriptorLoadOrStoreOp (loadOrStoreOp)))
@@ -85,106 +83,13 @@ struct TritonIntelTensorDescToBlockPointer
85
83
}
86
84
87
85
private:
88
- tt::MakeTensorDescOp getMakeTensorDescOp (Value base) const {
89
- assert (base && isa<tt::TensorDescType>(base.getType ()) &&
90
- " Expecting tensor desc" );
91
-
92
- Operation *defOp = base.getDefiningOp ();
93
- if (!defOp) {
94
- BlockArgument blockArg = cast<BlockArgument>(base);
95
- Operation *parentOp = blockArg.getOwner ()->getParentOp ();
96
- if (scf::ForOp forOp = dyn_cast<scf::ForOp>(parentOp)) {
97
- unsigned numIVs = forOp.getNumInductionVars ();
98
- int initArgIdx = blockArg.getArgNumber () - numIVs;
99
- if (isModifiedInLoop (forOp, blockArg)) {
100
- LLVM_DEBUG (llvm::dbgs () << blockArg << " is loop variant\n " );
101
- return nullptr ;
102
- }
103
- Operation::operand_range initArgs = forOp.getInitArgs ();
104
- assert (initArgIdx >= 0 && initArgIdx < initArgs.size () &&
105
- " Unexpected 'initArgIdx' value" );
106
- return getMakeTensorDescOp (initArgs[initArgIdx]);
107
- }
108
- LLVM_DEBUG (llvm::dbgs ()
109
- << " TODO: Unhandled non operation: " << base << " \n " );
110
- return nullptr ;
111
- }
112
-
113
- if (defOp->getNumRegions () != 0 ) {
114
- LLVM_DEBUG (llvm::dbgs () << " TODO: defOp with region: " << *defOp << " \n " );
115
- return nullptr ;
116
- }
117
- if (auto makeTensorDescOp = dyn_cast<tt::MakeTensorDescOp>(defOp))
118
- return makeTensorDescOp;
119
-
120
- llvm_unreachable (" TODO: Unhandled defOp kind" );
121
- return nullptr ;
122
- }
123
-
124
- bool isModifiedInLoop (scf::ForOp forOp, BlockArgument &blockArg) const {
125
- unsigned argNo = blockArg.getArgNumber ();
126
- unsigned numIVs = forOp.getNumInductionVars ();
127
- int initArgIdx = blockArg.getArgNumber () - numIVs;
128
- Value yieldedVal = forOp.getYieldedValues ()[initArgIdx];
129
- return (yieldedVal != blockArg);
130
- }
131
-
132
86
// Create a new block pointer if a suitable one doesn't already exist.
133
87
// Otherwise, return the existing one. The function takes the base, shape,
134
88
// strides, offsets, sizes of the block pointer to create/lookup and its
135
89
// tensor element type (to ensure the block pointer has the tensor layout).
136
90
Value findOrCreateMakeTensorPtr (Location loc, Value base, ValueRange shape,
137
91
ValueRange strides, ValueRange offsets,
138
- ArrayRef<int32_t > sizes,
139
- RankedTensorType tensorType,
140
- OpBuilder &builder) {
141
- Block *block = builder.getInsertionBlock ();
142
- const Block::iterator insertPoint = builder.getInsertionPoint ();
143
- auto ptrType = tt::PointerType::get (
144
- tensorType, tt::TritonGEN::TritonGENMemorySpace::kCrossWorkgroup );
145
-
146
- auto it = std::find_if (block->begin (), insertPoint, [&](Operation &op) {
147
- if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
148
- triton::PointerType resType = makeTensorPtrOp.getResult ().getType ();
149
- auto tensorType = cast<RankedTensorType>(resType.getPointeeType ());
150
- auto sameShape = [](ArrayRef<int64_t > arr1, ArrayRef<int32_t > arr2) {
151
- for (auto [dim1, dim2] : llvm::zip (arr1, arr2)) {
152
- if (dim1 != dim2)
153
- return false ;
154
- }
155
- return true ;
156
- };
157
-
158
- return makeTensorPtrOp.getType () == ptrType &&
159
- makeTensorPtrOp.getBase () == base &&
160
- makeTensorPtrOp.getShape () == shape &&
161
- makeTensorPtrOp.getStrides () == strides &&
162
- makeTensorPtrOp.getOffsets () == offsets &&
163
- sameShape (tensorType.getShape (), sizes);
164
- }
165
- return false ;
166
- });
167
-
168
- auto makeTensorPtrOp = [&]() {
169
- Value makeTensorPtr = builder.create <tt::MakeTensorPtrOp>(
170
- loc, base, shape, strides, offsets, sizes,
171
- builder.getDenseI32ArrayAttr ({1 , 0 }));
172
- makeTensorPtr.setType (ptrType);
173
- return makeTensorPtr;
174
- };
175
-
176
- return (it != insertPoint) ? cast<tt::MakeTensorPtrOp>(*it)
177
- : makeTensorPtrOp ();
178
- }
179
-
180
- // Create a new block pointer if a suitable one doesn't already exist.
181
- // Otherwise, return the existing one. The function takes the base, shape,
182
- // strides, offsets, sizes of the block pointer to create/lookup and its
183
- // tensor element type (to ensure the block pointer has the tensor layout).
184
- Value findOrCreateMakeTensorPtrTmp (Location loc, Value base, ValueRange shape,
185
- ValueRange strides, ValueRange offsets,
186
- ArrayRef<int32_t > sizes,
187
- OpBuilder &builder) {
92
+ ArrayRef<int32_t > sizes, OpBuilder &builder) {
188
93
Block *block = builder.getInsertionBlock ();
189
94
const Block::iterator insertPoint = builder.getInsertionPoint ();
190
95
auto it = std::find_if (block->begin (), insertPoint, [&](Operation &op) {
@@ -245,7 +150,7 @@ struct TritonIntelTensorDescToBlockPointer
245
150
sizes.push_back (static_cast <int32_t >(size));
246
151
}
247
152
248
- Value tensorPtr = findOrCreateMakeTensorPtrTmp (
153
+ Value tensorPtr = findOrCreateMakeTensorPtr (
249
154
loc, op.getBase (), shapes, strides, offsets, sizes, builder);
250
155
LLVM_DEBUG ({
251
156
llvm::dbgs () << " With:\n " ;
@@ -276,81 +181,6 @@ struct TritonIntelTensorDescToBlockPointer
276
181
return success ();
277
182
}
278
183
279
- template <typename OpTy,
280
- std::enable_if_t <llvm::is_one_of<OpTy, tt::DescriptorLoadOp,
281
- tt::DescriptorStoreOp>::value,
282
- bool > = true >
283
- LogicalResult rewriteDescriptorLoadOrStoreOpOld (OpTy op) {
284
- assert (op && " Expecting a valid operation" );
285
- LLVM_DEBUG (llvm::dbgs () << " Rewriting: " << op << " \n " );
286
-
287
- OpBuilder builder (op);
288
- Location loc = op.getLoc ();
289
- TypedValue<tt::TensorDescType> tDesc = op.getDesc ();
290
- tt::TensorDescType tDescType = tDesc.getType ();
291
- tt::MakeTensorDescOp makeTensorDescOp = getMakeTensorDescOp (tDesc);
292
-
293
- if (!makeTensorDescOp) {
294
- LLVM_DEBUG (llvm::dbgs ()
295
- << " could not find tt.make_tensor_descriptor defining: "
296
- << tDesc << " \n " );
297
- return failure ();
298
- }
299
-
300
- LLVM_DEBUG (llvm::dbgs () << " which has tdesc: " << makeTensorDescOp << " \n " );
301
-
302
- // Create a new block pointer if a suitable one doesn't already exist.
303
- SmallVector<Value> shapes, strides, offsets;
304
- SmallVector<int32_t > sizes;
305
- for (const auto [shape, stride, offset, size] :
306
- llvm::zip (makeTensorDescOp.getShape (), makeTensorDescOp.getStrides (),
307
- op.getIndices (), tDescType.getBlockType ().getShape ())) {
308
- shapes.push_back (findOrCreateCast (
309
- loc, shape, builder.getIntegerType (shapeAndStridesBitwidth),
310
- builder));
311
- strides.push_back (findOrCreateCast (
312
- loc, stride, builder.getIntegerType (shapeAndStridesBitwidth),
313
- builder));
314
- offsets.push_back (findOrCreateCast (
315
- loc, offset, builder.getIntegerType (offsetBitwidth), builder));
316
- sizes.push_back (static_cast <int32_t >(size));
317
- }
318
-
319
- constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
320
- RankedTensorType tensorType;
321
- if constexpr (isLoad)
322
- tensorType = op.getResult ().getType ();
323
- else
324
- tensorType = op.getSrc ().getType ();
325
-
326
- Value makeTensorPtrOp =
327
- findOrCreateMakeTensorPtr (loc, makeTensorDescOp.getBase (), shapes,
328
- strides, offsets, sizes, tensorType, builder);
329
-
330
- LLVM_DEBUG ({
331
- llvm::dbgs () << " With:\n " ;
332
- llvm::dbgs ().indent (2 ) << makeTensorPtrOp << " \n " ;
333
- });
334
-
335
- if constexpr (isLoad) {
336
- auto loadOp = builder.createOrFold <tt::LoadOp>(
337
- loc, makeTensorPtrOp, op.getCache (), op.getEvict (),
338
- /* volatile*/ false );
339
- LLVM_DEBUG (llvm::dbgs ().indent (2 ) << loadOp << " \n " );
340
- op.replaceAllUsesWith (loadOp);
341
- } else {
342
- [[maybe_unused]] auto storeOp = builder.createOrFold <tt::StoreOp>(
343
- loc, makeTensorPtrOp, op.getSrc (), tt::CacheModifier::NONE,
344
- tt::EvictionPolicy::NORMAL);
345
- LLVM_DEBUG (llvm::dbgs ().indent (2 ) << storeOp << " \n " );
346
- }
347
-
348
- cleanUp.insert (op);
349
- cleanUp.insert (makeTensorDescOp);
350
-
351
- return success ();
352
- }
353
-
354
184
template <typename OpTy,
355
185
std::enable_if_t <llvm::is_one_of<OpTy, tt::DescriptorLoadOp,
356
186
tt::DescriptorStoreOp>::value,
0 commit comments