Skip to content

Commit f2f202c

Browse files
authored
[AffineParallelOpUnparallelize] Simplify nested SCF IndexSwitch that has a modulo expression argument (#8401)
1 parent 58f592b commit f2f202c

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

lib/Dialect/Calyx/Transforms/AffinePloopUnparallelize.cpp

+43
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Func/IR/FuncOps.h"
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1717
#include "mlir/Dialect/SCF/IR/SCF.h"
18+
#include "mlir/IR/AffineExpr.h"
1819
#include "mlir/IR/PatternMatch.h"
1920
#include "mlir/IR/Visitors.h"
2021
#include "mlir/Pass/PassManager.h"
@@ -75,6 +76,9 @@ class AffinePloopUnparallelize
7576

7677
int64_t factor = factorAttr.getInt();
7778

79+
SmallVector<scf::IndexSwitchOp> simplifiableIndexSwitchOps =
80+
collectSimplifiableIndexSwitchOps(affineParallelOp, factor);
81+
7882
auto outerLoop = rewriter.create<affine::AffineForOp>(
7983
loc, lowerBound, rewriter.getDimIdentityMap(), upperBound,
8084
rewriter.getDimIdentityMap(), step * factor);
@@ -132,8 +136,47 @@ class AffinePloopUnparallelize
132136
rewriter.setInsertionPointToEnd(destBlock);
133137
rewriter.create<affine::AffineYieldOp>(loc);
134138

139+
for (auto indexSwitchOp : simplifiableIndexSwitchOps) {
140+
indexSwitchOp.setOperand(innerParallel.getIVs().front());
141+
}
142+
135143
return success();
136144
}
145+
146+
private:
147+
// Collect all simplifiable `scf.index_switch` ops in `affineParallelOp`. An
148+
// `scf.index_switch` op is simpliiable if its argument only depends on
149+
// `affineParallelOp`'s loop IV and if it's a result of a modulo expression.
150+
SmallVector<scf::IndexSwitchOp>
151+
collectSimplifiableIndexSwitchOps(affine::AffineParallelOp affineParallelOp,
152+
int64_t factor) const {
153+
SmallVector<scf::IndexSwitchOp> result;
154+
affineParallelOp->walk([&](scf::IndexSwitchOp indexSwitchOp) {
155+
auto switchArg = indexSwitchOp.getArg();
156+
auto affineApplyOp =
157+
dyn_cast_or_null<affine::AffineApplyOp>(switchArg.getDefiningOp());
158+
if (!affineApplyOp || affineApplyOp->getNumOperands() != 1 ||
159+
affineApplyOp->getNumResults() != 1)
160+
return WalkResult::advance();
161+
162+
auto affineMap = affineApplyOp.getAffineMap();
163+
auto binExpr = dyn_cast<AffineBinaryOpExpr>(affineMap.getResult(0));
164+
if (!binExpr || binExpr.getKind() != AffineExprKind::Mod)
165+
return WalkResult::advance();
166+
167+
if (affineApplyOp.getOperand(0) != affineParallelOp.getIVs().front())
168+
return WalkResult::advance();
169+
170+
auto rhs = binExpr.getRHS();
171+
auto constRhs = dyn_cast<AffineConstantExpr>(rhs);
172+
if (!constRhs || factor != constRhs.getValue())
173+
return WalkResult::advance();
174+
175+
result.push_back(indexSwitchOp);
176+
return WalkResult::advance();
177+
});
178+
return result;
179+
}
137180
};
138181

