Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG] Add set_cartesian_product operation #8376

Open
wants to merge 1 commit into
base: maerhart-rtg-elaboration-folders
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions frontends/PyRTG/src/pyrtg/sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ def __sub__(self, other: Value) -> Set:
), "type of the provided value must match element type of the set"
return self - Set.create(other)

@staticmethod
def cartesian_product(*args: Set) -> Set:
"""
Compute the n-ary cartesian product of the given sets.
This means, for n input sets it computes
`X_1 x ... x X_n = {(x_1, ..., x_n) | x_i \in X_i for i \in {1, ..., n}}`.
At least one input set has to be provided (i.e., `n > 0`).
"""

assert len(args) > 0, "at least one set must be provided"
return rtg.SetCartesianProductOp(args)

def get_random(self) -> Value:
"""
Returns an element from the set picked uniformly at random. If the set is
Expand Down
18 changes: 16 additions & 2 deletions frontends/PyRTG/test/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ def seq1():
Label.declare("s1").place()


@sequence(Set.type(Tuple.type(Integer.type(), Bool.type())))
def seq2(set):
pass


# MLIR-LABEL: rtg.test @test0
# MLIR-NEXT: }

Expand Down Expand Up @@ -418,13 +423,22 @@ def test8_random_integer(a, b):
int_consumer(Integer.random(a, b))


# MLIR-LABEL: rtg.test @test9_tuples
# MLIR-LABEL: rtg.test @test90_tuples
# MLIR-NEXT: [[V0:%.+]] = rtg.tuple_create %a, %b : index, i1
# MLIR-NEXT: rtg.tuple_extract [[V0]] at 1 : tuple<index, i1>


@test(("a", Integer.type()), ("b", Bool.type()),
("tup", Tuple.type(Integer.type(), Bool.type())))
def test9_tuples(a, b, tup):
def test90_tuples(a, b, tup):
tup = Tuple.create(a, b)
consumer(tup[1])


# MLIR-LABEL: rtg.test @test91_sets
# MLIR-NEXT: rtg.set_cartesian_product %a, %b : !rtg.set<index>, !rtg.set<i1>


@test(("a", Set.type(Integer.type())), ("b", Set.type(Bool.type())))
def test91_sets(a, b):
seq2(Set.cartesian_product(a, b))
31 changes: 31 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,37 @@ def SetSizeOp : RTGOp<"set_size", [Pure]> {
}];
}

def SetCartesianProductOp : RTGOp<"set_cartesian_product", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
let summary = "computes the n-ary cartesian product of sets";
let description = [{
This operation computes a set of tuples from a list of input sets such that
each combination of elements from the input sets is present in the result
set. More formally, for n input sets it computes
`X_1 x ... x X_n = {(x_1, ..., x_n) | x_i \in X_i for i \in {1, ..., n}}`.
At least one input set has to be provided (i.e., `n > 0`).

For example, given two sets A and B with elements
`A = {a0, a1}, B = {b0, b1}` the result set R will be
`R = {(a0, b0), (a0, b1), (a1, b0), (a1, b1)}`.

Note that an RTG set does not provide any guarantees about the order of
elements an can thus not be iterated over or indexed into, however, a
random element can be selected and subtracted from the set until it is
empty. This procedure is determinstic and will yield the same sequence of
elements for a fixed seed and RTG version. If more guarantees about the
order of elements is necessary, use arrays instead (and compute the
cartesian product manually using nested loops).
}];

let arguments = (ins Variadic<SetType>:$inputs);
let results = (outs SetType:$result);

let assemblyFormat = "$inputs `:` qualified(type($inputs)) attr-dict";
}

//===- Bag Operations ------------------------------------------------------===//

def BagCreateOp : RTGOp<"bag_create", [Pure, SameVariadicOperandSize]> {
Expand Down
5 changes: 5 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def SetType : RTGTypeDef<"Set"> {

let mnemonic = "set";
let assemblyFormat = "`<` $elementType `>`";

let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$elementType),
"return $_get(elementType.getContext(), elementType);">,
];
}

def BagType : RTGTypeDef<"Bag"> {
Expand Down
3 changes: 2 additions & 1 deletion include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class RTGOpVisitor {
RandomizeSequenceOp, EmbedSequenceOp, InterleaveSequencesOp,
// Sets
SetCreateOp, SetSelectRandomOp, SetDifferenceOp, SetUnionOp,
SetSizeOp,
SetSizeOp, SetCartesianProductOp,
// Arrays
ArrayCreateOp, ArrayExtractOp, ArrayInjectOp, ArraySizeOp,
// Tuples
Expand Down Expand Up @@ -109,6 +109,7 @@ class RTGOpVisitor {
HANDLE(SetDifferenceOp, Unhandled);
HANDLE(SetUnionOp, Unhandled);
HANDLE(SetSizeOp, Unhandled);
HANDLE(SetCartesianProductOp, Unhandled);
HANDLE(BagCreateOp, Unhandled);
HANDLE(BagSelectRandomOp, Unhandled);
HANDLE(BagDifferenceOp, Unhandled);
Expand Down
22 changes: 22 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,28 @@ LogicalResult SetCreateOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// SetCartesianProductOp
//===----------------------------------------------------------------------===//

LogicalResult SetCartesianProductOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty()) {
if (loc)
return mlir::emitError(*loc) << "at least one set must be provided";
return failure();
}

SmallVector<Type> elementTypes;
for (auto operand : operands)
elementTypes.push_back(cast<SetType>(operand.getType()).getElementType());
inferredReturnTypes.push_back(
SetType::get(TupleType::get(context, elementTypes)));
return success();
}

//===----------------------------------------------------------------------===//
// BagCreateOp
//===----------------------------------------------------------------------===//
Expand Down
36 changes: 36 additions & 0 deletions lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,42 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return DeletionKind::Delete;
}

// {a0,a1} x {b0,b1} x {c0,c1} -> {(a0), (a1)} -> {(a0,b0), (a0,b1), (a1,b0),
// (a1,b1)} -> {(a0,b0,c0), (a0,b0,c1), (a0,b1,c0), (a0,b1,c1), (a1,b0,c0),
// (a1,b0,c1), (a1,b1,c0), (a1,b1,c1)}
FailureOr<DeletionKind> visitOp(SetCartesianProductOp op) {
SetVector<ElaboratorValue> result;
SmallVector<SmallVector<ElaboratorValue>> tuples;
tuples.push_back({});

for (auto input : op.getInputs()) {
auto &set = get<SetStorage *>(input)->set;
if (set.empty()) {
SetVector<ElaboratorValue> empty;
state[op.getResult()] =
sharedState.internalizer.internalize<SetStorage>(std::move(empty),
op.getType());
return DeletionKind::Delete;
}

for (unsigned i = 0, e = tuples.size(); i < e; ++i) {
for (auto setEl : set.getArrayRef().drop_back()) {
tuples.push_back(tuples[i]);
tuples.back().push_back(setEl);
}
tuples[i].push_back(set.back());
}
}

for (auto &tup : tuples)
result.insert(
sharedState.internalizer.internalize<TupleStorage>(std::move(tup)));

state[op.getResult()] = sharedState.internalizer.internalize<SetStorage>(
std::move(result), op.getType());
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(BagCreateOp op) {
MapVector<ElaboratorValue, uint64_t> bag;
for (auto [val, multiple] :
Expand Down
6 changes: 4 additions & 2 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,23 @@ rtg.sequence @seqRandomizationAndEmbedding() {
}

// CHECK-LABEL: @sets
func.func @sets(%arg0: i32, %arg1: i32) {
func.func @sets(%arg0: i32, %arg1: i32) -> !rtg.set<tuple<i32, i32>> {
// CHECK: [[SET:%.+]] = rtg.set_create %arg0, %arg1 : i32
// CHECK: [[R:%.+]] = rtg.set_select_random [[SET]] : !rtg.set<i32>
// CHECK: [[EMPTY:%.+]] = rtg.set_create : i32
// CHECK: [[DIFF:%.+]] = rtg.set_difference [[SET]], [[EMPTY]] : !rtg.set<i32>
// CHECK: rtg.set_union [[SET]], [[DIFF]] : !rtg.set<i32>
// CHECK: rtg.set_size [[SET]] : !rtg.set<i32>
// CHECK: rtg.set_cartesian_product [[SET]], [[SET]] : !rtg.set<i32>, !rtg.set<i32>
%set = rtg.set_create %arg0, %arg1 : i32
%r = rtg.set_select_random %set : !rtg.set<i32>
%empty = rtg.set_create : i32
%diff = rtg.set_difference %set, %empty : !rtg.set<i32>
%union = rtg.set_union %set, %diff : !rtg.set<i32>
%size = rtg.set_size %set : !rtg.set<i32>
%prod = rtg.set_cartesian_product %set, %set : !rtg.set<i32>, !rtg.set<i32>

return
return %prod : !rtg.set<tuple<i32, i32>>
}

// CHECK-LABEL: @bags
Expand Down
8 changes: 8 additions & 0 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ rtg.sequence @seq() {

// -----

rtg.sequence @setCartesianProduct() {
// expected-error @below {{at least one set must be provided}}
// expected-error @below {{failed to infer returned types}}
%0 = "rtg.set_cartesian_product"() : () -> (!rtg.set<tuple<index>>)
}

// -----

rtg.sequence @seq() {
// expected-error @below {{expected 1 or more operands, but found 0}}
rtg.bag_union : !rtg.bag<i32>
Expand Down
47 changes: 47 additions & 0 deletions test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ func.func @dummy5(%arg0: i1) -> () {return}
func.func @dummy6(%arg0: !rtg.isa.immediate<2>) -> () {return}
func.func @dummy7(%arg0: !rtg.array<index>) -> () {return}
func.func @dummy8(%arg0: tuple<index, index>) -> () {return}
func.func @dummy9(%arg0: !rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> () {return}
func.func @dummy10(%arg0: !rtg.set<tuple<index>>) -> () {return}

// CHECK-LABEL: @immediates
rtg.test @immediates() {
Expand Down Expand Up @@ -46,6 +48,51 @@ rtg.test @setOperations() {
func.call @dummy1(%4, %5, %diff) : (index, index, !rtg.set<index>) -> ()
}

// CHECK-LABEL: rtg.test @setCartesianProduct
rtg.test @setCartesianProduct() {
%idx0 = index.constant 0
%idx1 = index.constant 1
%0 = rtg.set_create %idx0, %idx1 : index
%true = index.bool.constant true
%false = index.bool.constant false
%1 = rtg.set_create %true, %false : i1
%s0 = rtg.fixed_reg #rtgtest.s0
%s1 = rtg.fixed_reg #rtgtest.s1
%2 = rtg.set_create %s0, %s1 : !rtgtest.ireg

// CHECK-DAG: [[IDX1:%.+]] = index.constant 1
// CHECK-DAG: [[FALSE:%.+]] = index.bool.constant false
// CHECK-DAG: [[S1:%.+]] = rtg.fixed_reg #rtgtest.s1 : !rtgtest.ireg
// CHECK-DAG: [[T1:%.+]] = rtg.tuple_create [[IDX1]], [[FALSE]], [[S1]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[IDX0:%.+]] = index.constant 0
// CHECK-DAG: [[T2:%.+]] = rtg.tuple_create [[IDX0]], [[FALSE]], [[S1]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[TRUE:%.+]] = index.bool.constant true
// CHECK-DAG: [[T3:%.+]] = rtg.tuple_create [[IDX1]], [[TRUE]], [[S1]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[T4:%.+]] = rtg.tuple_create [[IDX0]], [[TRUE]], [[S1]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[S0:%.+]] = rtg.fixed_reg #rtgtest.s0 : !rtgtest.ireg
// CHECK-DAG: [[T5:%.+]] = rtg.tuple_create [[IDX1]], [[FALSE]], [[S0]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[T6:%.+]] = rtg.tuple_create [[IDX0]], [[FALSE]], [[S0]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[T7:%.+]] = rtg.tuple_create [[IDX1]], [[TRUE]], [[S0]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[T8:%.+]] = rtg.tuple_create [[IDX0]], [[TRUE]], [[S0]] : index, i1, !rtgtest.ireg
// CHECK-DAG: [[SET:%.+]] = rtg.set_create [[T1]], [[T2]], [[T3]], [[T4]], [[T5]], [[T6]], [[T7]], [[T8]] : tuple<index, i1, !rtgtest.ireg>
// CHECK-NEXT: func.call @dummy9([[SET]]) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()
%3 = rtg.set_cartesian_product %0, %1, %2 : !rtg.set<index>, !rtg.set<i1>, !rtg.set<!rtgtest.ireg>
func.call @dummy9(%3) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()

// CHECK-NEXT: [[EMPTY:%.+]] = rtg.set_create : tuple<index, i1, !rtgtest.ireg>
// CHECK-NEXT: func.call @dummy9([[EMPTY]]) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()
%4 = rtg.set_create : !rtgtest.ireg
%5 = rtg.set_cartesian_product %0, %1, %4 : !rtg.set<index>, !rtg.set<i1>, !rtg.set<!rtgtest.ireg>
func.call @dummy9(%5) : (!rtg.set<tuple<index, i1, !rtgtest.ireg>>) -> ()

// CHECK-NEXT: [[T9:%.+]] = rtg.tuple_create [[IDX1]] : index
// CHECK-NEXT: [[T10:%.+]] = rtg.tuple_create [[IDX0]] : index
// CHECK-NEXT: [[SET2:%.+]] = rtg.set_create [[T9]], [[T10]] : tuple<index>
// CHECK-NEXT: func.call @dummy10([[SET2]]) : (!rtg.set<tuple<index>>) -> ()
%6 = rtg.set_cartesian_product %0 : !rtg.set<index>
func.call @dummy10(%6) : (!rtg.set<tuple<index>>) -> ()
}

// CHECK-LABEL: rtg.test @bagOperations
rtg.test @bagOperations() {
// CHECK-NEXT: [[V0:%.+]] = index.constant 2
Expand Down