Skip to content
This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit 7c50225

Browse files
srinathavaSrinath Avadhanula
and
Srinath Avadhanula
authored
Slight improvement to fusion (#86)
As discussed, a small improvement to the fusion algorithm to account for the case when we have multiple uses of a definition but all those uses already belong to the group we are in. This allows us to fuse "diamond patterns" of multiple uses into a single group. See the update to the lit test for the improvement. Note that this is still not a "maximal" fusion, which can create groups with multiple returns, etc. --------- Co-authored-by: Srinath Avadhanula <[email protected]>
1 parent 7a5f782 commit 7c50225

File tree

2 files changed

+64
-17
lines changed

2 files changed

+64
-17
lines changed

lib/Dialect/Transforms/FusionPatterns.cpp

+39-8
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ LogicalResult
1919
GenericBottomUpFuser::matchAndRewrite(Operation *op,
2020
PatternRewriter &rewriter) const {
2121
Operation *use = op;
22+
bool opIsInsideGroup = op->getParentOfType<tcp::GroupOp>() != nullptr;
2223
bool isChanged = false;
2324
for (auto operand : op->getOperands()) {
2425
if (operand.getDefiningOp()) {
@@ -36,22 +37,48 @@ GenericBottomUpFuser::matchAndRewrite(Operation *op,
3637
continue;
3738
}
3839

39-
// We only support fusing def ops that have exactly one use, for
40-
// now. Special-case the uses of the def in
41-
// tcp.bind_symbolic_shape
42-
bool cannotFuse = false;
4340
SmallVector<tcp::BindSymbolicShapeOp> bindSymbolicUsersOfDef;
41+
SmallVector<Operation *> otherUses;
4442
for (auto otherUserOfDef : def->getUsers()) {
4543
if (auto bindSymbolicShapeOp =
4644
dyn_cast<tcp::BindSymbolicShapeOp>(otherUserOfDef)) {
4745
bindSymbolicUsersOfDef.push_back(bindSymbolicShapeOp);
48-
} else if (otherUserOfDef != use) {
49-
cannotFuse = true;
50-
break;
46+
} else {
47+
otherUses.push_back(otherUserOfDef);
5148
}
5249
}
5350

54-
if (cannotFuse)
51+
// Check that all the uses of this def are still valid after we
52+
// move the def op. If there's a single use, it's always safe to
53+
// fuse with the def. For the case when we have more than 1 use,
54+
// see below for when it is safe to fuse with the def.
55+
bool areUsesValidForFusion = false;
56+
if (otherUses.size() > 1) {
57+
// If we have more than one use, either
58+
// 1. All those uses belong to the current op
59+
if (llvm::all_of(otherUses,
60+
[&](Operation *userOp) { return userOp == op; }))
61+
areUsesValidForFusion = true;
62+
63+
// 2. All those uses are in the same group as the current op.
64+
// NOTE: We are checking here that the original op is already
65+
// inside a group and that all the other uses of this def are in
66+
// that group. That means that we can safely move this def to the
67+
// beginning of the group.
68+
//
69+
// We cannot do this if the use is not inside a group because
70+
// then we are creating a new group.
71+
if (opIsInsideGroup &&
72+
llvm::all_of(otherUses, [&](Operation *userOp) {
73+
return userOp->getParentRegion() == op->getParentRegion();
74+
}))
75+
areUsesValidForFusion = true;
76+
} else if (otherUses.size() == 1) {
77+
// If we have exactly one use, then we can fuse.
78+
areUsesValidForFusion = true;
79+
}
80+
81+
if (!areUsesValidForFusion)
5582
continue;
5683

5784
// Fuse the def and use ops into a group.
@@ -84,6 +111,10 @@ GenericBottomUpFuser::matchAndRewrite(Operation *op,
84111
def->moveBefore(use);
85112
}
86113
} else if (auto groupOp = dyn_cast<tcp::GroupOp>(use->getParentOp())) {
114+
// We already know that all other uses are in the same group
115+
// and because we are doing this bottom up, this is the "first"
116+
use of this op in this group. So it's OK to move it to just
117+
// before this use.
87118
def->moveBefore(use);
88119
} else {
89120
llvm_unreachable("Unhandled case during fusion");

test/Dialect/tcp_fusion.mlir

+25-9
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,13 @@ func.func @test_multiple_fusions(%arg0 : tensor<?x?xf32>,
5757

5858
// CHECK: func.func @test_multi_use_fusion(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
5959
// CHECK: %[[V0:.+]] = tcp.group {
60-
// CHECK: %[[V2:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
61-
// CHECK: %[[V3:.+]] = tcp.add %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
62-
// CHECK: tcp.yield %[[V3]] : tensor<?x?xf32>
63-
// CHECK: } : tensor<?x?xf32>
64-
// CHECK: %[[V1:.+]] = tcp.group {
65-
// CHECK: %[[V2]] = tcp.sub %[[V0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
66-
// CHECK: %[[V3]] = tcp.mul %[[V0]], %[[V2]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
67-
// CHECK: tcp.yield %[[V3]] : tensor<?x?xf32>
60+
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
61+
// CHECK: %[[V2:.+]] = tcp.add %[[V1]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
62+
// CHECK: %[[V3:.+]] = tcp.sub %[[V2]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
63+
// CHECK: %[[V4:.+]] = tcp.mul %[[V2]], %[[V3]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
64+
// CHECK: tcp.yield %[[V4]] : tensor<?x?xf32>
6865
// CHECK: } : tensor<?x?xf32>
69-
// CHECK: return %[[V1]] : tensor<?x?xf32>
66+
// CHECK: return %[[V0]] : tensor<?x?xf32>
7067
// CHECK: }
7168
func.func @test_multi_use_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
7269
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
@@ -207,3 +204,22 @@ func.func @buggy_tcp_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
207204
%6 = tcp.custom_op("test.op") %5 : tensor<?x?xf32> -> tensor<?x?xf32>
208205
return %2 : tensor<?x?xf32>
209206
}
207+
208+
// -----
209+
210+
// Make sure that things do not break if a value is used twice by the same
211+
// op.
212+
213+
// CHECK: func.func @test_multi_use_fusion_same_op_uses(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
214+
// CHECK: %[[V0:.+]] = tcp.group {
215+
// CHECK: %[[V1:.+]] = tcp.tanh %[[ARG0]] : tensor<?x?xf32> -> tensor<?x?xf32>
216+
// CHECK: %[[V2:.+]] = tcp.mul %[[V1]], %[[V1]] : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
217+
// CHECK: tcp.yield %[[V2]] : tensor<?x?xf32>
218+
// CHECK: } : tensor<?x?xf32>
219+
// CHECK: return %[[V0]] : tensor<?x?xf32>
220+
// CHECK: }
221+
func.func @test_multi_use_fusion_same_op_uses(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
222+
%0 = tcp.tanh %arg0 : tensor<?x?xf32> -> tensor<?x?xf32>
223+
%3 = tcp.mul %0, %0 : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
224+
return %3 : tensor<?x?xf32>
225+
}

0 commit comments

Comments
 (0)