Skip to content

Commit b48ef8d

Browse files
author
Peiming Liu
authored
[mlir][sparse] unify block arguments order between iterate/coiterate operations. (#105567)
1 parent 283dff4 commit b48ef8d

File tree

3 files changed

+31
-43
lines changed

3 files changed

+31
-43
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

+3-4
Original file line numberDiff line numberDiff line change
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
16441644
return getIterSpace().getType().getSpaceDim();
16451645
}
16461646
BlockArgument getIterator() {
1647-
return getRegion().getArguments().front();
1647+
return getRegion().getArguments().back();
16481648
}
16491649
std::optional<BlockArgument> getLvlCrd(Level lvl) {
16501650
if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
16541654
return std::nullopt;
16551655
}
16561656
Block::BlockArgListType getCrds() {
1657-
// The first block argument is iterator, the remaining arguments are
1658-
// referenced coordinates.
1659-
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1657+
// User-provided iteration arguments -> coords -> iterator.
1658+
return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
16601659
}
16611660
unsigned getNumRegionIterArgs() {
16621661
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

+17-14
Original file line numberDiff line numberDiff line change
@@ -2228,16 +2228,19 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
22282228
parser.getNameLoc(),
22292229
"mismatch in number of sparse iterators and sparse spaces");
22302230

2231-
if (failed(parseUsedCoordList(parser, state, blockArgs)))
2231+
SmallVector<OpAsmParser::Argument> coords;
2232+
if (failed(parseUsedCoordList(parser, state, coords)))
22322233
return failure();
2233-
size_t numCrds = blockArgs.size();
2234+
size_t numCrds = coords.size();
22342235

22352236
// Parse "iter_args(%arg = %init, ...)"
22362237
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
22372238
if (hasIterArgs)
22382239
if (parser.parseAssignmentList(blockArgs, initArgs))
22392240
return failure();
22402241

2242+
blockArgs.append(coords);
2243+
22412244
SmallVector<Type> iterSpaceTps;
22422245
// parse ": sparse_tensor.iter_space -> ret"
22432246
if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state,
22672270

22682271
if (hasIterArgs) {
22692272
// Strip off leading args that used for coordinates.
2270-
MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
2273+
MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
22712274
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
22722275
return parser.emitError(
22732276
parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState,
24482451
odsState.addTypes(initArgs.getTypes());
24492452
Block *bodyBlock = builder.createBlock(bodyRegion);
24502453

2451-
// First argument, sparse iterator
2452-
bodyBlock->addArgument(
2453-
llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2454-
odsState.location);
2454+
// Starts with a list of user-provided loop arguments.
2455+
for (Value v : initArgs)
2456+
bodyBlock->addArgument(v.getType(), v.getLoc());
24552457

2456-
// Followed by a list of used coordinates.
2458+
// Follows by a list of used coordinates.
24572459
for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
24582460
bodyBlock->addArgument(builder.getIndexType(), odsState.location);
24592461

2460-
// Followed by a list of user-provided loop arguments.
2461-
for (Value v : initArgs)
2462-
bodyBlock->addArgument(v.getType(), v.getLoc());
2462+
// Ends with sparse iterator
2463+
bodyBlock->addArgument(
2464+
llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(),
2465+
odsState.location);
24632466
}
24642467

24652468
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
24732476
return parser.emitError(parser.getNameLoc(),
24742477
"expected only one iterator/iteration space");
24752478

2476-
iters.append(iterArgs);
2479+
iterArgs.append(iters);
24772480
Region *body = result.addRegion();
2478-
if (parser.parseRegion(*body, iters))
2481+
if (parser.parseRegion(*body, iterArgs))
24792482
return failure();
24802483

24812484
IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location);
@@ -2580,7 +2583,7 @@ MutableArrayRef<OpOperand> IterateOp::getInitsMutable() {
25802583
}
25812584

25822585
Block::BlockArgListType IterateOp::getRegionIterArgs() {
2583-
return getRegion().getArguments().take_back(getNumRegionIterArgs());
2586+
return getRegion().getArguments().take_front(getNumRegionIterArgs());
25842587
}
25852588

25862589
std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

+11-25
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
111111

112112
static ValueRange genLoopWithIterator(
113113
PatternRewriter &rewriter, Location loc, SparseIterator *it,
114-
ValueRange reduc, bool iterFirst,
114+
ValueRange reduc,
115115
function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
116116
Region &loopBody, SparseIterator *it,
117117
ValueRange reduc)>
@@ -138,15 +138,9 @@ static ValueRange genLoopWithIterator(
138138
}
139139
return forOp.getResults();
140140
}
141-
SmallVector<Value> ivs;
142-
// TODO: always put iterator SSA values at the end of argument list to be
143-
// consistent with coiterate operation.
144-
if (!iterFirst)
145-
llvm::append_range(ivs, it->getCursor());
146-
// Appends the user-provided values.
147-
llvm::append_range(ivs, reduc);
148-
if (iterFirst)
149-
llvm::append_range(ivs, it->getCursor());
141+
142+
SmallVector<Value> ivs(reduc);
143+
llvm::append_range(ivs, it->getCursor());
150144

151145
TypeRange types = ValueRange(ivs).getTypes();
152146
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
@@ -164,25 +158,17 @@ static ValueRange genLoopWithIterator(
164158
Region &dstRegion = whileOp.getAfter();
165159
Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
166160
ValueRange aArgs = whileOp.getAfterArguments();
167-
if (iterFirst) {
168-
aArgs = it->linkNewScope(aArgs);
169-
} else {
170-
aArgs = aArgs.take_front(reduc.size());
171-
it->linkNewScope(aArgs.drop_front(reduc.size()));
172-
}
161+
it->linkNewScope(aArgs.drop_front(reduc.size()));
162+
aArgs = aArgs.take_front(reduc.size());
173163

174164
rewriter.setInsertionPointToStart(after);
175165
SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
176166
rewriter.setInsertionPointToEnd(after);
177167

178168
// Forward loops
179169
SmallVector<Value> yields;
180-
ValueRange nx = it->forward(rewriter, loc);
181-
if (iterFirst)
182-
llvm::append_range(yields, nx);
183170
llvm::append_range(yields, ret);
184-
if (!iterFirst)
185-
llvm::append_range(yields, nx);
171+
llvm::append_range(yields, it->forward(rewriter, loc));
186172
rewriter.create<scf::YieldOp>(loc, yields);
187173
}
188174
return whileOp.getResults().drop_front(it->getCursor().size());
@@ -258,13 +244,13 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
258244

259245
Block *block = op.getBody();
260246
ValueRange ret = genLoopWithIterator(
261-
rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
247+
rewriter, loc, it.get(), ivs,
262248
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
263249
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
264-
SmallVector<Value> blockArgs(it->getCursor());
250+
SmallVector<Value> blockArgs(reduc);
265251
// TODO: Also appends coordinates if used.
266252
// blockArgs.push_back(it->deref(rewriter, loc));
267-
llvm::append_range(blockArgs, reduc);
253+
llvm::append_range(blockArgs, it->getCursor());
268254

269255
Block *dstBlock = &loopBody.getBlocks().front();
270256
rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
@@ -404,7 +390,7 @@ class SparseCoIterateOpConverter
404390

405391
Block *block = &r.getBlocks().front();
406392
ValueRange curResult = genLoopWithIterator(
407-
rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
393+
rewriter, loc, validIters.front(), userReduc,
408394
/*bodyBuilder=*/
409395
[block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
410396
SparseIterator *it,

0 commit comments

Comments
 (0)