From 4a4ee1fe3338e2c6930f64b9d67ae1d7d83bb165 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Tue, 1 Apr 2025 15:32:45 +0100 Subject: [PATCH] [RTG] Add bag_convert_to_set operation --- frontends/PyRTG/src/pyrtg/bags.py | 8 ++++++++ frontends/PyRTG/test/basic.py | 7 +++++-- include/circt/Dialect/RTG/IR/RTGOps.td | 18 ++++++++++++++++++ include/circt/Dialect/RTG/IR/RTGVisitors.h | 3 ++- lib/Dialect/RTG/Transforms/ElaborationPass.cpp | 10 ++++++++++ test/Dialect/RTG/IR/basic.mlir | 2 ++ test/Dialect/RTG/Transform/elaboration.mlir | 6 ++++++ 7 files changed, 51 insertions(+), 3 deletions(-) diff --git a/frontends/PyRTG/src/pyrtg/bags.py b/frontends/PyRTG/src/pyrtg/bags.py index b911be9efee0..306b07f58743 100644 --- a/frontends/PyRTG/src/pyrtg/bags.py +++ b/frontends/PyRTG/src/pyrtg/bags.py @@ -129,6 +129,14 @@ def get_random_and_exclude(self) -> Value: self._value = self.exclude(r)._get_ssa_value() return r + def to_set(self) -> Value: + """ + Returns this bag converted to a set, i.e., all duplicates are dropped. Does + not modify this object. + """ + + return rtg.BagConvertToSetOp(self) + def _get_ssa_value(self) -> ir.Value: return self._value diff --git a/frontends/PyRTG/test/basic.py b/frontends/PyRTG/test/basic.py index c6d064061b39..73add7f1b9ca 100644 --- a/frontends/PyRTG/test/basic.py +++ b/frontends/PyRTG/test/basic.py @@ -437,8 +437,11 @@ def test90_tuples(a, b, tup): # MLIR-LABEL: rtg.test @test91_sets # MLIR-NEXT: rtg.set_cartesian_product %a, %b : !rtg.set, !rtg.set +# MLIR: rtg.bag_convert_to_set %c : !rtg.bag -@test(("a", Set.type(Integer.type())), ("b", Set.type(Bool.type()))) -def test91_sets(a, b): +@test(("a", Set.type(Integer.type())), ("b", Set.type(Bool.type())), + ("c", Bag.type(Integer.type()))) +def test91_sets(a, b, c): seq2(Set.cartesian_product(a, b)) + int_consumer(c.to_set().get_random()) diff --git a/include/circt/Dialect/RTG/IR/RTGOps.td b/include/circt/Dialect/RTG/IR/RTGOps.td index ca2453a39909..220f63927e67 100644 --- a/include/circt/Dialect/RTG/IR/RTGOps.td +++ b/include/circt/Dialect/RTG/IR/RTGOps.td @@ -487,6 +487,24 @@ def BagUniqueSizeOp : RTGOp<"bag_unique_size", [Pure]> { }]; } +def BagConvertToSetOp : RTGOp<"bag_convert_to_set", [ + Pure, + TypesMatchWith<"element type of set must match the bag's element type", + "input", "result", + "SetType::get(cast($_self).getElementType())">, +]> { + let summary = "convert a bag to a set"; + let description = [{ + This operation converts a bag to a set by dropping all duplicate elements. + For example, the bag `{a, a, b}` is converted to `{a, b}`. + }]; + + let arguments = (ins BagType:$input); + let results = (outs SetType:$result); + + let assemblyFormat = "$input `:` qualified(type($input)) attr-dict"; +} + //===- Array Operations -------------------------------------------------===// def ArrayCreateOp : RTGOp<"array_create", [ diff --git a/include/circt/Dialect/RTG/IR/RTGVisitors.h b/include/circt/Dialect/RTG/IR/RTGVisitors.h index 1cfb94b664ed..d57b613590c1 100644 --- a/include/circt/Dialect/RTG/IR/RTGVisitors.h +++ b/include/circt/Dialect/RTG/IR/RTGVisitors.h @@ -36,7 +36,7 @@ class RTGOpVisitor { ConstantOp, // Bags BagCreateOp, BagSelectRandomOp, BagDifferenceOp, BagUnionOp, - BagUniqueSizeOp, + BagUniqueSizeOp, BagConvertToSetOp, // Contexts OnContextOp, ContextSwitchOp, // Labels @@ -115,6 +115,7 @@ class RTGOpVisitor { HANDLE(BagDifferenceOp, Unhandled); HANDLE(BagUnionOp, Unhandled); HANDLE(BagUniqueSizeOp, Unhandled); + HANDLE(BagConvertToSetOp, Unhandled); HANDLE(ArrayCreateOp, Unhandled); HANDLE(ArrayExtractOp, Unhandled); HANDLE(ArrayInjectOp, Unhandled); diff --git a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp index 18ff394bfc0d..39e1d2ef99c6 100644 --- a/lib/Dialect/RTG/Transforms/ElaborationPass.cpp +++ b/lib/Dialect/RTG/Transforms/ElaborationPass.cpp @@ -1264,6 +1264,16 @@ class Elaborator : public RTGOpVisitor> { return DeletionKind::Delete; } + FailureOr visitOp(BagConvertToSetOp op) { + auto bag = get(op.getInput())->bag; + SetVector set; + for (auto [k, v] : bag) + set.insert(k); + state[op.getResult()] = sharedState.internalizer.internalize( + std::move(set), op.getType()); + return DeletionKind::Delete; + } + FailureOr visitOp(FixedRegisterOp op) { return visitPureOp(op); } diff --git a/test/Dialect/RTG/IR/basic.mlir b/test/Dialect/RTG/IR/basic.mlir index bc7c5138e562..b868749fef38 100644 --- a/test/Dialect/RTG/IR/basic.mlir +++ b/test/Dialect/RTG/IR/basic.mlir @@ -87,6 +87,7 @@ rtg.sequence @bags(%arg0: i32, %arg1: i32, %arg2: index) { // CHECK: rtg.bag_difference [[BAG]], [[EMPTY]] inf : !rtg.bag // CHECK: rtg.bag_union [[BAG]], [[EMPTY]], [[DIFF]] : !rtg.bag // CHECK: rtg.bag_unique_size [[BAG]] : !rtg.bag + // CHECK: rtg.bag_convert_to_set [[BAG]] : !rtg.bag %bag = rtg.bag_create (%arg2 x %arg0, %arg2 x %arg1) : i32 {rtg.some_attr} %r = rtg.bag_select_random %bag : !rtg.bag {rtg.some_attr} %empty = rtg.bag_create : i32 @@ -94,6 +95,7 @@ rtg.sequence @bags(%arg0: i32, %arg1: i32, %arg2: index) { %diff2 = rtg.bag_difference %bag, %empty inf : !rtg.bag %union = rtg.bag_union %bag, %empty, %diff : !rtg.bag %size = rtg.bag_unique_size %bag : !rtg.bag + %set = rtg.bag_convert_to_set %bag : !rtg.bag } // CHECK-LABEL: rtg.target @empty_target : !rtg.dict<> { diff --git a/test/Dialect/RTG/Transform/elaboration.mlir b/test/Dialect/RTG/Transform/elaboration.mlir index 1a4860b417b1..baffc36da419 100644 --- a/test/Dialect/RTG/Transform/elaboration.mlir +++ b/test/Dialect/RTG/Transform/elaboration.mlir @@ -10,6 +10,7 @@ func.func @dummy7(%arg0: !rtg.array) -> () {return} func.func @dummy8(%arg0: tuple) -> () {return} func.func @dummy9(%arg0: !rtg.set>) -> () {return} func.func @dummy10(%arg0: !rtg.set>) -> () {return} +func.func @dummy11(%arg0: !rtg.set) -> () {return} // CHECK-LABEL: @immediates rtg.test @immediates() { @@ -117,6 +118,11 @@ rtg.test @bagOperations() { %diff2 = rtg.bag_difference %bag, %new_bag inf : !rtg.bag %4 = rtg.bag_select_random %diff2 : !rtg.bag {rtg.elaboration_custom_seed = 5} func.call @dummy4(%3, %4, %diff, %diff2) : (index, index, !rtg.bag, !rtg.bag) -> () + + // CHECK-NEXT: [[SET:%.+]] = rtg.set_create [[V0]], [[V2]] : + // CHECK-NEXT: func.call @dummy11([[SET]]) + %5 = rtg.bag_convert_to_set %bag0 : !rtg.bag + func.call @dummy11(%5) : (!rtg.set) -> () } // CHECK-LABEL: rtg.test @setSize