@@ -5,7 +5,7 @@ use std::ops::Deref;
5
5
6
6
use itertools:: Itertools ;
7
7
use p3_air:: { Air , AirBuilder , BaseAir } ;
8
- use p3_field:: PrimeField32 ;
8
+ use p3_field:: { AbstractField , PrimeField32 } ;
9
9
use p3_matrix:: dense:: RowMajorMatrix ;
10
10
use p3_matrix:: Matrix ;
11
11
use sp1_core:: air:: { BaseAirBuilder , MachineAir } ;
@@ -37,6 +37,14 @@ pub struct MultiCols<T: Copy> {
37
37
38
38
pub is_poseidon2 : T ,
39
39
40
+ /// A flag column to indicate whether the row is the first poseidon2 row.
41
+ pub poseidon2_first_row : T ,
42
+ /// A flag column to indicate whether the row is the last poseidon2 row.
43
+ pub poseidon2_last_row : T ,
44
+
45
+ /// Similar for Fri_fold.
46
+ pub fri_fold_last_row : T ,
47
+
40
48
/// Rows that needs to receive a poseidon2 syscall.
41
49
pub poseidon2_receive_table : T ,
42
50
/// Hash/Permute state entries that needs to access memory. This is for the the first half of the permute state.
@@ -81,6 +89,9 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for MultiChip<DEGREE> {
81
89
let fri_fold_trace = fri_fold_chip. generate_trace ( input, output) ;
82
90
let mut poseidon2_trace = poseidon2. generate_trace ( input, output) ;
83
91
92
+ let fri_fold_height = fri_fold_trace. height ( ) ;
93
+ let poseidon2_height = poseidon2_trace. height ( ) ;
94
+
84
95
let num_columns = <MultiChip < DEGREE > as BaseAir < F > >:: width ( self ) ;
85
96
86
97
let mut rows = fri_fold_trace
@@ -104,6 +115,9 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for MultiChip<DEGREE> {
104
115
FriFoldChip :: < DEGREE > :: do_receive_table ( fri_fold_cols) ;
105
116
multi_cols. fri_fold_memory_access =
106
117
FriFoldChip :: < DEGREE > :: do_memory_access ( fri_fold_cols) ;
118
+ if i == fri_fold_trace. height ( ) - 1 {
119
+ multi_cols. fri_fold_last_row = F :: one ( ) ;
120
+ }
107
121
} else {
108
122
let multi_cols: & mut MultiCols < F > = row[ 0 ..NUM_MULTI_COLS ] . borrow_mut ( ) ;
109
123
multi_cols. is_poseidon2 = F :: one ( ) ;
@@ -116,6 +130,11 @@ impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for MultiChip<DEGREE> {
116
130
multi_cols. poseidon2_2nd_half_memory_access =
117
131
poseidon2_cols. control_flow ( ) . is_compress ;
118
132
multi_cols. poseidon2_send_range_check = poseidon2_cols. control_flow ( ) . is_absorb ;
133
+
134
+ // The first row of the poseidon2 trace has index fri_fold_trace.height()
135
+ multi_cols. poseidon2_first_row = F :: from_bool ( i == fri_fold_height) ;
136
+ multi_cols. poseidon2_last_row =
137
+ F :: from_bool ( i == fri_fold_height + poseidon2_height - 1 ) ;
119
138
}
120
139
121
140
row
@@ -169,6 +188,37 @@ where
169
188
builder. assert_bool ( local_multi_cols. is_poseidon2 ) ;
170
189
builder. assert_bool ( local_is_real. clone ( ) ) ;
171
190
191
+ // Constrain the flags to be boolean.
192
+ builder. assert_bool ( local_multi_cols. poseidon2_first_row ) ;
193
+ builder. assert_bool ( local_multi_cols. poseidon2_last_row ) ;
194
+ builder. assert_bool ( local_multi_cols. fri_fold_last_row ) ;
195
+
196
+ // Constrain that the flags are computed correctly.
197
+ builder. when_transition ( ) . assert_eq (
198
+ local_multi_cols. is_fri_fold * ( AB :: Expr :: one ( ) - next_multi_cols. is_fri_fold ) ,
199
+ local_multi_cols. fri_fold_last_row ,
200
+ ) ;
201
+ builder. when_last_row ( ) . assert_eq (
202
+ local_multi_cols. is_fri_fold ,
203
+ local_multi_cols. fri_fold_last_row ,
204
+ ) ;
205
+ builder. when_first_row ( ) . assert_eq (
206
+ local_multi_cols. is_poseidon2 ,
207
+ local_multi_cols. poseidon2_first_row ,
208
+ ) ;
209
+ builder. when_transition ( ) . assert_eq (
210
+ next_multi_cols. poseidon2_first_row ,
211
+ local_multi_cols. is_fri_fold * next_multi_cols. is_poseidon2 ,
212
+ ) ;
213
+ builder. when_transition ( ) . assert_eq (
214
+ local_multi_cols. is_poseidon2 * ( AB :: Expr :: one ( ) - next_multi_cols. is_poseidon2 ) ,
215
+ local_multi_cols. poseidon2_last_row ,
216
+ ) ;
217
+ builder. when_last_row ( ) . assert_eq (
218
+ local_multi_cols. is_poseidon2 ,
219
+ local_multi_cols. poseidon2_last_row ,
220
+ ) ;
221
+
172
222
// Fri fold requires that it's rows are contiguous, since each invocation spans multiple rows
173
223
// and it's AIR checks for consistencies among them. The following constraints enforce that
174
224
// all the fri fold rows are first, then the posiedon2 rows, and finally any padded (non-real) rows.
@@ -189,6 +239,8 @@ where
189
239
let mut sub_builder = MultiBuilder :: new (
190
240
builder,
191
241
local_multi_cols. is_fri_fold . into ( ) ,
242
+ builder. is_first_row ( ) ,
243
+ local_multi_cols. fri_fold_last_row . into ( ) ,
192
244
next_multi_cols. is_fri_fold . into ( ) ,
193
245
) ;
194
246
@@ -218,6 +270,8 @@ where
218
270
let mut sub_builder = MultiBuilder :: new (
219
271
builder,
220
272
local_multi_cols. is_poseidon2 . into ( ) ,
273
+ local_multi_cols. poseidon2_first_row . into ( ) ,
274
+ local_multi_cols. poseidon2_last_row . into ( ) ,
221
275
next_multi_cols. is_poseidon2 . into ( ) ,
222
276
) ;
223
277
0 commit comments