Skip to content

Commit 24980a6

Browse files
authored
[flang][OpenMP] Privatize locally destroyed values in do concurent (#112)
Collects values that are local to a loop: "loop-local values". A loop-local value is one that is used exclusively inside the loop but allocated outside of it. This usually corresponds to temporary values that are used inside the loop body for initialzing other variables for example. For a "loop-local" value within a loop's scope, localizes that value within the scope of the parallel region the loop maps to.
1 parent e8b67b7 commit 24980a6

File tree

3 files changed

+150
-9
lines changed

3 files changed

+150
-9
lines changed

flang/docs/DoConcurrentConversionToOpenMP.md

+10-3
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,16 @@ see the "Data environment" section below.
234234
By default, variables that are used inside a `do concurernt` loop nest are
235235
either treated as `shared` in case of mapping to `host`, or mapped into the
236236
`target` region using a `map` clause in case of mapping to `device`. The only
237-
exception to this is the loop's iteration variable(s) (IV) of **perfect** loop
238-
nest. In that case, for each IV, we allocate a local copy as shown the by the
239-
mapping examples above.
237+
exceptions to this are:
238+
1. the loop's iteration variable(s) (IV) of **perfect** loop nests. In that
239+
case, for each IV, we allocate a local copy as shown the by the mapping
240+
examples above.
241+
1. any values that are from allocations outside the loop nest and used
242+
exclusively inside of it. In such cases, a local privatized
243+
value is created in the OpenMP region to prevent multiple teams of threads
244+
from accessing and destroying the same memory block which causes runtime
245+
issues. For an example of such cases, see
246+
`flang/test/Transforms/DoConcurrent/locally_destroyed_temp.f90`.
240247

241248
#### Non-perfectly-nested loops' IVs
242249

flang/lib/Optimizer/Transforms/DoConcurrentConversion.cpp

+74-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
1616
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1717
#include "flang/Optimizer/Transforms/Passes.h"
18+
#include "mlir/Analysis/SliceAnalysis.h"
19+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1820
#include "mlir/Dialect/Func/IR/FuncOps.h"
21+
#include "mlir/Dialect/Math/IR/Math.h"
1922
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2023
#include "mlir/IR/Diagnostics.h"
2124
#include "mlir/IR/IRMapping.h"
@@ -468,6 +471,61 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
468471
++idx;
469472
}
470473
}
474+
475+
/// Collects values that are local to a loop: "loop-local values". A loop-local
476+
/// value is one that is used exclusively inside the loop but allocated outside
477+
/// of it. This usually corresponds to temporary values that are used inside the
478+
/// loop body for initialzing other variables for example.
479+
///
480+
/// \param [in] doLoop - the loop within which the function searches for values
481+
/// used exclusively inside.
482+
///
483+
/// \param [out] locals - the list of loop-local values detected for \p doLoop.
484+
static void collectLoopLocalValues(fir::DoLoopOp doLoop,
485+
llvm::SetVector<mlir::Value> &locals) {
486+
doLoop.walk([&](mlir::Operation *op) {
487+
for (mlir::Value operand : op->getOperands()) {
488+
if (locals.contains(operand))
489+
continue;
490+
491+
bool isLocal = true;
492+
493+
if (!mlir::isa_and_present<fir::AllocaOp>(operand.getDefiningOp()))
494+
continue;
495+
496+
// Values defined inside the loop are not interesting since they do not
497+
// need to be localized.
498+
if (doLoop->isAncestor(operand.getDefiningOp()))
499+
continue;
500+
501+
for (auto *user : operand.getUsers()) {
502+
if (!doLoop->isAncestor(user)) {
503+
isLocal = false;
504+
break;
505+
}
506+
}
507+
508+
if (isLocal)
509+
locals.insert(operand);
510+
}
511+
});
512+
}
513+
514+
/// For a "loop-local" value \p local within a loop's scope, localizes that
515+
/// value within the scope of the parallel region the loop maps to. Towards that
516+
/// end, this function moves the allocation of \p local within \p allocRegion.
517+
///
518+
/// \param local - the value used exclusively within a loop's scope (see
519+
/// collectLoopLocalValues).
520+
///
521+
/// \param allocRegion - the parallel region where \p local's allocation will be
522+
/// privatized.
523+
///
524+
/// \param rewriter - builder used for updating \p allocRegion.
525+
static void localizeLoopLocalValue(mlir::Value local, mlir::Region &allocRegion,
526+
mlir::ConversionPatternRewriter &rewriter) {
527+
rewriter.moveOpBefore(local.getDefiningOp(), &allocRegion.front().front());
528+
}
471529
} // namespace looputils
472530

473531
class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
@@ -519,9 +577,13 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
519577
bool hasRemainingNestedLoops =
520578
failed(looputils::collectLoopNest(doLoop, loopNest));
521579

