Skip to content

PDLL: Add equals operator #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_PDL_IR_BUILTINS_H_
#define MLIR_DIALECT_PDL_IR_BUILTINS_H_

#include "mlir/Support/LogicalResult.h"

namespace mlir {
class PDLPatternModule;
class Attribute;
Expand All @@ -29,6 +31,7 @@ Attribute addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
LogicalResult equals(PatternRewriter &rewriter, Attribute lhs, Attribute rhs);
} // namespace builtin
} // namespace pdl
} // namespace mlir
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,30 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
values.push_back(element);
return rewriter.getArrayAttr(values);
}

LogicalResult equals(mlir::PatternRewriter &, mlir::Attribute lhs,
mlir::Attribute rhs) {
if (auto lhsAttr = dyn_cast_or_null<IntegerAttr>(lhs)) {
auto rhsAttr = dyn_cast_or_null<IntegerAttr>(rhs);
if (!rhsAttr || lhsAttr.getType() != rhsAttr.getType())
return failure();

APInt lhsVal = lhsAttr.getValue();
APInt rhsVal = rhsAttr.getValue();
return success(lhsVal.eq(rhsVal));
}

if (auto lhsAttr = dyn_cast_or_null<FloatAttr>(lhs)) {
auto rhsAttr = dyn_cast_or_null<FloatAttr>(rhs);
if (!rhsAttr || lhsAttr.getType() != rhsAttr.getType())
return failure();

APFloat lhsVal = lhsAttr.getValue();
APFloat rhsVal = rhsAttr.getValue();
return success(lhsVal.compare(rhsVal) == llvm::APFloatBase::cmpEqual);
}
return failure();
}
Comment on lines +43 to +65

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only IntegerAttr and floatAttr and why don't we use == operator from Attribute class?

} // namespace builtin

void registerBuiltins(PDLPatternModule &pdlPattern) {
Expand All @@ -52,5 +76,6 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
createArrayAttr);
pdlPattern.registerRewriteFunction("__builtin_addElemToArrayAttr",
addElemToArrayAttr);
pdlPattern.registerConstraintFunction("__builtin_equals", equals);
}
} // namespace mlir::pdl
4 changes: 4 additions & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,10 @@ Token Lexer::lexToken() {
++curPtr;
return formToken(Token::equal_arrow, tokStart);
}
if (*curPtr == '=') {
++curPtr;
return formToken(Token::equal_equal, tokStart);
}
return formToken(Token::equal, tokStart);
case ';':
return formToken(Token::semicolon, tokStart);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Token {
dot,
equal,
equal_arrow,
equal_equal,
semicolon,
exclam,
/// Paired punctuation.
Expand Down
83 changes: 78 additions & 5 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,14 @@ class Parser {
// Exprs

FailureOr<ast::Expr *> parseExpr();
FailureOr<ast::Expr *> parseLogicalOrExpr();
FailureOr<ast::Expr *> parseLogicalAndExpr();
FailureOr<ast::Expr *> parseEqualityExpr();
FailureOr<ast::Expr *> parseRelationExpr();
FailureOr<ast::Expr *> parseAddSubExpr();
FailureOr<ast::Expr *> parseMulDivExpr();
FailureOr<ast::Expr *> parseLogicalNotExpr();
FailureOr<ast::Expr *> parseOtherExpr();

/// Identifier expressions.
FailureOr<ast::Expr *> parseArrayAttrExpr();
Expand Down Expand Up @@ -593,6 +601,7 @@ class Parser {
ast::UserRewriteDecl *addEntryToDictionaryAttr;
ast::UserRewriteDecl *createArrayAttr;
ast::UserRewriteDecl *addElemToArrayAttr;
ast::UserConstraintDecl *equals;
} builtins{};
};
} // namespace
Expand Down Expand Up @@ -621,7 +630,7 @@ T *Parser::declareBuiltin(StringRef name, ArrayRef<StringRef> argNames,
popDeclScope();

auto *constraintDecl = T::createNative(ctx, ast::Name::create(ctx, name, loc),
args, results, {}, attrTy);
args, results, {}, createUserConstraintRewriteResultType(results));
curDeclScope->add(constraintDecl);
return constraintDecl;
}
Expand All @@ -637,6 +646,10 @@ void Parser::declareBuiltins() {
builtins.addElemToArrayAttr = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_addElemToArrayAttr", {"attr", "element"},
/*returnsAttr=*/true);

