Skip to content

Commit 33efb96

Browse files
Parallelize keccak circuit slots (#322)
* parallelize keccak * improve fixed const load * Cargo clippy, fmt and update * Cargo fmt --------- Co-authored-by: roger.taule <[email protected]>
1 parent c7a8013 commit 33efb96

File tree

2 files changed

+114
-65
lines changed

2 files changed

+114
-65
lines changed

Cargo.lock

+23-23
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

precompiles/keccakf/src/keccakf.rs

+91-42
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use zisk_pil::{KeccakfFixed, KeccakfTrace, KeccakfTraceRow};
1313

1414
use crate::{keccakf_constants::*, KeccakfTableGateOp, KeccakfTableSM, Script, ValueType};
1515

16+
use rayon::prelude::*;
17+
1618
/// The `KeccakfSM` struct encapsulates the logic of the Keccakf State Machine.
1719
pub struct KeccakfSM {
1820
/// Reference to the Keccakf Table State Machine.
@@ -238,35 +240,60 @@ impl KeccakfSM {
238240

239241
// Set the values of free_in_a, free_in_b, free_in_c using the script
240242
let script = self.script.clone();
241-
let mut offset = 0;
242-
for (i, input) in inputs_bits.iter().enumerate() {
243+
244+
let row0 = trace.buffer[0];
245+
246+
let mut trace_slice = &mut trace.buffer[1..];
247+
let mut par_traces = Vec::new();
248+
249+
for _ in 0..inputs_bits.len() {
250+
// while !par_traces.is_empty() {
251+
let take = self.slot_size.min(trace_slice.len());
252+
let (head, tail) = trace_slice.split_at_mut(take);
253+
par_traces.push(head);
254+
trace_slice = tail;
255+
}
256+
257+
par_traces.into_par_iter().enumerate().for_each(|(i, par_trace)| {
243258
let mut bit_input_pos = [0u64; INPUT_DATA_SIZE_BITS];
244259
let mut bit_output_pos = [0u64; INPUT_DATA_SIZE_BITS];
260+
245261
for j in 0..self.slot_size {
246262
let line = &script.program[j];
247-
let row = line.ref_ + i * self.slot_size;
263+
let row = line.ref_ - 1;
248264

249265
let a = &line.a;
250266
match a {
251267
ValueType::Input(a) => {
252-
set_col(trace, |row| &mut row.free_in_a, row, input[a.bit]);
268+
set_col(par_trace, |row| &mut row.free_in_a, row, inputs_bits[i][a.bit]);
253269
}
254270
ValueType::Wired(b) => {
255-
let mut gate = b.gate;
256-
if gate > 0 {
257-
gate += offset;
258-
}
271+
let gate = b.gate;
259272

260273
let pin = &b.pin;
261274
if pin == "a" {
262-
let pinned_value = get_col(trace, |row| &mut row.free_in_a, gate);
263-
set_col(trace, |row| &mut row.free_in_a, row, pinned_value);
275+
let pinned_value = if gate > 0 {
276+
get_col(par_trace, |row| &row.free_in_a, gate - 1)
277+
} else {
278+
get_col_row(&row0, |row| &row.free_in_a)
279+
};
280+
set_col(par_trace, |row| &mut row.free_in_a, row, pinned_value);
264281
} else if pin == "b" {
265-
let pinned_value = get_col(trace, |row| &mut row.free_in_b, gate);
266-
set_col(trace, |row| &mut row.free_in_a, row, pinned_value);
282+
let pinned_value = if gate > 0 {
283+
get_col(par_trace, |row| &row.free_in_b, gate - 1)
284+
} else {
285+
get_col_row(&row0, |row| &row.free_in_b)
286+
};
287+
288+
set_col(par_trace, |row| &mut row.free_in_a, row, pinned_value);
267289
} else if pin == "c" {
268-
let pinned_value = get_col(trace, |row| &mut row.free_in_c, gate);
269-
set_col(trace, |row| &mut row.free_in_a, row, pinned_value);
290+
let pinned_value = if gate > 0 {
291+
get_col(par_trace, |row| &row.free_in_c, gate - 1)
292+
} else {
293+
get_col_row(&row0, |row| &row.free_in_c)
294+
};
295+
296+
set_col(par_trace, |row| &mut row.free_in_a, row, pinned_value);
270297
} else {
271298
panic!("Invalid pin");
272299
}
@@ -276,32 +303,44 @@ impl KeccakfSM {
276303
let b = &line.b;
277304
match b {
278305
ValueType::Input(b) => {
279-
set_col(trace, |row| &mut row.free_in_b, row, input[b.bit]);
306+
set_col(par_trace, |row| &mut row.free_in_b, row, inputs_bits[i][b.bit]);
280307
}
281308
ValueType::Wired(b) => {
282-
let mut gate = b.gate;
283-
if gate > 0 {
284-
gate += offset;
285-
}
309+
let gate = b.gate;
286310

287311
let pin = &b.pin;
288312
if pin == "a" {
289-
let pinned_value = get_col(trace, |row| &mut row.free_in_a, gate);
290-
set_col(trace, |row| &mut row.free_in_b, row, pinned_value);
313+
let pinned_value = if gate > 0 {
314+
get_col(par_trace, |row| &row.free_in_a, gate - 1)
315+
} else {
316+
get_col_row(&row0, |row| &row.free_in_a)
317+
};
318+
319+
set_col(par_trace, |row| &mut row.free_in_b, row, pinned_value);
291320
} else if pin == "b" {
292-
let pinned_value = get_col(trace, |row| &mut row.free_in_b, gate);
293-
set_col(trace, |row| &mut row.free_in_b, row, pinned_value);
321+
let pinned_value = if gate > 0 {
322+
get_col(par_trace, |row| &row.free_in_b, gate - 1)
323+
} else {
324+
get_col_row(&row0, |row| &row.free_in_b)
325+
};
326+
327+
set_col(par_trace, |row| &mut row.free_in_b, row, pinned_value);
294328
} else if pin == "c" {
295-
let pinned_value = get_col(trace, |row| &mut row.free_in_c, gate);
296-
set_col(trace, |row| &mut row.free_in_b, row, pinned_value);
329+
let pinned_value = if gate > 0 {
330+
get_col(par_trace, |row| &row.free_in_c, gate - 1)
331+
} else {
332+
get_col_row(&row0, |row| &row.free_in_c)
333+
};
334+
335+
set_col(par_trace, |row| &mut row.free_in_b, row, pinned_value);
297336
} else {
298337
panic!("Invalid pin");
299338
}
300339
}
301340
}
302341

303-
let a_val = get_col(trace, |row| &mut row.free_in_a, row) & MASK_CHUNK_BITS_KECCAKF;
304-
let b_val = get_col(trace, |row| &mut row.free_in_b, row) & MASK_CHUNK_BITS_KECCAKF;
342+
let a_val = get_col(par_trace, |row| &row.free_in_a, row) & MASK_CHUNK_BITS_KECCAKF;
343+
let b_val = get_col(par_trace, |row| &row.free_in_b, row) & MASK_CHUNK_BITS_KECCAKF;
305344
let op = &line.op;
306345
let c_val;
307346
if op == "xor" {
@@ -312,7 +351,7 @@ impl KeccakfSM {
312351
panic!("Invalid operation");
313352
}
314353

315-
set_col(trace, |row| &mut row.free_in_c, row, c_val);
354+
set_col(par_trace, |row| &mut row.free_in_c, row, c_val);
316355

317356
if (line.ref_ >= STATE_IN_REF_0)
318357
&& (line.ref_
@@ -340,11 +379,10 @@ impl KeccakfSM {
340379
}
341380

342381
// Update the multiplicity table for the slot
343-
let row_idx = if offset == 0 { 1 } else { offset + 1 };
344-
for i in row_idx..(row_idx + self.slot_size) {
345-
let a = trace[i].free_in_a;
346-
let b = trace[i].free_in_b;
347-
let gate_op = fixed[i].GATE_OP;
382+
for k in 0..self.slot_size {
383+
let a = par_trace[k].free_in_a;
384+
let b = par_trace[k].free_in_b;
385+
let gate_op = fixed[k + 1 + i * self.slot_size].GATE_OP;
348386
let gate_op_val = match F::as_canonical_u64(&gate_op) {
349387
0u64 => KeccakfTableGateOp::Xor,
350388
1u64 => KeccakfTableGateOp::Andp,
@@ -357,10 +395,7 @@ impl KeccakfSM {
357395
self.keccakf_table_sm.update_input(table_row, 1);
358396
}
359397
}
360-
361-
// Move to the next slot
362-
offset += self.slot_size;
363-
}
398+
});
364399

365400
fn update_bit_val<F: PrimeField64>(
366401
fixed: &KeccakfFixed<F>,
@@ -379,7 +414,7 @@ impl KeccakfSM {
379414
}
380415

381416
fn set_col<F: PrimeField64>(
382-
trace: &mut KeccakfTrace<F>,
417+
trace: &mut [KeccakfTraceRow<F>],
383418
cols: impl Fn(&mut KeccakfTraceRow<F>) -> &mut [F; CHUNKS_KECCAKF],
384419
index: usize,
385420
value: u64,
@@ -394,12 +429,26 @@ impl KeccakfSM {
394429
}
395430

396431
fn get_col<F: PrimeField64>(
397-
trace: &mut KeccakfTrace<F>,
398-
cols: impl Fn(&mut KeccakfTraceRow<F>) -> &mut [F; CHUNKS_KECCAKF],
432+
trace: &[KeccakfTraceRow<F>],
433+
cols: impl Fn(&KeccakfTraceRow<F>) -> &[F; CHUNKS_KECCAKF],
399434
index: usize,
400435
) -> u64 {
401436
let mut value = 0;
402-
let row = &mut trace[index];
437+
let row = &trace[index];
438+
let cols = cols(row);
439+
for (i, col) in cols.iter().enumerate() {
440+
let col_i_val = F::as_canonical_u64(col);
441+
value += col_i_val << ((i * BITS_KECCAKF) as u64);
442+
}
443+
value
444+
}
445+
446+
fn get_col_row<F: PrimeField64>(
447+
trace_row: &KeccakfTraceRow<F>,
448+
cols: impl Fn(&KeccakfTraceRow<F>) -> &[F; CHUNKS_KECCAKF],
449+
) -> u64 {
450+
let mut value = 0;
451+
let row = trace_row;
403452
let cols = cols(row);
404453
for (i, col) in cols.iter().enumerate() {
405454
let col_i_val = F::as_canonical_u64(col);
@@ -425,7 +474,7 @@ impl KeccakfSM {
425474
let airgroup_id = KeccakfTrace::<usize>::AIRGROUP_ID;
426475
let air_id = KeccakfTrace::<usize>::AIR_ID;
427476
let fixed_pols = sctx.get_fixed(airgroup_id, air_id);
428-
let fixed = KeccakfFixed::from_slice(&fixed_pols);
477+
let fixed = KeccakfFixed::from_vec(fixed_pols);
429478

430479
timer_start_trace!(KECCAKF_TRACE);
431480
let mut keccakf_trace = KeccakfTrace::new();

0 commit comments

Comments
 (0)