580+
mlir::IRMapping mapper;
581+
582+
llvm::SetVector<mlir::Value> locals;
583+
looputils::collectLoopLocalValues(loopNest.back().first, locals);
584+
522585
looputils::sinkLoopIVArgs(rewriter, loopNest);
523586

524-
mlir::IRMapping mapper;
525587
mlir::omp::TargetOp targetOp;
526588
mlir::omp::LoopNestClauseOps loopNestClauseOps;
527589

@@ -541,8 +603,13 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
541603
genDistributeOp(doLoop.getLoc(), rewriter);
542604
}
543605

544-
genParallelOp(doLoop.getLoc(), rewriter, loopNest, mapper,
545-
loopNestClauseOps);
606+
mlir::omp::ParallelOp parallelOp = genParallelOp(
607+
doLoop.getLoc(), rewriter, loopNest, mapper, loopNestClauseOps);
608+
609+
for (mlir::Value local : locals)
610+
looputils::localizeLoopLocalValue(local, parallelOp.getRegion(),
611+
rewriter);
612+
546613
mlir::omp::LoopNestOp ompLoopNest =
547614
genWsLoopOp(rewriter, loopNest.back().first, mapper, loopNestClauseOps);
548615

@@ -919,9 +986,10 @@ class DoConcurrentConversionPass
919986
context, mapTo == fir::omp::DoConcurrentMappingKind::DCMK_Device,
920987
concurrentLoopsToSkip);
921988
mlir::ConversionTarget target(*context);
922-
target.addLegalDialect<fir::FIROpsDialect, hlfir::hlfirDialect,
923-
mlir::arith::ArithDialect, mlir::func::FuncDialect,
924-
mlir::omp::OpenMPDialect>();
989+
target.addLegalDialect<
990+
fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect,
991+
mlir::func::FuncDialect, mlir::omp::OpenMPDialect,
992+
mlir::cf::ControlFlowDialect, mlir::math::MathDialect>();
925993

926994
target.addDynamicallyLegalOp<fir::DoLoopOp>([&](fir::DoLoopOp op) {
927995
return !op.getUnordered() || concurrentLoopsToSkip.contains(op);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
! Tests that locally destroyed values in a `do concurrent` loop are properly
2+
! handled. Locally destroyed values are those values for which the Fortran runtime
3+
! calls `@_FortranADestroy` inside the loops body. If these values are allocated
4+
! outside the loop, and the loop is mapped to OpenMP, then a runtime error would
5+
! occur due to multiple teams trying to access the same allocation.
6+
7+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-parallel=host %s -o - \
8+
! RUN: | FileCheck %s
9+
10+
module struct_mod
11+
type test_struct
12+
integer, allocatable :: x_
13+
end type
14+
15+
interface test_struct
16+
pure module function construct_from_components(x) result(struct)
17+
implicit none
18+
integer, intent(in) :: x
19+
type(test_struct) struct
20+
end function
21+
end interface
22+
end module
23+
24+
submodule(struct_mod) struct_sub
25+
implicit none
26+
27+
contains
28+
module procedure construct_from_components
29+
struct%x_ = x
30+
end procedure
31+
end submodule struct_sub
32+
33+
program main
34+
use struct_mod, only : test_struct
35+
36+
implicit none
37+
type(test_struct), dimension(10) :: a
38+
integer :: i
39+
integer :: total
40+
41+
do concurrent (i=1:10)
42+
a(i) = test_struct(i)
43+
end do
44+
45+
do i=1,10
46+
total = total + a(i)%x_
47+
end do
48+
49+
print *, "total =", total
50+
end program main
51+
52+
! CHECK: omp.parallel {
53+
! CHECK: %[[LOCAL_TEMP:.*]] = fir.alloca !fir.type<_QMstruct_modTtest_struct{x_:!fir.box<!fir.heap<i32>>}> {bindc_name = ".result"}
54+
! CHECK: omp.wsloop {
55+
! CHECK: omp.loop_nest {{.*}} {
56+
! CHECK: %[[TEMP_VAL:.*]] = fir.call @_QMstruct_modPconstruct_from_components
57+
! CHECK: fir.save_result %[[TEMP_VAL]] to %[[LOCAL_TEMP]]
58+
! CHECK: %[[EMBOXED_LOCAL:.*]] = fir.embox %[[LOCAL_TEMP]]
59+
! CHECK: %[[CONVERTED_LOCAL:.*]] = fir.convert %[[EMBOXED_LOCAL]]
60+
! CHECK: fir.call @_FortranADestroy(%[[CONVERTED_LOCAL]])
61+
! CHECK: omp.yield
62+
! CHECK: }
63+
! CHECK: omp.terminator
64+
! CHECK: }
65+
! CHECK: omp.terminator
66+
! CHECK: }

0 commit comments

Comments
 (0)