139182
namespace {

test/Dialect/Calyx/affine-ploop-unparallelize.mlir

+118
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,121 @@ module {
5555
return
5656
}
5757
}
58+
59+
// -----
60+
61+
// Test simplifying `scf.index_switch` ops inside nested `affine.parallel` ops
62+
63+
// CHECK-LABEL: func.func @main(
64+
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<4x6xf32>,
65+
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<4x6xf32>,
66+
// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<4x6xf32>,
67+
// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<4x6xf32>,
68+
// CHECK-SAME: %[[VAL_4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<4x6xf32>,
69+
// CHECK-SAME: %[[VAL_5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<4x6xf32>) {
70+
// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
71+
// CHECK: affine.for %[[VAL_7:.*]] = 0 to 8 step 2 {
72+
// CHECK: affine.parallel (%[[VAL_8:.*]]) = (0) to (2) {
73+
// CHECK: affine.for %[[VAL_9:.*]] = 0 to 18 step 3 {
74+
// CHECK: affine.parallel (%[[VAL_10:.*]]) = (0) to (3) {
75+
// CHECK: scf.index_switch %[[VAL_8]]
76+
// CHECK: case 0 {
77+
// CHECK: scf.index_switch %[[VAL_10]]
78+
// CHECK: case 0 {
79+
// CHECK: affine.store %[[VAL_6]], %[[VAL_0]][(%[[VAL_7]] + %[[VAL_8]]) floordiv 2, (%[[VAL_9]] + %[[VAL_10]]) floordiv 3] : memref<4x6xf32>
80+
// CHECK: scf.yield
81+
// CHECK: }
82+
// CHECK: case 1 {
83+
// CHECK: affine.store %[[VAL_6]], %[[VAL_1]][(%[[VAL_7]] + %[[VAL_8]]) floordiv 2, (%[[VAL_9]] + %[[VAL_10]]) floordiv 3] : memref<4x6xf32>
84+
// CHECK: scf.yield
85+
// CHECK: }
86+
// CHECK: case 2 {
87+
// CHECK: affine.store %[[VAL_6]], %[[VAL_2]][(%[[VAL_7]] + %[[VAL_8]]) floordiv 2, (%[[VAL_9]] + %[[VAL_10]]) floordiv 3] : memref<4x6xf32>
88+
// CHECK: scf.yield
89+
// CHECK: }
90+
// CHECK: default {
91+
// CHECK: }
92+
// CHECK: scf.yield
93+
// CHECK: }
94+
// CHECK: case 1 {
95+
// CHECK: scf.index_switch %[[VAL_10]]
96+
// CHECK: case 0 {
97+
// CHECK: affine.store %[[VAL_6]], %[[VAL_3]][(%[[VAL_7]] + %[[VAL_8]]) floordiv 2, (%[[VAL_9]] + %[[VAL_10]]) floordiv 3] : memref<4x6xf32>
98+
// CHECK: scf.yield
99+
// CHECK: }
100+
// CHECK: case 1 {
101+
// CHECK: affine.store %[[VAL_6]], %[[VAL_4]][(%[[VAL_7]] + %[[VAL_8]]) floordiv 2, (%[[VAL_9]] + %[[VAL_10]]) floordiv 3] : memref<4x6xf32>
102+
// CHECK: scf.yield
103+
// CHECK: }
104+
// CHECK: case 2 {
105+
// CHECK: affine.store %[[VAL_6]], %[[VAL_5]][(%[[VAL_7]] + %[[VAL_8]]) floordiv 2, (%[[VAL_9]] + %[[VAL_10]]) floordiv 3] : memref<4x6xf32>
106+
// CHECK: scf.yield
107+
// CHECK: }
108+
// CHECK: default {
109+
// CHECK: }
110+
// CHECK: scf.yield
111+
// CHECK: }
112+
// CHECK: default {
113+
// CHECK: }
114+
// CHECK: }
115+
// CHECK: } {unparallelized}
116+
// CHECK: }
117+
// CHECK: } {unparallelized}
118+
// CHECK: return
119+
// CHECK: }
120+
121+
// Selector maps for the two switch levels: the loop IV modulo 2 and modulo 3.
// Each modulus equals the `unparallelize.factor` attribute attached to the
// `affine.parallel` whose IV the map is applied to below.
#map = affine_map<(d0) -> (d0 mod 2)>
122+
#map1 = affine_map<(d0) -> (d0 mod 3)>
123+
module {
124+
func.func @main(%arg0: memref<4x6xf32>, %arg1: memref<4x6xf32>, %arg2: memref<4x6xf32>, %arg3: memref<4x6xf32>, %arg4: memref<4x6xf32>, %arg5: memref<4x6xf32>) {
125+
%cst = arith.constant 0.000000e+00 : f32
126+
// Outer loop over [0, 8); its attribute at the closing brace is
// unparallelize.factor=2.
affine.parallel (%arg6) = (0) to (8) {
127+
// Inner loop over [0, 18); its attribute at the closing brace is
// unparallelize.factor=3.
affine.parallel (%arg7) = (0) to (18) {
128+
// First-level selector: %arg6 mod 2.
%0 = affine.apply #map(%arg6)
129+
scf.index_switch %0
130+
case 0 {
131+
// Second-level selector: %arg7 mod 3. Each case stores %cst into a
// different memref argument at [%arg6 floordiv 2, %arg7 floordiv 3].
%1 = affine.apply #map1(%arg7)
132+
scf.index_switch %1
133+
case 0 {
134+
affine.store %cst, %arg0[%arg6 floordiv 2, %arg7 floordiv 3] : memref<4x6xf32>
135+
scf.yield
136+
}
137+
case 1 {
138+
affine.store %cst, %arg1[%arg6 floordiv 2, %arg7 floordiv 3] : memref<4x6xf32>
139+
scf.yield
140+
}
141+
case 2 {
142+
affine.store %cst, %arg2[%arg6 floordiv 2, %arg7 floordiv 3] : memref<4x6xf32>
143+
scf.yield
144+
}
145+
default {
146+
}
147+
scf.yield
148+
}
149+
case 1 {
150+
// Same second-level switch as in case 0, targeting %arg3..%arg5 instead.
%1 = affine.apply #map1(%arg7)
151+
scf.index_switch %1
152+
case 0 {
153+
affine.store %cst, %arg3[%arg6 floordiv 2, %arg7 floordiv 3] : memref<4x6xf32>
154+
scf.yield
155+
}
156+
case 1 {
157+
affine.store %cst, %arg4[%arg6 floordiv 2, %arg7 floordiv 3] : memref<4x6xf32>
158+
scf.yield
159+
}
160+
case 2 {
161+
affine.store %cst, %arg5[%arg6 floordiv 2, %arg7 floordiv 3] : memref<4x6xf32>
162+
scf.yield
163+
}
164+
default {
165+
}
166+
scf.yield
167+
}
168+
default {
169+
}
170+
} {unparallelize.factor=3}
171+
} {unparallelize.factor=2}
172+
return
173+
}
174+
}
175+

0 commit comments

Comments
 (0)