builtins.equals = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_equals", {"lhs", "rhs"},
/*returnsAttr=*/false);
}

FailureOr<ast::Module *> Parser::parseModule() {
Expand Down Expand Up @@ -1859,7 +1872,68 @@ FailureOr<ast::ConstraintRef> Parser::parseArgOrResultConstraint() {
//===----------------------------------------------------------------------===//
// Exprs

FailureOr<ast::Expr *> Parser::parseExpr() {
// Operator precedence follows C++:
// When parsing an expression, an operator which is listed on some row below with a precedence will be bound tighter (as if by parentheses) to
// its arguments than any operator that is listed on a row further below it with
// a lower precedence. Operators that have the same precedence are bound to
// their arguments left-to-right.
// Highest precedence first:
// - call, member access
// - logical not
// - multipication, division, remainder
// - addition, subtraction
// - relation operators
// - equality operators
// - logical and
// - logical or
FailureOr<ast::Expr *> Parser::parseExpr() { return parseLogicalOrExpr(); }

FailureOr<ast::Expr *> Parser::parseLogicalOrExpr() {
return parseLogicalAndExpr();
}

FailureOr<ast::Expr *> Parser::parseLogicalAndExpr() {
return parseEqualityExpr();
}

FailureOr<ast::Expr *> Parser::parseEqualityExpr() {
auto lhs = parseRelationExpr();
if (failed(lhs))
return failure();

switch (curToken.getKind()) {
case Token::equal_equal: {
consumeToken();
auto rhs = parseRelationExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
return createBuiltinCall(curToken.getLoc(), builtins.equals, args);
}
default:
return lhs;
}
}

FailureOr<ast::Expr *> Parser::parseRelationExpr() { return parseAddSubExpr(); }

FailureOr<ast::Expr *> Parser::parseAddSubExpr() { return parseMulDivExpr(); }

FailureOr<ast::Expr *> Parser::parseMulDivExpr() {
return parseLogicalNotExpr();
}

FailureOr<ast::Expr *> Parser::parseLogicalNotExpr() {
switch (curToken.getKind()) {
case Token::exclam:
return parseNegatedExpr();
break;
default:
return parseOtherExpr();
}
}

FailureOr<ast::Expr *> Parser::parseOtherExpr() {
if (curToken.is(Token::underscore))
return parseUnderscoreExpr();

Expand Down Expand Up @@ -1893,9 +1967,6 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
case Token::l_square:
lhsExpr = parseArrayAttrExpr();
break;
case Token::exclam:
lhsExpr = parseNegatedExpr();
break;
case Token::string_block:
return emitError("expected expression. If you are trying to create an "
"ArrayAttr, use a space between `[` and `{`.");
Expand Down Expand Up @@ -2136,6 +2207,8 @@ FailureOr<ast::Expr *> Parser::parseNegatedExpr() {
FailureOr<ast::Expr *> identifierExpr = parseIdentifierExpr();
if (failed(identifierExpr))
return failure();
if (!curToken.is(Token::l_paren))
return emitError("expected `(` after function name");
return parseCallExpr(*identifierExpr, /*isNegated = */ true);
}

Expand Down
7 changes: 7 additions & 0 deletions mlir/test/lib/Tools/PDLL/TestPDLL.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,10 @@ Pattern TestSimplePattern => replace op<test.simple> with op<test.success>;

// Test the import of interfaces.
Pattern TestInterface => replace _: CastOpInterface with op<test.success>;

// Test equals builtin
Pattern TestEquals {
let op = op<test.equals> {val = val : Attr};
val == attr<"4 : i32">;
replace op with op<test.success>;
}
14 changes: 14 additions & 0 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,17 @@ Pattern RewriteMultiplyElementsArrayAttr {
replace root with newRoot;
};
}

// -----
// CHECK-LABEL: pdl.pattern @TestEquals : benefit(0) {
// CHECK: %[[VAL_0:.*]] = operands
// CHECK: %[[VAL_1:.*]] = attribute
// CHECK: %[[VAL_2:.*]] = types
// CHECK: %[[VAL_3:.*]] = operation "test.op"(%[[VAL_0]] : !pdl.range<value>) {"val" = %[[VAL_1]]} -> (%[[VAL_2]] : !pdl.range<type>)
// CHECK: %[[VAL_4:.*]] = attribute = 4 : i32
// CHECK: apply_native_constraint "__builtin_equals"(%[[VAL_1]], %[[VAL_4]] : !pdl.attribute, !pdl.attribute)
Pattern TestEquals {
let op = op<test.op> {val = val : Attr};
val == attr<"4 : i32">;
replace op with op<test.success>;
}
9 changes: 9 additions & 0 deletions mlir/test/mlir-pdll/Integration/test-pdll.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,12 @@ func.func @testImportedInterface() -> i1 {
%value = "builtin.unrealized_conversion_cast"() : () -> (i1)
return %value : i1
}

// CHECK-LABEL: func @test_builtin
func.func @test_builtin() {
// CHECK: test.success
// CHECK: test.equals_neg
"test.equals"() { val = 4 : i32 }: () -> ()
"test.equals_neg"() { val = 4 : i32 }: () -> ()
return
}
37 changes: 37 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,23 @@ Pattern {

// -----

Pattern {
// CHECK: expected native constraint
!attr<"0 : i1">
erase _;
}

// -----

Pattern {
let tuple = (attr<"3 : i34">);
// CHECK: expected `(` after function name
!tuple.0;
erase _;
}

// -----

Pattern {
// CHECK: expected expression
let tuple = (10 = _: Value);
Expand Down Expand Up @@ -395,3 +412,23 @@ Pattern {
// CHECK: expected `>` after type literal
let foo = type<"";
}

// -----

//===----------------------------------------------------------------------===//
// Builtins
//===----------------------------------------------------------------------===//

Pattern {
// CHECK: expected expression
==
erase _: Op;
}

// -----

Pattern {
// CHECK: expected expression
attr<"4 : i32"> ==
erase _: Op;
}
16 changes: 16 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,19 @@ Pattern {

erase _: Op;
}

// -----

//===----------------------------------------------------------------------===//
// Builtins
//===----------------------------------------------------------------------===//

// CHECK: Module {{.*}}
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_equals> ResultType<Tuple<>>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_equals> ResultType<Tuple<>>
Pattern {
attr<"4 : i32"> == attr<"5 : i32">;
let a: Attr;
a == a;
erase _: Op;
}
20 changes: 20 additions & 0 deletions mlir/unittests/Dialect/PDL/BuiltinTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,24 @@ TEST_F(BuiltinTest, addElemToArrayAttr) {
cast<DictionaryAttr>(*cast<ArrayAttr>(updatedArrAttr).begin());
EXPECT_EQ(dictInsideArrAttr, dict);
}

TEST_F(BuiltinTest, equals) {
auto onei16 = rewriter.getI16IntegerAttr(1);
auto onei32 = rewriter.getI32IntegerAttr(1);
auto zeroi32 = rewriter.getI32IntegerAttr(0);

EXPECT_TRUE(builtin::equals(rewriter, onei16, onei16).succeeded());
EXPECT_TRUE(builtin::equals(rewriter, onei16, onei32).failed());
EXPECT_TRUE(builtin::equals(rewriter, zeroi32, onei32).failed());

auto onef32 = rewriter.getF32FloatAttr(1.0);
auto zerof32 = rewriter.getF32FloatAttr(0.0);
auto negzerof32 = rewriter.getF32FloatAttr(-0.0);
auto zerof64 = rewriter.getF64FloatAttr(0.0);

EXPECT_TRUE(builtin::equals(rewriter, onef32, onef32).succeeded());
EXPECT_TRUE(builtin::equals(rewriter, onef32, zerof32).failed());
EXPECT_TRUE(builtin::equals(rewriter, negzerof32, zerof32).succeeded());
EXPECT_TRUE(builtin::equals(rewriter, zerof32, zerof64).failed());
}
} // namespace