Skip to content

Commit 711419e

Browse files
authored
[flang] Enable loop-versioning for slices. (llvm#120344)
Loops resulting from array expressions like array(:,i) may be versioned for the unit stride of the innermost dimension, when the initial array is an assumed-shape array (which are contiguous in many Fortran programs). This speeds up facerec for about 12% due to further vectorization of the innermost loop produced for the total SUM reduction.
1 parent e3f8c22 commit 711419e

File tree

2 files changed

+442
-27
lines changed

2 files changed

+442
-27
lines changed

flang/lib/Optimizer/Transforms/LoopVersioning.cpp

+89-27
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,45 @@ struct ArgsUsageInLoop {
145145
};
146146
} // namespace
147147

148-
static fir::SequenceType getAsSequenceType(mlir::Value *v) {
149-
mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType()));
148+
static fir::SequenceType getAsSequenceType(mlir::Value v) {
149+
mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v.getType()));
150150
return mlir::dyn_cast<fir::SequenceType>(argTy);
151151
}
152152

153+
/// Return the rank and the element size (in bytes) of the given
154+
/// value \p v. If it is not an array or the element type is not
155+
/// supported, then return <0, 0>. Only trivial data types
156+
/// are currently supported.
157+
/// When \p isArgument is true, \p v is assumed to be a function
158+
/// argument. If \p v's type does not look like a type of an assumed
159+
/// shape array, then the function returns <0, 0>.
160+
/// When \p isArgument is false, array types with known innermost
161+
/// dimension are allowed to proceed.
162+
static std::pair<unsigned, size_t>
163+
getRankAndElementSize(const fir::KindMapping &kindMap,
164+
const mlir::DataLayout &dl, mlir::Value v,
165+
bool isArgument = false) {
166+
if (auto seqTy = getAsSequenceType(v)) {
167+
unsigned rank = seqTy.getDimension();
168+
if (rank > 0 &&
169+
(!isArgument ||
170+
seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent())) {
171+
size_t typeSize = 0;
172+
mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(v.getType());
173+
if (fir::isa_trivial(elementType)) {
174+
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
175+
v.getLoc(), elementType, dl, kindMap);
176+
typeSize = llvm::alignTo(eleSize, eleAlign);
177+
}
178+
if (typeSize)
179+
return {rank, typeSize};
180+
}
181+
}
182+
183+
LLVM_DEBUG(llvm::dbgs() << "Unsupported rank/type: " << v << '\n');
184+
return {0, 0};
185+
}
186+
153187
/// if a value comes from a fir.declare, follow it to the original source,
154188
/// otherwise return the value
155189
static mlir::Value unwrapFirDeclare(mlir::Value val) {
@@ -160,12 +194,48 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
160194
return val;
161195
}
162196

197+
/// Return true, if \p rebox operation keeps the input array
198+
/// continuous in the innermost dimension, if it is initially continuous
199+
/// in the innermost dimension.
200+
static bool reboxPreservesContinuity(fir::ReboxOp rebox) {
201+
// If slicing is not involved, then the rebox does not affect
202+
// the continuity of the array.
203+
auto sliceArg = rebox.getSlice();
204+
if (!sliceArg)
205+
return true;
206+
207+
// A slice with step=1 in the innermost dimension preserves
208+
// the continuity of the array in the innermost dimension.
209+
if (auto sliceOp =
210+
mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) {
211+
if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) {
212+
auto triples = sliceOp.getTriples();
213+
if (triples.size() > 2)
214+
if (auto innermostStep = fir::getIntIfConstant(triples[2]))
215+
if (*innermostStep == 1)
216+
return true;
217+
}
218+
219+
LLVM_DEBUG(llvm::dbgs()
220+
<< "REBOX with slicing may produce non-contiguous array: "
221+
<< sliceOp << '\n'
222+
<< rebox << '\n');
223+
return false;
224+
}
225+
226+
LLVM_DEBUG(llvm::dbgs() << "REBOX with unknown slice" << sliceArg << '\n'
227+
<< rebox << '\n');
228+
return false;
229+
}
230+
163231
/// if a value comes from a fir.rebox, follow the rebox to the original source,
164232
/// of the value, otherwise return the value
165233
static mlir::Value unwrapReboxOp(mlir::Value val) {
166-
// don't support reboxes of reboxes
167-
if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>())
234+
while (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) {
235+
if (!reboxPreservesContinuity(rebox))
236+
break;
168237
val = rebox.getBox();
238+
}
169239
return val;
170240
}
171241

@@ -257,25 +327,10 @@ void LoopVersioningPass::runOnOperation() {
257327
continue;
258328
}
259329

260-
if (auto seqTy = getAsSequenceType(&arg)) {
261-
unsigned rank = seqTy.getDimension();
262-
if (rank > 0 &&
263-
seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) {
264-
size_t typeSize = 0;
265-
mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType());
266-
if (mlir::isa<mlir::FloatType>(elementType) ||
267-
mlir::isa<mlir::IntegerType>(elementType) ||
268-
mlir::isa<mlir::ComplexType>(elementType)) {
269-
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
270-
arg.getLoc(), elementType, *dl, kindMap);
271-
typeSize = llvm::alignTo(eleSize, eleAlign);
272-
}
273-
if (typeSize)
274-
argsOfInterest.push_back({arg, typeSize, rank, {}});
275-
else
276-
LLVM_DEBUG(llvm::dbgs() << "Type not supported\n");
277-
}
278-
}
330+
auto [rank, typeSize] =
331+
getRankAndElementSize(kindMap, *dl, arg, /*isArgument=*/true);
332+
if (rank != 0 && typeSize != 0)
333+
argsOfInterest.push_back({arg, typeSize, rank, {}});
279334
}
280335

281336
if (argsOfInterest.empty()) {
@@ -326,6 +381,13 @@ void LoopVersioningPass::runOnOperation() {
326381
if (arrayCoor.getSlice())
327382
argsInLoop.cannotTransform.insert(a.arg);
328383

384+
// We need to compute the rank and element size
385+
// based on the operand, not the original argument,
386+
// because array slicing may affect it.
387+
std::tie(a.rank, a.size) = getRankAndElementSize(kindMap, *dl, a.arg);
388+
if (a.rank == 0 || a.size == 0)
389+
argsInLoop.cannotTransform.insert(a.arg);
390+
329391
if (argsInLoop.cannotTransform.contains(a.arg)) {
330392
// Remove any previously recorded usage, if any.
331393
argsInLoop.usageInfo.erase(a.arg);
@@ -416,8 +478,8 @@ void LoopVersioningPass::runOnOperation() {
416478
mlir::Location loc = builder.getUnknownLoc();
417479
mlir::IndexType idxTy = builder.getIndexType();
418480

419-
LLVM_DEBUG(llvm::dbgs() << "Module Before transformation:");
420-
LLVM_DEBUG(module->dump());
481+
LLVM_DEBUG(llvm::dbgs() << "Func Before transformation:\n");
482+
LLVM_DEBUG(func->dump());
421483

422484
LLVM_DEBUG(llvm::dbgs() << "loopsOfInterest: " << loopsOfInterest.size()
423485
<< "\n");
@@ -551,8 +613,8 @@ void LoopVersioningPass::runOnOperation() {
551613
}
552614
}
553615

554-
LLVM_DEBUG(llvm::dbgs() << "After transform:\n");
555-
LLVM_DEBUG(module->dump());
616+
LLVM_DEBUG(llvm::dbgs() << "Func After transform:\n");
617+
LLVM_DEBUG(func->dump());
556618

557619
LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
558620
}

0 commit comments

Comments
 (0)