Skip to content

Commit 4d1090b

Browse files
committed
WIP: more transform dialect
1 parent 5a2655e commit 4d1090b

File tree

6 files changed

+542
-101
lines changed

6 files changed

+542
-101
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4010,8 +4010,24 @@ template <typename T> struct CSE final : OpRewritePattern<T> {
40104010
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"
40114011

40124012
void mlir::transform::addPadDotGeneral(RewritePatternSet &patterns,
4013-
bool postPad, MLIRContext &context) {
4014-
patterns.insert<PadDotGeneral>(postPad, &context);
4013+
bool postPad, MLIRContext &context,
4014+
PatternBenefit benefit) {
4015+
patterns.insert<PadDotGeneral>(postPad, &context, benefit);
4016+
}
4017+
4018+
void mlir::transform::addIotaSimplify(RewritePatternSet &patterns,
4019+
int64_t maxConstantExpansion,
4020+
MLIRContext &context,
4021+
PatternBenefit benefit) {
4022+
patterns.insert<IotaSimplify>(maxConstantExpansion, &context, benefit);
4023+
}
4024+
4025+
void mlir::transform::addBroadcastInDimSimplify(RewritePatternSet &patterns,
4026+
int64_t maxConstantExpansion,
4027+
MLIRContext &context,
4028+
PatternBenefit benefit) {
4029+
patterns.insert<BroadcastInDimSimplify>(maxConstantExpansion, &context,
4030+
benefit);
40154031
}
40164032

40174033
namespace {
Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
1+
//===- EnzymeHLOPatterns.h - functions to register patterns -----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
19
namespace mlir {
2-
class RewritePatternSet;
310
class MLIRContext;
11+
class PatternBenefit;
12+
class RewritePatternSet;
413
} // namespace mlir
514

615
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h.inc"
716

817
namespace mlir::transform {
918
void addPadDotGeneral(RewritePatternSet &patterns, bool postPad,
10-
MLIRContext &context);
11-
}
19+
MLIRContext &context, PatternBenefit benefit);
20+
void addIotaSimplify(RewritePatternSet &patterns, int64_t maxConstantExpansion,
21+
MLIRContext &context, PatternBenefit benefit);
22+
void addBroadcastInDimSimplify(RewritePatternSet &patterns,
23+
int64_t maxConstantExpansion,
24+
MLIRContext &context, PatternBenefit benefit);
25+
} // namespace mlir::transform

src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp

Lines changed: 83 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,30 @@ void generatePatternGroup(OpBuilder &builder, Location loc, Value root,
3737
}
3838
}
3939

40-
LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
41-
auto loc = builder.getUnknownLoc();
40+
Value generateTransformMain(OpBuilder &builder, Location loc) {
4241
auto namedSequence = builder.create<transform::NamedSequenceOp>(
4342
loc, "__transform_main", builder.getType<transform::AnyOpType>(),
4443
TypeRange(), [](OpBuilder &builder, Location loc, BlockArgument) {
4544
builder.create<transform::YieldOp>(loc);
4645
});
46+
builder.setInsertionPointToStart(&namedSequence.getBody().front());
47+
auto match = builder.create<transform::MatchOp>(
48+
loc, namedSequence.getBody().front().getArgument(0),
49+
ArrayRef<StringRef>{func::FuncOp::getOperationName()});
50+
return match;
51+
}
52+
53+
LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
54+
auto loc = builder.getUnknownLoc();
55+
Value match = generateTransformMain(builder, loc);
4756

4857
SmallVector<OpConfig> opConfigurations;
4958
for (StringRef name : mlir::enzyme::getTransformOperationNames()) {
5059
std::optional<RegisteredOperationName> opName =
5160
RegisteredOperationName::lookup(name, builder.getContext());
5261
if (!opName) {
53-
return namedSequence->emitError() << "unregistered pattern op '" << name
54-
<< "' listed for construction";
62+
return emitError(loc) << "unregistered pattern op '" << name
63+
<< "' listed for construction";
5564
}
5665
auto *concept =
5766
opName->getInterface<SearchablePatternDescriptorOpInterface>();
@@ -60,11 +69,6 @@ LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
6069
}
6170
}
6271

63-
builder.setInsertionPointToStart(&namedSequence.getBody().front());
64-
auto match = builder.create<transform::MatchOp>(
65-
loc, namedSequence.getBody().front().getArgument(0),
66-
ArrayRef<StringRef>{func::FuncOp::getOperationName()});
67-
6872
auto configPow = llvm::APInt::getOneBitSet(opConfigurations.size() + 1,
6973
opConfigurations.size());
7074
do {
@@ -75,6 +79,60 @@ LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
7579
return success();
7680
}
7781

