Skip to content

Commit 094d579

Browse files
authored
fix: multi-builder first/last row issue (#997)
1 parent 2f6f3c8 commit 094d579

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

recursion/core/src/air/multi_builder.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,28 @@ use sp1_core::air::MessageBuilder;
88
/// the sub tables in the multi table.
99
pub struct MultiBuilder<'a, AB: AirBuilder> {
1010
inner: FilteredAirBuilder<'a, AB>,
11+
12+
/// These fields are used to determine whether a row is is the first or last row of the subtable,
13+
/// which requires hinting from the parent table.
14+
is_first_row: AB::Expr,
15+
is_last_row: AB::Expr,
16+
1117
next_condition: AB::Expr,
1218
}
1319

1420
impl<'a, AB: AirBuilder> MultiBuilder<'a, AB> {
15-
pub fn new(builder: &'a mut AB, local_condition: AB::Expr, next_condition: AB::Expr) -> Self {
21+
pub fn new(
22+
builder: &'a mut AB,
23+
local_condition: AB::Expr,
24+
is_first_row: AB::Expr,
25+
is_last_row: AB::Expr,
26+
next_condition: AB::Expr,
27+
) -> Self {
1628
let inner = builder.when(local_condition.clone());
1729
Self {
1830
inner,
31+
is_first_row,
32+
is_last_row,
1933
next_condition,
2034
}
2135
}
@@ -32,11 +46,11 @@ impl<'a, AB: AirBuilder> AirBuilder for MultiBuilder<'a, AB> {
3246
}
3347

3448
fn is_first_row(&self) -> Self::Expr {
35-
self.inner.is_first_row()
49+
self.is_first_row.clone()
3650
}
3751

3852
fn is_last_row(&self) -> Self::Expr {
39-
self.inner.is_last_row()
53+
self.is_last_row.clone()
4054
}
4155

4256
fn is_transition_window(&self, size: usize) -> Self::Expr {

recursion/core/src/multi/mod.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::ops::Deref;
55

66
use itertools::Itertools;
77
use p3_air::{Air, AirBuilder, BaseAir};
8-
use p3_field::PrimeField32;
8+
use p3_field::{AbstractField, PrimeField32};
99
use p3_matrix::dense::RowMajorMatrix;
1010
use p3_matrix::Matrix;
1111
use sp1_core::air::{BaseAirBuilder, MachineAir};
@@ -37,6 +37,14 @@ pub struct MultiCols<T: Copy> {
3737

3838
pub is_poseidon2: T,
3939

40+
/// A flag column to indicate whether the row is the first poseidon2 row.
41+
pub poseidon2_first_row: T,
42+
/// A flag column to indicate whether the row is the last poseidon2 row.
43+
pub poseidon2_last_row: T,
44+
45+
/// Similar for Fri_fold.
46+
pub fri_fold_last_row: T,
47+
4048
/// Rows that needs to receive a poseidon2 syscall.
4149
pub poseidon2_receive_table: T,
4250
/// Hash/Permute state entries that needs to access memory. This is for the the first half of the permute state.
@@ -81,6 +89,9 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for MultiChip<DEGREE> {
8189
let fri_fold_trace = fri_fold_chip.generate_trace(input, output);
8290
let mut poseidon2_trace = poseidon2.generate_trace(input, output);
8391

92+
let fri_fold_height = fri_fold_trace.height();
93+
let poseidon2_height = poseidon2_trace.height();
94+
8495
let num_columns = <MultiChip<DEGREE> as BaseAir<F>>::width(self);
8596

8697
let mut rows = fri_fold_trace
@@ -104,6 +115,9 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for MultiChip<DEGREE> {
104115
FriFoldChip::<DEGREE>::do_receive_table(fri_fold_cols);
105116
multi_cols.fri_fold_memory_access =
106117
FriFoldChip::<DEGREE>::do_memory_access(fri_fold_cols);
118+
if i == fri_fold_trace.height() - 1 {
119+
multi_cols.fri_fold_last_row = F::one();
120+
}
107121
} else {
108122
let multi_cols: &mut MultiCols<F> = row[0..NUM_MULTI_COLS].borrow_mut();
109123
multi_cols.is_poseidon2 = F::one();
@@ -116,6 +130,11 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for MultiChip<DEGREE> {
116130
multi_cols.poseidon2_2nd_half_memory_access =
117131
poseidon2_cols.control_flow().is_compress;
118132
multi_cols.poseidon2_send_range_check = poseidon2_cols.control_flow().is_absorb;
133+
134+
// The first row of the poseidon2 trace has index fri_fold_trace.height()
135+
multi_cols.poseidon2_first_row = F::from_bool(i == fri_fold_height);
136+
multi_cols.poseidon2_last_row =
137+
F::from_bool(i == fri_fold_height + poseidon2_height - 1);
119138
}
120139

121140
row
@@ -169,6 +188,37 @@ where
169188
builder.assert_bool(local_multi_cols.is_poseidon2);
170189
builder.assert_bool(local_is_real.clone());
171190

191+
// Constrain the flags to be boolean.
192+
builder.assert_bool(local_multi_cols.poseidon2_first_row);
193+
builder.assert_bool(local_multi_cols.poseidon2_last_row);
194+
builder.assert_bool(local_multi_cols.fri_fold_last_row);
195+
196+
// Constrain that the flags are computed correctly.
197+
builder.when_transition().assert_eq(
198+
local_multi_cols.is_fri_fold * (AB::Expr::one() - next_multi_cols.is_fri_fold),
199+
local_multi_cols.fri_fold_last_row,
200+
);
201+
builder.when_last_row().assert_eq(
202+
local_multi_cols.is_fri_fold,
203+
local_multi_cols.fri_fold_last_row,
204+
);
205+
builder.when_first_row().assert_eq(
206+
local_multi_cols.is_poseidon2,
207+
local_multi_cols.poseidon2_first_row,
208+
);
209+
builder.when_transition().assert_eq(
210+
next_multi_cols.poseidon2_first_row,
211+
local_multi_cols.is_fri_fold * next_multi_cols.is_poseidon2,
212+
);
213+
builder.when_transition().assert_eq(
214+
local_multi_cols.is_poseidon2 * (AB::Expr::one() - next_multi_cols.is_poseidon2),
215+
local_multi_cols.poseidon2_last_row,
216+
);
217+
builder.when_last_row().assert_eq(
218+
local_multi_cols.is_poseidon2,
219+
local_multi_cols.poseidon2_last_row,
220+
);
221+
172222
// Fri fold requires that it's rows are contiguous, since each invocation spans multiple rows
173223
// and it's AIR checks for consistencies among them. The following constraints enforce that
174224
// all the fri fold rows are first, then the posiedon2 rows, and finally any padded (non-real) rows.
@@ -189,6 +239,8 @@ where
189239
let mut sub_builder = MultiBuilder::new(
190240
builder,
191241
local_multi_cols.is_fri_fold.into(),
242+
builder.is_first_row(),
243+
local_multi_cols.fri_fold_last_row.into(),
192244
next_multi_cols.is_fri_fold.into(),
193245
);
194246

@@ -218,6 +270,8 @@ where
218270
let mut sub_builder = MultiBuilder::new(
219271
builder,
220272
local_multi_cols.is_poseidon2.into(),
273+
local_multi_cols.poseidon2_first_row.into(),
274+
local_multi_cols.poseidon2_last_row.into(),
221275
next_multi_cols.is_poseidon2.into(),
222276
);
223277

0 commit comments

Comments
 (0)