Skip to content

Commit ed2a10b

Browse files
[Flang][OpenMP] Erase trip count for generic kernels (#125)
We determine the type of the kernel based on the MLIR code, which changes during the lowering phase. Some kernels, such as those with multiple workshare loops, are initially classified as SPMD kernels, but are later recognized as generic kernels during PFT lowering. In such cases, we need to identify the change in type and clear the trip count if it was previously set.
1 parent b3b35ea commit ed2a10b

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

+19-14
Original file line numberDiff line numberDiff line change
@@ -1681,21 +1681,26 @@ genLoopNestOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
16811681
firOpBuilder.getModule().getOperation());
16821682
auto targetOp = loopNestOp->getParentOfType<mlir::omp::TargetOp>();
16831683

1684-
if (offloadMod && targetOp && !offloadMod.getIsTargetDevice() &&
1685-
targetOp.isTargetSPMDLoop()) {
1686-
// Lower loop bounds and step, and process collapsing again, putting lowered
1687-
// values outside of omp.target this time. This enables calculating and
1688-
// accessing the trip count in the host, which is needed when lowering to
1689-
// LLVM IR via the OMPIRBuilder.
1690-
HostClausesInsertionGuard guard(firOpBuilder);
1691-
mlir::omp::CollapseClauseOps collapseClauseOps;
1692-
llvm::SmallVector<const semantics::Symbol *> iv;
1693-
ClauseProcessor cp(converter, semaCtx, item->clauses);
1694-
cp.processCollapse(loc, eval, collapseClauseOps, iv);
1695-
targetOp.getTripCountMutable().assign(calculateTripCount(
1696-
converter.getFirOpBuilder(), loc, collapseClauseOps));
1684+
if (offloadMod && targetOp && !offloadMod.getIsTargetDevice()) {
1685+
if (targetOp.isTargetSPMDLoop()) {
1686+
// Lower loop bounds and step, and process collapsing again, putting
1687+
// lowered values outside of omp.target this time. This enables
1688+
// calculating and accessing the trip count in the host, which is needed
1689+
// when lowering to LLVM IR via the OMPIRBuilder.
1690+
HostClausesInsertionGuard guard(firOpBuilder);
1691+
mlir::omp::CollapseClauseOps collapseClauseOps;
1692+
llvm::SmallVector<const semantics::Symbol *> iv;
1693+
ClauseProcessor cp(converter, semaCtx, item->clauses);
1694+
cp.processCollapse(loc, eval, collapseClauseOps, iv);
1695+
targetOp.getTripCountMutable().assign(calculateTripCount(
1696+
converter.getFirOpBuilder(), loc, collapseClauseOps));
1697+
} else if (targetOp.getTripCountMutable().size()) {
1698+
// The MLIR target operation was updated during PFT lowering,
1699+
// and it is no longer an SPMD kernel. Erase the trip count because
1700+
// as it is now invalid.
1701+
targetOp.getTripCountMutable().erase(0);
1702+
}
16971703
}
1698-
16991704
return loopNestOp;
17001705
}
17011706

0 commit comments

Comments
 (0)