@@ -497,18 +497,61 @@ LogicalResult ContextSwitchOp::verify() {
497
497
// ===----------------------------------------------------------------------===//
498
498
499
499
LogicalResult TestOp::verifyRegions () {
500
- if (!getTarget ().entryTypesMatch (getBody ()->getArgumentTypes ()))
500
+ if (!getTargetType ().entryTypesMatch (getBody ()->getArgumentTypes ()))
501
501
return emitOpError (" argument types must match dict entry types" );
502
502
503
503
return success ();
504
504
}
505
505
506
+ LogicalResult TestOp::verify () {
507
+ if (getTemplateName ().empty ())
508
+ return emitOpError (" template name must not be empty" );
509
+
510
+ return success ();
511
+ }
512
+
513
+ LogicalResult TestOp::verifySymbolUses (SymbolTableCollection &symbolTable) {
514
+ if (!getTargetAttr ())
515
+ return success ();
516
+
517
+ auto target =
518
+ symbolTable.lookupNearestSymbolFrom <TargetOp>(*this , getTargetAttr ());
519
+ if (!target)
520
+ return emitOpError ()
521
+ << " '" << *getTarget ()
522
+ << " ' does not reference a valid 'rtg.target' operation" ;
523
+
524
+ // Check if target is a subtype of test requirements
525
+ // Since entries are sorted by name, we can do this in a single pass
526
+ size_t targetIdx = 0 ;
527
+ auto targetEntries = target.getTarget ().getEntries ();
528
+ for (auto testEntry : getTargetType ().getEntries ()) {
529
+ // Find the matching entry in target entries.
530
+ while (targetIdx < targetEntries.size () &&
531
+ targetEntries[targetIdx].name .getValue () < testEntry.name .getValue ())
532
+ targetIdx++;
533
+
534
+ // Check if we found a matching entry with the same name and type
535
+ if (targetIdx >= targetEntries.size () ||
536
+ targetEntries[targetIdx].name != testEntry.name ||
537
+ targetEntries[targetIdx].type != testEntry.type ) {
538
+ return emitOpError (" referenced 'rtg.target' op's type is invalid: "
539
+ " missing entry called '" )
540
+ << testEntry.name .getValue () << " ' of type " << testEntry.type ;
541
+ }
542
+ }
543
+
544
+ return success ();
545
+ }
546
+
506
547
ParseResult TestOp::parse (OpAsmParser &parser, OperationState &result) {
507
548
// Parse the name as a symbol.
508
- if (parser. parseSymbolName (
509
- result. getOrAddProperties <TestOp::Properties>(). sym_name ))
549
+ StringAttr symNameAttr;
550
+ if (parser. parseSymbolName (symNameAttr ))
510
551
return failure ();
511
552
553
+ result.getOrAddProperties <TestOp::Properties>().sym_name = symNameAttr;
554
+
512
555
// Parse the function signature.
513
556
SmallVector<OpAsmParser::Argument> arguments;
514
557
SmallVector<StringAttr> names;
@@ -544,7 +587,31 @@ ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
544
587
ArrayRef<DictEntry>(entries));
545
588
if (!type)
546
589
return failure ();
547
- result.getOrAddProperties <TestOp::Properties>().target = TypeAttr::get (type);
590
+ result.getOrAddProperties <TestOp::Properties>().targetType =
591
+ TypeAttr::get (type);
592
+
593
+ std::string templateName;
594
+ if (!parser.parseOptionalKeyword (" template" )) {
595
+ auto loc = parser.getCurrentLocation ();
596
+ if (parser.parseString (&templateName))
597
+ return failure ();
598
+
599
+ if (templateName.empty ())
600
+ return parser.emitError (loc, " template name must not be empty" );
601
+ }
602
+
603
+ StringAttr templateNameAttr = symNameAttr;
604
+ if (!templateName.empty ())
605
+ templateNameAttr = StringAttr::get (result.getContext (), templateName);
606
+
607
+ StringAttr targetName;
608
+ if (!parser.parseOptionalKeyword (" target" ))
609
+ if (parser.parseSymbolName (targetName))
610
+ return failure ();
611
+
612
+ result.getOrAddProperties <TestOp::Properties>().templateName =
613
+ templateNameAttr;
614
+ result.getOrAddProperties <TestOp::Properties>().target = targetName;
548
615
549
616
auto loc = parser.getCurrentLocation ();
550
617
if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
@@ -574,23 +641,33 @@ void TestOp::print(OpAsmPrinter &p) {
574
641
p << " (" ;
575
642
SmallString<32 > resultNameStr;
576
643
llvm::interleaveComma (
577
- llvm::zip (getTarget ().getEntries (), getBody ()->getArguments ()), p,
644
+ llvm::zip (getTargetType ().getEntries (), getBody ()->getArguments ()), p,
578
645
[&](auto entryAndArg) {
579
646
auto [entry, arg] = entryAndArg;
580
647
p << entry.name .getValue () << " = " ;
581
648
p.printRegionArgument (arg);
582
649
});
583
650
p << " )" ;
651
+
652
+ if (getSymNameAttr () != getTemplateNameAttr ())
653
+ p << " template " << getTemplateNameAttr ();
654
+
655
+ if (getTargetAttr ()) {
656
+ p << " target " ;
657
+ p.printSymbolName (getTargetAttr ().getValue ());
658
+ }
659
+
584
660
p.printOptionalAttrDictWithKeyword (
585
- (*this )->getAttrs (), {getSymNameAttrName (), getTargetAttrName ()});
661
+ (*this )->getAttrs (), {getSymNameAttrName (), getTargetTypeAttrName (),
662
+ getTargetAttrName (), getTemplateNameAttrName ()});
586
663
p << ' ' ;
587
664
p.printRegion (getBodyRegion (), /* printEntryBlockArgs=*/ false );
588
665
}
589
666
590
667
void TestOp::getAsmBlockArgumentNames (Region ®ion,
591
668
OpAsmSetValueNameFn setNameFn) {
592
669
for (auto [entry, arg] :
593
- llvm::zip (getTarget ().getEntries (), region.getArguments ()))
670
+ llvm::zip (getTargetType ().getEntries (), region.getArguments ()))
594
671
setNameFn (arg, entry.name .getValue ());
595
672
}
596
673
0 commit comments