Skip to content

Transform dialect ops for all patterns #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ cc_library(
"@llvm-project//mlir:LinalgTransformOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:TransformDialectInterfaces",
":TransformOpsIncGen",
":TransformOpsImplIncGen",
":XLADerivatives",
Expand Down Expand Up @@ -237,8 +238,8 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
Expand Down
20 changes: 18 additions & 2 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4090,8 +4090,24 @@ template <typename T> struct CSE final : OpRewritePattern<T> {
#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.cpp.inc"

void mlir::transform::addPadDotGeneral(RewritePatternSet &patterns,
bool postPad, MLIRContext &context) {
patterns.insert<PadDotGeneral>(postPad, &context);
bool postPad, MLIRContext &context,
PatternBenefit benefit) {
patterns.insert<PadDotGeneral>(postPad, &context, benefit);
}

void mlir::transform::addIotaSimplify(RewritePatternSet &patterns,
int64_t maxConstantExpansion,
MLIRContext &context,
PatternBenefit benefit) {
patterns.insert<IotaSimplify>(maxConstantExpansion, &context, benefit);
}

void mlir::transform::addBroadcastInDimSimplify(RewritePatternSet &patterns,
int64_t maxConstantExpansion,
MLIRContext &context,
PatternBenefit benefit) {
patterns.insert<BroadcastInDimSimplify>(maxConstantExpansion, &context,
benefit);
}

namespace {
Expand Down
20 changes: 17 additions & 3 deletions src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
//===- EnzymeHLOPatterns.h - functions to register patterns -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

namespace mlir {
class RewritePatternSet;
class MLIRContext;
class PatternBenefit;
class RewritePatternSet;
} // namespace mlir

#include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h.inc"

namespace mlir::transform {
void addPadDotGeneral(RewritePatternSet &patterns, bool postPad,
MLIRContext &context);
}
MLIRContext &context, PatternBenefit benefit);
void addIotaSimplify(RewritePatternSet &patterns, int64_t maxConstantExpansion,
MLIRContext &context, PatternBenefit benefit);
void addBroadcastInDimSimplify(RewritePatternSet &patterns,
int64_t maxConstantExpansion,
MLIRContext &context, PatternBenefit benefit);
} // namespace mlir::transform
98 changes: 83 additions & 15 deletions src/enzyme_ad/jax/TransformOps/GenerateApplyPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,21 +37,30 @@ void generatePatternGroup(OpBuilder &builder, Location loc, Value root,
}
}

LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
auto loc = builder.getUnknownLoc();
Value generateTransformMain(OpBuilder &builder, Location loc) {
auto namedSequence = builder.create<transform::NamedSequenceOp>(
loc, "__transform_main", builder.getType<transform::AnyOpType>(),
TypeRange(), [](OpBuilder &builder, Location loc, BlockArgument) {
builder.create<transform::YieldOp>(loc);
});
builder.setInsertionPointToStart(&namedSequence.getBody().front());
auto match = builder.create<transform::MatchOp>(
loc, namedSequence.getBody().front().getArgument(0),
ArrayRef<StringRef>{func::FuncOp::getOperationName()});
return match;
}

LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
auto loc = builder.getUnknownLoc();
Value match = generateTransformMain(builder, loc);

SmallVector<OpConfig> opConfigurations;
for (StringRef name : mlir::enzyme::getTransformOperationNames()) {
std::optional<RegisteredOperationName> opName =
RegisteredOperationName::lookup(name, builder.getContext());
if (!opName) {
return namedSequence->emitError() << "unregistered pattern op '" << name
<< "' listed for construction";
return emitError(loc) << "unregistered pattern op '" << name
<< "' listed for construction";
}
auto *conceptV =
opName->getInterface<SearchablePatternDescriptorOpInterface>();
Expand All @@ -61,11 +70,6 @@ LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
}
}

builder.setInsertionPointToStart(&namedSequence.getBody().front());
auto match = builder.create<transform::MatchOp>(
loc, namedSequence.getBody().front().getArgument(0),
ArrayRef<StringRef>{func::FuncOp::getOperationName()});

if (version.getBitWidth() < opConfigurations.size() + 1)
version = version.zext(opConfigurations.size() + 1);

Expand All @@ -79,6 +83,60 @@ LogicalResult generateTransform(OpBuilder &builder, llvm::APInt version) {
return success();
}

