15
15
#include " flang/Optimizer/HLFIR/HLFIRDialect.h"
16
16
#include " flang/Optimizer/HLFIR/HLFIROps.h"
17
17
#include " flang/Optimizer/Transforms/Passes.h"
18
+ #include " mlir/Analysis/SliceAnalysis.h"
19
+ #include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
18
20
#include " mlir/Dialect/Func/IR/FuncOps.h"
21
+ #include " mlir/Dialect/Math/IR/Math.h"
19
22
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
20
23
#include " mlir/IR/Diagnostics.h"
21
24
#include " mlir/IR/IRMapping.h"
@@ -468,6 +471,61 @@ void sinkLoopIVArgs(mlir::ConversionPatternRewriter &rewriter,
468
471
++idx;
469
472
}
470
473
}
474
+
475
+ // / Collects values that are local to a loop: "loop-local values". A loop-local
476
+ // / value is one that is used exclusively inside the loop but allocated outside
477
+ // / of it. This usually corresponds to temporary values that are used inside the
478
+ // / loop body for initialzing other variables for example.
479
+ // /
480
+ // / \param [in] doLoop - the loop within which the function searches for values
481
+ // / used exclusively inside.
482
+ // /
483
+ // / \param [out] locals - the list of loop-local values detected for \p doLoop.
484
+ static void collectLoopLocalValues (fir::DoLoopOp doLoop,
485
+ llvm::SetVector<mlir::Value> &locals) {
486
+ doLoop.walk ([&](mlir::Operation *op) {
487
+ for (mlir::Value operand : op->getOperands ()) {
488
+ if (locals.contains (operand))
489
+ continue ;
490
+
491
+ bool isLocal = true ;
492
+
493
+ if (!mlir::isa_and_present<fir::AllocaOp>(operand.getDefiningOp ()))
494
+ continue ;
495
+
496
+ // Values defined inside the loop are not interesting since they do not
497
+ // need to be localized.
498
+ if (doLoop->isAncestor (operand.getDefiningOp ()))
499
+ continue ;
500
+
501
+ for (auto *user : operand.getUsers ()) {
502
+ if (!doLoop->isAncestor (user)) {
503
+ isLocal = false ;
504
+ break ;
505
+ }
506
+ }
507
+
508
+ if (isLocal)
509
+ locals.insert (operand);
510
+ }
511
+ });
512
+ }
513
+
514
+ // / For a "loop-local" value \p local within a loop's scope, localizes that
515
+ // / value within the scope of the parallel region the loop maps to. Towards that
516
+ // / end, this function moves the allocation of \p local within \p allocRegion.
517
+ // /
518
+ // / \param local - the value used exclusively within a loop's scope (see
519
+ // / collectLoopLocalValues).
520
+ // /
521
+ // / \param allocRegion - the parallel region where \p local's allocation will be
522
+ // / privatized.
523
+ // /
524
+ // / \param rewriter - builder used for updating \p allocRegion.
525
+ static void localizeLoopLocalValue (mlir::Value local, mlir::Region &allocRegion,
526
+ mlir::ConversionPatternRewriter &rewriter) {
527
+ rewriter.moveOpBefore (local.getDefiningOp (), &allocRegion.front ().front ());
528
+ }
471
529
} // namespace looputils
472
530
473
531
class DoConcurrentConversion : public mlir ::OpConversionPattern<fir::DoLoopOp> {
@@ -519,9 +577,13 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
519
577
bool hasRemainingNestedLoops =
520
578
failed (looputils::collectLoopNest (doLoop, loopNest));
521
579
580
+ mlir::IRMapping mapper;
581
+
582
+ llvm::SetVector<mlir::Value> locals;
583
+ looputils::collectLoopLocalValues (loopNest.back ().first , locals);
584
+
522
585
looputils::sinkLoopIVArgs (rewriter, loopNest);
523
586
524
- mlir::IRMapping mapper;
525
587
mlir::omp::TargetOp targetOp;
526
588
mlir::omp::LoopNestClauseOps loopNestClauseOps;
527
589
@@ -541,8 +603,13 @@ class DoConcurrentConversion : public mlir::OpConversionPattern<fir::DoLoopOp> {
541
603
genDistributeOp (doLoop.getLoc (), rewriter);
542
604
}
543
605
544
- genParallelOp (doLoop.getLoc (), rewriter, loopNest, mapper,
545
- loopNestClauseOps);
606
+ mlir::omp::ParallelOp parallelOp = genParallelOp (
607
+ doLoop.getLoc (), rewriter, loopNest, mapper, loopNestClauseOps);
608
+
609
+ for (mlir::Value local : locals)
610
+ looputils::localizeLoopLocalValue (local, parallelOp.getRegion (),
611
+ rewriter);
612
+
546
613
mlir::omp::LoopNestOp ompLoopNest =
547
614
genWsLoopOp (rewriter, loopNest.back ().first , mapper, loopNestClauseOps);
548
615
@@ -919,9 +986,10 @@ class DoConcurrentConversionPass
919
986
context, mapTo == fir::omp::DoConcurrentMappingKind::DCMK_Device,
920
987
concurrentLoopsToSkip);
921
988
mlir::ConversionTarget target (*context);
922
- target.addLegalDialect <fir::FIROpsDialect, hlfir::hlfirDialect,
923
- mlir::arith::ArithDialect, mlir::func::FuncDialect,
924
- mlir::omp::OpenMPDialect>();
989
+ target.addLegalDialect <
990
+ fir::FIROpsDialect, hlfir::hlfirDialect, mlir::arith::ArithDialect,
991
+ mlir::func::FuncDialect, mlir::omp::OpenMPDialect,
992
+ mlir::cf::ControlFlowDialect, mlir::math::MathDialect>();
925
993
926
994
target.addDynamicallyLegalOp <fir::DoLoopOp>([&](fir::DoLoopOp op) {
927
995
return !op.getUnordered () || concurrentLoopsToSkip.contains (op);
0 commit comments