@@ -37,21 +37,30 @@ void generatePatternGroup(OpBuilder &builder, Location loc, Value root,
37
37
}
38
38
}
39
39
40
- LogicalResult generateTransform (OpBuilder &builder, llvm::APInt version) {
41
- auto loc = builder.getUnknownLoc ();
40
+ Value generateTransformMain (OpBuilder &builder, Location loc) {
42
41
auto namedSequence = builder.create <transform::NamedSequenceOp>(
43
42
loc, " __transform_main" , builder.getType <transform::AnyOpType>(),
44
43
TypeRange (), [](OpBuilder &builder, Location loc, BlockArgument) {
45
44
builder.create <transform::YieldOp>(loc);
46
45
});
46
+ builder.setInsertionPointToStart (&namedSequence.getBody ().front ());
47
+ auto match = builder.create <transform::MatchOp>(
48
+ loc, namedSequence.getBody ().front ().getArgument (0 ),
49
+ ArrayRef<StringRef>{func::FuncOp::getOperationName ()});
50
+ return match;
51
+ }
52
+
53
+ LogicalResult generateTransform (OpBuilder &builder, llvm::APInt version) {
54
+ auto loc = builder.getUnknownLoc ();
55
+ Value match = generateTransformMain (builder, loc);
47
56
48
57
SmallVector<OpConfig> opConfigurations;
49
58
for (StringRef name : mlir::enzyme::getTransformOperationNames ()) {
50
59
std::optional<RegisteredOperationName> opName =
51
60
RegisteredOperationName::lookup (name, builder.getContext ());
52
61
if (!opName) {
53
- return namedSequence-> emitError () << " unregistered pattern op '" << name
54
- << " ' listed for construction" ;
62
+ return emitError (loc ) << " unregistered pattern op '" << name
63
+ << " ' listed for construction" ;
55
64
}
56
65
auto *concept =
57
66
opName->getInterface <SearchablePatternDescriptorOpInterface>();
@@ -60,11 +69,6 @@ LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
60
69
}
61
70
}
62
71
63
- builder.setInsertionPointToStart (&namedSequence.getBody ().front ());
64
- auto match = builder.create <transform::MatchOp>(
65
- loc, namedSequence.getBody ().front ().getArgument (0 ),
66
- ArrayRef<StringRef>{func::FuncOp::getOperationName ()});
67
-
68
72
auto configPow = llvm::APInt::getOneBitSet (opConfigurations.size () + 1 ,
69
73
opConfigurations.size ());
70
74
do {
@@ -75,6 +79,60 @@ LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
75
79
return success ();
76
80
}
77
81
82
+ LogicalResult parseTransform (OpBuilder &builder, Location loc,
83
+ StringRef patterns) {
84
+ Value root = generateTransformMain (builder, loc);
85
+ auto apply = builder.create <transform::ApplyPatternsOp>(
86
+ loc, root, [](OpBuilder &builder, Location loc) {});
87
+ builder.setInsertionPointToStart (apply.getBody ());
88
+
89
+ SmallVector<StringRef> singlePatterns;
90
+ patterns.split (singlePatterns, ' ;' , /* MaxSplit=*/ -1 , /* KeepEmpty=*/ false );
91
+ for (StringRef pattern : singlePatterns) {
92
+ pattern = pattern.trim ();
93
+ size_t pos = pattern.find_first_of (" <(" );
94
+ StringRef opName =
95
+ pos == std::string::npos ? pattern : pattern.take_front (pos).trim ();
96
+ StringRef remainder =
97
+ pos == std::string::npos ? " " : pattern.drop_front (pos);
98
+
99
+ int64_t benefit = 1 ;
100
+ if (remainder.starts_with (" <" )) {
101
+ size_t closing = remainder.find (' >' );
102
+ if (closing == std::string::npos) {
103
+ return ::emitError (loc)
104
+ << " couldn't find matching '>' in " << remainder;
105
+ }
106
+ StringRef benefitStr = remainder.drop_front ().take_front (closing - 1 );
107
+ if (benefitStr.getAsInteger (0 , benefit)) {
108
+ return ::emitError (loc) << " couldn't parse benefit: " << benefitStr;
109
+ }
110
+ remainder = remainder.drop_front (closing + 1 ).trim ();
111
+ }
112
+
113
+ int64_t parameter = -1 ;
114
+ if (remainder.starts_with (" (" )) {
115
+ if (!remainder.ends_with (" )" )) {
116
+ return ::emitError (loc)
117
+ << " couldn't find the closing ')' in " << remainder;
118
+ }
119
+ StringRef parameterStr = remainder.drop_front ().drop_back ();
120
+ if (parameterStr.getAsInteger (0 , parameter)) {
121
+ return ::emitError (loc) << " couldn't parse parameter: " << parameterStr;
122
+ }
123
+ }
124
+
125
+ OperationState state (loc,
126
+ " transform.apply_patterns.enzyme_hlo." + opName.str ());
127
+ if (benefit != 1 )
128
+ state.addAttribute (" benefit" , builder.getI64IntegerAttr (benefit));
129
+ if (parameter != -1 )
130
+ state.addAttribute (" parameter" , builder.getI64IntegerAttr (parameter));
131
+ builder.create (state);
132
+ }
133
+ return success ();
134
+ }
135
+
78
136
namespace {
79
137
class GenerateApplyPatternsPass
80
138
: public PassWrapper<GenerateApplyPatternsPass, OperationPass<>> {
@@ -93,27 +151,37 @@ class GenerateApplyPatternsPass
93
151
94
152
void runOnOperation () override {
95
153
Operation *op = getOperation ();
154
+ if (!flags.getValue ().empty () && !patterns.getValue ().empty ()) {
155
+ op->emitError () << " flags and patterns are mutually exclusive" ;
156
+ return signalPassFailure ();
157
+ }
96
158
if (op->getNumRegions () != 1 || !llvm::hasSingleElement (op->getRegion (0 ))) {
97
159
op->emitError ()
98
160
<< " can only run on a single-region single-block operation" ;
99
161
return signalPassFailure ();
100
162
}
101
163
102
- llvm::APInt version (
103
- llvm::APInt::getSufficientBitsNeeded (flags.getValue (), radix),
104
- flags.getValue (), radix);
105
-
106
164
OpBuilder builder (&getContext ());
107
165
op->setAttr (transform::TransformDialect::kWithNamedSequenceAttrName ,
108
166
builder.getUnitAttr ());
109
167
110
168
builder.setInsertionPointToStart (&op->getRegion (0 ).front ());
111
- if (failed (generateTransform (builder, version)))
112
- return signalPassFailure ();
169
+
170
+ if (!flags.empty ()) {
171
+ llvm::APInt version (
172
+ llvm::APInt::getSufficientBitsNeeded (flags.getValue (), radix) + 1 ,
173
+ flags.getValue (), radix);
174
+ if (failed (generateTransform (builder, version)))
175
+ return signalPassFailure ();
176
+ } else {
177
+ if (failed (parseTransform (builder, op->getLoc (), patterns)))
178
+ return signalPassFailure ();
179
+ }
113
180
}
114
181
115
182
Option<std::string> flags{*this , " flags" , llvm::cl::init (" " )};
116
183
Option<int > radix{*this , " radix" , llvm::cl::init (10 )};
184
+ Option<std::string> patterns{*this , " patterns" , llvm::cl::init (" " )};
117
185
};
118
186
119
187
class RemoveTransform : public PassWrapper <RemoveTransform, OperationPass<>> {
0 commit comments