LogicalResult parseTransform(OpBuilder &builder, Location loc,
StringRef patterns) {
Value root = generateTransformMain(builder, loc);
auto apply = builder.create<transform::ApplyPatternsOp>(
loc, root, [](OpBuilder &builder, Location loc) {});
builder.setInsertionPointToStart(apply.getBody());

SmallVector<StringRef> singlePatterns;
patterns.split(singlePatterns, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
for (StringRef pattern : singlePatterns) {
pattern = pattern.trim();
size_t pos = pattern.find_first_of("<(");
StringRef opName =
pos == std::string::npos ? pattern : pattern.take_front(pos).trim();
StringRef remainder =
pos == std::string::npos ? "" : pattern.drop_front(pos);

int64_t benefit = 1;
if (remainder.starts_with("<")) {
size_t closing = remainder.find('>');
if (closing == std::string::npos) {
return ::emitError(loc)
<< "couldn't find matching '>' in " << remainder;
}
StringRef benefitStr = remainder.drop_front().take_front(closing - 1);
if (benefitStr.getAsInteger(0, benefit)) {
return ::emitError(loc) << "couldn't parse benefit: " << benefitStr;
}
remainder = remainder.drop_front(closing + 1).trim();
}

int64_t parameter = -1;
if (remainder.starts_with("(")) {
if (!remainder.ends_with(")")) {
return ::emitError(loc)
<< "couldn't find the closing ')' in " << remainder;
}
StringRef parameterStr = remainder.drop_front().drop_back();
if (parameterStr.getAsInteger(0, parameter)) {
return ::emitError(loc) << "couldn't parse parameter: " << parameterStr;
}
}

OperationState state(loc,
"transform.apply_patterns.enzyme_hlo." + opName.str());
if (benefit != 1)
state.addAttribute("benefit", builder.getI64IntegerAttr(benefit));
if (parameter != -1)
state.addAttribute("parameter", builder.getI64IntegerAttr(parameter));
builder.create(state);
}
return success();
}

namespace {
class GenerateApplyPatternsPass
: public PassWrapper<GenerateApplyPatternsPass, OperationPass<>> {
Expand All @@ -97,27 +155,37 @@ class GenerateApplyPatternsPass

void runOnOperation() override {
Operation *op = getOperation();
if (!flags.getValue().empty() && !patterns.getValue().empty()) {
op->emitError() << "flags and patterns are mutually exclusive";
return signalPassFailure();
}
if (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0))) {
op->emitError()
<< "can only run on a single-region single-block operation";
return signalPassFailure();
}

llvm::APInt version(
llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix),
flags.getValue(), radix);

OpBuilder builder(&getContext());
op->setAttr(transform::TransformDialect::kWithNamedSequenceAttrName,
builder.getUnitAttr());

builder.setInsertionPointToStart(&op->getRegion(0).front());
if (failed(generateTransform(builder, version)))
return signalPassFailure();

if (!flags.empty()) {
llvm::APInt version(
llvm::APInt::getSufficientBitsNeeded(flags.getValue(), radix) + 1,
flags.getValue(), radix);
if (failed(generateTransform(builder, version)))
return signalPassFailure();
} else {
if (failed(parseTransform(builder, op->getLoc(), patterns)))
return signalPassFailure();
}
}

Option<std::string> flags{*this, "flags", llvm::cl::init("")};
Option<int> radix{*this, "radix", llvm::cl::init(10)};
Option<std::string> patterns{*this, "patterns", llvm::cl::init("")};
};

class RemoveTransform : public PassWrapper<RemoveTransform, OperationPass<>> {
Expand Down
13 changes: 12 additions & 1 deletion src/enzyme_ad/jax/TransformOps/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,18 @@ namespace mlir {
namespace transform {

void ApplyPadDotGeneralPatterns::populatePatterns(RewritePatternSet &patterns) {
addPadDotGeneral(patterns, getPostPad(), *getContext());
addPadDotGeneral(patterns, getParameter(), *getContext(),
PatternBenefit(getBenefit().value_or(1)));
}

void ApplyIotaSimplifyPatterns::populatePatterns(RewritePatternSet &patterns) {
addIotaSimplify(patterns, getParameter(), *getContext(),
PatternBenefit(getBenefit().value_or(1)));
}
void ApplyBroadcastInDimSimplifyPatterns::populatePatterns(
RewritePatternSet &patterns) {
addBroadcastInDimSimplify(patterns, getParameter(), *getContext(),
PatternBenefit(getBenefit().value_or(1)));
}

} // namespace transform
Expand Down
Loading