Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RTG] Add tuple operations #8370

Open
wants to merge 1 commit into
base: maerhart-pyrtg-random-integer
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//

include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/CommonTypeConstraints.td"
include "mlir/IR/CommonAttrConstraints.td"
Expand Down Expand Up @@ -525,6 +526,36 @@ def ArraySizeOp : RTGOp<"array_size", [Pure]> {
let assemblyFormat = "$array `:` qualified(type($array)) attr-dict";
}

//===- Tuple Operations ---------------------------------------------------===//

def TupleCreateOp : RTGOp<"tuple_create", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
let summary = "create a tuple";

let arguments = (ins Variadic<AnyType>:$elements);
let results = (outs Builtin_Tuple:$result);

let assemblyFormat = [{
($elements^ `:` qualified(type($elements)))? attr-dict
}];
}

def TupleExtractOp : RTGOp<"tuple_extract", [
Pure,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
]> {
let summary = "get an element from a tuple";

let arguments = (ins Builtin_Tuple:$tuple, IndexAttr:$index);
let results = (outs AnyType:$result);

let assemblyFormat = [{
$tuple `at` $index `:` qualified(type($tuple)) attr-dict
}];
}

//===- Integer Operations -------------------------------------------------===//

def RandomNumberInRangeOp : RTGOp<"random_number_in_range", []> {
Expand Down
4 changes: 4 additions & 0 deletions include/circt/Dialect/RTG/IR/RTGVisitors.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class RTGOpVisitor {
SetSizeOp,
// Arrays
ArrayCreateOp, ArrayExtractOp, ArrayInjectOp, ArraySizeOp,
// Tuples
TupleCreateOp, TupleExtractOp,
// Immediates
IntToImmediateOp,
// Memories
Expand Down Expand Up @@ -116,6 +118,8 @@ class RTGOpVisitor {
HANDLE(ArrayExtractOp, Unhandled);
HANDLE(ArrayInjectOp, Unhandled);
HANDLE(ArraySizeOp, Unhandled);
HANDLE(TupleCreateOp, Unhandled);
HANDLE(TupleExtractOp, Unhandled);
HANDLE(LabelDeclOp, Unhandled);
HANDLE(LabelUniqueDeclOp, Unhandled);
HANDLE(LabelOp, Unhandled);
Expand Down
46 changes: 46 additions & 0 deletions lib/Dialect/RTG/IR/RTGOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,52 @@ LogicalResult BagCreateOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// TupleCreateOp
//===----------------------------------------------------------------------===//

LogicalResult TupleCreateOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty()) {
if (loc)
return mlir::emitError(*loc) << "empty tuples not allowed";
return failure();
}

SmallVector<Type> elementTypes;
for (auto operand : operands)
elementTypes.push_back(operand.getType());
inferredReturnTypes.push_back(TupleType::get(context, elementTypes));
return success();
}

//===----------------------------------------------------------------------===//
// TupleExtractOp
//===----------------------------------------------------------------------===//

LogicalResult TupleExtractOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
assert(operands.size() == 1 && "must have exactly one operand");

auto tupleTy = dyn_cast<TupleType>(operands[0].getType());
size_t idx = properties.as<Properties *>()->getIndex().getInt();
if (!tupleTy || tupleTy.getTypes().size() <= idx) {
if (loc)
return mlir::emitError(*loc)
<< "index (" << idx
<< ") must be smaller than number of elements in tuple ("
<< tupleTy.getTypes().size() << ")";
return failure();
}

inferredReturnTypes.push_back(tupleTy.getTypes()[idx]);
return success();
}

//===----------------------------------------------------------------------===//
// FixedRegisterOp
//===----------------------------------------------------------------------===//
Expand Down
60 changes: 59 additions & 1 deletion lib/Dialect/RTG/Transforms/ElaborationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ struct SetStorage;
struct VirtualRegisterStorage;
struct UniqueLabelStorage;
struct MemoryBlockStorage;
struct TupleStorage;

