Skip to content

Commit 36e27d5

Browse files
#519 improve set integration bench (#527)
* #519 improve set integration bench closes #519 issue * fix + improved perf * fix error name --------- Co-authored-by: lanaivina <[email protected]>
1 parent 28e4fdc commit 36e27d5

File tree

5 files changed

+87
-50
lines changed

5 files changed

+87
-50
lines changed

build.zig.zon

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
.hash = "1220ab73fb7cc11b2308edc3364988e05efcddbcac31b707f55e6216d1b9c0da13f1",
1515
},
1616
.starknet = .{
17-
.url = "https://github.com/StringNick/starknet-zig/archive/8cfb4286ffda4ad2781647c3d96b2aec8ccfeb32.zip",
18-
.hash = "122026eaa24834fd2e2cc7e8b6c4eefb03dda08158a2844615f189758fa24d32fc44",
17+
.url = "https://github.com/StringNick/starknet-zig/archive/57810b7a64364f1bf12725ba823385c2a213bfa5.zip",
18+
.hash = "1220d848be799ff21a80c6751c088ea619891ec450f20017cc7aa5cbbeb5904ae8b8",
1919
},
2020
},
2121
}

src/hint_processor/set.zig

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ const HintProcessor = @import("hint_processor_def.zig").CairoVMHintProcessor;
1212
const HintData = @import("hint_processor_def.zig").HintData;
1313
const Relocatable = @import("../vm/memory/relocatable.zig").Relocatable;
1414
const MaybeRelocatable = @import("../vm/memory/relocatable.zig").MaybeRelocatable;
15+
const MemoryCell = @import("../vm/memory/memory.zig").MemoryCell;
1516
const Felt252 = @import("../math/fields/starknet.zig").Felt252;
1617
const hint_codes = @import("builtin_hint_codes.zig");
1718
const MathError = @import("../vm/error.zig").MathError;
@@ -60,10 +61,22 @@ pub fn setAdd(
6061
// Calculate the range limit.
6162
const range_limit = (try set_end_ptr.sub(set_ptr)).offset;
6263

64+
// load all list, and then we compare elements
65+
var elm_segment = vm.segments.memory.getSegmentAtIndex(elm_ptr.segment_index) orelse return HintError.InvalidSetRange;
66+
67+
if (elm_segment.len < elm_ptr.offset + elm_size) return HintError.InvalidSetRange;
68+
69+
var set_segment = vm.segments.memory.getSegmentAtIndex(set_ptr.segment_index) orelse return HintError.InvalidSetRange;
70+
71+
if (set_ptr.offset + range_limit > set_segment.len) return HintError.InvalidSetRange;
72+
73+
elm_segment = elm_segment[elm_ptr.offset .. elm_ptr.offset + elm_size];
74+
set_segment = set_segment[set_ptr.offset .. set_ptr.offset + range_limit];
75+
6376
// Iterate over the set elements.
64-
for (0..range_limit) |i| {
77+
for (0..range_limit / elm_size) |i| {
6578
// Check if the element is in the set.
66-
if (try vm.memEq(elm_ptr, try set_ptr.addUint(elm_size * i), elm_size)) {
79+
if (MemoryCell.eqlSlice(elm_segment, set_segment[i * elm_size .. (i + 1) * elm_size])) {
6780
// Insert index of the element into the virtual machine.
6881
try hint_utils.insertValueFromVarName(
6982
allocator,

src/vm/core_test.zig

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3666,7 +3666,7 @@ test "CairoVM: runInstruction without any insertion in the memory" {
36663666
// Compare each cell in VM's memory with the corresponding cell in the expected memory.
36673667
for (vm.segments.memory.data.items, 0..) |d, i| {
36683668
for (d.items, 0..) |cell, j| {
3669-
try expect(cell.eql(expected_memory.data.items[i].items[j]));
3669+
try expect(cell.eql(&expected_memory.data.items[i].items[j]));
36703670
}
36713671
}
36723672
}
@@ -3839,7 +3839,7 @@ test "CairoVM: runInstruction with Op0 being deduced" {
38393839
// Compare each cell in VM's memory with the corresponding cell in the expected memory.
38403840
for (vm.segments.memory.data.items, 0..) |d, i| {
38413841
for (d.items, 0..) |cell, j| {
3842-
try expect(cell.eql(expected_memory.data.items[i].items[j]));
3842+
try expect(cell.eql(&expected_memory.data.items[i].items[j]));
38433843
}
38443844
}
38453845
}

src/vm/memory/memory.zig

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ const RangeCheckBuiltinRunner = @import("../builtins/builtin_runner/range_check.
2323
// Function that validates a memory address and returns a list of validated adresses
2424
pub const validation_rule = *const fn (Allocator, *Memory, Relocatable) anyerror!std.ArrayList(Relocatable);
2525

26-
pub const MemoryCell = struct {
26+
pub const MemoryCell = extern struct {
2727
/// Represents a memory cell that holds relocation information and access status.
2828
const Self = @This();
2929
const ACCESS_MASK: u64 = 1 << 62;
@@ -103,8 +103,12 @@ pub const MemoryCell = struct {
103103
/// # Returns
104104
///
105105
/// Returns `true` if both MemoryCell instances are equal, otherwise `false`.
106-
pub fn eql(self: Self, other: Self) bool {
107-
return std.mem.eql(u64, self.data[0..], other.data[0..]);
106+
pub fn eql(self: *const Self, other: *const Self) bool {
107+
inline for (0..4) |i| {
108+
if (self.data[i] != other.data[i]) return false;
109+
}
110+
111+
return true;
108112
}
109113

110114
/// Checks equality between slices of MemoryCell instances.
@@ -124,7 +128,7 @@ pub const MemoryCell = struct {
124128
if (a.len != b.len) return false;
125129
if (a.ptr == b.ptr) return true;
126130

127-
for (a, b) |a_elem, b_elem| {
131+
for (a, b) |*a_elem, *b_elem| {
128132
if (!a_elem.eql(b_elem)) return false;
129133
}
130134

@@ -609,20 +613,11 @@ pub const Memory = struct {
609613
/// # Returns
610614
///
611615
/// Returns the segment of MemoryCell items if it exists, or `null` if not found.
612-
fn getSegmentAtIndex(self: *Self, idx: i64) ?[]MemoryCell {
613-
return switch (idx < 0) {
614-
true => blk: {
615-
const i: usize = @intCast(-(idx + 1));
616-
break :blk if (i < self.temp_data.items.len)
617-
self.temp_data.items[i].items
618-
else
619-
null;
620-
},
621-
false => if (idx < self.data.items.len)
622-
self.data.items[@intCast(idx)].items
623-
else
624-
null,
625-
};
616+
pub inline fn getSegmentAtIndex(self: *const Self, idx: i64) ?[]MemoryCell {
617+
return if (idx < 0) {
618+
const i: usize = @bitCast(-(idx + 1));
619+
return if (i >= self.temp_data.items.len) null else self.temp_data.items[i].items;
620+
} else if (idx >= self.data.items.len) null else self.data.items[@intCast(idx)].items;
626621
}
627622

628623
/// Compares two memory segments within the VM's memory starting from specified addresses
@@ -663,12 +658,6 @@ pub const Memory = struct {
663658
const l_idx = lhs.offset + i;
664659
const r_idx = rhs.offset + i;
665660

666-
// std.log.err("lhs: {any}, rhs: {any}, i: {any}, {any}", .{
667-
// if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE, if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE, i, MemoryCell.cmp(
668-
// if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE,
669-
// if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE,
670-
// ),
671-
// });
672661
return switch (MemoryCell.cmp(
673662
if (l_idx < ls.len) ls[l_idx] else MemoryCell.NONE,
674663
if (r_idx < rs.len) rs[r_idx] else MemoryCell.NONE,
@@ -700,7 +689,7 @@ pub const Memory = struct {
700689
/// # Returns
701690
///
702691
/// Returns `true` if segments are equal up to the specified length, otherwise `false`.
703-
pub fn memEq(self: *Self, lhs: Relocatable, rhs: Relocatable, len: usize) !bool {
692+
pub fn memEq(self: *const Self, lhs: Relocatable, rhs: Relocatable, len: usize) !bool {
704693
// Check if the left and right addresses are the same, in which case the segments are equal.
705694
if (lhs.eq(rhs)) return true;
706695

@@ -714,29 +703,25 @@ pub const Memory = struct {
714703
// Get the segment starting from the right-hand address.
715704
const r: ?[]MemoryCell = if (self.getSegmentAtIndex(rhs.segment_index)) |s|
716705
// Check if the offset is within the bounds of the segment.
717-
if (rhs.offset < s.len) s[rhs.offset..] else if (l == null) return true else return false
718-
else if (l == null) return true else return false;
706+
if (rhs.offset < s.len) s[rhs.offset..] else return l == null
707+
else
708+
return l == null;
719709

720710
// If the left segment exists, perform further checks.
721711
if (l) |ls| {
722712
// If the right segment also exists, compare the segments up to the specified length.
723-
if (r) |rs| {
724-
// Determine the actual lengths to compare.
725-
const lhs_len = @min(ls.len, len);
726-
const rhs_len = @min(rs.len, len);
713+
// Determine the actual lengths to compare.
714+
const lhs_len = @min(ls.len, len);
715+
const rhs_len = @min(r.?.len, len);
727716

728-
// Compare slices of MemoryCell items up to the specified length.
729-
if (lhs_len != rhs_len) return false;
730-
731-
return MemoryCell.eqlSlice(ls[0..lhs_len], rs[0..rhs_len]);
732-
}
717+
// Compare slices of MemoryCell items up to the specified length.
718+
if (lhs_len != rhs_len) return false;
733719

734-
// If only the left segment exists, return false.
735-
return false;
720+
return MemoryCell.eqlSlice(ls[0..lhs_len], r.?[0..rhs_len]);
736721
}
737722

738-
// If the left segment does not exist, return true only if the right segment is also null.
739-
return r == null;
723+
// If only the left segment exists, return false.
724+
return false;
740725
}
741726

742727
/// Retrieves a range of memory values starting from a specified address.
@@ -769,6 +754,36 @@ pub const Memory = struct {
769754
return values;
770755
}
771756

757+
/// Retrieves a range of memory values starting from a specified address.
758+
///
759+
/// # Arguments
760+
///
761+
/// * `allocator`: The allocator used for the memory allocation of the returned list.
762+
/// * `address`: The starting address in the memory from which the range is retrieved.
763+
/// * `size`: The size of the range to be retrieved.
764+
///
765+
/// # Returns
766+
///
767+
/// Returns a list containing memory values retrieved from the specified range starting at the given address.
768+
/// The list may contain `MemoryCell.NONE` elements for inaccessible memory positions.
769+
///
770+
/// # Errors
771+
///
772+
/// Returns an error if there are any issues encountered during the retrieval of the memory range.
773+
pub fn getRangeRaw(
774+
self: *Self,
775+
allocator: Allocator,
776+
address: Relocatable,
777+
size: usize,
778+
) !std.ArrayList(?MaybeRelocatable) {
779+
var values = std.ArrayList(?MaybeRelocatable).init(allocator);
780+
errdefer values.deinit();
781+
for (0..size) |i| {
782+
try values.append(self.get(try address.addUint(i)));
783+
}
784+
return values;
785+
}
786+
772787
/// Counts the number of accessed addresses within a specified segment in the VM memory.
773788
///
774789
/// # Arguments
@@ -2426,9 +2441,9 @@ test "MemoryCell: eql function" {
24262441
memoryCell4.markAccessed();
24272442

24282443
// Test checks
2429-
try expect(memoryCell1.eql(memoryCell2));
2430-
try expect(!memoryCell1.eql(memoryCell3));
2431-
try expect(!memoryCell1.eql(memoryCell4));
2444+
try expect(memoryCell1.eql(&memoryCell2));
2445+
try expect(!memoryCell1.eql(&memoryCell3));
2446+
try expect(!memoryCell1.eql(&memoryCell4));
24322447
}
24332448

24342449
test "MemoryCell: eqlSlice should return false if slice len are not the same" {

src/vm/memory/relocatable.zig

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,16 @@ pub const MaybeRelocatable = union(enum) {
312312
/// * `true` if the two instances are equal.
313313
/// * `false` otherwise.
314314
pub fn eq(self: Self, other: Self) bool {
315-
return std.meta.eql(self, other);
315+
return switch (self) {
316+
inline .felt => |f| switch (other) {
317+
inline .felt => |f1| f.eql(f1),
318+
else => false,
319+
},
320+
inline .relocatable => |r| switch (other) {
321+
inline .relocatable => |r1| r.eq(r1),
322+
else => false,
323+
},
324+
};
316325
}
317326

318327
/// Determines if self is less than other.

0 commit comments

Comments
 (0)