@@ -13,6 +13,8 @@ use zisk_pil::{KeccakfFixed, KeccakfTrace, KeccakfTraceRow};
13
13
14
14
use crate :: { keccakf_constants:: * , KeccakfTableGateOp , KeccakfTableSM , Script , ValueType } ;
15
15
16
+ use rayon:: prelude:: * ;
17
+
16
18
/// The `KeccakfSM` struct encapsulates the logic of the Keccakf State Machine.
17
19
pub struct KeccakfSM {
18
20
/// Reference to the Keccakf Table State Machine.
@@ -238,35 +240,60 @@ impl KeccakfSM {
238
240
239
241
// Set the values of free_in_a, free_in_b, free_in_c using the script
240
242
let script = self . script . clone ( ) ;
241
- let mut offset = 0 ;
242
- for ( i, input) in inputs_bits. iter ( ) . enumerate ( ) {
243
+
244
+ let row0 = trace. buffer [ 0 ] ;
245
+
246
+ let mut trace_slice = & mut trace. buffer [ 1 ..] ;
247
+ let mut par_traces = Vec :: new ( ) ;
248
+
249
+ for _ in 0 ..inputs_bits. len ( ) {
250
+ // while !par_traces.is_empty() {
251
+ let take = self . slot_size . min ( trace_slice. len ( ) ) ;
252
+ let ( head, tail) = trace_slice. split_at_mut ( take) ;
253
+ par_traces. push ( head) ;
254
+ trace_slice = tail;
255
+ }
256
+
257
+ par_traces. into_par_iter ( ) . enumerate ( ) . for_each ( |( i, par_trace) | {
243
258
let mut bit_input_pos = [ 0u64 ; INPUT_DATA_SIZE_BITS ] ;
244
259
let mut bit_output_pos = [ 0u64 ; INPUT_DATA_SIZE_BITS ] ;
260
+
245
261
for j in 0 ..self . slot_size {
246
262
let line = & script. program [ j] ;
247
- let row = line. ref_ + i * self . slot_size ;
263
+ let row = line. ref_ - 1 ;
248
264
249
265
let a = & line. a ;
250
266
match a {
251
267
ValueType :: Input ( a) => {
252
- set_col ( trace , |row| & mut row. free_in_a , row, input [ a. bit ] ) ;
268
+ set_col ( par_trace , |row| & mut row. free_in_a , row, inputs_bits [ i ] [ a. bit ] ) ;
253
269
}
254
270
ValueType :: Wired ( b) => {
255
- let mut gate = b. gate ;
256
- if gate > 0 {
257
- gate += offset;
258
- }
271
+ let gate = b. gate ;
259
272
260
273
let pin = & b. pin ;
261
274
if pin == "a" {
262
- let pinned_value = get_col ( trace, |row| & mut row. free_in_a , gate) ;
263
- set_col ( trace, |row| & mut row. free_in_a , row, pinned_value) ;
275
+ let pinned_value = if gate > 0 {
276
+ get_col ( par_trace, |row| & row. free_in_a , gate - 1 )
277
+ } else {
278
+ get_col_row ( & row0, |row| & row. free_in_a )
279
+ } ;
280
+ set_col ( par_trace, |row| & mut row. free_in_a , row, pinned_value) ;
264
281
} else if pin == "b" {
265
- let pinned_value = get_col ( trace, |row| & mut row. free_in_b , gate) ;
266
- set_col ( trace, |row| & mut row. free_in_a , row, pinned_value) ;
282
+ let pinned_value = if gate > 0 {
283
+ get_col ( par_trace, |row| & row. free_in_b , gate - 1 )
284
+ } else {
285
+ get_col_row ( & row0, |row| & row. free_in_b )
286
+ } ;
287
+
288
+ set_col ( par_trace, |row| & mut row. free_in_a , row, pinned_value) ;
267
289
} else if pin == "c" {
268
- let pinned_value = get_col ( trace, |row| & mut row. free_in_c , gate) ;
269
- set_col ( trace, |row| & mut row. free_in_a , row, pinned_value) ;
290
+ let pinned_value = if gate > 0 {
291
+ get_col ( par_trace, |row| & row. free_in_c , gate - 1 )
292
+ } else {
293
+ get_col_row ( & row0, |row| & row. free_in_c )
294
+ } ;
295
+
296
+ set_col ( par_trace, |row| & mut row. free_in_a , row, pinned_value) ;
270
297
} else {
271
298
panic ! ( "Invalid pin" ) ;
272
299
}
@@ -276,32 +303,44 @@ impl KeccakfSM {
276
303
let b = & line. b ;
277
304
match b {
278
305
ValueType :: Input ( b) => {
279
- set_col ( trace , |row| & mut row. free_in_b , row, input [ b. bit ] ) ;
306
+ set_col ( par_trace , |row| & mut row. free_in_b , row, inputs_bits [ i ] [ b. bit ] ) ;
280
307
}
281
308
ValueType :: Wired ( b) => {
282
- let mut gate = b. gate ;
283
- if gate > 0 {
284
- gate += offset;
285
- }
309
+ let gate = b. gate ;
286
310
287
311
let pin = & b. pin ;
288
312
if pin == "a" {
289
- let pinned_value = get_col ( trace, |row| & mut row. free_in_a , gate) ;
290
- set_col ( trace, |row| & mut row. free_in_b , row, pinned_value) ;
313
+ let pinned_value = if gate > 0 {
314
+ get_col ( par_trace, |row| & row. free_in_a , gate - 1 )
315
+ } else {
316
+ get_col_row ( & row0, |row| & row. free_in_a )
317
+ } ;
318
+
319
+ set_col ( par_trace, |row| & mut row. free_in_b , row, pinned_value) ;
291
320
} else if pin == "b" {
292
- let pinned_value = get_col ( trace, |row| & mut row. free_in_b , gate) ;
293
- set_col ( trace, |row| & mut row. free_in_b , row, pinned_value) ;
321
+ let pinned_value = if gate > 0 {
322
+ get_col ( par_trace, |row| & row. free_in_b , gate - 1 )
323
+ } else {
324
+ get_col_row ( & row0, |row| & row. free_in_b )
325
+ } ;
326
+
327
+ set_col ( par_trace, |row| & mut row. free_in_b , row, pinned_value) ;
294
328
} else if pin == "c" {
295
- let pinned_value = get_col ( trace, |row| & mut row. free_in_c , gate) ;
296
- set_col ( trace, |row| & mut row. free_in_b , row, pinned_value) ;
329
+ let pinned_value = if gate > 0 {
330
+ get_col ( par_trace, |row| & row. free_in_c , gate - 1 )
331
+ } else {
332
+ get_col_row ( & row0, |row| & row. free_in_c )
333
+ } ;
334
+
335
+ set_col ( par_trace, |row| & mut row. free_in_b , row, pinned_value) ;
297
336
} else {
298
337
panic ! ( "Invalid pin" ) ;
299
338
}
300
339
}
301
340
}
302
341
303
- let a_val = get_col ( trace , |row| & mut row. free_in_a , row) & MASK_CHUNK_BITS_KECCAKF ;
304
- let b_val = get_col ( trace , |row| & mut row. free_in_b , row) & MASK_CHUNK_BITS_KECCAKF ;
342
+ let a_val = get_col ( par_trace , |row| & row. free_in_a , row) & MASK_CHUNK_BITS_KECCAKF ;
343
+ let b_val = get_col ( par_trace , |row| & row. free_in_b , row) & MASK_CHUNK_BITS_KECCAKF ;
305
344
let op = & line. op ;
306
345
let c_val;
307
346
if op == "xor" {
@@ -312,7 +351,7 @@ impl KeccakfSM {
312
351
panic ! ( "Invalid operation" ) ;
313
352
}
314
353
315
- set_col ( trace , |row| & mut row. free_in_c , row, c_val) ;
354
+ set_col ( par_trace , |row| & mut row. free_in_c , row, c_val) ;
316
355
317
356
if ( line. ref_ >= STATE_IN_REF_0 )
318
357
&& ( line. ref_
@@ -340,11 +379,10 @@ impl KeccakfSM {
340
379
}
341
380
342
381
// Update the multiplicity table for the slot
343
- let row_idx = if offset == 0 { 1 } else { offset + 1 } ;
344
- for i in row_idx..( row_idx + self . slot_size ) {
345
- let a = trace[ i] . free_in_a ;
346
- let b = trace[ i] . free_in_b ;
347
- let gate_op = fixed[ i] . GATE_OP ;
382
+ for k in 0 ..self . slot_size {
383
+ let a = par_trace[ k] . free_in_a ;
384
+ let b = par_trace[ k] . free_in_b ;
385
+ let gate_op = fixed[ k + 1 + i * self . slot_size ] . GATE_OP ;
348
386
let gate_op_val = match F :: as_canonical_u64 ( & gate_op) {
349
387
0u64 => KeccakfTableGateOp :: Xor ,
350
388
1u64 => KeccakfTableGateOp :: Andp ,
@@ -357,10 +395,7 @@ impl KeccakfSM {
357
395
self . keccakf_table_sm . update_input ( table_row, 1 ) ;
358
396
}
359
397
}
360
-
361
- // Move to the next slot
362
- offset += self . slot_size ;
363
- }
398
+ } ) ;
364
399
365
400
fn update_bit_val < F : PrimeField64 > (
366
401
fixed : & KeccakfFixed < F > ,
@@ -379,7 +414,7 @@ impl KeccakfSM {
379
414
}
380
415
381
416
fn set_col < F : PrimeField64 > (
382
- trace : & mut KeccakfTrace < F > ,
417
+ trace : & mut [ KeccakfTraceRow < F > ] ,
383
418
cols : impl Fn ( & mut KeccakfTraceRow < F > ) -> & mut [ F ; CHUNKS_KECCAKF ] ,
384
419
index : usize ,
385
420
value : u64 ,
@@ -394,12 +429,26 @@ impl KeccakfSM {
394
429
}
395
430
396
431
fn get_col < F : PrimeField64 > (
397
- trace : & mut KeccakfTrace < F > ,
398
- cols : impl Fn ( & mut KeccakfTraceRow < F > ) -> & mut [ F ; CHUNKS_KECCAKF ] ,
432
+ trace : & [ KeccakfTraceRow < F > ] ,
433
+ cols : impl Fn ( & KeccakfTraceRow < F > ) -> & [ F ; CHUNKS_KECCAKF ] ,
399
434
index : usize ,
400
435
) -> u64 {
401
436
let mut value = 0 ;
402
- let row = & mut trace[ index] ;
437
+ let row = & trace[ index] ;
438
+ let cols = cols ( row) ;
439
+ for ( i, col) in cols. iter ( ) . enumerate ( ) {
440
+ let col_i_val = F :: as_canonical_u64 ( col) ;
441
+ value += col_i_val << ( ( i * BITS_KECCAKF ) as u64 ) ;
442
+ }
443
+ value
444
+ }
445
+
446
+ fn get_col_row < F : PrimeField64 > (
447
+ trace_row : & KeccakfTraceRow < F > ,
448
+ cols : impl Fn ( & KeccakfTraceRow < F > ) -> & [ F ; CHUNKS_KECCAKF ] ,
449
+ ) -> u64 {
450
+ let mut value = 0 ;
451
+ let row = trace_row;
403
452
let cols = cols ( row) ;
404
453
for ( i, col) in cols. iter ( ) . enumerate ( ) {
405
454
let col_i_val = F :: as_canonical_u64 ( col) ;
@@ -425,7 +474,7 @@ impl KeccakfSM {
425
474
let airgroup_id = KeccakfTrace :: < usize > :: AIRGROUP_ID ;
426
475
let air_id = KeccakfTrace :: < usize > :: AIR_ID ;
427
476
let fixed_pols = sctx. get_fixed ( airgroup_id, air_id) ;
428
- let fixed = KeccakfFixed :: from_slice ( & fixed_pols) ;
477
+ let fixed = KeccakfFixed :: from_vec ( fixed_pols) ;
429
478
430
479
timer_start_trace ! ( KECCAKF_TRACE ) ;
431
480
let mut keccakf_trace = KeccakfTrace :: new ( ) ;
0 commit comments