Skip to content

Commit 49afef6

Browse files
committed
[RTG][Elaboration] Improve how tests are matched with targets
1 parent 5750c2a commit 49afef6

File tree

13 files changed

+434
-103
lines changed

13 files changed

+434
-103
lines changed

frontends/PyRTG/src/pyrtg/tests.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def name(self) -> str:
2424

2525
def _codegen(self):
2626
test = rtg.TestOp(
27-
self.name,
27+
self.name, self.name,
2828
ir.TypeAttr.get(
2929
rtg.DictType.get([
3030
(ir.StringAttr.get(name), ty)

frontends/PyRTG/test/basic.mlir

+2
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,5 @@ rtg.test @test0() {
3030
%1 = rtg.randomize_sequence %0
3131
rtg.embed_sequence %1
3232
}
33+
34+
rtg.target @singleton : !rtg.dict<> {}

frontends/PyRTG/test/basic.py

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@
44

55
from pyrtg import test, sequence, target, entry, rtg, Label, Set, Integer, Bag, rtgtest, Immediate, IntegerRegister, Array, Bool, MemoryBlock, Memory, Tuple
66

7+
# MLIR-LABEL: rtg.target @Singleton : !rtg.dict<>
8+
# MLIR-NEXT: }
9+
10+
11+
@target
12+
class Singleton:
13+
pass
14+
15+
716
# MLIR-LABEL: rtg.target @Tgt0 : !rtg.dict<entry0: !rtg.set<index>>
817
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
918
# MLIR-NEXT: [[C1:%.+]] = index.constant 1

include/circt/Dialect/RTG/IR/RTGOps.td

+14-3
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,7 @@ def TestOp : RTGOp<"test", [
687687
SingleBlock,
688688
NoTerminator,
689689
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
690+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
690691
HasParent<"mlir::ModuleOp">
691692
]> {
692693
let summary = "the root of a test";
@@ -701,8 +702,15 @@ def TestOp : RTGOp<"test", [
701702
with that target.
702703

703704
By default each test can be matched with all targets that fulfill its
704-
requirements, but the user should be able to specify more constraints on the
705-
matching procedure.
705+
requirements, but the user can also directly provide a target via the
706+
'target' attribute. In that case, the test will only be randomized against
707+
that target.
708+
709+
The 'templateName' attribute specifies the name of the original test
710+
template (mostly for result reporting purposes). This is because a test
711+
(template) can be matched against many targets and during this process one
712+
test per match is created, but all of them preserve the same test template
713+
name.
706714

707715
The body of this operation shall be processed the same way as an
708716
`rtg.sequence`'s body with the exception of the block arguments.
@@ -712,11 +720,14 @@ def TestOp : RTGOp<"test", [
712720
}];
713721

714722
let arguments = (ins SymbolNameAttr:$sym_name,
715-
TypeAttrOf<DictType>:$target);
723+
StrAttr:$templateName,
724+
TypeAttrOf<DictType>:$targetType,
725+
OptionalAttr<SymbolNameAttr>:$target);
716726
let regions = (region SizedRegion<1>:$bodyRegion);
717727

718728
let hasCustomAssemblyFormat = 1;
719729
let hasRegionVerifier = 1;
730+
let hasVerifier = 1;
720731
}
721732

722733
def TargetOp : RTGOp<"target", [

include/circt/Dialect/RTG/Transforms/RTGPasses.td

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def ElaborationPass : Pass<"rtg-elaborate", "mlir::ModuleOp"> {
2727
let options = [
2828
Option<"seed", "seed", "unsigned", /*default=*/"",
2929
"The seed for any RNG constructs used in the pass.">,
30+
Option<"deleteUnmatchedTests", "delete-unmatched-tests", "bool", /*default=*/"true",
31+
"Delete tests that could not be matched with a target.">,
3032
];
3133

3234
let dependentDialects = ["mlir::index::IndexDialect"];

integration_test/Bindings/Python/dialects/rtg.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
cpu1 = rtg.ConstantOp(rtgtest.CPUAttr.get(cpuAttr.id + 1))
2525
rtg.YieldOp([cpu0, cpu1])
2626

27-
test = rtg.TestOp('test_name', TypeAttr.get(dictTy))
27+
test = rtg.TestOp('test_name', 'test_name', TypeAttr.get(dictTy))
2828
Block.create_at_start(test.bodyRegion, [cpuTy, cpuTy])
2929

3030
# CHECK: rtg.target @target_name : !rtg.dict<cpu0: !rtgtest.cpu, cpu1: !rtgtest.cpu> {
@@ -58,12 +58,18 @@
5858
seq = rtg.SequenceOp('sequence_name', TypeAttr.get(rtg.SequenceType.get()))
5959
Block.create_at_start(seq.bodyRegion, [])
6060

61-
test = rtg.TestOp('test_name', TypeAttr.get(rtg.DictType.get()))
61+
test = rtg.TestOp('test_name', 'test_name',
62+
TypeAttr.get(rtg.DictType.get()))
6263
block = Block.create_at_start(test.bodyRegion, [])
6364
with InsertionPoint(block):
6465
seq_get = rtg.GetSequenceOp(rtg.SequenceType.get(), 'sequence_name')
6566
rtg.RandomizeSequenceOp(seq_get)
6667

68+
target = rtg.TargetOp('target', TypeAttr.get(rtg.DictType.get()))
69+
block = Block.create_at_start(target.bodyRegion, [])
70+
with InsertionPoint(block):
71+
rtg.YieldOp([])
72+
6773
# CHECK: rtg.test @test_name() {
6874
# CHECK-NEXT: [[SEQ:%.+]] = rtg.get_sequence @sequence_name
6975
# CHECK-NEXT: rtg.randomize_sequence [[SEQ]]
@@ -78,7 +84,7 @@
7884
rtgtool.populate_randomizer_pipeline(pm, options)
7985
pm.run(m.operation)
8086

81-
# CHECK: rtg.test @test_name() {
87+
# CHECK: rtg.test @test_name_target() template "test_name" target @target {
8288
# CHECK-NEXT: }
8389
print(m)
8490

integration_test/Bindings/Python/rtg_pipeline.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# REQUIRES: bindings_python
2-
# RUN: %PYTHON% %s %T && FileCheck %s --input-file=%T/test0.s --check-prefix=TEST0 && FileCheck %s --input-file=%T/test1.s --check-prefix=TEST1
2+
# RUN: %PYTHON% %s %T && FileCheck %s --input-file=%T/test0_target.s --check-prefix=TEST0 && FileCheck %s --input-file=%T/test1_target.s --check-prefix=TEST1
33

44
import sys
55
import circt
@@ -15,16 +15,21 @@
1515
circt.register_dialects(ctx)
1616
m = Module.create()
1717
with InsertionPoint(m.body):
18-
test = rtg.TestOp('test0', TypeAttr.get(rtg.DictType.get()))
18+
test = rtg.TestOp('test0', 'test0', TypeAttr.get(rtg.DictType.get()))
1919
block = Block.create_at_start(test.bodyRegion, [])
2020
with InsertionPoint(block):
2121
rtgtest.rv32i_ebreak()
2222

23-
test = rtg.TestOp('test1', TypeAttr.get(rtg.DictType.get()))
23+
test = rtg.TestOp('test1', 'test1', TypeAttr.get(rtg.DictType.get()))
2424
block = Block.create_at_start(test.bodyRegion, [])
2525
with InsertionPoint(block):
2626
rtgtest.rv32i_ecall()
2727

28+
target = rtg.TargetOp('target', TypeAttr.get(rtg.DictType.get()))
29+
block = Block.create_at_start(target.bodyRegion, [])
30+
with InsertionPoint(block):
31+
rtg.YieldOp([])
32+
2833
pm = PassManager()
2934
options = rtgtool.Options(seed=0,
3035
output_format=rtgtool.OutputFormat.ASM,

lib/Dialect/RTG/IR/RTGOps.cpp

+84-7
Original file line numberDiff line numberDiff line change
@@ -497,18 +497,61 @@ LogicalResult ContextSwitchOp::verify() {
497497
//===----------------------------------------------------------------------===//
498498

499499
LogicalResult TestOp::verifyRegions() {
500-
if (!getTarget().entryTypesMatch(getBody()->getArgumentTypes()))
500+
if (!getTargetType().entryTypesMatch(getBody()->getArgumentTypes()))
501501
return emitOpError("argument types must match dict entry types");
502502

503503
return success();
504504
}
505505

506+
LogicalResult TestOp::verify() {
507+
if (getTemplateName().empty())
508+
return emitOpError("template name must not be empty");
509+
510+
return success();
511+
}
512+
513+
LogicalResult TestOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
514+
if (!getTargetAttr())
515+
return success();
516+
517+
auto target =
518+
symbolTable.lookupNearestSymbolFrom<TargetOp>(*this, getTargetAttr());
519+
if (!target)
520+
return emitOpError()
521+
<< "'" << *getTarget()
522+
<< "' does not reference a valid 'rtg.target' operation";
523+
524+
// Check if target is a subtype of test requirements
525+
// Since entries are sorted by name, we can do this in a single pass
526+
size_t targetIdx = 0;
527+
auto targetEntries = target.getTarget().getEntries();
528+
for (auto testEntry : getTargetType().getEntries()) {
529+
// Find the matching entry in target entries.
530+
while (targetIdx < targetEntries.size() &&
531+
targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
532+
targetIdx++;
533+
534+
// Check if we found a matching entry with the same name and type
535+
if (targetIdx >= targetEntries.size() ||
536+
targetEntries[targetIdx].name != testEntry.name ||
537+
targetEntries[targetIdx].type != testEntry.type) {
538+
return emitOpError("referenced 'rtg.target' op's type is invalid: "
539+
"missing entry called '")
540+
<< testEntry.name.getValue() << "' of type " << testEntry.type;
541+
}
542+
}
543+
544+
return success();
545+
}
546+
506547
ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
507548
// Parse the name as a symbol.
508-
if (parser.parseSymbolName(
509-
result.getOrAddProperties<TestOp::Properties>().sym_name))
549+
StringAttr symNameAttr;
550+
if (parser.parseSymbolName(symNameAttr))
510551
return failure();
511552

553+
result.getOrAddProperties<TestOp::Properties>().sym_name = symNameAttr;
554+
512555
// Parse the function signature.
513556
SmallVector<OpAsmParser::Argument> arguments;
514557
SmallVector<StringAttr> names;
@@ -544,7 +587,31 @@ ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
544587
ArrayRef<DictEntry>(entries));
545588
if (!type)
546589
return failure();
547-
result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);
590+
result.getOrAddProperties<TestOp::Properties>().targetType =
591+
TypeAttr::get(type);
592+
593+
std::string templateName;
594+
if (!parser.parseOptionalKeyword("template")) {
595+
auto loc = parser.getCurrentLocation();
596+
if (parser.parseString(&templateName))
597+
return failure();
598+
599+
if (templateName.empty())
600+
return parser.emitError(loc, "template name must not be empty");
601+
}
602+
603+
StringAttr templateNameAttr = symNameAttr;
604+
if (!templateName.empty())
605+
templateNameAttr = StringAttr::get(result.getContext(), templateName);
606+
607+
StringAttr targetName;
608+
if (!parser.parseOptionalKeyword("target"))
609+
if (parser.parseSymbolName(targetName))
610+
return failure();
611+
612+
result.getOrAddProperties<TestOp::Properties>().templateName =
613+
templateNameAttr;
614+
result.getOrAddProperties<TestOp::Properties>().target = targetName;
548615

549616
auto loc = parser.getCurrentLocation();
550617
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
@@ -574,23 +641,33 @@ void TestOp::print(OpAsmPrinter &p) {
574641
p << "(";
575642
SmallString<32> resultNameStr;
576643
llvm::interleaveComma(
577-
llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
644+
llvm::zip(getTargetType().getEntries(), getBody()->getArguments()), p,
578645
[&](auto entryAndArg) {
579646
auto [entry, arg] = entryAndArg;
580647
p << entry.name.getValue() << " = ";
581648
p.printRegionArgument(arg);
582649
});
583650
p << ")";
651+
652+
if (getSymNameAttr() != getTemplateNameAttr())
653+
p << " template " << getTemplateNameAttr();
654+
655+
if (getTargetAttr()) {
656+
p << " target ";
657+
p.printSymbolName(getTargetAttr().getValue());
658+
}
659+
584660
p.printOptionalAttrDictWithKeyword(
585-
(*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
661+
(*this)->getAttrs(), {getSymNameAttrName(), getTargetTypeAttrName(),
662+
getTargetAttrName(), getTemplateNameAttrName()});
586663
p << ' ';
587664
p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
588665
}
589666

590667
void TestOp::getAsmBlockArgumentNames(Region &region,
591668
OpAsmSetValueNameFn setNameFn) {
592669
for (auto [entry, arg] :
593-
llvm::zip(getTarget().getEntries(), region.getArguments()))
670+
llvm::zip(getTargetType().getEntries(), region.getArguments()))
594671
setNameFn(arg, entry.name.getValue());
595672
}
596673

0 commit comments

Comments
 (0)