@@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
111
111
112
112
static ValueRange genLoopWithIterator (
113
113
PatternRewriter &rewriter, Location loc, SparseIterator *it,
114
- ValueRange reduc, bool iterFirst,
114
+ ValueRange reduc,
115
115
function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
116
116
Region &loopBody, SparseIterator *it,
117
117
ValueRange reduc)>
@@ -138,15 +138,9 @@ static ValueRange genLoopWithIterator(
138
138
}
139
139
return forOp.getResults ();
140
140
}
141
- SmallVector<Value> ivs;
142
- // TODO: always put iterator SSA values at the end of argument list to be
143
- // consistent with coiterate operation.
144
- if (!iterFirst)
145
- llvm::append_range (ivs, it->getCursor ());
146
- // Appends the user-provided values.
147
- llvm::append_range (ivs, reduc);
148
- if (iterFirst)
149
- llvm::append_range (ivs, it->getCursor ());
141
+
142
+ SmallVector<Value> ivs (reduc);
143
+ llvm::append_range (ivs, it->getCursor ());
150
144
151
145
TypeRange types = ValueRange (ivs).getTypes ();
152
146
auto whileOp = rewriter.create <scf::WhileOp>(loc, types, ivs);
@@ -164,25 +158,17 @@ static ValueRange genLoopWithIterator(
164
158
Region &dstRegion = whileOp.getAfter ();
165
159
Block *after = rewriter.createBlock (&dstRegion, {}, types, l);
166
160
ValueRange aArgs = whileOp.getAfterArguments ();
167
- if (iterFirst) {
168
- aArgs = it->linkNewScope (aArgs);
169
- } else {
170
- aArgs = aArgs.take_front (reduc.size ());
171
- it->linkNewScope (aArgs.drop_front (reduc.size ()));
172
- }
161
+ it->linkNewScope (aArgs.drop_front (reduc.size ()));
162
+ aArgs = aArgs.take_front (reduc.size ());
173
163
174
164
rewriter.setInsertionPointToStart (after);
175
165
SmallVector<Value> ret = bodyBuilder (rewriter, loc, dstRegion, it, aArgs);
176
166
rewriter.setInsertionPointToEnd (after);
177
167
178
168
// Forward loops
179
169
SmallVector<Value> yields;
180
- ValueRange nx = it->forward (rewriter, loc);
181
- if (iterFirst)
182
- llvm::append_range (yields, nx);
183
170
llvm::append_range (yields, ret);
184
- if (!iterFirst)
185
- llvm::append_range (yields, nx);
171
+ llvm::append_range (yields, it->forward (rewriter, loc));
186
172
rewriter.create <scf::YieldOp>(loc, yields);
187
173
}
188
174
return whileOp.getResults ().drop_front (it->getCursor ().size ());
@@ -258,13 +244,13 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
258
244
259
245
Block *block = op.getBody ();
260
246
ValueRange ret = genLoopWithIterator (
261
- rewriter, loc, it.get (), ivs, /* iterFirst= */ true ,
247
+ rewriter, loc, it.get (), ivs,
262
248
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
263
249
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
264
- SmallVector<Value> blockArgs (it-> getCursor () );
250
+ SmallVector<Value> blockArgs (reduc );
265
251
// TODO: Also appends coordinates if used.
266
252
// blockArgs.push_back(it->deref(rewriter, loc));
267
- llvm::append_range (blockArgs, reduc );
253
+ llvm::append_range (blockArgs, it-> getCursor () );
268
254
269
255
Block *dstBlock = &loopBody.getBlocks ().front ();
270
256
rewriter.inlineBlockBefore (block, dstBlock, dstBlock->end (),
@@ -404,7 +390,7 @@ class SparseCoIterateOpConverter
404
390
405
391
Block *block = &r.getBlocks ().front ();
406
392
ValueRange curResult = genLoopWithIterator (
407
- rewriter, loc, validIters.front (), userReduc, /* iterFirst= */ false ,
393
+ rewriter, loc, validIters.front (), userReduc,
408
394
/* bodyBuilder=*/
409
395
[block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
410
396
SparseIterator *it,
0 commit comments