/// Simple wrapper around a 'StringAttr' such that we know to materialize it as
/// a label declaration instead of calling the builtin dialect constant
Expand All @@ -113,7 +114,8 @@ using ElaboratorValue =
std::variant<TypedAttr, BagStorage *, bool, size_t, SequenceStorage *,
RandomizedSequenceStorage *, InterleavedSequenceStorage *,
SetStorage *, VirtualRegisterStorage *, UniqueLabelStorage *,
LabelValue, MemoryBlockStorage *, ArrayStorage *>;
LabelValue, MemoryBlockStorage *, ArrayStorage *,
TupleStorage *>;

// NOLINTNEXTLINE(readability-identifier-naming)
llvm::hash_code hash_value(const LabelValue &val) {
Expand Down Expand Up @@ -401,6 +403,22 @@ struct ArrayStorage {
const SmallVector<ElaboratorValue> array;
};

/// Storage object for 'tuple`-typed values.
struct TupleStorage {
TupleStorage(SmallVector<ElaboratorValue> &&values)
: hashcode(llvm::hash_combine_range(values.begin(), values.end())),
values(std::move(values)) {}

bool isEqual(const TupleStorage *other) const {
return hashcode == other->hashcode && values == other->values;
}

// The cached hashcode to avoid repeated computations.
const unsigned hashcode;

const SmallVector<ElaboratorValue> values;
};

/// An 'Internalizer' object internalizes storages and takes ownership of them.
/// When the initializer object is destroyed, all owned storages are also
/// deallocated and thus must not be accessed anymore.
Expand Down Expand Up @@ -447,6 +465,8 @@ class Internalizer {
return internedRandomizedSequences;
else if constexpr (std::is_same_v<StorageTy, InterleavedSequenceStorage>)
return internedInterleavedSequences;
else if constexpr (std::is_same_v<StorageTy, TupleStorage>)
return internedTuples;
else
static_assert(!sizeof(StorageTy),
"no intern set available for this storage type.");
Expand All @@ -471,6 +491,8 @@ class Internalizer {
DenseSet<HashedStorage<InterleavedSequenceStorage>,
StorageKeyInfo<InterleavedSequenceStorage>>
internedInterleavedSequences;
DenseSet<HashedStorage<TupleStorage>, StorageKeyInfo<TupleStorage>>
internedTuples;
};

} // namespace
Expand Down Expand Up @@ -556,6 +578,13 @@ static void print(const MemoryBlockStorage *val, llvm::raw_ostream &os) {
<< ", base-address=" << val->baseAddress << "}>";
}

static void print(const TupleStorage *val, llvm::raw_ostream &os) {
os << "<tuple (";
llvm::interleaveComma(val->values, os,
[&](const ElaboratorValue &val) { os << val; });
os << ")>";
}

static llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const ElaboratorValue &value) {
std::visit([&](auto val) { print(val, os); }, value);
Expand Down Expand Up @@ -839,6 +868,18 @@ class Materializer {
return res;
}

Value visit(TupleStorage *val, Location loc,
std::queue<RandomizedSequenceStorage *> &elabRequests,
function_ref<InFlightDiagnostic()> emitError) {
SmallVector<Value> materialized;
materialized.reserve(val->values.size());
for (auto v : val->values)
materialized.push_back(materialize(v, loc, elabRequests, emitError));
Value res = builder.create<TupleCreateOp>(loc, materialized);
materializedValues[val] = res;
return res;
}

private:
/// Cache values we have already materialized to reuse them later. We start
/// with an insertion point at the start of the block and cache the (updated)
Expand Down Expand Up @@ -1360,6 +1401,23 @@ class Elaborator : public RTGOpVisitor<Elaborator, FailureOr<DeletionKind>> {
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(TupleCreateOp op) {
SmallVector<ElaboratorValue> values;
values.reserve(op.getElements().size());
for (auto el : op.getElements())
values.push_back(state[el]);

state[op.getResult()] =
sharedState.internalizer.internalize<TupleStorage>(std::move(values));
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(TupleExtractOp op) {
auto *tuple = get<TupleStorage *>(op.getTuple());
state[op.getResult()] = tuple->values[op.getIndex().getZExtValue()];
return DeletionKind::Delete;
}

FailureOr<DeletionKind> visitOp(scf::IfOp op) {
bool cond = get<bool>(op.getCondition());
auto &toElaborate = cond ? op.getThenRegion() : op.getElseRegion();
Expand Down
12 changes: 12 additions & 0 deletions test/Dialect/RTG/IR/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,15 @@ rtg.test @arrays(arr = %arr: !rtg.array<index>) {
%3 = rtg.array_inject %2[%idx1], %idx1 : !rtg.array<index>
%4 = rtg.array_size %3 : !rtg.array<index>
}

// CHECK-LABEL: rtg.test @tuples
rtg.test @tuples() {
// CHECK-NEXT: [[IDX0:%.+]] = index.constant 0
// CHECK-NEXT: [[TRUE:%.+]] = index.bool.constant true
// CHECK-NEXT: [[TUPLE:%.+]] = rtg.tuple_create [[IDX0]], [[TRUE]] : index, i1
// CHECK-NEXT: rtg.tuple_extract [[TUPLE]] at 1 : tuple<index, i1>
%idx0 = index.constant 0
%true = index.bool.constant true
%0 = rtg.tuple_create %idx0, %true : index, i1
%1 = rtg.tuple_extract %0 at 1 : tuple<index, i1>
}
14 changes: 14 additions & 0 deletions test/Dialect/RTG/IR/errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,17 @@ rtg.test @test() {
// expected-error @below {{must have at least one sequence in the list}}
%0 = rtg.interleave_sequences
}

// -----

rtg.test @emptyTuple() {
// expected-error @below {{empty tuples not allowed}}
%0 = rtg.tuple_create
}

// -----

rtg.test @tupleExtractOOB(tup = %tup : tuple<index, i1>) {
// expected-error @below {{index (2) must be smaller than number of elements in tuple (2)}}
rtg.tuple_extract %tup at 2 : tuple<index, i1>
}
18 changes: 18 additions & 0 deletions test/Dialect/RTG/Transform/elaboration.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ func.func @dummy4(%arg0: index, %arg1: index, %arg2: !rtg.bag<index>, %arg3: !rt
func.func @dummy5(%arg0: i1) -> () {return}
func.func @dummy6(%arg0: !rtg.isa.immediate<2>) -> () {return}
func.func @dummy7(%arg0: !rtg.array<index>) -> () {return}
func.func @dummy8(%arg0: tuple<index, index>) -> () {return}

// CHECK-LABEL: @immediates
rtg.test @immediates() {
Expand Down Expand Up @@ -559,6 +560,23 @@ rtg.test @arithOps() {
func.call @dummy2(%6) : (index) -> ()
}

// CHECK-LABEL: rtg.test @tuples
rtg.test @tuples() {
%idx0 = index.constant 0
%idx1 = index.constant 1
%0 = rtg.tuple_create %idx1, %idx0 : index, index
%1 = rtg.tuple_extract %0 at 1 : tuple<index, index>

// CHECK-NEXT: %idx1 = index.constant 1
// CHECK-NEXT: %idx0 = index.constant 0
// CHECK-NEXT: [[V0:%.+]] = rtg.tuple_create %idx1, %idx0 : index, index
// CHECK-NEXT: func.call @dummy8([[V0]])
func.call @dummy8(%0) : (tuple<index, index>) -> ()

// CHECK-NEXT: func.call @dummy2(%idx0)
func.call @dummy2(%1) : (index) -> ()
}

// -----

rtg.test @nestedRegionsNotSupported() {
Expand Down
Loading