Skip to content
This repository was archived by the owner on Jan 30, 2025. It is now read-only.

Commit a145d84

Browse files
srinathava (Srinath Avadhanula)
and
Srinath Avadhanula
authored
Add support for customizing const inlining and ops with regions to IsolateGroupOps pass (#87)
A few small improvements to the `IsolateGroupOps` pass: - Provide a hook to customize whether a const-like op used in the `tcp.group` will get copied into the group or whether it will be passed in as an input argument. The main pass always returns true, so this is a non-functional change in `mlir-tcp`. - The previous version did not handle ops with contained regions and block arguments (such as an `scf.forall` inside the `tcp.group`). We now handle this (see updated lit test). --------- Co-authored-by: Srinath Avadhanula <[email protected]>
1 parent 7c50225 commit a145d84

File tree

3 files changed

+147
-11
lines changed

3 files changed

+147
-11
lines changed

include/mlir-tcp/Dialect/Transforms/IsolateGroupOpsPass.h

+11
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#pragma once
1111

12+
#include "mlir-tcp/Dialect/IR/TcpOps.h"
1213
#include "mlir/IR/BuiltinOps.h"
1314
#include "mlir/Pass/Pass.h"
1415
#include <memory>
@@ -18,4 +19,14 @@ namespace mlir::tcp {
1819
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
1920
createTcpIsolateGroupOpsPass();
2021

22+
// `createTcpIsolateGroupOpsPass` will clone all const operations used
23+
// inside a `tcp.group` into the new `tcp.isolated_group` it creates. If
24+
// you want to customize this behavior, you can use this instead to
25+
// pass a predicate function to control when a `const-like` operation
26+
// should be cloned into the isolated group or whether it should be added
27+
// as an argument to the isolated group.
28+
void populateIsolateGroupPatterns(
29+
RewritePatternSet &patterns,
30+
std::function<bool(GroupOp, Value)> shouldCopyConstPredicate);
31+
2132
} // namespace mlir::tcp

lib/Dialect/Transforms/IsolateGroupOpsPass.cpp

+49-11
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ namespace mlir::tcp {
2929

3030
namespace {
3131

32-
class IsolateGroups : public OpRewritePattern<tcp::GroupOp> {
32+
class IsolateGroups : public OpRewritePattern<GroupOp> {
3333
public:
34-
using OpRewritePattern::OpRewritePattern;
34+
IsolateGroups(MLIRContext *context,
35+
std::function<bool(tcp::GroupOp, Value)> shouldInlineConst)
36+
: OpRewritePattern<GroupOp>(context),
37+
shouldInlineConst_(shouldInlineConst) {}
3538

3639
LogicalResult matchAndRewrite(tcp::GroupOp groupOp,
3740
PatternRewriter &rewriter) const override {
@@ -41,33 +44,42 @@ class IsolateGroups : public OpRewritePattern<tcp::GroupOp> {
4144
llvm::SmallVector<Value> inputs;
4245
llvm::SmallDenseSet<Value> addedInputs;
4346
llvm::SmallDenseSet<Value> consts;
44-
llvm::SmallDenseSet<Value> defs;
45-
for (auto &op : groupOp.getBody().front()) {
46-
for (auto operand : op.getOperands()) {
47-
if (defs.find(operand) == defs.end()) {
47+
48+
groupOp->walk([&](Operation *op) {
49+
for (auto operand : op->getOperands()) {
50+
// Find the operation defining this Value, or whose block argument
51+
// this Value is.
52+
auto operandDefiningOp = operand.getDefiningOp();
53+
if (!operandDefiningOp) {
54+
operandDefiningOp = operand.getParentBlock()->getParentOp();
55+
}
56+
// If that operation lives outside the group, we need to add it as
57+
// an input to the newly created isolated group.
58+
if (!groupOp->isProperAncestor(operandDefiningOp)) {
4859
if (operand.getDefiningOp() &&
49-
operand.getDefiningOp()->hasTrait<OpTrait::ConstantLike>()) {
60+
operand.getDefiningOp()->hasTrait<OpTrait::ConstantLike>() &&
61+
shouldInlineConst_(groupOp, operand)) {
5062
consts.insert(operand);
5163
} else if (!addedInputs.contains(operand)) {
5264
inputs.push_back(operand);
5365
addedInputs.insert(operand);
5466
}
5567
}
5668
}
57-
defs.insert(op.getResults().begin(), op.getResults().end());
58-
}
69+
});
5970

6071
auto isolatedGroupOp = rewriter.create<tcp::IsolatedGroupOp>(
6172
groupOp.getLoc(), groupOp.getResultTypes(), inputs);
6273
isolatedGroupOp->setAttrs(groupOp->getAttrs());
6374

6475
isolatedGroupOp.getBody().takeBody(groupOp.getBody());
76+
6577
auto &isolatedGroupBlock = isolatedGroupOp.getBody().front();
6678
{
6779
OpBuilder::InsertionGuard guard(rewriter);
6880
rewriter.setInsertionPointToStart(&isolatedGroupBlock);
6981
auto belongsToIsolatedGroup = [&](OpOperand &opOperand) {
70-
return (opOperand.getOwner()->getParentOp() == isolatedGroupOp);
82+
return (isolatedGroupOp->isProperAncestor(opOperand.getOwner()));
7183
};
7284

7385
// Clone the constants at the start of the isolated group block.
@@ -91,6 +103,23 @@ class IsolateGroups : public OpRewritePattern<tcp::GroupOp> {
91103
rewriter.eraseOp(groupOp);
92104
return success();
93105
}
106+
107+
private:
108+
std::function<bool(tcp::GroupOp, Value)> shouldInlineConst_;
109+
};
110+
111+
class DropSymbolicShapesInsideGroups
112+
: public OpRewritePattern<tcp::BindSymbolicShapeOp> {
113+
using OpRewritePattern<tcp::BindSymbolicShapeOp>::OpRewritePattern;
114+
115+
LogicalResult matchAndRewrite(tcp::BindSymbolicShapeOp shapeOp,
116+
PatternRewriter &rewriter) const override {
117+
if (isa<tcp::GroupOp>(shapeOp->getParentOp())) {
118+
rewriter.eraseOp(shapeOp);
119+
return success();
120+
}
121+
return failure();
122+
}
94123
};
95124

96125
class TcpIsolateGroupOpsPass
@@ -100,7 +129,8 @@ class TcpIsolateGroupOpsPass
100129
MLIRContext *context = op->getContext();
101130
RewritePatternSet patterns(context);
102131

103-
patterns.add<IsolateGroups>(context);
132+
auto shouldCopyConstPredicate = [&](tcp::GroupOp, Value) { return true; };
133+
populateIsolateGroupPatterns(patterns, shouldCopyConstPredicate);
104134
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
105135
return signalPassFailure();
106136
}
@@ -112,4 +142,12 @@ std::unique_ptr<OperationPass<ModuleOp>> createTcpIsolateGroupOpsPass() {
112142
return std::make_unique<TcpIsolateGroupOpsPass>();
113143
}
114144

145+
void populateIsolateGroupPatterns(
146+
RewritePatternSet &patterns,
147+
std::function<bool(tcp::GroupOp, Value)> shouldCopyConstPredicate) {
148+
149+
patterns.add<IsolateGroups>(patterns.getContext(), shouldCopyConstPredicate);
150+
patterns.add<DropSymbolicShapesInsideGroups>(patterns.getContext());
151+
}
152+
115153
} // namespace mlir::tcp

test/Dialect/tcp_isolate_groups.mlir

+87
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,90 @@ func.func @test_inputs_with_multiple_uses(%arg0 : tensor<5xi32>) -> tensor<5xi32
120120
}) : () -> tensor<5xi32>
121121
return %10 : tensor<5xi32>
122122
}
123+
124+
125+
// -----
126+
127+
// isolate tcp.group ops in the presence of nested regions.
128+
129+
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
130+
// CHECK: module {
131+
// CHECK: func.func @forward(%[[ARG0:.+]]: tensor<?x4096xf32>, %[[ARG1:.+]]: tensor<?x4096xf32>, %[[ARG2:.+]]: tensor<?x4096xf32>) -> tensor<?x4096xf32> {
132+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
133+
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x4096xf32>
134+
// CHECK: %[[V0:.+]] = tcp.isolated_group %[[DIM]], %[[ARG0]], %[[ARG1]] attributes {group_type = "codegen_group"} {
135+
// CHECK: ^bb0(%[[ARG3:.+]]: index, %[[ARG4:.+]]: tensor<?x4096xf32>, %[[ARG5:.+]]: tensor<?x4096xf32>):
136+
// CHECK: %[[V1:.+]] = tensor.empty(%[[ARG3]]) : tensor<?x4096xf32>
137+
// CHECK: %[[V2:.+]] = scf.forall (%[[ARG6:.+]], %[[ARG7:.+]]) in (%[[ARG3]], 4096) shared_outs(%[[ARG8:.+]] = %[[V1]]) -> (tensor<?x4096xf32>) {
138+
// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[ARG4]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
139+
// CHECK: %[[EXTRACTED_SLICE_0:.+]] = tensor.extract_slice %[[ARG5]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
140+
// CHECK: %[[V3:.+]] = tensor.empty() : tensor<1x1xf32>
141+
// CHECK: %[[V4:.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE_0]] : tensor<1x1xf32>, tensor<1x1xf32>) outs(%[[V3]] : tensor<1x1xf32>) {
142+
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
143+
// CHECK: %[[V5:.+]] = arith.mulf %[[IN]], %[[IN_1]] : f32
144+
// CHECK: linalg.yield %[[V5]] : f32
145+
// CHECK: } -> tensor<1x1xf32>
146+
// CHECK: scf.forall.in_parallel {
147+
// CHECK: tensor.parallel_insert_slice %[[V4]] into %[[ARG8]][%[[ARG6]], %[[ARG7]]] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<?x4096xf32>
148+
// CHECK: }
149+
// CHECK: }
150+
// CHECK: tcp.yield %[[V2]] : tensor<?x4096xf32>
151+
// CHECK: } : index, tensor<?x4096xf32>, tensor<?x4096xf32> -> tensor<?x4096xf32>
152+
// CHECK: return %[[V0]] : tensor<?x4096xf32>
153+
// CHECK: }
154+
// CHECK: }
155+
#map = affine_map<(d0, d1) -> (d0, d1)>
156+
func.func @forward(%arg0: tensor<?x4096xf32>, %arg1: tensor<?x4096xf32>, %arg2: tensor<?x4096xf32>) -> tensor<?x4096xf32> {
157+
%c0 = arith.constant 0 : index
158+
%dim = tensor.dim %arg0, %c0 : tensor<?x4096xf32>
159+
%0 = tcp.group attributes {group_type = "codegen_group"} {
160+
%1 = tensor.empty(%dim) : tensor<?x4096xf32>
161+
%2 = scf.forall (%arg3, %arg4) in (%dim, 4096) shared_outs(%arg5 = %1) -> (tensor<?x4096xf32>) {
162+
%extracted_slice = tensor.extract_slice %arg0[%arg3, %arg4] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
163+
%extracted_slice_0 = tensor.extract_slice %arg1[%arg3, %arg4] [1, 1] [1, 1] : tensor<?x4096xf32> to tensor<1x1xf32>
164+
%3 = tensor.empty() : tensor<1x1xf32>
165+
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<1x1xf32>, tensor<1x1xf32>) outs(%3 : tensor<1x1xf32>) {
166+
^bb0(%in: f32, %in_4: f32, %out: f32):
167+
%8 = arith.mulf %in, %in_4 : f32
168+
linalg.yield %8 : f32
169+
} -> tensor<1x1xf32>
170+
scf.forall.in_parallel {
171+
tensor.parallel_insert_slice %4 into %arg5[%arg3, %arg4] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<?x4096xf32>
172+
}
173+
}
174+
tcp.yield %2 : tensor<?x4096xf32>
175+
} : tensor<?x4096xf32>
176+
return %0 : tensor<?x4096xf32>
177+
}
178+
179+
// -----
180+
181+
// Ensure that we correctly drop `tcp.bind_symbolic_shape` ops within the
182+
// newly created tcp.isolated_group region.
183+
184+
// CHECK: func.func @test_symbolic_shape_ops(%[[ARG0:.+]]: tensor<?x3xf32>) -> tensor<?x3xf32> {
185+
// CHECK: %[[V0:.+]] = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
186+
// CHECK: tcp.bind_symbolic_shape %[[ARG0]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
187+
// CHECK: %[[V1:.+]] = tcp.isolated_group %[[ARG0]] {
188+
// CHECK: ^bb0(%[[ARG1:.+]]: tensor<?x3xf32>):
189+
// CHECK: %[[V2:.+]] = tcp.add %[[ARG1]], %[[ARG1]] : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
190+
// CHECK-NOT: tcp.bind_symbolic_shape
191+
// CHECK: %[[V3:.+]] = tcp.mul %[[V2]], %[[V2]] : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
192+
// CHECK: tcp.yield %[[V3]] : tensor<?x3xf32>
193+
// CHECK: } : tensor<?x3xf32> -> tensor<?x3xf32>
194+
// CHECK: tcp.bind_symbolic_shape %[[V1]], [%[[V0]]], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
195+
// CHECK: return %[[V1]] : tensor<?x3xf32>
196+
// CHECK: }
197+
func.func @test_symbolic_shape_ops(%arg0 : tensor<?x3xf32>) -> tensor<?x3xf32> {
198+
%0 = tcp.symbolic_int "s0" {min_val = 2, max_val = 9223372036854775806} : i64
199+
tcp.bind_symbolic_shape %arg0, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
200+
%10 = "tcp.group" () ({
201+
^bb0() :
202+
%2 = tcp.add %arg0, %arg0 : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
203+
tcp.bind_symbolic_shape %2, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
204+
%3 = tcp.mul %2, %2 : tensor<?x3xf32>, tensor<?x3xf32> -> tensor<?x3xf32>
205+
tcp.yield %3 : tensor<?x3xf32>
206+
}) : () -> tensor<?x3xf32>
207+
tcp.bind_symbolic_shape %10, [%0], affine_map<()[s0] -> (s0, 3)> : tensor<?x3xf32>
208+
return %10 : tensor<?x3xf32>
209+
}

0 commit comments

Comments (0)