diff --git a/Cargo.toml b/Cargo.toml index a0a0f09f..e14d5c4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,7 @@ block-multiplier-codegen = { path = "skyscraper/block-multiplier-codegen" } fp-rounding = { path = "skyscraper/fp-rounding" } hla = { path = "skyscraper/hla" } skyscraper = { path = "skyscraper/core" } +ntt = {path = "ntt"} # Workspace members - ProveKit provekit-bench = { path = "tooling/provekit-bench" } @@ -142,7 +143,8 @@ noirc_driver = { git = "https://github.com/noir-lang/noir", rev = "v1.0.0-beta.1 ark-bn254 = { version = "0.5.0", default-features = false, features = [ "scalar_field", ] } -ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] } +# ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] } +ark-crypto-primitives = { git = "https://github.com/veljkovranic/crypto-primitives", features = ["merkle_tree"], rev = "fa0623c7fb69ce8b39a5397a5c6de1ec84151295" } ark-ff = { version = "0.5", features = ["asm", "std"] } ark-poly = "0.5" ark-serialize = "0.5" @@ -151,4 +153,4 @@ spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "arkworks-algebra", ] } spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish" } -whir = { git = "https://github.com/WizardOfMenlo/whir/", features = ["tracing"], rev = "3d627d31cec7d73a470a31a913229dd3128ee0cf" } +whir = { git = "https://github.com/xrvdg/whir", branch="xr/integration-custom-arkworks", features = ["tracing"] } diff --git a/ntt/src/lib.rs b/ntt/src/lib.rs index 3443f9d0..1eb62fd2 100644 --- a/ntt/src/lib.rs +++ b/ntt/src/lib.rs @@ -13,25 +13,41 @@ impl + AsMut<[T]>> NTTContainer for C {} /// The NTT is optimized for NTTs of a power of two. Arbitrary sized NTTs are /// not supported. Note: empty vectors (size 0) are also supported as a special /// case. +/// +/// NTTContainer can be a single polynomial or multiple polynomials that are +/// interleaved. 
For example: `[a0, b0, c0, d0, a1, b1, c1, d1, +/// ...]` for four polynomials `a`, `b`, `c`, and `d`. By operating on +/// interleaved data, you can perform the NTT on all polynomials in-place +/// without needing to first transpose the data. #[derive(Debug, Clone, PartialEq)] pub struct NTT> { container: C, + order: Pow2, _phantom: PhantomData, } impl> NTT { - pub fn new(vec: C) -> Option { - match Pow2::::new(vec.as_ref().len()) { - Some(_) => Some(Self { - container: vec, - _phantom: PhantomData, - }), - _ => None, + pub fn new(vec: C, number_of_polynomials: usize) -> Option { + let n = vec.as_ref().len(); + // All polynomials must have the same size + if number_of_polynomials == 0 || n % number_of_polynomials != 0 { + return None; } + + // The order of the individual polynomials needs to be a power of two + Pow2::new(n / number_of_polynomials).map(|order| Self { + container: vec, + order, + _phantom: PhantomData, + }) } pub fn order(&self) -> Pow2 { - Pow2(self.container.as_ref().len()) + self.order + } + + pub fn into_inner(self) -> C { + self.container } } @@ -54,7 +70,7 @@ impl> DerefMut for NTT { /// The allowed values depend on the type parameter: /// - `Pow2`: length is 0 or a power of two (`{0} ∪ {2ⁿ : n ≥ 0}`). /// - `Pow2>`: length is a nonzero power of two (`{2ⁿ : n ≥ 0}`). 
-#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Pow2(T); impl Pow2 { diff --git a/ntt/src/main.rs b/ntt/src/main.rs index 61120d09..447a7066 100644 --- a/ntt/src/main.rs +++ b/ntt/src/main.rs @@ -1,15 +1,14 @@ /// Executable for profiling NTT use { ark_bn254::Fr, - ntt::{NTTEngine, NTT}, + ntt::{ntt_nr, NTT}, std::hint::black_box, }; fn main() { rayon::ThreadPoolBuilder::new().build_global().unwrap(); - let mut input = NTT::new(vec![Fr::from(1); 2_usize.pow(24)]).unwrap(); - let mut engine = NTTEngine::with_order(input.order()); - engine.ntt_nr(&mut input); + let mut input = NTT::new(vec![Fr::from(1); 2_usize.pow(24)], 1).unwrap(); + ntt_nr(&mut input); black_box(input); } diff --git a/ntt/src/ntt.rs b/ntt/src/ntt.rs index 9c39b42f..c8c05a96 100644 --- a/ntt/src/ntt.rs +++ b/ntt/src/ntt.rs @@ -6,7 +6,10 @@ use { iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, slice::ParallelSliceMut, }, - std::{mem::size_of, num::NonZeroUsize}, + std::{ + mem::size_of, + sync::{LazyLock, RwLock}, + }, }; // Taken from utils in noir-r1cs crate @@ -66,8 +69,7 @@ impl NTTEngine { if new_half_order > old_half_order { let col_len = new_half_order / old_half_order; let unity = Fr::get_root_of_unity(*order as u64).unwrap(); - // Remark: change this to reserve exact if tighter control on memory is needed - table.reserve(new_half_order - old_half_order); + table.reserve_exact(new_half_order - old_half_order); let (init, uninit) = table.split_at_spare_mut(); // When viewing the roots as a matrix every row is a multiple of the first row @@ -91,43 +93,35 @@ impl NTTEngine { } } - /// Performs an in-place, interleaved Number Theoretic Transform (NTT) in - /// normal-to-reverse order. 
- /// - /// # Use Case - /// Use this function when you have multiple polynomials - /// stored in an interleaved fashion within a single vector, such as - /// `[a0, b0, c0, d0, a1, b1, c1, d1, ...]` for four polynomials `a`, `b`, - /// `c`, and `d`. By operating on interleaved data, you can perform the - /// NTT on all polynomials in-place without needing to first transpose - /// the data. - /// - /// For a single polynomial use [`NTTEngine::ntt_nr`]. - /// - /// # Arguments - /// * `values` - A mutable reference to an NTT container holding the - /// coefficients to be transformed. - /// * `num_of_polys` - The number of interleaved polynomials in `values`. - - // TODO(xrvdg) The NTT can work with any number of interleaving but requires the - // individual polynomials to be a power of two. - pub fn interleaved_ntt_nr>( - &mut self, - values: &mut NTT, - num_of_polys: Pow2, - ) { - self.extend_roots_table(Pow2::new(*values.order() / *num_of_polys).unwrap()); - interleaved_ntt_nr(&self.0, values, num_of_polys); + // Returns the maximum order that it supports without extension + fn order(&self) -> Pow2 { + Pow2(self.0.len() * 2) } +} - pub fn ntt_nr>(&mut self, values: &mut NTT) { - self.interleaved_ntt_nr(values, NonZeroUsize::new(1).and_then(Pow2::new).unwrap()); - } +static ENGINE: LazyLock> = LazyLock::new(|| RwLock::new(NTTEngine::new())); - pub fn intt_rn>(&mut self, values: &mut NTT) { - self.extend_roots_table(values.order()); - intt_rn(&self.0, values); - } +/// Performs an in-place, interleaved Number Theoretic Transform (NTT) in +/// normal-to-reverse bit order. +/// +/// # Arguments +/// * `values` - A mutable reference to an NTT container holding the +/// coefficients to be transformed. 
+pub fn ntt_nr>(values: &mut NTT) { + let roots = ENGINE.read().unwrap(); + let new_root = if roots.order() >= values.order() { + roots + } else { + // Drop read lock + drop(roots); + let mut roots = ENGINE.write().unwrap(); + roots.extend_roots_table(values.order()); + // Drop write lock + drop(roots); + ENGINE.read().unwrap() + }; + + interleaved_ntt_nr(&new_root.0, values) } impl Default for NTTEngine { @@ -136,40 +130,39 @@ impl Default for NTTEngine { } } -fn ntt_nr>(reverse_ordered_roots: &[Fr], values: &mut NTT) { - interleaved_ntt_nr( - reverse_ordered_roots, - values, - NonZeroUsize::new(1).and_then(Pow2::new).unwrap(), - ); -} - /// In-place Number Theoretic Transform (NTT) from normal order to reverse bit /// order. /// +/// # Use Case +/// +/// Use this function when you have multiple polynomials +/// stored in an interleaved fashion within a single vector, such as +/// `[a0, b0, c0, d0, a1, b1, c1, d1, ...]` for four polynomials `a`, `b`, +/// `c`, and `d`. By operating on interleaved data, you can perform the +/// NTT on all polynomials in-place without needing to first transpose +/// the data +/// /// # Arguments /// * `reversed_ordered_roots` - Precomputed roots of unity in reverse bit /// order. /// * `values` - coefficients to be transformed in place with evaluation or vice /// versa. -fn interleaved_ntt_nr>( - reversed_ordered_roots: &[Fr], - values: &mut NTT, - num_of_polys: Pow2, -) { +fn interleaved_ntt_nr>(reversed_ordered_roots: &[Fr], values: &mut NTT) { // Reversed ordered roots idea from "Inside the FFT blackbox" // Implementation is a DIT NR algorithm let n = values.len(); // The order of the interleaved NTTs themselves - let order = n / *num_of_polys; + let order = values.order().0; // This conditional is here because chunk_size for *chunk_exact_mut can't be 0 if order <= 1 { return; } + let number_of_polyes = n / order; + // Each unique twiddle factor within a stage is a group. 
let mut pairs_in_group = n / 2; let mut num_of_groups = 1; @@ -222,7 +215,7 @@ fn interleaved_ntt_nr>( .par_chunks_exact_mut(2 * pairs_in_group) .enumerate() .for_each(|(k, group)| { - dit_nr_cache(reversed_ordered_roots, k, group, num_of_polys); + dit_nr_cache(reversed_ordered_roots, k, group, number_of_polyes); }); } @@ -230,7 +223,7 @@ fn dit_nr_cache( reverse_ordered_roots: &[Fr], segment: usize, input: &mut [Fr], - num_of_polys: Pow2, + num_of_polys: usize, ) { let n = input.len(); debug_assert!(n.is_power_of_two()); @@ -238,7 +231,7 @@ fn dit_nr_cache( let mut pairs_in_group = n / 2; let mut num_of_groups = 1; - let single_n = n / *num_of_polys; + let single_n = n / num_of_polys; while num_of_groups < single_n { let twiddle_base = segment * num_of_groups; @@ -328,20 +321,20 @@ fn reverse_order>(values: &mut NTT) { } /// Note: not specifically optimized -fn intt_rn>(reverse_ordered_roots: &[Fr], input: &mut NTT) { +pub fn intt_rn>(input: &mut NTT) { reverse_order(input); - intt_nr(reverse_ordered_roots, input); + intt_nr(input); reverse_order(input); } // Inverse NTT -fn intt_nr>(reverse_ordered_roots: &[Fr], values: &mut NTT) { +fn intt_nr>(values: &mut NTT) { match *values.order() { 0 => (), n => { // Reverse the input such that the roots act as inverse roots values[1..].reverse(); - ntt_nr(reverse_ordered_roots, values); + ntt_nr(values); let factor = Fr::ONE / Fr::from(n as u64); @@ -358,7 +351,10 @@ mod tests { use proptest::prelude::*; use { super::{init_roots_reverse_ordered, reverse_order}, - crate::{ntt::NTTEngine, Pow2, NTT}, + crate::{ + ntt::{intt_rn, NTTEngine}, + ntt_nr, Pow2, NTT, + }, ark_bn254::Fr, ark_ff::BigInt, proptest::collection, @@ -386,12 +382,13 @@ mod tests { /// length. 
fn ntt( sizes: impl Strategy, + number_of_polynomials: usize, elem: impl Strategy + Clone, ) -> impl Strategy>> { sizes .prop_map(|k| 1 << k) .prop_flat_map(move |len| collection::vec(elem.clone(), len..=len)) - .prop_map(|v| NTT::new(v).unwrap()) + .prop_map(move |v| NTT::new(v, number_of_polynomials).unwrap()) } /// Newtype wrapper to prevent proptest from writing the contents of an NTT @@ -399,6 +396,7 @@ mod tests { /// /// If the contents does have to be viewed replace [`hidden_ntt`] with /// [`ntt`] as the test strategy + #[derive(Clone, PartialEq)] struct HiddenNTT(NTT>); impl fmt::Debug for HiddenNTT { @@ -409,23 +407,23 @@ mod tests { fn hidden_ntt( sizes: impl Strategy, + number_of_polynomials: usize, elem: impl Strategy + Clone, ) -> impl Strategy> { - ntt(sizes, elem).prop_map(HiddenNTT) + ntt(sizes, number_of_polynomials, elem).prop_map(HiddenNTT) } proptest! { #[test] - fn round_trip_ntt(original in ntt(0_usize..15, fr())) + fn round_trip_ntt(original in hidden_ntt(0_usize..15, 1, fr())) { let mut s = original.clone(); - let mut engine = NTTEngine::new(); // Forward NTT - engine.ntt_nr(&mut s); + ntt_nr(&mut s.0); // Inverse NTT - engine.intt_rn(&mut s); + intt_rn(&mut s.0); prop_assert_eq!(original,s); } @@ -483,7 +481,7 @@ mod tests { ( Just(constr(len - column)), Just(constr(column)), - hidden_ntt(len..=len, fr()), + hidden_ntt(len..=len, constr(column).get(), fr()), ) }) }) @@ -494,17 +492,16 @@ mod tests { fn test_interleaved((rows, columns, ntt) in interleaving_strategy(0_usize..20)) { let mut ntt = ntt.0; let mut transposed = transpose(&ntt, rows.get(), columns.get()); - let mut engine = NTTEngine::new(); for chunk in transposed.chunks_exact_mut(rows.get()){ - let mut fold = NTT::new(chunk).unwrap(); - engine.ntt_nr(&mut fold); + let mut fold = NTT::new(chunk,1).unwrap(); + ntt_nr(&mut fold); } - let double_transposed = NTT::new(transpose(&transposed, columns.get(), rows.get())).unwrap(); + let double_transposed = transpose(&transposed, 
columns.get(), rows.get()); - engine.interleaved_ntt_nr(&mut ntt, columns); - prop_assert!(double_transposed == ntt); + ntt_nr(&mut ntt); + prop_assert!(double_transposed == ntt.into_inner()); } } @@ -512,9 +509,8 @@ mod tests { #[test] // The roundtrip test doesn't test size 0. fn ntt_empty() { - let mut v = NTT::new(vec![]).unwrap(); - let mut engine = NTTEngine::new(); - engine.ntt_nr(&mut v); + let mut v = NTT::new(vec![], 1).unwrap(); + ntt_nr(&mut v); } // Compare direct generation of the roots vs. extending from a base set of roots @@ -529,7 +525,7 @@ mod tests { proptest! { #[test] - fn round_trip_reverse_order(original in ntt(0_usize..10, any::())){ + fn round_trip_reverse_order(original in ntt(0_usize..10, 1, any::())){ let mut v = original.clone(); reverse_order(&mut v); reverse_order(&mut v); @@ -539,7 +535,7 @@ mod tests { proptest! { #[test] - fn reverse_order_noop(original in ntt(0_usize..=1, any::())) { + fn reverse_order_noop(original in ntt(0_usize..=1, 1, any::())) { let mut v = original.clone(); reverse_order(&mut v); assert_eq!(original, v) diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index d39db74e..ad508ebc 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -23,6 +23,7 @@ ark-crypto-primitives.workspace = true ark-ff.workspace = true ark-serialize.workspace = true ark-std.workspace = true +ntt.workspace = true spongefish.workspace = true spongefish-pow.workspace = true whir.workspace = true diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 680715d8..b89539e6 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -1,6 +1,7 @@ pub mod file; mod interner; mod noir_proof_scheme; +pub mod ntt; mod prover; mod r1cs; pub mod skyscraper; @@ -16,11 +17,11 @@ use crate::{ }; pub use { acir::FieldElement as NoirElement, + ark_bn254::Fr as FieldElement, noir_proof_scheme::{NoirProof, NoirProofScheme}, prover::Prover, r1cs::R1CS, verifier::Verifier, - 
whir::crypto::fields::Field256 as FieldElement, whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, }; diff --git a/provekit/common/src/ntt.rs b/provekit/common/src/ntt.rs new file mode 100644 index 00000000..c8cfaaa5 --- /dev/null +++ b/provekit/common/src/ntt.rs @@ -0,0 +1,39 @@ +use {ark_bn254::Fr, ark_ff::AdditiveGroup, ntt::ntt_nr, whir::ntt::ReedSolomon}; + +pub struct RSFr; +impl ReedSolomon for RSFr { + fn interleaved_encode( + &self, + interleaved_coeffs: &[Fr], + expansion: usize, + fold_factor: usize, + ) -> Vec { + interleaved_rs_encode(interleaved_coeffs, expansion, fold_factor) + } + + fn is_bit_reversed(&self) -> bool { + true + } +} + +fn interleaved_rs_encode( + interleaved_coeffs: &[Fr], + expansion: usize, + fold_factor: usize, +) -> Vec { + let fold_factor_exp = 2usize.pow(fold_factor as u32); + let expanded_size = interleaved_coeffs.len() * expansion; + + debug_assert_eq!(expanded_size % fold_factor_exp, 0); + + // 1. Create zero-padded message of appropriate size + let mut result = vec![Fr::ZERO; expanded_size]; + result[..interleaved_coeffs.len()].copy_from_slice(interleaved_coeffs); + + let mut ntt = ntt::NTT::new(result, fold_factor_exp) + .expect("interleaved_coeffs.len() * expension / 2^fold_factor needs to be a power of two."); + + ntt_nr(&mut ntt); + + ntt.into_inner() +} diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 4f92e79c..ea2f4abe 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -146,10 +146,7 @@ pub fn compute_blinding_coefficients_for_round( for _ in 0..(n - 1 - compute_for) { prefix_multiplier = prefix_multiplier + prefix_multiplier; } - let suffix_multiplier: ark_ff::Fp< - ark_ff::MontBackend, - 4, - > = prefix_multiplier / two; + let suffix_multiplier = prefix_multiplier / two; let constant_term_from_other_items = prefix_multiplier * prefix_sum + suffix_multiplier * suffix_sum; diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs 
b/provekit/r1cs-compiler/src/whir_r1cs.rs index e08ac26f..1411518b 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -1,5 +1,6 @@ use { - provekit_common::{utils::next_power_of_two, FieldElement, WhirConfig, WhirR1CSScheme, R1CS}, + provekit_common::{ntt::RSFr, utils::next_power_of_two, WhirConfig, WhirR1CSScheme, R1CS}, + std::sync::Arc, whir::parameters::{ default_max_pow, DeduplicationStrategy, FoldingFactor, MerkleProofStrategy, MultivariateParameters, ProtocolParameters, SoundnessType, @@ -63,6 +64,8 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { deduplication_strategy: DeduplicationStrategy::Disabled, merkle_proof_strategy: MerkleProofStrategy::Uncompressed, }; - WhirConfig::new(mv_params, whir_params) + let reed_solomon = Arc::new(RSFr); + let basefield_reed_solomon = reed_solomon.clone(); + WhirConfig::new(reed_solomon, basefield_reed_solomon, mv_params, whir_params) } }