Skip to content
This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit 9a4c12a

Browse files
author
Srinath Avadhanula
committed
fix indentation
1 parent 1838c8b commit 9a4c12a

File tree

1 file changed

+76
-76
lines changed

1 file changed

+76
-76
lines changed

lib/Dialect/Transforms/FusionPatterns.cpp

+76-76
Original file line numberDiff line numberDiff line change
@@ -7,94 +7,94 @@
77
//
88
//===----------------------------------------------------------------------===//
99

10+
#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h"
1011
#include "mlir-tcp/Dialect/IR/TcpDialect.h"
1112
#include "mlir-tcp/Dialect/IR/TcpOps.h"
12-
#include "mlir-tcp/Dialect/Transforms/FusionPatterns.h"
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/IR/BuiltinOps.h"
1515
#include "mlir/IR/OpDefinition.h"
1616

1717
namespace mlir::tcp {
18-
LogicalResult GenericBottomUpFuser::matchAndRewrite(
19-
Operation *op, PatternRewriter &rewriter) const {
20-
Operation *use = op;
21-
bool isChanged = false;
22-
for (auto operand : op->getOperands()) {
23-
if (operand.getDefiningOp()) {
24-
Operation *def = operand.getDefiningOp();
25-
if (canFuse(def, use)) {
26-
// Currently we are only fusing ops at the top-level.
27-
// This is to avoid recursing inside a group and ending up with
28-
// nested groups that contain the same ops.
29-
// Since we are iterating bottom up in a block, we only need to
30-
// check if the def op has a func parent.
31-
//
32-
// TODO: Remove this restriction to allow fusing in nested
33-
// regions.
34-
if (!isa<func::FuncOp>(def->getParentOp())) {
35-
continue;
36-
}
18+
LogicalResult
19+
GenericBottomUpFuser::matchAndRewrite(Operation *op,
20+
PatternRewriter &rewriter) const {
21+
Operation *use = op;
22+
bool isChanged = false;
23+
for (auto operand : op->getOperands()) {
24+
if (operand.getDefiningOp()) {
25+
Operation *def = operand.getDefiningOp();
26+
if (canFuse(def, use)) {
27+
// Currently we are only fusing ops at the top-level.
28+
// This is to avoid recursing inside a group and ending up with
29+
// nested groups that contain the same ops.
30+
// Since we are iterating bottom up in a block, we only need to
31+
// check if the def op has a func parent.
32+
//
33+
// TODO: Remove this restriction to allow fusing in nested
34+
// regions.
35+
if (!isa<func::FuncOp>(def->getParentOp())) {
36+
continue;
37+
}
3738

38-
// We only support fusing def ops that have exactly one use, for
39-
// now. Special-case the uses of the def in
40-
// tcp.bind_symbolic_shape
41-
bool cannotFuse = false;
42-
SmallVector<tcp::BindSymbolicShapeOp> bindSymbolicUsersOfDef;
43-
for (auto otherUserOfDef : def->getUsers()) {
44-
if (auto bindSymbolicShapeOp =
45-
dyn_cast<tcp::BindSymbolicShapeOp>(
46-
otherUserOfDef)) {
47-
bindSymbolicUsersOfDef.push_back(bindSymbolicShapeOp);
48-
} else if (otherUserOfDef != use) {
49-
cannotFuse = true;
50-
break;
51-
}
52-
}
39+
// We only support fusing def ops that have exactly one use, for
40+
// now. Special-case the uses of the def in
41+
// tcp.bind_symbolic_shape
42+
bool cannotFuse = false;
43+
SmallVector<tcp::BindSymbolicShapeOp> bindSymbolicUsersOfDef;
44+
for (auto otherUserOfDef : def->getUsers()) {
45+
if (auto bindSymbolicShapeOp =
46+
dyn_cast<tcp::BindSymbolicShapeOp>(otherUserOfDef)) {
47+
bindSymbolicUsersOfDef.push_back(bindSymbolicShapeOp);
48+
} else if (otherUserOfDef != use) {
49+
cannotFuse = true;
50+
break;
51+
}
52+
}
5353

54-
if (cannotFuse) continue;
54+
if (cannotFuse)
55+
continue;
5556

56-
// Fuse the def and use ops into a group.
57+
// Fuse the def and use ops into a group.
5758

58-
// * If both the ops have the same parent region, they must be
59-
// part
60-
// of the top-level func. So, we need to create a new group.
61-
// * The only other case is when the def op is part of the
62-
// top-level
63-
// func and the use is already inside a group.
64-
isChanged = true;
65-
if (def->getParentRegion() == use->getParentRegion()) {
66-
auto groupOp = rewriter.create<tcp::GroupOp>(
67-
use->getLoc(), use->getResultTypes());
68-
if (postFunc) {
69-
postFunc(groupOp, rewriter);
70-
}
71-
Block *groupBlock = new Block();
72-
groupOp.getBody().push_back(groupBlock);
73-
for (unsigned num = 0; num < use->getNumResults(); ++num) {
74-
rewriter.replaceAllUsesWith(use->getResult(num),
75-
groupOp->getResult(num));
76-
}
77-
{
78-
OpBuilder::InsertionGuard guard(rewriter);
79-
rewriter.setInsertionPointToStart(groupBlock);
80-
auto yieldOp = rewriter.create<tcp::YieldOp>(
81-
use->getLoc(), use->getResults());
82-
use->moveBefore(yieldOp);
83-
def->moveBefore(use);
84-
}
85-
} else if (auto groupOp =
86-
dyn_cast<tcp::GroupOp>(use->getParentOp())) {
87-
def->moveBefore(use);
88-
} else {
89-
llvm_unreachable("Unhandled case during fusion");
90-
}
59+
// * If both the ops have the same parent region, they must be
60+
// part
61+
// of the top-level func. So, we need to create a new group.
62+
// * The only other case is when the def op is part of the
63+
// top-level
64+
// func and the use is already inside a group.
65+
isChanged = true;
66+
if (def->getParentRegion() == use->getParentRegion()) {
67+
auto groupOp = rewriter.create<tcp::GroupOp>(use->getLoc(),
68+
use->getResultTypes());
69+
if (postFunc) {
70+
postFunc(groupOp, rewriter);
71+
}
72+
Block *groupBlock = new Block();
73+
groupOp.getBody().push_back(groupBlock);
74+
for (unsigned num = 0; num < use->getNumResults(); ++num) {
75+
rewriter.replaceAllUsesWith(use->getResult(num),
76+
groupOp->getResult(num));
77+
}
78+
{
79+
OpBuilder::InsertionGuard guard(rewriter);
80+
rewriter.setInsertionPointToStart(groupBlock);
81+
auto yieldOp =
82+
rewriter.create<tcp::YieldOp>(use->getLoc(), use->getResults());
83+
use->moveBefore(yieldOp);
84+
def->moveBefore(use);
85+
}
86+
} else if (auto groupOp = dyn_cast<tcp::GroupOp>(use->getParentOp())) {
87+
def->moveBefore(use);
88+
} else {
89+
llvm_unreachable("Unhandled case during fusion");
90+
}
9191

92-
for (auto bindSymbolicShapeOp : bindSymbolicUsersOfDef) {
93-
bindSymbolicShapeOp->moveAfter(def);
94-
}
95-
}
92+
for (auto bindSymbolicShapeOp : bindSymbolicUsersOfDef) {
93+
bindSymbolicShapeOp->moveAfter(def);
9694
}
95+
}
9796
}
98-
return isChanged ? success() : failure();
97+
}
98+
return isChanged ? success() : failure();
9999
}
100-
} // namespace mlir::tcp
100+
} // namespace mlir::tcp

0 commit comments

Comments
 (0)