From 303a876cb3bfa1107130d65fe292694a83c774f7 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 21 Oct 2025 12:55:50 +0800 Subject: [PATCH 1/6] ReedSolomon: Use WHIR RS trait (Generics variant) --- Cargo.toml | 3 +- provekit/common/Cargo.toml | 1 + provekit/common/src/lib.rs | 3 +- provekit/common/src/ntt.rs | 47 ++++++++++++++++++++++++++++++++ provekit/prover/src/whir_r1cs.rs | 12 ++++---- 5 files changed, 58 insertions(+), 8 deletions(-) create mode 100644 provekit/common/src/ntt.rs diff --git a/Cargo.toml b/Cargo.toml index a0a0f09f..20660c12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,7 @@ block-multiplier-codegen = { path = "skyscraper/block-multiplier-codegen" } fp-rounding = { path = "skyscraper/fp-rounding" } hla = { path = "skyscraper/hla" } skyscraper = { path = "skyscraper/core" } +ntt = {path = "ntt"} # Workspace members - ProveKit provekit-bench = { path = "tooling/provekit-bench" } @@ -151,4 +152,4 @@ spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "arkworks-algebra", ] } spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish" } -whir = { git = "https://github.com/WizardOfMenlo/whir/", features = ["tracing"], rev = "3d627d31cec7d73a470a31a913229dd3128ee0cf" } +whir = { path = "../whir-main", features = ["tracing"] } diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index d39db74e..ad508ebc 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -23,6 +23,7 @@ ark-crypto-primitives.workspace = true ark-ff.workspace = true ark-serialize.workspace = true ark-std.workspace = true +ntt.workspace = true spongefish.workspace = true spongefish-pow.workspace = true whir.workspace = true diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 680715d8..b89539e6 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -1,6 +1,7 @@ pub mod file; mod interner; mod noir_proof_scheme; +pub mod ntt; mod prover; mod r1cs; pub mod skyscraper; @@ -16,11 +17,11 @@ use crate::{ }; pub use { acir::FieldElement as NoirElement, + ark_bn254::Fr as FieldElement, noir_proof_scheme::{NoirProof, NoirProofScheme}, prover::Prover, r1cs::R1CS, verifier::Verifier, - whir::crypto::fields::Field256 as FieldElement, whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, }; diff --git a/provekit/common/src/ntt.rs b/provekit/common/src/ntt.rs new file mode 100644 index 00000000..d7d2f6f9 --- /dev/null +++ b/provekit/common/src/ntt.rs @@ -0,0 +1,47 @@ +use {ark_bn254::Fr, ark_ff::AdditiveGroup, std::num::NonZero, whir::ntt::ReedSolomon}; + +pub struct RSFr; +impl ReedSolomon for RSFr { + fn interleaved_encode( + interleaved_coeffs: &[Fr], + expansion: usize, + fold_factor: usize, + ) -> Vec { + interleaved_rs_encode(interleaved_coeffs, expansion, fold_factor) + } + + fn interleaved_basefield_encode( + interleaved_coeffs: &[Fr], + expansion: usize, + fold_factor: usize, + ) -> Vec { + interleaved_rs_encode(interleaved_coeffs, expansion, fold_factor) + } +} + +fn interleaved_rs_encode( + interleaved_coeffs: &[Fr], + expansion: usize, + fold_factor: usize, +) -> Vec { + let fold_factor_exp = 2usize.pow(fold_factor as u32); + let expanded_size = interleaved_coeffs.len() * expansion; + + debug_assert_eq!(expanded_size % fold_factor_exp, 0); + + // 1. Create zero-padded message of appropriate size + let mut result = vec![Fr::ZERO; expanded_size]; + result[..interleaved_coeffs.len()].copy_from_slice(interleaved_coeffs); + + let mut ntt = ntt::NTT::new(&mut result) + .expect("interleaved_coeffs.len() * expension needs to be a power of two."); + let mut engine = ntt::NTTEngine::new(); + engine.interleaved_ntt_nr( + &mut ntt, + NonZero::new(fold_factor_exp) + .and_then(ntt::Pow2::new) + .unwrap(), + ); + + result +} diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 4f92e79c..80cd5911 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -22,6 +22,7 @@ use { }, tracing::{info, instrument, warn}, whir::{ + ntt::RSDefault, poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ committer::{CommitmentWriter, Witness}, @@ -32,6 +33,8 @@ use { }, }; +type RS = RSDefault; + pub trait WhirR1CSProver { fn prove(&self, r1cs: R1CS, witness: Vec) -> Result; } @@ -146,10 +149,7 @@ pub fn compute_blinding_coefficients_for_round( for _ in 0..(n - 1 - compute_for) { prefix_multiplier = prefix_multiplier + prefix_multiplier; } - let suffix_multiplier: ark_ff::Fp< - ark_ff::MontBackend, - 4, - > = prefix_multiplier / two; + let suffix_multiplier = prefix_multiplier / two; let constant_term_from_other_items = prefix_multiplier * prefix_sum + suffix_multiplier * suffix_sum; @@ -211,7 +211,7 @@ pub fn batch_commit_to_polynomial( let committer = CommitmentWriter::new(whir_config.clone()); let witness_new = committer - .commit_batch(merlin, &[ + .commit_batch::<_, RS>(merlin, &[ &masked_polynomial_coeff, &random_polynomial_coeff, ]) @@ -494,7 +494,7 @@ pub fn run_zk_whir_pcs_prover( let prover = Prover::new(params.clone()); let (randomness, deferred) = prover - .prove(&mut merlin, statement, witness) + .prove::<_, RS>(&mut merlin, statement, witness) .expect("WHIR prover failed to generate a proof"); (merlin, randomness, deferred) From 57fcd1858f0a351b9d328ee78b950e8350e8027c Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 21 Oct 2025 13:56:55 +0800 Subject: [PATCH 2/6] NTT: NTTContainer support for interleaved polynomials --- ntt/src/lib.rs | 16 ++++++---- ntt/src/main.rs | 2 +- ntt/src/ntt.rs | 64 +++++++++++++++++--------------------- provekit/common/src/ntt.rs | 9 ++---- 4 files changed, 42 insertions(+), 49 deletions(-) diff --git a/ntt/src/lib.rs b/ntt/src/lib.rs index 3443f9d0..a8a60e0d 100644 --- a/ntt/src/lib.rs +++ b/ntt/src/lib.rs @@ -16,22 +16,26 @@ impl + AsMut<[T]>> NTTContainer for C {} #[derive(Debug, Clone, PartialEq)] pub struct NTT> { container: C, + order: Pow2, _phantom: PhantomData, } impl> NTT { - pub fn new(vec: C) -> Option { - match Pow2::::new(vec.as_ref().len()) { - Some(_) => Some(Self { + pub fn new(vec: C, number_of_polynomials: usize) -> Option { + // This needs a better division. 7/3 will give 2 for example + match Pow2::::new(vec.as_ref().len() / number_of_polynomials) { + Some(order) => Some(Self { container: vec, - _phantom: PhantomData, + order, + _phantom: PhantomData, }), _ => None, } } + // TODO maybe a read only field is nicer? pub fn order(&self) -> Pow2 { - Pow2(self.container.as_ref().len()) + self.order } } @@ -54,7 +58,7 @@ impl> DerefMut for NTT { /// The allowed values depend on the type parameter: /// - `Pow2`: length is 0 or a power of two (`{0} ∪ {2ⁿ : n ≥ 0}`). /// - `Pow2>`: length is a nonzero power of two (`{2ⁿ : n ≥ 0}`). -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub struct Pow2(T); impl Pow2 { diff --git a/ntt/src/main.rs b/ntt/src/main.rs index 61120d09..5bb8634b 100644 --- a/ntt/src/main.rs +++ b/ntt/src/main.rs @@ -8,7 +8,7 @@ use { fn main() { rayon::ThreadPoolBuilder::new().build_global().unwrap(); - let mut input = NTT::new(vec![Fr::from(1); 2_usize.pow(24)]).unwrap(); + let mut input = NTT::new(vec![Fr::from(1); 2_usize.pow(24)], 1).unwrap(); let mut engine = NTTEngine::with_order(input.order()); engine.ntt_nr(&mut input); black_box(input); diff --git a/ntt/src/ntt.rs b/ntt/src/ntt.rs index 9c39b42f..b79810d7 100644 --- a/ntt/src/ntt.rs +++ b/ntt/src/ntt.rs @@ -6,7 +6,7 @@ use { iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, slice::ParallelSliceMut, }, - std::{mem::size_of, num::NonZeroUsize}, + std::mem::size_of, }; // Taken from utils in noir-r1cs crate @@ -111,17 +111,14 @@ impl NTTEngine { // TODO(xrvdg) The NTT can work with any number of interleaving but requires the // individual polynomials to be a power of two. - pub fn interleaved_ntt_nr>( - &mut self, - values: &mut NTT, - num_of_polys: Pow2, - ) { - self.extend_roots_table(Pow2::new(*values.order() / *num_of_polys).unwrap()); - interleaved_ntt_nr(&self.0, values, num_of_polys); + pub fn interleaved_ntt_nr>(&mut self, values: &mut NTT) { + self.extend_roots_table(values.order()); + interleaved_ntt_nr(&self.0, values); } + // TODO(xrvdg) remove this one pub fn ntt_nr>(&mut self, values: &mut NTT) { - self.interleaved_ntt_nr(values, NonZeroUsize::new(1).and_then(Pow2::new).unwrap()); + self.interleaved_ntt_nr(values); } pub fn intt_rn>(&mut self, values: &mut NTT) { @@ -137,11 +134,7 @@ impl Default for NTTEngine { } fn ntt_nr>(reverse_ordered_roots: &[Fr], values: &mut NTT) { - interleaved_ntt_nr( - reverse_ordered_roots, - values, - NonZeroUsize::new(1).and_then(Pow2::new).unwrap(), - ); + interleaved_ntt_nr(reverse_ordered_roots, values); } /// In-place Number Theoretic Transform (NTT) from normal order to reverse bit @@ -152,24 +145,22 @@ fn ntt_nr>(reverse_ordered_roots: &[Fr], values: &mut NTT>( - reversed_ordered_roots: &[Fr], - values: &mut NTT, - num_of_polys: Pow2, -) { +fn interleaved_ntt_nr>(reversed_ordered_roots: &[Fr], values: &mut NTT) { // Reversed ordered roots idea from "Inside the FFT blackbox" // Implementation is a DIT NR algorithm let n = values.len(); // The order of the interleaved NTTs themselves - let order = n / *num_of_polys; + let order = values.order().0; // This conditional is here because chunk_size for *chunk_exact_mut can't be 0 if order <= 1 { return; } + let number_of_polyes = n / order; + // Each unique twiddle factor within a stage is a group. let mut pairs_in_group = n / 2; let mut num_of_groups = 1; @@ -222,7 +213,7 @@ fn interleaved_ntt_nr>( .par_chunks_exact_mut(2 * pairs_in_group) .enumerate() .for_each(|(k, group)| { - dit_nr_cache(reversed_ordered_roots, k, group, num_of_polys); + dit_nr_cache(reversed_ordered_roots, k, group, number_of_polyes); }); } @@ -230,7 +221,7 @@ fn dit_nr_cache( reverse_ordered_roots: &[Fr], segment: usize, input: &mut [Fr], - num_of_polys: Pow2, + num_of_polys: usize, ) { let n = input.len(); debug_assert!(n.is_power_of_two()); @@ -238,7 +229,7 @@ fn dit_nr_cache( let mut pairs_in_group = n / 2; let mut num_of_groups = 1; - let single_n = n / *num_of_polys; + let single_n = n / num_of_polys; while num_of_groups < single_n { let twiddle_base = segment * num_of_groups; @@ -386,12 +377,13 @@ mod tests { /// length. fn ntt( sizes: impl Strategy, + number_of_polynomials: usize, elem: impl Strategy + Clone, ) -> impl Strategy>> { sizes .prop_map(|k| 1 << k) .prop_flat_map(move |len| collection::vec(elem.clone(), len..=len)) - .prop_map(|v| NTT::new(v).unwrap()) + .prop_map(move |v| NTT::new(v, number_of_polynomials).unwrap()) } /// Newtype wrapper to prevent proptest from writing the contents of an NTT @@ -399,6 +391,7 @@ mod tests { /// /// If the contents does have to be viewed replace [`hidden_ntt`] with /// [`ntt`] as the test strategy + #[derive(Clone, PartialEq)] struct HiddenNTT(NTT>); impl fmt::Debug for HiddenNTT { @@ -409,23 +402,24 @@ mod tests { fn hidden_ntt( sizes: impl Strategy, + number_of_polynomials: usize, elem: impl Strategy + Clone, ) -> impl Strategy> { - ntt(sizes, elem).prop_map(HiddenNTT) + ntt(sizes, number_of_polynomials, elem).prop_map(HiddenNTT) } proptest! { #[test] - fn round_trip_ntt(original in ntt(0_usize..15, fr())) + fn round_trip_ntt(original in hidden_ntt(0_usize..15, 1, fr())) { let mut s = original.clone(); let mut engine = NTTEngine::new(); // Forward NTT - engine.ntt_nr(&mut s); + engine.ntt_nr(&mut s.0); // Inverse NTT - engine.intt_rn(&mut s); + engine.intt_rn(&mut s.0); prop_assert_eq!(original,s); } @@ -483,7 +477,7 @@ mod tests { ( Just(constr(len - column)), Just(constr(column)), - hidden_ntt(len..=len, fr()), + hidden_ntt(len..=len, constr(column).get(), fr()), ) }) }) @@ -497,13 +491,13 @@ mod tests { let mut engine = NTTEngine::new(); for chunk in transposed.chunks_exact_mut(rows.get()){ - let mut fold = NTT::new(chunk).unwrap(); + let mut fold = NTT::new(chunk,1).unwrap(); engine.ntt_nr(&mut fold); } - let double_transposed = NTT::new(transpose(&transposed, columns.get(), rows.get())).unwrap(); + let double_transposed = NTT::new(transpose(&transposed, columns.get(), rows.get()),columns.get()).unwrap(); - engine.interleaved_ntt_nr(&mut ntt, columns); + engine.interleaved_ntt_nr(&mut ntt); prop_assert!(double_transposed == ntt); } @@ -512,7 +506,7 @@ mod tests { #[test] // The roundtrip test doesn't test size 0. fn ntt_empty() { - let mut v = NTT::new(vec![]).unwrap(); + let mut v = NTT::new(vec![], 1).unwrap(); let mut engine = NTTEngine::new(); engine.ntt_nr(&mut v); } @@ -529,7 +523,7 @@ mod tests { proptest! { #[test] - fn round_trip_reverse_order(original in ntt(0_usize..10, any::())){ + fn round_trip_reverse_order(original in ntt(0_usize..10, 1, any::())){ let mut v = original.clone(); reverse_order(&mut v); reverse_order(&mut v); @@ -539,7 +533,7 @@ mod tests { proptest! { #[test] - fn reverse_order_noop(original in ntt(0_usize..=1, any::())) { + fn reverse_order_noop(original in ntt(0_usize..=1, 1, any::())) { let mut v = original.clone(); reverse_order(&mut v); assert_eq!(original, v) diff --git a/provekit/common/src/ntt.rs b/provekit/common/src/ntt.rs index d7d2f6f9..57d5e854 100644 --- a/provekit/common/src/ntt.rs +++ b/provekit/common/src/ntt.rs @@ -33,15 +33,10 @@ fn interleaved_rs_encode( let mut result = vec![Fr::ZERO; expanded_size]; result[..interleaved_coeffs.len()].copy_from_slice(interleaved_coeffs); - let mut ntt = ntt::NTT::new(&mut result) + let mut ntt = ntt::NTT::new(&mut result, fold_factor_exp) .expect("interleaved_coeffs.len() * expension needs to be a power of two."); let mut engine = ntt::NTTEngine::new(); - engine.interleaved_ntt_nr( - &mut ntt, - NonZero::new(fold_factor_exp) - .and_then(ntt::Pow2::new) - .unwrap(), - ); + engine.interleaved_ntt_nr(&mut ntt); result } From 9c9539468db845a349313fdf1866b9df2cfa51ba Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 21 Oct 2025 15:10:52 +0800 Subject: [PATCH 3/6] NTT: support for non-power of two interleaving. Power of two check is now on the individual polynomials --- ntt/src/lib.rs | 21 ++++++++++++++++++--- ntt/src/ntt.rs | 38 ++++++++++++++------------------------ provekit/common/src/ntt.rs | 8 ++++---- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/ntt/src/lib.rs b/ntt/src/lib.rs index a8a60e0d..36e2b2e7 100644 --- a/ntt/src/lib.rs +++ b/ntt/src/lib.rs @@ -13,6 +13,12 @@ impl + AsMut<[T]>> NTTContainer for C {} /// The NTT is optimized for NTTs of a power of two. Arbitrary sized NTTs are /// not supported. Note: empty vectors (size 0) are also supported as a special /// case. +/// +/// NTTContainer can be a single polynomial or multiple polynomials that are +/// interleaved. interleaved polynomials; `[a0, b0, c0, d0, a1, b1, c1, d1, +/// ...]` for four polynomials `a`, `b`, `c`, and `d`. By operating on +/// interleaved data, you can perform the NTT on all polynomials in-place +/// without needing to first transpose the data #[derive(Debug, Clone, PartialEq)] pub struct NTT> { container: C, @@ -22,8 +28,14 @@ pub struct NTT> { impl> NTT { pub fn new(vec: C, number_of_polynomials: usize) -> Option { - // This needs a better division. 7/3 will give 2 for example - match Pow2::::new(vec.as_ref().len() / number_of_polynomials) { + let n = vec.as_ref().len(); + // All polynomials of the same size + if number_of_polynomials == 0 || n % number_of_polynomials != 0 { + return None; + } + + // The order of the individual polynomials needs to be a power of two + match Pow2::new(n / number_of_polynomials) { Some(order) => Some(Self { container: vec, order, @@ -33,10 +45,13 @@ impl> NTT { } } - // TODO maybe a read only field is nicer? pub fn order(&self) -> Pow2 { self.order } + + pub fn into_inner(self) -> C { + self.container + } } impl> Deref for NTT { diff --git a/ntt/src/ntt.rs b/ntt/src/ntt.rs index b79810d7..10cf2838 100644 --- a/ntt/src/ntt.rs +++ b/ntt/src/ntt.rs @@ -92,35 +92,16 @@ impl NTTEngine { } /// Performs an in-place, interleaved Number Theoretic Transform (NTT) in - /// normal-to-reverse order. - /// - /// # Use Case - /// Use this function when you have multiple polynomials - /// stored in an interleaved fashion within a single vector, such as - /// `[a0, b0, c0, d0, a1, b1, c1, d1, ...]` for four polynomials `a`, `b`, - /// `c`, and `d`. By operating on interleaved data, you can perform the - /// NTT on all polynomials in-place without needing to first transpose - /// the data. - /// - /// For a single polynomial use [`NTTEngine::ntt_nr`]. + /// normal-to-reverse bit order. /// /// # Arguments /// * `values` - A mutable reference to an NTT container holding the /// coefficients to be transformed. - /// * `num_of_polys` - The number of interleaved polynomials in `values`. - - // TODO(xrvdg) The NTT can work with any number of interleaving but requires the - // individual polynomials to be a power of two. - pub fn interleaved_ntt_nr>(&mut self, values: &mut NTT) { + pub fn ntt_nr>(&mut self, values: &mut NTT) { self.extend_roots_table(values.order()); interleaved_ntt_nr(&self.0, values); } - // TODO(xrvdg) remove this one - pub fn ntt_nr>(&mut self, values: &mut NTT) { - self.interleaved_ntt_nr(values); - } - pub fn intt_rn>(&mut self, values: &mut NTT) { self.extend_roots_table(values.order()); intt_rn(&self.0, values); @@ -140,6 +121,15 @@ fn ntt_nr>(reverse_ordered_roots: &[Fr], values: &mut NTT for RSFr { @@ -33,10 +33,10 @@ fn interleaved_rs_encode( let mut result = vec![Fr::ZERO; expanded_size]; result[..interleaved_coeffs.len()].copy_from_slice(interleaved_coeffs); - let mut ntt = ntt::NTT::new(&mut result, fold_factor_exp) - .expect("interleaved_coeffs.len() * expension needs to be a power of two."); + let mut ntt = ntt::NTT::new(result, fold_factor_exp) + .expect("interleaved_coeffs.len() * expension / 2^fold_factor needs to be a power of two."); let mut engine = ntt::NTTEngine::new(); engine.interleaved_ntt_nr(&mut ntt); - result + ntt.into_inner() } From f3c0b8388e39beb3c0669bc2c17285693b5034bd Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Wed, 22 Oct 2025 12:13:22 +0800 Subject: [PATCH 4/6] NTT: cache roots global --- Cargo.toml | 2 +- ntt/src/lib.rs | 13 +++---- ntt/src/main.rs | 5 +-- ntt/src/ntt.rs | 78 ++++++++++++++++++++++---------------- provekit/common/src/ntt.rs | 6 +-- 5 files changed, 56 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 20660c12..f36c9da3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -152,4 +152,4 @@ spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "arkworks-algebra", ] } spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish" } -whir = { path = "../whir-main", features = ["tracing"] } +whir = { git = "https://github.com/xrvdg/whir", branch="xr/integration", features = ["tracing"] } diff --git a/ntt/src/lib.rs b/ntt/src/lib.rs index 36e2b2e7..1eb62fd2 100644 --- a/ntt/src/lib.rs +++ b/ntt/src/lib.rs @@ -35,14 +35,11 @@ impl> NTT { } // The order of the individual polynomials needs to be a power of two - match Pow2::new(n / number_of_polynomials) { - Some(order) => Some(Self { - container: vec, - order, - _phantom: PhantomData, - }), - _ => None, - } + Pow2::new(n / number_of_polynomials).map(|order| Self { + container: vec, + order, + _phantom: PhantomData, + }) } pub fn order(&self) -> Pow2 { diff --git a/ntt/src/main.rs b/ntt/src/main.rs index 5bb8634b..447a7066 100644 --- a/ntt/src/main.rs +++ b/ntt/src/main.rs @@ -1,7 +1,7 @@ /// Executable for profiling NTT use { ark_bn254::Fr, - ntt::{NTTEngine, NTT}, + ntt::{ntt_nr, NTT}, std::hint::black_box, }; @@ -9,7 +9,6 @@ fn main() { rayon::ThreadPoolBuilder::new().build_global().unwrap(); let mut input = NTT::new(vec![Fr::from(1); 2_usize.pow(24)], 1).unwrap(); - let mut engine = NTTEngine::with_order(input.order()); - engine.ntt_nr(&mut input); + ntt_nr(&mut input); black_box(input); } diff --git a/ntt/src/ntt.rs b/ntt/src/ntt.rs index 10cf2838..c8c05a96 100644 --- a/ntt/src/ntt.rs +++ b/ntt/src/ntt.rs @@ -6,7 +6,10 @@ use { iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}, slice::ParallelSliceMut, }, - std::mem::size_of, + std::{ + mem::size_of, + sync::{LazyLock, RwLock}, + }, }; // Taken from utils in noir-r1cs crate @@ -66,8 +69,7 @@ impl NTTEngine { if new_half_order > old_half_order { let col_len = new_half_order / old_half_order; let unity = Fr::get_root_of_unity(*order as u64).unwrap(); - // Remark: change this to reserve exact if tighter control on memory is needed - table.reserve(new_half_order - old_half_order); + table.reserve_exact(new_half_order - old_half_order); let (init, uninit) = table.split_at_spare_mut(); // When viewing the roots as a matrix every row is a multiple of the first row @@ -91,21 +93,35 @@ impl NTTEngine { } } - /// Performs an in-place, interleaved Number Theoretic Transform (NTT) in - /// normal-to-reverse bit order. - /// - /// # Arguments - /// * `values` - A mutable reference to an NTT container holding the - /// coefficients to be transformed. - pub fn ntt_nr>(&mut self, values: &mut NTT) { - self.extend_roots_table(values.order()); - interleaved_ntt_nr(&self.0, values); + // Returns the maximum order that it supports without extention + fn order(&self) -> Pow2 { + Pow2(self.0.len() * 2) } +} - pub fn intt_rn>(&mut self, values: &mut NTT) { - self.extend_roots_table(values.order()); - intt_rn(&self.0, values); - } +static ENGINE: LazyLock> = LazyLock::new(|| RwLock::new(NTTEngine::new())); + +/// Performs an in-place, interleaved Number Theoretic Transform (NTT) in +/// normal-to-reverse bit order. +/// +/// # Arguments +/// * `values` - A mutable reference to an NTT container holding the +/// coefficients to be transformed. +pub fn ntt_nr>(values: &mut NTT) { + let roots = ENGINE.read().unwrap(); + let new_root = if roots.order() >= values.order() { + roots + } else { + // Drop read lock + drop(roots); + let mut roots = ENGINE.write().unwrap(); + roots.extend_roots_table(values.order()); + // Drop write lock + drop(roots); + ENGINE.read().unwrap() + }; + + interleaved_ntt_nr(&new_root.0, values) } impl Default for NTTEngine { @@ -114,10 +130,6 @@ impl Default for NTTEngine { } } -fn ntt_nr>(reverse_ordered_roots: &[Fr], values: &mut NTT) { - interleaved_ntt_nr(reverse_ordered_roots, values); -} - /// In-place Number Theoretic Transform (NTT) from normal order to reverse bit /// order. /// @@ -309,20 +321,20 @@ fn reverse_order>(values: &mut NTT) { } /// Note: not specifically optimized -fn intt_rn>(reverse_ordered_roots: &[Fr], input: &mut NTT) { +pub fn intt_rn>(input: &mut NTT) { reverse_order(input); - intt_nr(reverse_ordered_roots, input); + intt_nr(input); reverse_order(input); } // Inverse NTT -fn intt_nr>(reverse_ordered_roots: &[Fr], values: &mut NTT) { +fn intt_nr>(values: &mut NTT) { match *values.order() { 0 => (), n => { // Reverse the input such that the roots act as inverse roots values[1..].reverse(); - ntt_nr(reverse_ordered_roots, values); + ntt_nr(values); let factor = Fr::ONE / Fr::from(n as u64); @@ -339,7 +351,10 @@ mod tests { use proptest::prelude::*; use { super::{init_roots_reverse_ordered, reverse_order}, - crate::{ntt::NTTEngine, Pow2, NTT}, + crate::{ + ntt::{intt_rn, NTTEngine}, + ntt_nr, Pow2, NTT, + }, ark_bn254::Fr, ark_ff::BigInt, proptest::collection, @@ -404,12 +419,11 @@ mod tests { { let mut s = original.clone(); - let mut engine = NTTEngine::new(); // Forward NTT - engine.ntt_nr(&mut s.0); + ntt_nr(&mut s.0); // Inverse NTT - engine.intt_rn(&mut s.0); + intt_rn(&mut s.0); prop_assert_eq!(original,s); } @@ -478,16 +492,15 @@ mod tests { fn test_interleaved((rows, columns, ntt) in interleaving_strategy(0_usize..20)) { let mut ntt = ntt.0; let mut transposed = transpose(&ntt, rows.get(), columns.get()); - let mut engine = NTTEngine::new(); for chunk in transposed.chunks_exact_mut(rows.get()){ let mut fold = NTT::new(chunk,1).unwrap(); - engine.ntt_nr(&mut fold); + ntt_nr(&mut fold); } let double_transposed = transpose(&transposed, columns.get(), rows.get()); - engine.ntt_nr(&mut ntt); + ntt_nr(&mut ntt); prop_assert!(double_transposed == ntt.into_inner()); } @@ -497,8 +510,7 @@ mod tests { // The roundtrip test doesn't test size 0. fn ntt_empty() { let mut v = NTT::new(vec![], 1).unwrap(); - let mut engine = NTTEngine::new(); - engine.ntt_nr(&mut v); + ntt_nr(&mut v); } // Compare direct generation of the roots vs. extending from a base set of roots diff --git a/provekit/common/src/ntt.rs b/provekit/common/src/ntt.rs index adf6b6cc..c03edb75 100644 --- a/provekit/common/src/ntt.rs +++ b/provekit/common/src/ntt.rs @@ -1,4 +1,4 @@ -use {ark_bn254::Fr, ark_ff::AdditiveGroup, whir::ntt::ReedSolomon}; +use {ark_bn254::Fr, ark_ff::AdditiveGroup, ntt::ntt_nr, whir::ntt::ReedSolomon}; pub struct RSFr; impl ReedSolomon for RSFr { @@ -35,8 +35,8 @@ fn interleaved_rs_encode( let mut ntt = ntt::NTT::new(result, fold_factor_exp) .expect("interleaved_coeffs.len() * expension / 2^fold_factor needs to be a power of two."); - let mut engine = ntt::NTTEngine::new(); - engine.interleaved_ntt_nr(&mut ntt); + + ntt_nr(&mut ntt); ntt.into_inner() } From e0b29f8b4f6f782fa3a81fd53c1ef8b4e3b60686 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 11 Nov 2025 12:10:31 +0800 Subject: [PATCH 5/6] Reed Solomon: support for WHIR's RS Trait (vtable/dyn) --- Cargo.toml | 2 +- provekit/common/src/ntt.rs | 9 +-------- provekit/prover/src/whir_r1cs.rs | 7 ++----- provekit/r1cs-compiler/src/whir_r1cs.rs | 14 ++++++++++---- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f36c9da3..5fcb0edc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -152,4 +152,4 @@ spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "arkworks-algebra", ] } spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish" } -whir = { git = "https://github.com/xrvdg/whir", branch="xr/integration", features = ["tracing"] } +whir = { git = "https://github.com/xrvdg/whir", branch="xr/rs-send-sync", features = ["tracing"] } diff --git a/provekit/common/src/ntt.rs b/provekit/common/src/ntt.rs index c03edb75..1ececbd4 100644 --- a/provekit/common/src/ntt.rs +++ b/provekit/common/src/ntt.rs @@ -3,14 +3,7 @@ use {ark_bn254::Fr, ark_ff::AdditiveGroup, ntt::ntt_nr, whir::ntt::ReedSolomon}; pub struct RSFr; impl ReedSolomon for RSFr { fn interleaved_encode( - interleaved_coeffs: &[Fr], - expansion: usize, - fold_factor: usize, - ) -> Vec { - interleaved_rs_encode(interleaved_coeffs, expansion, fold_factor) - } - - fn interleaved_basefield_encode( + &self, interleaved_coeffs: &[Fr], expansion: usize, fold_factor: usize, diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 80cd5911..ea2f4abe 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -22,7 +22,6 @@ use { }, tracing::{info, instrument, warn}, whir::{ - ntt::RSDefault, poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{ committer::{CommitmentWriter, Witness}, @@ -33,8 +32,6 @@ use { }, }; -type RS = RSDefault; - pub trait WhirR1CSProver { fn prove(&self, r1cs: R1CS, witness: Vec) -> Result; } @@ -211,7 +208,7 @@ pub fn batch_commit_to_polynomial( let committer = CommitmentWriter::new(whir_config.clone()); let witness_new = committer - .commit_batch::<_, RS>(merlin, &[ + .commit_batch(merlin, &[ &masked_polynomial_coeff, &random_polynomial_coeff, ]) @@ -494,7 +491,7 @@ pub fn run_zk_whir_pcs_prover( let prover = Prover::new(params.clone()); let (randomness, deferred) = prover - .prove::<_, RS>(&mut merlin, statement, witness) + .prove(&mut merlin, statement, witness) .expect("WHIR prover failed to generate a proof"); (merlin, randomness, deferred) diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index e08ac26f..3b3fccfc 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -1,8 +1,12 @@ use { provekit_common::{utils::next_power_of_two, FieldElement, WhirConfig, WhirR1CSScheme, R1CS}, - whir::parameters::{ - default_max_pow, DeduplicationStrategy, FoldingFactor, MerkleProofStrategy, - MultivariateParameters, ProtocolParameters, SoundnessType, + std::sync::Arc, + whir::{ + ntt::RSDefault, + parameters::{ + default_max_pow, DeduplicationStrategy, FoldingFactor, MerkleProofStrategy, + MultivariateParameters, ProtocolParameters, SoundnessType, + }, }, }; @@ -63,6 +67,8 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { deduplication_strategy: DeduplicationStrategy::Disabled, merkle_proof_strategy: MerkleProofStrategy::Uncompressed, }; - WhirConfig::new(mv_params, whir_params) + let reed_solomon = Arc::new(RSDefault); + let basefield_reed_solomon = reed_solomon.clone(); + WhirConfig::new(reed_solomon, basefield_reed_solomon, mv_params, whir_params) } } From 9619c1f8e9ff0e615971910df27975797dac1ff7 Mon Sep 17 00:00:00 2001 From: Xander van der Goot Date: Tue, 11 Nov 2025 14:15:25 +0800 Subject: [PATCH 6/6] arkworks integration --- Cargo.toml | 5 +++-- provekit/common/src/ntt.rs | 4 ++++ provekit/r1cs-compiler/src/whir_r1cs.rs | 13 +++++-------- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5fcb0edc..e14d5c4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -143,7 +143,8 @@ noirc_driver = { git = "https://github.com/noir-lang/noir", rev = "v1.0.0-beta.1 ark-bn254 = { version = "0.5.0", default-features = false, features = [ "scalar_field", ] } -ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] } +# ark-crypto-primitives = { version = "0.5", features = ["merkle_tree"] } +ark-crypto-primitives = { git = "https://github.com/veljkovranic/crypto-primitives", features = ["merkle_tree"], rev = "fa0623c7fb69ce8b39a5397a5c6de1ec84151295" } ark-ff = { version = "0.5", features = ["asm", "std"] } ark-poly = "0.5" ark-serialize = "0.5" @@ -152,4 +153,4 @@ spongefish = { git = "https://github.com/arkworks-rs/spongefish", features = [ "arkworks-algebra", ] } spongefish-pow = { git = "https://github.com/arkworks-rs/spongefish" } -whir = { git = "https://github.com/xrvdg/whir", branch="xr/rs-send-sync", features = ["tracing"] } +whir = { git = "https://github.com/xrvdg/whir", branch="xr/integration-custom-arkworks", features = ["tracing"] } diff --git a/provekit/common/src/ntt.rs b/provekit/common/src/ntt.rs index 1ececbd4..c8cfaaa5 100644 --- a/provekit/common/src/ntt.rs +++ b/provekit/common/src/ntt.rs @@ -10,6 +10,10 @@ impl ReedSolomon for RSFr { ) -> Vec { interleaved_rs_encode(interleaved_coeffs, expansion, fold_factor) } + + fn is_bit_reversed(&self) -> bool { + true + } } fn interleaved_rs_encode( diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index 3b3fccfc..1411518b 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -1,12 +1,9 @@ use { - provekit_common::{utils::next_power_of_two, FieldElement, WhirConfig, WhirR1CSScheme, R1CS}, + provekit_common::{ntt::RSFr, utils::next_power_of_two, WhirConfig, WhirR1CSScheme, R1CS}, std::sync::Arc, - whir::{ - ntt::RSDefault, - parameters::{ - default_max_pow, DeduplicationStrategy, FoldingFactor, MerkleProofStrategy, - MultivariateParameters, ProtocolParameters, SoundnessType, - }, + whir::parameters::{ + default_max_pow, DeduplicationStrategy, FoldingFactor, MerkleProofStrategy, + MultivariateParameters, ProtocolParameters, SoundnessType, }, }; @@ -67,7 +64,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { deduplication_strategy: DeduplicationStrategy::Disabled, merkle_proof_strategy: MerkleProofStrategy::Uncompressed, }; - let reed_solomon = Arc::new(RSDefault); + let reed_solomon = Arc::new(RSFr); let basefield_reed_solomon = reed_solomon.clone(); WhirConfig::new(reed_solomon, basefield_reed_solomon, mv_params, whir_params) }