diff --git a/Cargo.toml b/Cargo.toml index fc3ac1c..89e6a79 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,11 @@ halo2_proofs = { git = "https://github.com/scroll-tech/halo2.git", branch = "scr [features] # Use an implementation using fewer rows (8) per permutation. short = [] +cached = [] # printout the layout of circuits for demo and some unittests print_layout = ["halo2_proofs/dev-graph"] legacy = [] +default = ["short","cached"] [dev-dependencies] rand = "0.8" @@ -32,6 +34,10 @@ subtle = "2" name = "hash" harness = false +[[bench]] +name = "synthesis" +harness = false + [profile.test] opt-level = 3 debug-assertions = true diff --git a/benches/synthesis.rs b/benches/synthesis.rs new file mode 100644 index 0000000..4f7c442 --- /dev/null +++ b/benches/synthesis.rs @@ -0,0 +1,70 @@ +#[macro_use] +extern crate bencher; +use bencher::Bencher; + +use halo2_proofs::dev::MockProver; +use halo2_proofs::halo2curves::{bn256::Fr as Fp, group::ff::PrimeField}; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, ConstraintSystem, Error}, +}; +use poseidon_circuit::{hash::*, DEFAULT_STEP}; + +struct TestCircuit(PoseidonHashTable, usize); + +// test circuit derived from table data +impl Circuit for TestCircuit { + type Config = PoseidonHashConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self(PoseidonHashTable::default(), self.1) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let hash_tbl = [0; 5].map(|_| meta.advice_column()); + let q_enable = meta.fixed_column(); + SpongeConfig::configure_sub(meta, (q_enable, hash_tbl), DEFAULT_STEP) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let chip = PoseidonHashChip::::construct(config, &self.0, self.1); + chip.load(&mut layouter) + } +} + +fn synthesis(bench: &mut Bencher) { + let message1 = [ + Fp::from_str_vartime("1").unwrap(), + Fp::from_str_vartime("2").unwrap(), + ]; + let message2 = [ + Fp::from_str_vartime("0").unwrap(), + Fp::from_str_vartime("1").unwrap(), + ]; + + let k = 12; + let circuit = TestCircuit( + PoseidonHashTable { + inputs: vec![message1, message2], + ..Default::default() + }, + 500, + ); + + bench.iter(|| { + MockProver::run(k, &circuit, vec![]).unwrap(); + }); +} + +fn synthesis_limited(bench: &mut Bencher) { + bench.bench_n(1, synthesis); +} + +benchmark_group!(syth_bench, synthesis_limited); + +benchmark_main!(syth_bench); diff --git a/src/poseidon/septidon/full_round.rs b/src/poseidon/septidon/full_round.rs index 5f0c349..756c2a4 100644 --- a/src/poseidon/septidon/full_round.rs +++ b/src/poseidon/septidon/full_round.rs @@ -1,7 +1,7 @@ use super::loop_chip::LoopBody; -use super::params::{mds, CachedConstants}; +use super::params::{calc::matmul, CachedConstants}; use super::state::{Cell, FullState, SBox}; -use super::util::{join_values, matmul, query, split_values}; +use super::util::{join_values, query, split_values}; use halo2_proofs::circuit::{Region, Value}; //use halo2_proofs::halo2curves::bn256::Fr as F; use halo2_proofs::plonk::{ConstraintSystem, Error, Expression, VirtualCells}; @@ -27,7 +27,7 @@ impl FullRoundChip { meta: &mut VirtualCells<'_, F>, ) -> [Expression; 3] { let sbox_out = self.0.map(|sbox: &SBox| sbox.output_expr(meta)); - matmul::expr(mds(), sbox_out) + matmul::expr(sbox_out) } pub fn input_cells(&self) -> [Cell; 3] { @@ -47,7 +47,7 @@ impl FullRoundChip { let sbox: &SBox = &self.0 .0[i]; sbox_out[i] = sbox.assign(region, offset, round_constants[i], input[i])?; } - let output = join_values(sbox_out).map(|sbox_out| matmul::value(mds(), sbox_out)); + let output = join_values(sbox_out).map(|sbox_out| matmul::value(sbox_out)); Ok(split_values(output)) } } diff --git a/src/poseidon/septidon/params.rs b/src/poseidon/septidon/params.rs index 823f2bb..cb0c263 100644 --- a/src/poseidon/septidon/params.rs +++ b/src/poseidon/septidon/params.rs @@ -12,32 +12,149 @@ pub trait CachedConstants: P128Pow5T3Constants { fn cached_mds() -> &'static Mds; /// cached inversed mds fn cached_mds_inv() -> &'static Mds; + /// cached pow5 calc result + fn cached_pow5(self) -> (Self, Option) { + (self, None) + } + /// cached muladd calc result + fn cached_muladd(vector: [Self; 3]) -> ([Self; 3], Option<[Self; 3]>) { + (vector, None) + } } -pub mod sbox { - use super::super::util::pow_5; - use halo2_proofs::arithmetic::FieldExt; +/// Wrap Fr as Hash key +#[derive(Eq, PartialEq, Debug, Hash)] +pub struct KeyConstant(T); + +pub mod calc { + use super::CachedConstants; use halo2_proofs::plonk::Expression; - pub fn expr(input: Expression, round_constant: Expression) -> Expression { - pow_5::expr(input + round_constant) + pub mod sbox { + use super::super::super::util::pow_5; + use super::*; + pub fn expr( + input: Expression, + round_constant: Expression, + ) -> Expression { + pow_5::expr(input + round_constant) + } + + pub fn value(input: F, round_constant: F) -> F { + let val_added = input + round_constant; + let (val_added, ret) = val_added.cached_pow5(); + ret.unwrap_or_else(|| pow_5::value(val_added)) + } } - pub fn value(input: F, round_constant: F) -> F { - pow_5::value(input + round_constant) + pub mod matmul { + use super::super::super::util::matmul; + use super::*; + + /// Multiply a vector of expressions by a constant matrix. + pub fn expr(vector: [Expression; 3]) -> [Expression; 3] { + matmul::expr(F::cached_mds(), vector) + } + + /// Multiply a vector of values by a constant matrix. + pub fn value(vector: [F; 3]) -> [F; 3] { + let (vector, ret) = F::cached_muladd(vector); + ret.unwrap_or_else(|| matmul::value(F::cached_mds(), vector)) + } } } pub type Mds = MdsT; mod bn254 { - use super::{CachedConstants, Mds}; + use super::super::util::{matmul, pow_5}; + use super::*; use crate::poseidon::primitives::{P128Pow5T3Compact, Spec}; use halo2_proofs::halo2curves::bn256::Fr as F; use lazy_static::lazy_static; + use std::iter; + + type Pow5CacheMap = std::collections::HashMap, F>; + type MulAddCacheMap = std::collections::HashMap<[KeyConstant; 3], [F; 3]>; + lazy_static! { // Cache the round constants and the MDS matrix (and unused inverse MDS matrix). static ref CONSTANTS: (Vec<[F; 3]>, Mds, Mds) = P128Pow5T3Compact::::constants(); + pub static ref POW5_CONSTANTS: Pow5CacheMap = { + + let r_f = P128Pow5T3Compact::::full_rounds() / 2; + let r_p = P128Pow5T3Compact::::partial_rounds(); + let mds = &CONSTANTS.1; + + let full_round = |ret: &mut Pow5CacheMap, state: &mut [F; 3], rcs: &[F; 3]| { + for (word, rc) in state.iter_mut().zip(rcs.iter()) { + let key = KeyConstant(*word + rc); + *word = pow_5::value(*word + rc); + ret.insert(key, *word); + } + *state = matmul::value(mds, *state); + }; + + let part_round = |ret: &mut Pow5CacheMap, state: &mut [F; 3], rcs: &[F; 3]| { + // In a partial round, the S-box is only applied to the first state word. + // and the compact constants has only first rc is not zero + let key = KeyConstant(state[0]+rcs[0]); + state[0] = pow_5::value(state[0]+rcs[0]); + ret.insert(key, state[0]); + *state = matmul::value(mds, *state); + }; + + let (ret, _) = iter::empty() + .chain(iter::repeat(&full_round as &dyn Fn(&mut Pow5CacheMap, &mut [F; 3], &[F; 3])).take(r_f)) + .chain(iter::repeat(&part_round as &dyn Fn(&mut Pow5CacheMap, &mut [F; 3], &[F; 3])).take(r_p)) + .chain(iter::repeat(&full_round as &dyn Fn(&mut Pow5CacheMap, &mut [F; 3], &[F; 3])).take(r_f)) + .zip(CONSTANTS.0.iter()) + .fold((Pow5CacheMap::new(), [F::zero();3]), |(mut ret, mut state), (round, rcs)| { + round(&mut ret, &mut state, rcs); + (ret, state) + }); + + //let mut t_state = [F::zero(); 3]; + //crate::poseidon::primitives::permute::, 3, 2>(&mut t_state, mds, &CONSTANTS.0); + //assert_eq!(t_state, state); + ret + }; + static ref MULADD_CONSTANTS: MulAddCacheMap = { + let r_f = P128Pow5T3Compact::::full_rounds() / 2; + let r_p = P128Pow5T3Compact::::partial_rounds(); + let mds = &CONSTANTS.1; + + let full_round = |ret: &mut MulAddCacheMap, state: &mut [F; 3], rcs: &[F; 3]| { + for (word, rc) in state.iter_mut().zip(rcs.iter()) { + *word = pow_5::value(*word + rc); + } + let key = state.map(KeyConstant); + *state = matmul::value(mds, *state); + ret.insert(key, *state); + }; + + let part_round = |ret: &mut MulAddCacheMap, state: &mut [F; 3], rcs: &[F; 3]| { + // In a partial round, the S-box is only applied to the first state word. + // and the compact constants has only first rc is not zero + state[0] = pow_5::value(state[0]+rcs[0]); + let key = state.map(KeyConstant); + *state = matmul::value(mds, *state); + ret.insert(key, *state); + }; + + let (ret, _) = iter::empty() + .chain(iter::repeat(&full_round as &dyn Fn(&mut MulAddCacheMap, &mut [F; 3], &[F; 3])).take(r_f)) + .chain(iter::repeat(&part_round as &dyn Fn(&mut MulAddCacheMap, &mut [F; 3], &[F; 3])).take(r_p)) + .chain(iter::repeat(&full_round as &dyn Fn(&mut MulAddCacheMap, &mut [F; 3], &[F; 3])).take(r_f)) + .zip(CONSTANTS.0.iter()) + .fold((MulAddCacheMap::new(), [F::zero();3]), |(mut ret, mut state), (round, rcs)| { + round(&mut ret, &mut state, rcs); + (ret, state) + }); + + ret + }; + } impl CachedConstants for F { @@ -50,6 +167,18 @@ mod bn254 { fn cached_mds_inv() -> &'static Mds { &CONSTANTS.2 } + #[cfg(feature = "cached")] + fn cached_pow5(self) -> (Self, Option) { + let key = KeyConstant(self); + let ret = POW5_CONSTANTS.get(&key).copied(); + (key.0, ret) + } + #[cfg(feature = "cached")] + fn cached_muladd(vector: [Self; 3]) -> ([Self; 3], Option<[Self; 3]>) { + let key = vector.map(KeyConstant); + let ret = MULADD_CONSTANTS.get(&key).copied(); + (key.map(|k| k.0), ret) + } } } @@ -60,3 +189,17 @@ pub fn round_constant(index: usize) -> [F; 3] { pub fn mds() -> &'static Mds { F::cached_mds() } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_constants() { + println!( + "{:?},{:?}", + bn254::POW5_CONSTANTS.keys(), + bn254::POW5_CONSTANTS.values() + ); + } +} diff --git a/src/poseidon/septidon/septuple_round.rs b/src/poseidon/septidon/septuple_round.rs index 95e36f5..d1443a7 100644 --- a/src/poseidon/septidon/septuple_round.rs +++ b/src/poseidon/septidon/septuple_round.rs @@ -1,7 +1,7 @@ use super::loop_chip::LoopBody; -use super::params::{mds, CachedConstants}; +use super::params::{calc::matmul, CachedConstants}; use super::state::{Cell, SBox}; -use super::util::{join_values, matmul, query, split_values}; +use super::util::{join_values, query, split_values}; use halo2_proofs::circuit::{Region, Value}; //use halo2_proofs::halo2curves::bn256::Fr as F; use halo2_proofs::plonk::{ConstraintSystem, Constraints, Error, Expression, VirtualCells}; @@ -82,7 +82,7 @@ impl SeptupleRoundChip { input: &[Expression; 3], ) -> [Expression; 3] { let sbox_out = [sbox.output_expr(meta), input[1].clone(), input[2].clone()]; - matmul::expr(mds(), sbox_out) + matmul::expr(sbox_out) } pub fn input(&self) -> [Cell; 3] { @@ -111,7 +111,7 @@ impl SeptupleRoundChip { // Assign the following S-Boxes. state[0] = sbox.assign(region, offset, round_constants[i], state[0])?; // Apply the matrix. - state = split_values(join_values(state).map(|s| matmul::value(mds(), s))); + state = split_values(join_values(state).map(|s| matmul::value(s))); Ok(()) }; diff --git a/src/poseidon/septidon/state.rs b/src/poseidon/septidon/state.rs index b0faee7..d12ef21 100644 --- a/src/poseidon/septidon/state.rs +++ b/src/poseidon/septidon/state.rs @@ -1,4 +1,4 @@ -use super::params; +use super::params::{self, CachedConstants}; use halo2_proofs::circuit::{Region, Value}; //use halo2_proofs::halo2curves::bn256::Fr as F; use halo2_proofs::arithmetic::FieldExt; @@ -72,7 +72,7 @@ impl SBox { } /// Assign the witness of the input. - pub fn assign( + pub fn assign( &self, region: &mut Region<'_, F>, offset: usize, @@ -91,7 +91,7 @@ impl SBox { offset + self.input.region_offset(), || input, )?; - let output = input.map(|i| params::sbox::value(i, round_constant)); + let output = input.map(|i| params::calc::sbox::value(i, round_constant)); Ok(output) } @@ -103,10 +103,10 @@ impl SBox { meta.query_fixed(self.round_constant, Rotation(self.input.offset)) } - pub fn output_expr(&self, meta: &mut VirtualCells<'_, F>) -> Expression { + pub fn output_expr(&self, meta: &mut VirtualCells<'_, F>) -> Expression { let input = self.input_expr(meta); let round_constant = self.rc_expr(meta); - params::sbox::expr(input, round_constant) + params::calc::sbox::expr(input, round_constant) } } diff --git a/src/poseidon/septidon/transition_round.rs b/src/poseidon/septidon/transition_round.rs index ab55e7f..4b9a7c6 100644 --- a/src/poseidon/septidon/transition_round.rs +++ b/src/poseidon/septidon/transition_round.rs @@ -1,7 +1,7 @@ use super::params; -use super::params::{mds, round_constant, CachedConstants}; +use super::params::{calc::matmul, round_constant, CachedConstants}; use super::state::Cell; -use super::util::{join_values, matmul, split_values}; +use super::util::{join_values, split_values}; use halo2_proofs::circuit::{Region, Value}; //use halo2_proofs::halo2curves::bn256::Fr as F; use halo2_proofs::plonk::{Advice, Column, ConstraintSystem, Constraints, Error, Expression}; @@ -58,11 +58,11 @@ impl TransitionRoundChip { ) -> [Expression; 3] { let rc = Expression::Constant(Self::round_constant()); let sbox_out = [ - params::sbox::expr(input[0].clone(), rc), + params::calc::sbox::expr(input[0].clone(), rc), input[1].clone(), input[2].clone(), ]; - matmul::expr(mds(), sbox_out) + matmul::expr(sbox_out) } fn round_constant() -> F { @@ -102,11 +102,11 @@ impl TransitionRoundChip { fn first_partial_round(input: &[Value; 3]) -> [Value; 3] { let sbox_out = [ - input[0].map(|f| params::sbox::value(f, Self::round_constant())), + input[0].map(|f| params::calc::sbox::value(f, Self::round_constant())), input[1], input[2], ]; - let output = join_values(sbox_out).map(|s| matmul::value(mds(), s)); + let output = join_values(sbox_out).map(|s| matmul::value(s)); split_values(output) } diff --git a/src/poseidon/septidon/util.rs b/src/poseidon/septidon/util.rs index 142afe9..f51f0df 100644 --- a/src/poseidon/septidon/util.rs +++ b/src/poseidon/septidon/util.rs @@ -41,7 +41,7 @@ pub fn split_values(values: Value<[F; 3]>) -> [Value; 3] { ] } -pub mod pow_5 { +pub(crate) mod pow_5 { use super::FieldExt; use halo2_proofs::plonk::Expression; @@ -57,7 +57,7 @@ pub mod pow_5 { } /// Matrix multiplication expressions and values. -pub mod matmul { +pub(crate) mod matmul { use super::super::params::Mds; use super::FieldExt; use halo2_proofs::plonk::Expression; diff --git a/tests/hash_proving.rs b/tests/hash_proving.rs index daefbea..20affea 100644 --- a/tests/hash_proving.rs +++ b/tests/hash_proving.rs @@ -17,7 +17,7 @@ use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner}, plonk::{Circuit, ConstraintSystem, Error}, }; -use poseidon_circuit::poseidon::Pow5Chip; +//use poseidon_circuit::poseidon::Pow5Chip; use poseidon_circuit::{hash::*, DEFAULT_STEP}; use rand::SeedableRng; use rand_chacha::ChaCha8Rng; @@ -26,7 +26,7 @@ struct TestCircuit(PoseidonHashTable, usize); // test circuit derived from table data impl Circuit for TestCircuit { - type Config = SpongeConfig>; + type Config = PoseidonHashConfig; type FloorPlanner = SimpleFloorPlanner; fn without_witnesses(&self) -> Self { @@ -44,8 +44,7 @@ impl Circuit for TestCircuit { config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), Error> { - let chip = - SpongeChip::>::construct(config, &self.0, self.1); + let chip = PoseidonHashChip::::construct(config, &self.0, self.1); chip.load(&mut layouter) } }