Skip to content

[Comb] delete slow canonicalizer #8014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 0 additions & 111 deletions lib/Dialect/Comb/CombFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1178,107 +1178,6 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
return constFoldAssociativeOp(inputs, hw::PEO::Or);
}

/// Rewrite an `or` of two concats into a single concat of narrower `or`s when
/// at least one operand of either concat is a constant.
///
/// Both concats are split at every operand boundary of either one, the
/// matching slices are or'd pairwise, and the pieces are concatenated again.
/// The narrower ors are often easier to fold than the original wide one.
///
/// For example:
///
///   or(..., concat(x, 0), concat(0, y))
///     ==> or(..., concat(x, 0, y)), when x and y don't overlap.
///
///   or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1))
///     ==> or(..., concat(or(x: i2, extract(cst2, 4..3)),
///                        or(extract(cst1, 3..1), extract(cst2, 2..0)),
///                        or(extract(cst1, 0..0), y: i1)))
///
/// `concatIdx1` and `concatIdx2` are positions in the inputs of `op`; both
/// must name ConcatOp results and `concatIdx1` must be the smaller one.
/// Returns true when `op` has been replaced, and false (without creating any
/// op) when neither concat has a constant operand.
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
                                                   size_t concatIdx2,
                                                   PatternRewriter &rewriter) {
  assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2");

  auto inputs = op.getInputs();
  auto concat1 = inputs[concatIdx1].getDefiningOp<ConcatOp>();
  auto concat2 = inputs[concatIdx2].getDefiningOp<ConcatOp>();

  assert(concat1 && concat2 && "expected indexes to point to ConcatOps");

  // The rewrite only pays off if some slice can fold against a constant, so
  // require a constant operand in at least one of the two concats.
  auto hasConstantOperand = [](ConcatOp concat) -> bool {
    return llvm::any_of(concat->getOperands(), [](Value operand) -> bool {
      return operand.getDefiningOp<hw::ConstantOp>();
    });
  };
  if (!hasConstantOperand(concat1) && !hasConstantOperand(concat2))
    return false;

  SmallVector<Value> newConcatOperands;

  // Walk the operand lists of both concats in lock-step, MSB first. Each step
  // emits the or of the widest slice that stays inside the current operand on
  // both sides.
  auto operands1 = concat1->getOperands();
  auto operands2 = concat2->getOperands();
  size_t index1 = 0;
  size_t index2 = 0;
  // Bits of the current operand already emitted, per side.
  unsigned usedBits1 = 0;
  unsigned usedBits2 = 0;
  while (index1 != operands1.size() && index2 != operands2.size()) {
    Value lhs = operands1[index1];
    Value rhs = operands2[index2];

    unsigned lhsBitsLeft = hw::getBitWidth(lhs.getType()) - usedBits1;
    unsigned rhsBitsLeft = hw::getBitWidth(rhs.getType()) - usedBits2;
    unsigned sliceWidth = std::min(lhsBitsLeft, rhsBitsLeft);
    auto sliceType = rewriter.getIntegerType(sliceWidth);

    // The slice is the top `sliceWidth` bits of what is left of each operand,
    // so its low bit sits at (bits left - slice width).
    Value lhsSlice = rewriter.createOrFold<ExtractOp>(
        op.getLoc(), sliceType, lhs, lhsBitsLeft - sliceWidth);
    Value rhsSlice = rewriter.createOrFold<ExtractOp>(
        op.getLoc(), sliceType, rhs, rhsBitsLeft - sliceWidth);

    newConcatOperands.push_back(
        rewriter.createOrFold<OrOp>(op.getLoc(), lhsSlice, rhsSlice, false));

    // Move to the next operand on whichever side was fully consumed;
    // otherwise just record how far into the operand we are.
    if (sliceWidth == lhsBitsLeft) {
      ++index1;
      usedBits1 = 0;
    } else {
      usedBits1 += sliceWidth;
    }
    if (sliceWidth == rhsBitsLeft) {
      ++index2;
      usedBits2 = 0;
    } else {
      usedBits2 += sliceWidth;
    }
  }

  ConcatOp mergedConcat =
      rewriter.create<ConcatOp>(op.getLoc(), newConcatOperands);

  // Keep every original or operand except the two concats, in their original
  // order, and put the merged concat last.
  SmallVector<Value> newOrOperands;
  for (size_t idx = 0, numInputs = inputs.size(); idx != numInputs; ++idx)
    if (idx != concatIdx1 && idx != concatIdx2)
      newOrOperands.push_back(inputs[idx]);
  newOrOperands.push_back(mergedConcat);

  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
                                      newOrOperands);
  return true;
}

LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
auto inputs = op.getInputs();
auto size = inputs.size();
Expand Down Expand Up @@ -1328,16 +1227,6 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
}
}

// or(..., concat(x, cst1), concat(cst2, y))
// ==> or(..., concat(x, cst3, y)), when x and y don't overlap.
for (size_t i = 0; i < size - 1; ++i) {
if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
for (size_t j = i + 1; j < size; ++j)
if (auto concat = inputs[j].getDefiningOp<ConcatOp>())
if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter))
return success();
}

// extracts only of or(...) -> or(extract()...)
if (narrowOperationWidth(op, true, rewriter))
return success();
Expand Down
81 changes: 0 additions & 81 deletions test/Dialect/Comb/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -181,87 +181,6 @@ hw.module @dedupLong(in %arg0 : i7, in %arg1 : i7, in %arg2: i7, out resAnd: i7,
hw.output %0, %1 : i7, i7
}

