@@ -145,11 +145,45 @@ struct ArgsUsageInLoop {
145
145
};
146
146
} // namespace
147
147
148
- static fir::SequenceType getAsSequenceType (mlir::Value * v) {
149
- mlir::Type argTy = fir::unwrapPassByRefType (fir::unwrapRefType (v-> getType ()));
148
+ static fir::SequenceType getAsSequenceType (mlir::Value v) { // new signature: the value is taken by value instead of through a pointer
149
+ mlir::Type argTy = fir::unwrapPassByRefType (fir::unwrapRefType (v. getType ())); // presumably strips reference / pass-by-reference wrappers from v's type -- confirm against the FIR helpers
150
150
return mlir::dyn_cast<fir::SequenceType>(argTy); // callers test the result in an `if`, so a non-sequence type yields a falsy result
151
151
}
152
152
153
+ /// Return the rank and the element size (in bytes) of the given
154
+ /// value \p v. If it is not an array or the element type is not
155
+ /// supported, then return <0, 0>. Only trivial data types
156
+ /// are currently supported.
157
+ /// When \p isArgument is true, \p v is assumed to be a function
158
+ /// argument. If \p v's type does not look like the type of an
159
+ /// assumed-shape array, then the function returns <0, 0>.
160
+ /// When \p isArgument is false, array types with a known innermost
161
+ /// dimension are allowed to proceed.
162
+ static std::pair<unsigned , size_t >
163
+ getRankAndElementSize (const fir::KindMapping &kindMap,
164
+ const mlir::DataLayout &dl, mlir::Value v,
165
+ bool isArgument = false ) {
166
+ if (auto seqTy = getAsSequenceType (v)) {
167
+ unsigned rank = seqTy.getDimension ();
168
+ if (rank > 0 &&
169
+ (!isArgument ||
170
+ seqTy.getShape ()[0 ] == fir::SequenceType::getUnknownExtent ())) { // only enforced for arguments: extent 0 (the innermost dimension) must be unknown
171
+ size_t typeSize = 0 ;
172
+ mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType (v.getType ());
173
+ if (fir::isa_trivial (elementType)) {
174
+ auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash (
175
+ v.getLoc (), elementType, dl, kindMap);
176
+ typeSize = llvm::alignTo (eleSize, eleAlign); // element size padded up to its alignment
177
+ }
178
+ if (typeSize) // stays 0 for non-trivial element types, which falls through to the <0, 0> return
179
+ return {rank, typeSize};
180
+ }
181
+ }
182
+
183
+ LLVM_DEBUG (llvm::dbgs () << " Unsupported rank/type: " << v << ' \n ' );
184
+ return {0 , 0 };
185
+ }
186
+
153
187
// / if a value comes from a fir.declare, follow it to the original source,
154
188
// / otherwise return the value
155
189
static mlir::Value unwrapFirDeclare (mlir::Value val) {
@@ -160,12 +194,48 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
160
194
return val;
161
195
}
162
196
197
+ /// Return true if the \p rebox operation keeps the input array
198
+ /// contiguous in the innermost dimension, given that it is initially
199
+ /// contiguous in the innermost dimension.
200
+ static bool reboxPreservesContinuity (fir::ReboxOp rebox) {
201
+ // If slicing is not involved, then the rebox does not affect
202
+ // the contiguity of the array.
203
+ auto sliceArg = rebox.getSlice ();
204
+ if (!sliceArg)
205
+ return true ;
206
+
207
+ // A slice with step=1 in the innermost dimension preserves
208
+ // the contiguity of the array in the innermost dimension.
209
+ if (auto sliceOp =
210
+ mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp ())) {
211
+ if (sliceOp.getFields ().empty () && sliceOp.getSubstr ().empty ()) { // slices carrying fields or a substring fall through to `return false`
212
+ auto triples = sliceOp.getTriples ();
213
+ if (triples.size () > 2 )
214
+ if (auto innermostStep = fir::getIntIfConstant (triples[2 ])) // presumably triples hold (lb, ub, step) per dimension, so index 2 is the first dimension's step -- confirm against fir.slice
215
+ if (*innermostStep == 1 )
216
+ return true ;
217
+ }
218
+
219
+ LLVM_DEBUG (llvm::dbgs ()
220
+ << " REBOX with slicing may produce non-contiguous array: "
221
+ << sliceOp << ' \n '
222
+ << rebox << ' \n ' );
223
+ return false ;
224
+ }
225
+
226
+ LLVM_DEBUG (llvm::dbgs () << " REBOX with unknown slice" << sliceArg << ' \n ' // reached when the slice operand is not produced by a fir::SliceOp: be conservative
227
+ << rebox << ' \n ' );
228
+ return false ;
229
+ }
230
+
163
231
/// If a value comes from a fir.rebox, follow the rebox to the original source
164
232
/// of the value; otherwise return the value.
165
233
static mlir::Value unwrapReboxOp (mlir::Value val) {
166
- // don't support reboxes of reboxes
167
- if (fir::ReboxOp rebox = val.getDefiningOp <fir::ReboxOp>())
234
+ while (fir::ReboxOp rebox = val.getDefiningOp <fir::ReboxOp>()) { // loop, so a chain of reboxes is followed rather than a single one
235
+ if (!reboxPreservesContinuity (rebox)) // stop at the first rebox that may break innermost-dimension contiguity; val then stays at that rebox's result
236
+ break ;
168
237
val = rebox.getBox ();
238
+ }
169
239
return val;
170
240
}
171
241
@@ -257,25 +327,10 @@ void LoopVersioningPass::runOnOperation() {
257
327
continue ;
258
328
}
259
329
260
- if (auto seqTy = getAsSequenceType (&arg)) {
261
- unsigned rank = seqTy.getDimension ();
262
- if (rank > 0 &&
263
- seqTy.getShape ()[0 ] == fir::SequenceType::getUnknownExtent ()) {
264
- size_t typeSize = 0 ;
265
- mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType (arg.getType ());
266
- if (mlir::isa<mlir::FloatType>(elementType) ||
267
- mlir::isa<mlir::IntegerType>(elementType) ||
268
- mlir::isa<mlir::ComplexType>(elementType)) {
269
- auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash (
270
- arg.getLoc (), elementType, *dl, kindMap);
271
- typeSize = llvm::alignTo (eleSize, eleAlign);
272
- }
273
- if (typeSize)
274
- argsOfInterest.push_back ({arg, typeSize, rank, {}});
275
- else
276
- LLVM_DEBUG (llvm::dbgs () << " Type not supported\n " );
277
- }
278
- }
330
+ auto [rank, typeSize] =
331
+ getRankAndElementSize (kindMap, *dl, arg, /* isArgument=*/ true );
332
+ if (rank != 0 && typeSize != 0 )
333
+ argsOfInterest.push_back ({arg, typeSize, rank, {}});
279
334
}
280
335
281
336
if (argsOfInterest.empty ()) {
@@ -326,6 +381,13 @@ void LoopVersioningPass::runOnOperation() {
326
381
if (arrayCoor.getSlice ())
327
382
argsInLoop.cannotTransform .insert (a.arg );
328
383
384
+ // We need to compute the rank and element size
385
+ // based on the operand, not the original argument,
386
+ // because array slicing may affect it.
387
+ std::tie (a.rank , a.size ) = getRankAndElementSize (kindMap, *dl, a.arg );
388
+ if (a.rank == 0 || a.size == 0 )
389
+ argsInLoop.cannotTransform .insert (a.arg );
390
+
329
391
if (argsInLoop.cannotTransform .contains (a.arg )) {
330
392
// Remove any previously recorded usage, if any.
331
393
argsInLoop.usageInfo .erase (a.arg );
@@ -416,8 +478,8 @@ void LoopVersioningPass::runOnOperation() {
416
478
mlir::Location loc = builder.getUnknownLoc ();
417
479
mlir::IndexType idxTy = builder.getIndexType ();
418
480
419
- LLVM_DEBUG (llvm::dbgs () << " Module Before transformation:" );
420
- LLVM_DEBUG (module ->dump ());
481
+ LLVM_DEBUG (llvm::dbgs () << " Func Before transformation:\n " );
482
+ LLVM_DEBUG (func ->dump ());
421
483
422
484
LLVM_DEBUG (llvm::dbgs () << " loopsOfInterest: " << loopsOfInterest.size ()
423
485
<< " \n " );
@@ -551,8 +613,8 @@ void LoopVersioningPass::runOnOperation() {
551
613
}
552
614
}
553
615
554
- LLVM_DEBUG (llvm::dbgs () << " After transform:\n " );
555
- LLVM_DEBUG (module ->dump ());
616
+ LLVM_DEBUG (llvm::dbgs () << " Func After transform:\n " );
617
+ LLVM_DEBUG (func ->dump ());
556
618
557
619
LLVM_DEBUG (llvm::dbgs () << " === End " DEBUG_TYPE " ===\n " );
558
620
}
0 commit comments