Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG][Elaboration] Improve how tests are matched with targets #8391

Open
wants to merge 1 commit into
base: maerhart-rtg-randomized-sequences-as-identity-values
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion frontends/PyRTG/src/pyrtg/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def name(self) -> str:

def _codegen(self):
test = rtg.TestOp(
self.name,
self.name, self.name,
ir.TypeAttr.get(
rtg.DictType.get([
(ir.StringAttr.get(name), ty)
Expand Down
2 changes: 2 additions & 0 deletions frontends/PyRTG/test/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ rtg.test @test0() {
%1 = rtg.randomize_sequence %0
rtg.embed_sequence %1
}

rtg.target @singleton : !rtg.dict<> {}
9 changes: 9 additions & 0 deletions frontends/PyRTG/test/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

from pyrtg import test, sequence, target, entry, rtg, Label, Set, Integer, Bag, rtgtest, Immediate, IntegerRegister, Array, Bool, MemoryBlock, Memory, Tuple

# MLIR-LABEL: rtg.target @Singleton : !rtg.dict<>
# MLIR-NEXT: }


@target
class Singleton:
pass


# MLIR-LABEL: rtg.target @Tgt0 : !rtg.dict<entry0: !rtg.set<index>>
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
# MLIR-NEXT: [[C1:%.+]] = index.constant 1
Expand Down
17 changes: 14 additions & 3 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,7 @@ def TestOp : RTGOp<"test", [
SingleBlock,
NoTerminator,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
HasParent<"mlir::ModuleOp">
]> {
let summary = "the root of a test";
Expand All @@ -701,8 +702,15 @@ def TestOp : RTGOp<"test", [
with that target.

By default each test can be matched with all targets that fulfill its
requirements, but the user should be able to specify more constraints on the
matching procedure.
requirements, but the user can also directly provide a target via the
'target' attribute. In that case, the test will only be randomized against
that target.

The 'templateName' attribute specifies the name of the original test
template (mostly for result reporting purposes). This is because a test
(template) can be matched against many targets and during this process one
test per match is created, but all of them preserve the same test template
name.

The body of this operation shall be processed the same way as an
`rtg.sequence`'s body with the exception of the block arguments.
Expand All @@ -712,11 +720,14 @@ def TestOp : RTGOp<"test", [
}];

let arguments = (ins SymbolNameAttr:$sym_name,
TypeAttrOf<DictType>:$target);
StrAttr:$templateName,
TypeAttrOf<DictType>:$targetType,
OptionalAttr<SymbolNameAttr>:$target);
let regions = (region SizedRegion<1>:$bodyRegion);

let hasCustomAssemblyFormat = 1;
let hasRegionVerifier = 1;
let hasVerifier = 1;
}

def TargetOp : RTGOp<"target", [
Expand Down
2 changes: 2 additions & 0 deletions include/circt/Dialect/RTG/Transforms/RTGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def ElaborationPass : Pass<"rtg-elaborate", "mlir::ModuleOp"> {
let options = [
Option<"seed", "seed", "unsigned", /*default=*/"",
"The seed for any RNG constructs used in the pass.">,
Option<"deleteUnmatchedTests", "delete-unmatched-tests", "bool", /*default=*/"true",
"Delete tests that could not be matched with a target.">,
];

let dependentDialects = ["mlir::index::IndexDialect"];
Expand Down
12 changes: 9 additions & 3 deletions integration_test/Bindings/Python/dialects/rtg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
cpu1 = rtg.ConstantOp(rtgtest.CPUAttr.get(cpuAttr.id + 1))
rtg.YieldOp([cpu0, cpu1])

test = rtg.TestOp('test_name', TypeAttr.get(dictTy))
test = rtg.TestOp('test_name', 'test_name', TypeAttr.get(dictTy))
Block.create_at_start(test.bodyRegion, [cpuTy, cpuTy])

# CHECK: rtg.target @target_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
Expand Down Expand Up @@ -58,12 +58,18 @@
seq = rtg.SequenceOp('sequence_name', TypeAttr.get(rtg.SequenceType.get()))
Block.create_at_start(seq.bodyRegion, [])

test = rtg.TestOp('test_name', TypeAttr.get(rtg.DictType.get()))
test = rtg.TestOp('test_name', 'test_name',
TypeAttr.get(rtg.DictType.get()))
block = Block.create_at_start(test.bodyRegion, [])
with InsertionPoint(block):
seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name')
rtg.RandomizeSequenceOp(seq_get)

target = rtg.TargetOp('target', TypeAttr.get(rtg.DictType.get()))
block = Block.create_at_start(target.bodyRegion, [])
with InsertionPoint(block):
rtg.YieldOp([])

# CHECK: rtg.test @test_name() {
# CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name
# CHECK-NEXT: rtg.randomize_sequence [[SEQ]]
Expand All @@ -78,7 +84,7 @@
rtgtool.populate_randomizer_pipeline(pm, options)
pm.run(m.operation)

# CHECK: rtg.test @test_name() {
# CHECK: rtg.test @test_name_target() template "test_name" target @target {
# CHECK-NEXT: }
print(m)

Expand Down
11 changes: 8 additions & 3 deletions integration_test/Bindings/Python/rtg_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# REQUIRES: bindings_python
# RUN: %PYTHON% %s %T && FileCheck %s --input-file=%T/test0.s --check-prefix=TEST0 && FileCheck %s --input-file=%T/test1.s --check-prefix=TEST1
# RUN: %PYTHON% %s %T && FileCheck %s --input-file=%T/test0_target.s --check-prefix=TEST0 && FileCheck %s --input-file=%T/test1_target.s --check-prefix=TEST1

import sys
import circt
Expand All @@ -15,16 +15,21 @@
circt.register_dialects(ctx)
m = Module.create()
with InsertionPoint(m.body):
test = rtg.TestOp('test0', TypeAttr.get(rtg.DictType.get()))
test = rtg.TestOp('test0', 'test0', TypeAttr.get(rtg.DictType.get()))
block = Block.create_at_start(test.bodyRegion, [])
with InsertionPoint(block):
rtgtest.rv32i_ebreak()

test = rtg.TestOp('test1', TypeAttr.get(rtg.DictType.get()))
test = rtg.TestOp('test1', 'test1', TypeAttr.get(rtg.DictType.get()))
block = Block.create_at_start(test.bodyRegion, [])
with InsertionPoint(block):
rtgtest.rv32i_ecall()

target = rtg.TargetOp('target', TypeAttr.get(rtg.DictType.get()))
block = Block.create_at_start(target.bodyRegion, [])
with InsertionPoint(block):
rtg.YieldOp([])

pm = PassManager()
options = rtgtool.Options(seed=0,
output_format=rtgtool.OutputFormat.ASM,
Expand Down
91 changes: 84 additions & 7 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,18 +497,61 @@ LogicalResult ContextSwitchOp::verify() {
//===----------------------------------------------------------------------===//

LogicalResult TestOp::verifyRegions() {
if (!getTarget().entryTypesMatch(getBody()->getArgumentTypes()))
if (!getTargetType().entryTypesMatch(getBody()->getArgumentTypes()))
return emitOpError("argument types must match dict entry types");

return success();
}

LogicalResult TestOp::verify() {
if (getTemplateName().empty())
return emitOpError("template name must not be empty");

return success();
}

LogicalResult TestOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
if (!getTargetAttr())
return success();

auto target =
symbolTable.lookupNearestSymbolFrom<TargetOp>(*this, getTargetAttr());
if (!target)
return emitOpError()
<< "'" << *getTarget()
<< "' does not reference a valid 'rtg.target' operation";

// Check if target is a subtype of test requirements
// Since entries are sorted by name, we can do this in a single pass
size_t targetIdx = 0;
auto targetEntries = target.getTarget().getEntries();
for (auto testEntry : getTargetType().getEntries()) {
// Find the matching entry in target entries.
while (targetIdx < targetEntries.size() &&
targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
targetIdx++;

// Check if we found a matching entry with the same name and type
if (targetIdx >= targetEntries.size() ||
targetEntries[targetIdx].name != testEntry.name ||
targetEntries[targetIdx].type != testEntry.type) {
return emitOpError("referenced 'rtg.target' op's type is invalid: "
"missing entry called '")
<< testEntry.name.getValue() << "' of type " << testEntry.type;
}
}

return success();
}

ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the name as a symbol.
if (parser.parseSymbolName(
result.getOrAddProperties<TestOp::Properties>().sym_name))
StringAttr symNameAttr;
if (parser.parseSymbolName(symNameAttr))
return failure();

result.getOrAddProperties<TestOp::Properties>().sym_name = symNameAttr;

// Parse the function signature.
SmallVector<OpAsmParser::Argument> arguments;
SmallVector<StringAttr> names;
Expand Down Expand Up @@ -544,7 +587,31 @@ ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
ArrayRef<DictEntry>(entries));
if (!type)
return failure();
result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);
result.getOrAddProperties<TestOp::Properties>().targetType =
TypeAttr::get(type);

std::string templateName;
if (!parser.parseOptionalKeyword("template")) {
auto loc = parser.getCurrentLocation();
if (parser.parseString(&templateName))
return failure();

if (templateName.empty())
return parser.emitError(loc, "template name must not be empty");
}

StringAttr templateNameAttr = symNameAttr;
if (!templateName.empty())
templateNameAttr = StringAttr::get(result.getContext(), templateName);

StringAttr targetName;
if (!parser.parseOptionalKeyword("target"))
if (parser.parseSymbolName(targetName))
return failure();

result.getOrAddProperties<TestOp::Properties>().templateName =
templateNameAttr;
result.getOrAddProperties<TestOp::Properties>().target = targetName;

auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
Expand Down Expand Up @@ -574,23 +641,33 @@ void TestOp::print(OpAsmPrinter &p) {
p << "(";
SmallString<32> resultNameStr;
llvm::interleaveComma(
llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
llvm::zip(getTargetType().getEntries(), getBody()->getArguments()), p,
[&](auto entryAndArg) {
auto [entry, arg] = entryAndArg;
p << entry.name.getValue() << " = ";
p.printRegionArgument(arg);
});
p << ")";

if (getSymNameAttr() != getTemplateNameAttr())
p << " template " << getTemplateNameAttr();

if (getTargetAttr()) {
p << " target ";
p.printSymbolName(getTargetAttr().getValue());
}

p.printOptionalAttrDictWithKeyword(
(*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
(*this)->getAttrs(), {getSymNameAttrName(), getTargetTypeAttrName(),
getTargetAttrName(), getTemplateNameAttrName()});
p << ' ';
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}

void TestOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
for (auto [entry, arg] :
llvm::zip(getTarget().getEntries(), region.getArguments()))
llvm::zip(getTargetType().getEntries(), region.getArguments()))
setNameFn(arg, entry.name.getValue());
}

Expand Down
Loading
Loading