// Or of two concats whose non-constant fields occupy disjoint bit ranges; the
// expected output is a single concat with no or left.
//   %0:     000 aaaaaa   (a = %arg0)
//   %1:     bb0 000000   (b = %arg1)
//   merged: bb0 aaaaaa
// CHECK-LABEL: hw.module @orExclusiveConcats
hw.module @orExclusiveConcats(in %arg0 : i6, in %arg1 : i2, out o: i9) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %0 = comb.concat %arg1, %false, %arg0 : i2, i1, i6
// CHECK-NEXT: hw.output %0 : i9
%c0 = hw.constant 0 : i3
%0 = comb.concat %c0, %arg0 : i3, i6
%c1 = hw.constant 0 : i7
%1 = comb.concat %arg1, %c1 : i2, i7
%2 = comb.or %0, %1 : i9
hw.output %2 : i9
}

// When two concats are or'd together and have mutually-exclusive fields, they
// can be merged together into a single concat.
// concat0: 0aaa aaa0 0000 0bb0
// concat1: 0000 0000 ccdd d000
// merged: 0aaa aaa0 ccdd dbb0
// (a = %arg0, b = %arg1, c = %arg2, d = %arg3; concat0 is %0 and concat1 is
// %1 in the body below.)
// CHECK-LABEL: hw.module @orExclusiveConcats2
hw.module @orExclusiveConcats2(in %arg0 : i6, in %arg1 : i2, in %arg2: i2, in %arg3: i3, out o: i16) {
// CHECK-NEXT: %false = hw.constant false
// CHECK-NEXT: %0 = comb.concat %false, %arg0, %false, %arg2, %arg3, %arg1, %false : i1, i6, i1, i2, i3, i2, i1
// CHECK-NEXT: hw.output %0 : i16
%c0 = hw.constant 0 : i1
%c1 = hw.constant 0 : i6
%c2 = hw.constant 0 : i1
%0 = comb.concat %c0, %arg0, %c1, %arg1, %c2: i1, i6, i6, i2, i1
%c3 = hw.constant 0 : i8
%c4 = hw.constant 0 : i3
%1 = comb.concat %c3, %arg2, %arg3, %c4 : i8, i2, i3, i3
%2 = comb.or %0, %1 : i16
hw.output %2 : i16
}

// When two concats are or'd together and every bit position is covered by a
// constant one in at least one of them, the variable fields are absorbed and
// the whole or is expected to fold to an all-ones constant.
// concat0: aaaa 1111
// concat1: 1111 10bb
// merged: 1111 1111
// CHECK-LABEL: hw.module @orExclusiveConcats3
hw.module @orExclusiveConcats3(in %arg0 : i4, in %arg1 : i2, out o: i8) {
// CHECK-NEXT: [[RES:%[a-z0-9_-]+]] = hw.constant -1 : i8
// CHECK-NEXT: hw.output [[RES]] : i8
%c0 = hw.constant -1 : i4
%0 = comb.concat %arg0, %c0: i4, i4
%c1 = hw.constant -1 : i5
%c2 = hw.constant 0 : i1
%1 = comb.concat %c1, %c2, %arg1 : i5, i1, i2
%2 = comb.or %0, %1 : i8
hw.output %2 : i8
}

// Same idea with three or operands: each concat contributes one 2-bit field
// and zeros elsewhere, so the expected output is one concat of the three
// fields.
//   %0:     aa 00 00   (a = %arg0)
//   %1:     00 bb 00   (b = %arg1)
//   %2:     00 00 cc   (c = %arg2)
//   merged: aa bb cc
// CHECK-LABEL: hw.module @orMultipleExclusiveConcats
hw.module @orMultipleExclusiveConcats(in %arg0 : i2, in %arg1 : i2, in %arg2: i2, out o: i6) {
// CHECK-NEXT: %0 = comb.concat %arg0, %arg1, %arg2 : i2, i2, i2
// CHECK-NEXT: hw.output %0 : i6
%c2 = hw.constant 0 : i2
%c4 = hw.constant 0 : i4
%0 = comb.concat %arg0, %c4: i2, i4
%1 = comb.concat %c2, %arg1, %c2: i2, i2, i2
%2 = comb.concat %c4, %arg2: i4, i2
%out = comb.or %0, %1, %2 : i6
hw.output %out : i6
}

// One of the concat operands is a mux between the constants 2 (0b10) and 0,
// i.e. its upper bit follows %cond and its lower bit is always zero. The
// expected output places %cond and %bit side by side below four zero bits.
//   %0:     0000 0 b     (b = %bit)
//   %2:     0000 m m     (m = the 2-bit mux result %1)
//   merged: 0000 c b     (c = %cond)
// CHECK-LABEL: hw.module @orConcatsWithMux
hw.module @orConcatsWithMux(in %bit: i1, in %cond: i1, out o: i6) {
// CHECK-NEXT: [[RES:%[a-z0-9_-]+]] = hw.constant 0 : i4
// CHECK-NEXT: %0 = comb.concat [[RES]], %cond, %bit : i4, i1, i1
// CHECK-NEXT: hw.output %0 : i6
%c0 = hw.constant 0 : i5
%0 = comb.concat %c0, %bit: i5, i1
%c1 = hw.constant 0 : i4
%c2 = hw.constant 2 : i2
%c3 = hw.constant 0 : i2
%1 = comb.mux %cond, %c2, %c3 : i2
%2 = comb.concat %c1, %1 : i4, i2
%3 = comb.or %0, %2 : i6
hw.output %3 : i6
}

// CHECK-LABEL: @extractNested
hw.module @extractNested(in %0: i5, out o1 : i1) {
// Multiple layers of nested extract is weak evidence that the canonicalization
Expand Down
Loading