@@ -247,21 +247,24 @@ struct ONNXParallelOpLowering : public OpConversionPattern<ONNXParallelOp> {
247
247
Value eq = create.math .eq (forkId, indices[0 ]);
248
248
scf::IfOp ifOp = rewriter.create <scf::IfOp>(loc, eq, /* else=*/ false );
249
249
Block &ifBlock = ifOp.getThenRegion ().back ();
250
- // Insert KrnlRegionOp within scf::IfOp
251
250
rewriter.setInsertionPointToStart (&ifBlock);
252
- KrnlRegionOp regionOp = rewriter.create <KrnlRegionOp>(loc);
253
- Block ®ionBlock = regionOp.getBodyRegion ().front ();
254
- rewriter.setInsertionPointToStart (®ionBlock);
255
- // Insert KrnlNoneOp. This op is used for inserting forkBlock into
256
- // regionBlock. This op is deleted after doing it.
257
- KrnlNoneOp noneOp = rewriter.create <KrnlNoneOp>(loc);
258
- // Delete terminator of forkRegion.
251
+ // Insert KrnlRegionOp in every KrnlIterateOps. This needs to avoid errors
252
+ // in convertKrnlToAffinePass.
259
253
Block &forkBlock = forkOp.getRegion ().back ();
254
+ for (auto kop : forkBlock.getOps <KrnlIterateOp>()) {
255
+ KrnlRegionOp regionOp = rewriter.create <KrnlRegionOp>(loc);
256
+ Block ®ionBlock = regionOp.getBodyRegion ().front ();
257
+ Block &iterateBlock = kop.getBodyRegion ().back ();
258
+ rewriter.eraseOp (iterateBlock.getTerminator ());
259
+ regionBlock.getOperations ().splice (
260
+ regionBlock.end (), iterateBlock.getOperations ());
261
+ rewriter.setInsertionPointToStart (&iterateBlock);
262
+ KrnlYieldOp krnlYieldOp = rewriter.create <KrnlYieldOp>(loc);
263
+ rewriter.moveOpBefore (regionOp, krnlYieldOp);
264
+ }
260
265
Operation *forkYieldOp = forkBlock.getTerminator ();
261
266
rewriter.eraseOp (forkYieldOp);
262
- // Insert forkBlock into regionBlock
263
- rewriter.inlineBlockBefore (&forkOp.getRegion ().back (), noneOp);
264
- rewriter.eraseOp (noneOp);
267
+ rewriter.inlineBlockBefore (&forkBlock, ifBlock.getTerminator ());
265
268
rewriter.eraseOp (forkOp);
266
269
id++;
267
270
}
0 commit comments