82+
LogicalResult parseTransform(OpBuilder &builder, Location loc,
83+
StringRef patterns) {
84+
Value root = generateTransformMain(builder, loc);
85+
auto apply = builder.create<transform::ApplyPatternsOp>(
86+
loc, root, [](OpBuilder &builder, Location loc) {});
87+
builder.setInsertionPointToStart(apply.getBody());
88+
89+
SmallVector<StringRef> singlePatterns;
90+
patterns.split(singlePatterns, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
91+
for (StringRef pattern : singlePatterns) {
92+
pattern = pattern.trim();
93+
size_t pos = pattern.find_first_of("<(");
94+
StringRef opName =
95+
pos == std::string::npos ? pattern : pattern.take_front(pos).trim();
96+
StringRef remainder =
97+
pos == std::string::npos ? "" : pattern.drop_front(pos);
98+
99+
int64_t benefit = 1;
100+
if (remainder.starts_with("<")) {
101+
size_t closing = remainder.find('>');
102+
if (closing == std::string::npos) {
103+
return ::emitError(loc)
104+
<< "couldn't find matching '>' in " << remainder;
105+
}
106+
StringRef benefitStr = remainder.drop_front().take_front(closing - 1);
107+
if (benefitStr.getAsInteger(0, benefit)) {
108+
return ::emitError(loc) << "couldn't parse benefit: " << benefitStr;
109+
}
110+
remainder = remainder.drop_front(closing + 1).trim();
111+
}
112+
113+
int64_t parameter = -1;
114+
if (remainder.starts_with("(")) {
115+
if (!remainder.ends_with(")")) {
116+
return ::emitError(loc)
117+
<< "couldn't find the closing ')' in " << remainder;
118+
}
119+
StringRef parameterStr = remainder.drop_front().drop_back();
120+
if (parameterStr.getAsInteger(0, parameter)) {
121+
return ::emitError(loc) << "couldn't parse parameter: " << parameterStr;
122+
}
123+
}
124+
125+
OperationState state(loc,
126+
"transform.apply_patterns.enzyme_hlo." + opName.str());
127+
if (benefit != 1)
128+
state.addAttribute("benefit", builder.getI64IntegerAttr(benefit));
129+
if (parameter != -1)
130+
state.addAttribute("parameter", builder.getI64IntegerAttr(parameter));
131+
builder.create(state);
132+
}
133+
return success();
134+
}
135+
78136
namespace {
79137
class GenerateApplyPatternsPass
80138
: public PassWrapper<GenerateApplyPatternsPass, OperationPass<>> {
@@ -93,27 +151,37 @@ class GenerateApplyPatternsPass
93151

94152
void runOnOperation() override {
95153
Operation *op = getOperation();
154+
if (!flags.getValue().empty() && !patterns.getValue().empty()) {
155+
op->emitError() << "flags and patterns are mutually exclusive";
156+
return signalPassFailure();
157+
}
96158
if (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0))) {
97159
op->emitError()
98160
<< "can only run on a single-region single-block operation";
99161
return signalPassFailure();
100162
}
101163

102-
llvm::APInt version(
103-
llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix),
104-
flags.getValue(), radix);
105-
106164
OpBuilder builder(&getContext());
107165
op->setAttr(transform::TransformDialect::kWithNamedSequenceAttrName,
108166
builder.getUnitAttr());
109167

110168
builder.setInsertionPointToStart(&op->getRegion(0).front());
111-
if (failed(generateTransform(builder, version)))
112-
return signalPassFailure();
169+
170+
if (!flags.empty()) {
171+
llvm::APInt version(
172+
llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix) + 1,
173+
flags.getValue(), radix);
174+
if (failed(generateTransform(builder, version)))
175+
return signalPassFailure();
176+
} else {
177+
if (failed(parseTransform(builder, op->getLoc(), patterns)))
178+
return signalPassFailure();
179+
}
113180
}
114181

115182
Option<std::string> flags{*this, "flags", llvm::cl::init("")};
116183
Option<int> radix{*this, "radix", llvm::cl::init(10)};
184+
Option<std::string> patterns{*this, "patterns", llvm::cl::init("")};
117185
};
118186

119187
class RemoveTransform : public PassWrapper<RemoveTransform, OperationPass<>> {

src/enzyme_ad/jax/TransformOps/TransformOps.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,18 @@ namespace mlir {
2525
namespace transform {
2626

2727
void ApplyPadDotGeneralPatterns::populatePatterns(RewritePatternSet &patterns) {
28-
addPadDotGeneral(patterns, getPostPad(), *getContext());
28+
addPadDotGeneral(patterns, getParameter(), *getContext(),
29+
PatternBenefit(getBenefit().value_or(1)));
30+
}
31+
32+
void ApplyIotaSimplifyPatterns::populatePatterns(RewritePatternSet &patterns) {
33+
addIotaSimplify(patterns, getParameter(), *getContext(),
34+
PatternBenefit(getBenefit().value_or(1)));
35+
}
36+
void ApplyBroadcastInDimSimplifyPatterns::populatePatterns(
37+
RewritePatternSet &patterns) {
38+
addBroadcastInDimSimplify(patterns, getParameter(), *getContext(),
39+
PatternBenefit(getBenefit().value_or(1)));
2940
}
3041

3142
} // namespace transform

0 commit comments

Comments
 (0)