From 789ba3d6b15a0e74bbf24feb139a353c8c4e68f4 Mon Sep 17 00:00:00 2001 From: ashpect Date: Mon, 8 Dec 2025 12:50:14 +0530 Subject: [PATCH 01/18] feat: pubwit struct --- provekit/common/Cargo.toml | 1 + provekit/common/src/lib.rs | 1 + provekit/common/src/noir_proof_scheme.rs | 3 +- provekit/common/src/utils/mod.rs | 1 + provekit/common/src/utils/serde_ark_vec.rs | 87 ++++++++++++++++++++++ provekit/common/src/witness/mod.rs | 52 ++++++++++++- provekit/prover/src/lib.rs | 6 +- 7 files changed, 146 insertions(+), 5 deletions(-) create mode 100644 provekit/common/src/utils/serde_ark_vec.rs diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index d39db74e..20fe2897 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -39,6 +39,7 @@ ruint.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true +sha2.workspace = true zerocopy.workspace = true zeroize.workspace = true zstd.workspace = true diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 680715d8..31cb4b56 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -22,6 +22,7 @@ pub use { verifier::Verifier, whir::crypto::fields::Field256 as FieldElement, whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, + witness::PublicInputs, }; #[cfg(test)] diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index b8c8ff93..f7e40fd2 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -2,7 +2,7 @@ use { crate::{ whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{NoirWitnessGenerator, SplitWitnessBuilders}, - NoirElement, R1CS, + NoirElement, R1CS, PublicInputs, }, acir::circuit::Program, serde::{Deserialize, Serialize}, @@ -20,6 +20,7 @@ pub struct NoirProofScheme { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct NoirProof { + pub public_inputs: PublicInputs, pub whir_r1cs_proof: WhirR1CSProof, } diff --git a/provekit/common/src/utils/mod.rs b/provekit/common/src/utils/mod.rs index a5f6aa5b..43943c56 100644 --- a/provekit/common/src/utils/mod.rs +++ b/provekit/common/src/utils/mod.rs @@ -2,6 +2,7 @@ mod print_abi; pub mod serde_ark; pub mod serde_ark_option; +pub mod serde_ark_vec; pub mod serde_hex; pub mod serde_jsonify; pub mod sumcheck; diff --git a/provekit/common/src/utils/serde_ark_vec.rs b/provekit/common/src/utils/serde_ark_vec.rs new file mode 100644 index 00000000..b26c309e --- /dev/null +++ b/provekit/common/src/utils/serde_ark_vec.rs @@ -0,0 +1,87 @@ +use { + crate::FieldElement, + ark_serialize::{CanonicalDeserialize, CanonicalSerialize}, + serde::{ + de::{Error as _, SeqAccess, Visitor}, + ser::{Error as _, SerializeSeq}, + Deserializer, Serializer, + }, + std::fmt, +}; + +pub fn serialize(vec: &Vec, serializer: S) -> Result +where + S: Serializer, +{ + let is_human_readable = serializer.is_human_readable(); + let mut seq = serializer.serialize_seq(Some(vec.len()))?; + for element in vec { + let mut buf = Vec::with_capacity(element.compressed_size()); + element + .serialize_compressed(&mut buf) + .map_err(|e| S::Error::custom(format!("Failed to serialize: {e}")))?; + + // Write bytes + if is_human_readable { + // ark_serialize doesn't have human-readable serialization. And Serde + // doesn't have good defaults for [u8]. So we implement hexadecimal + // serialization. + let hex = hex::encode(buf); + seq.serialize_element(&hex)?; + } else { + seq.serialize_element(&buf)?; + } + } + seq.end() +} + +pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + struct VecVisitor { + is_human_readable: bool, + } + + impl<'de> Visitor<'de> for VecVisitor { + type Value = Vec; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of field elements") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut vec = Vec::new(); + if self.is_human_readable { + while let Some(hex) = seq.next_element::()? { + let bytes = hex::decode(hex) + .map_err(|e| A::Error::custom(format!("invalid hex: {e}")))?; + let mut reader = &*bytes; + let element = FieldElement::deserialize_compressed(&mut reader) + .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?; + if !reader.is_empty() { + return Err(A::Error::custom("while deserializing: trailing bytes")); + } + vec.push(element); + } + } else { + while let Some(bytes) = seq.next_element::>()? { + let mut reader = &*bytes; + let element = FieldElement::deserialize_compressed(&mut reader) + .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?; + if !reader.is_empty() { + return Err(A::Error::custom("while deserializing: trailing bytes")); + } + vec.push(element); + } + } + Ok(vec) + } + } + + let is_human_readable = deserializer.is_human_readable(); + deserializer.deserialize_seq(VecVisitor { is_human_readable }) +} diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 4361a5dc..549ff34b 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -7,9 +7,10 @@ mod witness_generator; mod witness_io_pattern; use { - crate::{utils::serde_ark, FieldElement}, - ark_ff::One, + crate::{utils::{serde_ark, serde_ark_vec}, FieldElement}, + ark_ff::{BigInt, One, PrimeField}, serde::{Deserialize, Serialize}, + sha2::{Digest, Sha256}, }; pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, @@ -40,3 +41,50 @@ impl ConstantOrR1CSWitness { } } } + + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); + +impl PublicInputs { + /// Creates a new `PublicInputs` with a constant 1 field element at the + /// start. + pub fn new() -> Self { + Self(vec![FieldElement::one()]) + } + + /// Creates a new `PublicInputs` from a vector, adding a constant 1 field + /// element at the start. To emulate the constant 1 witness in the R1CS + /// instance. + pub fn from_vec(mut vec: Vec) -> Self { + vec.insert(0, FieldElement::one()); + Self(vec) + } + + pub fn len(&self) -> usize { + self.0.len() + } + + /// Hashes the public input values using SHA-256 and converts the result to + /// a FieldElement. + pub fn hash(&self) -> FieldElement { + let mut hasher = Sha256::new(); + + // Hash all public values from witness + for value in self.0.iter() { + let bigint = value.into_bigint(); + for limb in bigint.0.iter() { + hasher.update(&limb.to_le_bytes()); + } + } + + let result = hasher.finalize(); + + let limbs = result + .chunks_exact(8) + .map(|s| u64::from_le_bytes(s.try_into().unwrap())) + .collect::>(); + + FieldElement::new(BigInt::new(limbs.try_into().unwrap())) + } +} diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 031a29dd..0fdefd4e 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -6,7 +6,7 @@ use { nargo::foreign_calls::DefaultForeignCallBuilder, noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, - provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover}, + provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, std::path::Path, tracing::instrument, }; @@ -119,7 +119,9 @@ impl Prove for Prover { .prove(merlin, self.r1cs, commitments) .context("While proving R1CS instance")?; - Ok(NoirProof { whir_r1cs_proof }) + let public_inputs = PublicInputs::new(); + + Ok(NoirProof { public_inputs, whir_r1cs_proof }) } } From 70344a4d91b513849877478a353e087df5724658 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 01:08:43 +0530 Subject: [PATCH 02/18] feat: put pub wit at starting --- .../common/src/witness/scheduling/splitter.rs | 59 ++++++++++++++++++- .../common/src/witness/witness_builder.rs | 4 +- .../r1cs-compiler/src/noir_proof_scheme.rs | 7 ++- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 57a44367..eb990dd2 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -26,7 +26,7 @@ impl<'a> WitnessSplitter<'a> { /// (post-challenge). /// /// Returns (w1_builder_indices, w2_builder_indices) - pub fn split_builders(&self) -> (Vec, Vec) { + pub fn split_builders(&self, acir_public_inputs_indices_set: HashSet) -> (Vec, Vec) { let builder_count = self.witness_builders.len(); // Step 1: Find all Challenge builders @@ -40,7 +40,11 @@ impl<'a> WitnessSplitter<'a> { .collect(); if challenge_builders.is_empty() { - return ((0..builder_count).collect(), Vec::new()); + let w1_indices = self.rearrange_w1( + (0..builder_count).collect(), + &acir_public_inputs_indices_set, + ); + return (w1_indices, Vec::new()); } // Step 2: Forward DFS from challenges to find mandatory_w2 @@ -135,6 +139,7 @@ impl<'a> WitnessSplitter<'a> { // Step 7: Assign free builders greedily while respecting dependencies // Rule: if any dependency is in w2, the builder must also be in w2 // (because w1 is solved before w2) + // with the exception of public builders writing public witnesses) let mut w1_set = mandatory_w1; let mut w2_set = mandatory_w2; @@ -149,6 +154,15 @@ impl<'a> WitnessSplitter<'a> { let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len(); + // If free builder writes a public witness, add it to w1_set. + if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] { + if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + w1_set.insert(idx); + w1_witness_count += witness_count; + continue; + } + } + if must_be_w2 { w2_set.insert(idx); w2_witness_count += witness_count; @@ -170,4 +184,45 @@ impl<'a> WitnessSplitter<'a> { (w1_indices, w2_indices) } + + /// Rearranges w1 indices: constant builder (0) first, then public inputs, + /// then rest. + fn rearrange_w1( + &self, + w1_indices: Vec, + acir_public_inputs_indices_set: &HashSet, + ) -> Vec { + let mut public_input_builder_indices = Vec::new(); + let mut rest_indices = Vec::new(); + + // Sanity Check: Make sure all public inputs and WITNESS_ONE_IDX are in + // w1_indices. + for &idx in acir_public_inputs_indices_set.iter() { + if !w1_indices.contains(&(idx as usize)) { + panic!("Public input {} is not in w1_indices", idx); + } + } + + // Separate into: 0, public inputs, and rest + for builder_idx in w1_indices { + if builder_idx == 0 { + continue; // Will add 0 first + } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { + if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + public_input_builder_indices.push(builder_idx); + continue; + } + } + rest_indices.push(builder_idx); + } + + public_input_builder_indices.sort_unstable(); + rest_indices.sort_unstable(); + + // Reorder: 0 first, then public inputs, then rest + let mut new_w1_indices = vec![0]; + new_w1_indices.extend(public_input_builder_indices); + new_w1_indices.extend(rest_indices); + new_w1_indices + } } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index d212c9bc..ee248153 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -14,6 +14,7 @@ use { FieldElement, R1CS, }, serde::{Deserialize, Serialize}, + std::collections::HashSet, std::num::NonZeroU32, }; @@ -174,6 +175,7 @@ impl WitnessBuilder { witness_builders: &[WitnessBuilder], r1cs: R1CS, witness_map: Vec>, + acir_public_inputs_indices_set: HashSet, ) -> (SplitWitnessBuilders, R1CS, Vec>, usize) { if witness_builders.is_empty() { return ( @@ -190,7 +192,7 @@ impl WitnessBuilder { // Step 1: Analyze dependencies and split into w1/w2 let splitter = WitnessSplitter::new(witness_builders); - let (w1_indices, w2_indices) = splitter.split_builders(); + let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 58e9346a..6c9ca322 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -12,6 +12,7 @@ use { }, std::{fs::File, path::Path}, tracing::{info, instrument}, + std::collections::HashSet, }; pub trait NoirProofSchemeBuilder { @@ -61,9 +62,13 @@ impl NoirProofSchemeBuilder for NoirProofScheme { r1cs.c.num_entries() ); + // Extract ACIR public input indices set + let acir_public_inputs_indices_set: HashSet = + main.public_inputs().indices().iter().cloned().collect(); + // Split witness builders and remap indices for sound challenge generation let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = - WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map); + WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map, acir_public_inputs_indices_set); info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, From 978ccf7823e9d3315ca234a126c3789ba9375e32 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 10:34:59 +0530 Subject: [PATCH 03/18] feat: prover addon --- provekit/common/src/utils/sumcheck.rs | 10 +++ provekit/common/src/whir_r1cs.rs | 4 ++ provekit/prover/src/lib.rs | 43 +++++++++++-- provekit/prover/src/r1cs.rs | 13 ++++ provekit/prover/src/whir_r1cs.rs | 90 +++++++++++++++++++++++---- 5 files changed, 143 insertions(+), 17 deletions(-) diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index 7e1c5a24..6baef51d 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -114,6 +114,10 @@ pub trait SumcheckIOPattern { fn add_rand(self, num_rand: usize) -> Self; fn add_zk_sumcheck_polynomials(self, num_vars: usize) -> Self; + + /// Prover sends the hash of the public inputs + /// Verifier sends randomness to construct weights + fn add_public_inputs(self) -> Self; } impl SumcheckIOPattern for IOPattern @@ -136,6 +140,12 @@ where self } + fn add_public_inputs(mut self) -> Self { + self = self.add_scalars(1, "Public Inputs Hash"); + self = self.challenge_scalars(1, "Public Weights Vector Random"); + self + } + fn add_rand(self, num_rand: usize) -> Self { self.challenge_scalars(num_rand, "rand") } diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 302bfd79..9163c967 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -50,6 +50,8 @@ impl WhirR1CSScheme { .add_whir_proof(&self.whir_for_hiding_spartan) .hint("claimed_evaluations_1") .hint("claimed_evaluations_2") + .add_public_inputs() + .hint("public_weights_evaluations") .add_whir_batch_proof(&self.whir_witness, num_witnesses, num_constraints_total); } else { io = io @@ -59,6 +61,8 @@ impl WhirR1CSScheme { .add_zk_sumcheck_polynomials(self.m_0) .add_whir_proof(&self.whir_for_hiding_spartan) .hint("claimed_evaluations") + .add_public_inputs() + .hint("public_weights_evaluations") .add_whir_proof(&self.whir_witness); } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 0fdefd4e..7d69475b 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -7,6 +7,7 @@ use { noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, + std::collections::{HashMap, HashSet}, std::path::Path, tracing::instrument, }; @@ -57,6 +58,15 @@ impl Prove for Prover { let acir_witness_idx_to_value_map = self.generate_witness(input_map)?; + let acir_public_inputs = self.program.functions[0].public_inputs().indices(); + let acir_public_inputs_set: HashSet = acir_public_inputs.iter().cloned().collect(); + let mut acir_to_r1cs_public_map = HashMap::new(); + + println!("DEBUG_ASH: acir_witness_idx_to_value_map: {:?}", acir_witness_idx_to_value_map); + println!("DEBUG_ASH: acir_public_inputs: {:?}", acir_public_inputs); + println!("DEBUG_ASH: acir_public_inputs_set: {:?}", acir_public_inputs_set); + println!("DEBUG_ASH: acir_to_r1cs_public_map: {:?}", acir_to_r1cs_public_map); + // Set up transcript let io: IOPattern = self.whir_for_witness.create_io_pattern(); let mut merlin = io.to_prover_state(); @@ -70,13 +80,19 @@ impl Prove for Prover { self.split_witness_builders.w1_layers, &acir_witness_idx_to_value_map, &mut merlin, + &acir_public_inputs_set, + &mut acir_to_r1cs_public_map, ); + println!("DEBUG_ASH: acir_to_r1cs_public_map after w1: {:?}", acir_to_r1cs_public_map); + + let w1 = witness[..self.whir_for_witness.w1_size] .iter() .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) .collect::>>()?; + println!("DEBUG_ASH: w1: {:?}", w1); let commitment_1 = self .whir_for_witness .commit(&mut merlin, &self.r1cs, w1, true) @@ -90,7 +106,11 @@ impl Prove for Prover { self.split_witness_builders.w2_layers, &acir_witness_idx_to_value_map, &mut merlin, - ); + &acir_public_inputs_set, + &mut acir_to_r1cs_public_map, + ); // DEBUG_ASH : if w2 didn't have pub witness, no need honestly for this + + println!("DEBUG_ASH: acir_to_r1cs_public_map after w2: {:?}", acir_to_r1cs_public_map); let w2 = witness[self.whir_for_witness.w1_size..] .iter() @@ -112,15 +132,30 @@ impl Prove for Prover { self.r1cs .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; + + // Gather public inputs from witness + let public_indices = acir_to_r1cs_public_map + .values() + .map(|&x| x) + .collect::>(); + + let public_inputs = PublicInputs::from_vec( + public_indices + .iter() + .map(|&i| { + witness[i].ok_or_else(|| anyhow::anyhow!("Missing public input witness at index {i}")) + }) + .collect::>>()?, + ); + drop(witness); let whir_r1cs_proof = self .whir_for_witness - .prove(merlin, self.r1cs, commitments) + .prove(merlin, self.r1cs, commitments, &public_inputs) .context("While proving R1CS instance")?; - let public_inputs = PublicInputs::new(); - + println!("DEBUG_ASH: public_inputs: {:?}", public_inputs); Ok(NoirProof { public_inputs, whir_r1cs_proof }) } } diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index 18b3ffca..cb752c26 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -10,6 +10,7 @@ use { FieldElement, NoirElement, R1CS, }, spongefish::ProverState, + std::collections::{HashMap, HashSet}, tracing::instrument, }; @@ -20,6 +21,8 @@ pub trait R1CSSolver { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, + acir_public_inputs_set: &HashSet, + acir_to_r1cs_public_map: &mut HashMap, ); #[cfg(test)] @@ -54,12 +57,22 @@ impl R1CSSolver for R1CS { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, + acir_public_inputs_set: &HashSet, + acir_to_r1cs_public_map: &mut HashMap, ) { for layer in &plan.layers { match layer.typ { LayerType::Other => { // Execute regular operations for builder in &layer.witness_builders { + + if let WitnessBuilder::Acir(r1cs_witness_idx, acir_witness_idx) = builder { + if acir_public_inputs_set.contains(&(*acir_witness_idx as u32)) { + acir_to_r1cs_public_map + .insert(*acir_witness_idx as u32, *r1cs_witness_idx); + } + } + builder.solve(&acir_map, witness, transcript); } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 60bb54cd..4601105e 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -3,6 +3,7 @@ use { ark_ff::UniformRand, ark_std::{One, Zero}, provekit_common::{ + PublicInputs, skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, utils::{ pad_to_power_of_two, @@ -54,6 +55,7 @@ pub trait WhirR1CSProver { merlin: ProverState, r1cs: R1CS, commitments: Vec, + public_inputs: &PublicInputs, ) -> Result; } @@ -121,6 +123,7 @@ impl WhirR1CSProver for WhirR1CSScheme { mut merlin: ProverState, r1cs: R1CS, mut commitments: Vec, + public_inputs: &PublicInputs, ) -> Result { ensure!(!commitments.is_empty(), "Need at least one commitment"); @@ -141,6 +144,8 @@ impl WhirR1CSProver for WhirR1CSScheme { w }; + println!("DEBUG_ASH: full_witness: {:?}", full_witness); + // First round: ZK sumcheck to reduce R1CS to weighted evaluation let alpha = run_zk_sumcheck_prover( &r1cs, @@ -159,16 +164,28 @@ impl WhirR1CSProver for WhirR1CSScheme { let commitment = commitments.into_iter().next().unwrap(); let alphas: [Vec; 3] = alphas.try_into().unwrap(); - let (statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( + let (mut statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( self.m, &commitment.commitment_to_witness, - commitment.masked_polynomial, - commitment.random_polynomial, + &commitment.masked_polynomial, + &commitment.random_polynomial, &alphas, ); merlin.hint::<(Vec, Vec)>(&(f_sums, g_sums))?; + // VERIFY the size given by self.m + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); + let (public_f_sum, public_g_sum) = update_statement_with_public_weights( + &mut statement, + &commitment.commitment_to_witness, + &commitment.masked_polynomial, + &commitment.random_polynomial, + public_weight, + ); + + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + run_zk_whir_pcs_prover( commitment.commitment_to_witness, statement, @@ -197,8 +214,8 @@ impl WhirR1CSProver for WhirR1CSScheme { create_combined_statement_over_two_polynomials::<3>( self.m, &c1.commitment_to_witness, - c1.masked_polynomial, - c1.random_polynomial, + &c1.masked_polynomial, + &c1.random_polynomial, &alphas_1, ); drop(alphas_1); @@ -207,8 +224,8 @@ impl WhirR1CSProver for WhirR1CSScheme { create_combined_statement_over_two_polynomials::<3>( self.m, &c2.commitment_to_witness, - c2.masked_polynomial, - c2.random_polynomial, + &c2.masked_polynomial, + &c2.random_polynomial, &alphas_2, ); drop(alphas_2); @@ -514,8 +531,8 @@ pub fn run_zk_sumcheck_prover( create_combined_statement_over_two_polynomials::<1>( blinding_polynomial_variables + 1, &commitment_to_blinding_polynomial, - blindings_mask_polynomial, - blindings_blind_polynomial, + &blindings_mask_polynomial, + &blindings_blind_polynomial, &[expand_powers(alpha.as_slice())], ); @@ -548,8 +565,8 @@ fn expand_powers(values: &[FieldElement]) -> Vec { fn create_combined_statement_over_two_polynomials( cfg_nv: usize, witness: &Witness, - f_polynomial: EvaluationsList, - g_polynomial: EvaluationsList, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, alphas: &[Vec; N], ) -> ( Statement, @@ -579,8 +596,8 @@ fn create_combined_statement_over_two_polynomials( w_full.resize(final_len, FieldElement::zero()); let weight = Weights::linear(EvaluationsList::new(w_full)); - let f = weight.weighted_sum(&f_polynomial); - let g = weight.weighted_sum(&g_polynomial); + let f = weight.weighted_sum(f_polynomial); + let g = weight.weighted_sum(g_polynomial); statement.add_constraint(weight, f + witness.batching_randomness * g); f_sums.push(f); @@ -631,3 +648,50 @@ pub fn run_zk_whir_pcs_batch_prover( (randomness, deferred) } + + +fn update_statement_with_public_weights( + statement: &mut Statement, + witness: &Witness, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, + public_weights: Weights, +) -> (FieldElement, FieldElement) { + let f = public_weights.weighted_sum(f_polynomial); + let g = public_weights.weighted_sum(g_polynomial); + statement.add_constraint_in_front(public_weights, f + witness.batching_randomness * g); + (f, g) +} + +fn get_public_weights( + public_inputs: &PublicInputs, + merlin: &mut ProverState, + m: usize, +) -> Weights { + // Add hash to transcript + let public_inputs_hash = public_inputs.hash(); + let _ = merlin.add_scalars(&[public_inputs_hash]); + + // Get random point x + let mut x_buf = [FieldElement::zero()]; + merlin + .fill_challenge_scalars(&mut x_buf) + .expect("Failed to get challenge from Merlin"); + let x = x_buf[0]; + + let domain_size = 1 << m; + let mut public_weights = vec![FieldElement::zero(); domain_size]; + + // Set public weights for public inputs [1,x,x^2,x^3...x^n-1,0,0,0...0] + let mut current_pow = FieldElement::one(); + for (idx, _) in public_inputs.0.iter().enumerate() { + public_weights[idx] = current_pow; + current_pow = current_pow * x; + } + + Weights::geometric( + x, + public_inputs.0.len(), + EvaluationsList::new(public_weights), + ) +} From 3a471ff3cf5d650a40e682f30aebedaa86dc4e14 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 10:35:35 +0530 Subject: [PATCH 04/18] feat: verifier updates --- provekit/verifier/src/lib.rs | 2 +- provekit/verifier/src/whir_r1cs.rs | 65 ++++++++++++++++++++++++++++-- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs index abdb369f..d1ec351a 100644 --- a/provekit/verifier/src/lib.rs +++ b/provekit/verifier/src/lib.rs @@ -17,7 +17,7 @@ impl Verify for Verifier { self.whir_for_witness .take() .unwrap() - .verify(&proof.whir_r1cs_proof)?; + .verify(&proof.whir_r1cs_proof, &proof.public_inputs)?; Ok(()) } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 8a668fa0..3ae9dc93 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -2,6 +2,7 @@ use { anyhow::{ensure, Context, Result}, ark_std::{One, Zero}, provekit_common::{ + PublicInputs, skyscraper::SkyscraperSponge, utils::sumcheck::{calculate_eq, eval_cubic_poly}, FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, @@ -29,13 +30,13 @@ pub struct DataFromSumcheckVerifier { } pub trait WhirR1CSVerifier { - fn verify(&self, proof: &WhirR1CSProof) -> Result<()>; + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()>; } impl WhirR1CSVerifier for WhirR1CSScheme { #[instrument(skip_all)] #[allow(unused)] - fn verify(&self, proof: &WhirR1CSProof) -> Result<()> { + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()> { let io = self.create_io_pattern(); let mut arthur = io.to_verifier_state(&proof.transcript); @@ -68,7 +69,7 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); - let statement_1 = prepare_statement_for_witness_verifier::<3>( + let mut statement_1 = prepare_statement_for_witness_verifier::<3>( self.m, &parsed_commitment_1, &whir_sums_1, @@ -79,6 +80,27 @@ impl WhirR1CSVerifier for WhirR1CSScheme { &whir_sums_2, ); + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + update_statement_for_witness_verifier( + self.m, + &mut statement_1, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + run_whir_pcs_batch_verifier( &mut arthur, &self.whir_witness, @@ -98,12 +120,33 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); - let statement = prepare_statement_for_witness_verifier::<3>( + let mut statement = prepare_statement_for_witness_verifier::<3>( self.m, &parsed_commitment_1, &whir_sums, ); + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + update_statement_for_witness_verifier( + self.m, + &mut statement, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + run_whir_pcs_verifier( &mut arthur, &parsed_commitment_1, @@ -147,6 +190,20 @@ fn prepare_statement_for_witness_verifier( statement_verifier } +fn update_statement_for_witness_verifier( + m: usize, + statement_verifier: &mut Statement, + parsed_commitment: &ParsedCommitment, + whir_public_weights_query_answer: (FieldElement, FieldElement), +) { + let (public_f_sum, public_g_sum) = whir_public_weights_query_answer; + let public_weight = Weights::linear(EvaluationsList::new(vec![FieldElement::zero(); 1 << m])); + statement_verifier.add_constraint_in_front( + public_weight, + public_f_sum + public_g_sum * parsed_commitment.batching_randomness, + ); +} + #[instrument(skip_all)] pub fn run_sumcheck_verifier( arthur: &mut VerifierState, From a0ab7e6fe3577335684160c703762d5d09fd1a17 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 10 Dec 2025 12:28:55 +0530 Subject: [PATCH 05/18] feat: patch batch_prove --- provekit/prover/src/whir_r1cs.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 4601105e..9c7a00fe 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -210,7 +210,7 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); - let (statement_1, f_sums_1, g_sums_1) = + let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, &c1.commitment_to_witness, @@ -233,6 +233,18 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; + // VERIFY the size given by self.m + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); + let (public_f_sum, public_g_sum) = update_statement_with_public_weights( + &mut statement_1, + &c1.commitment_to_witness, + &c1.masked_polynomial, + &c1.random_polynomial, + public_weight, + ); + + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + run_zk_whir_pcs_batch_prover( &[c1.commitment_to_witness, c2.commitment_to_witness], &[statement_1, statement_2], From 23f99295603f77bc785a6d38ec6898e9f5651c39 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 11 Dec 2025 16:20:51 +0530 Subject: [PATCH 06/18] fix: rearrange for w2 --- provekit/common/src/whir_r1cs.rs | 4 ++-- provekit/common/src/witness/scheduling/splitter.rs | 4 +++- provekit/common/src/witness/witness_builder.rs | 3 +++ provekit/prover/src/whir_r1cs.rs | 4 ++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 9163c967..572fa562 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -34,10 +34,10 @@ impl WhirR1CSScheme { if self.num_challenges > 0 { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each - // Statement: 2 statements × 3 constraints each = 6 + // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, statement_2 has 3 = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = 6; // 2 statements × 3 constraints + let num_statement_constraints = 7; // (3+1) + (3) let num_constraints_total = num_ood_constraints + num_statement_constraints; io = io diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index eb990dd2..b7223dc2 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -157,6 +157,7 @@ impl<'a> WitnessSplitter<'a> { // If free builder writes a public witness, add it to w1_set. if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + println!("DEBUG_ASH w2 exists: acir_idx: {:?}, idx: {:?}", acir_idx, idx); w1_set.insert(idx); w1_witness_count += witness_count; continue; @@ -179,7 +180,7 @@ impl<'a> WitnessSplitter<'a> { let mut w1_indices: Vec = w1_set.into_iter().collect(); let mut w2_indices: Vec = w2_set.into_iter().collect(); - w1_indices.sort_unstable(); + w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set); w2_indices.sort_unstable(); (w1_indices, w2_indices) @@ -209,6 +210,7 @@ impl<'a> WitnessSplitter<'a> { continue; // Will add 0 first } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + println!("DEBUG_ASH: acir_idx: {:?}, builder_idx: {:?}", acir_idx, builder_idx); public_input_builder_indices.push(builder_idx); continue; } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index ee248153..f42851f1 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -194,6 +194,9 @@ impl WitnessBuilder { let splitter = WitnessSplitter::new(witness_builders); let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); + println!("Dx {:?}", w1_indices); + println!("DEBUG_ASH: w2_indices: {:?}", w2_indices); + // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices .iter() diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 9c7a00fe..751acd42 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -144,8 +144,6 @@ impl WhirR1CSProver for WhirR1CSScheme { w }; - println!("DEBUG_ASH: full_witness: {:?}", full_witness); - // First round: ZK sumcheck to reduce R1CS to weighted evaluation let alpha = run_zk_sumcheck_prover( &r1cs, @@ -210,6 +208,8 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); + println!("DEBUG_ASH: &c1.parsed_commitment: {:?}", &c1.padded_witness); + let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, From f5934209c7e987d37675352fc6c9fccb357f81ff Mon Sep 17 00:00:00 2001 From: ashpect Date: Fri, 12 Dec 2025 12:20:14 +0530 Subject: [PATCH 07/18] fix: reduce reduendancy in prove --- provekit/common/src/witness/mod.rs | 5 +++++ provekit/prover/src/lib.rs | 36 +++++------------------------- provekit/prover/src/r1cs.rs | 12 ---------- 3 files changed, 11 insertions(+), 42 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 549ff34b..0b2ec1de 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -61,6 +61,11 @@ impl PublicInputs { Self(vec) } + /// Assuming the given vector already has a constant 1 field element at the start. + pub fn from_vec_with_constant_one(vec: Vec) -> Self { + Self(vec) + } + pub fn len(&self) -> usize { self.0.len() } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 7d69475b..2bd0639f 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -57,15 +57,7 @@ impl Prove for Prover { read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; let acir_witness_idx_to_value_map = self.generate_witness(input_map)?; - let acir_public_inputs = self.program.functions[0].public_inputs().indices(); - let acir_public_inputs_set: HashSet = acir_public_inputs.iter().cloned().collect(); - let mut acir_to_r1cs_public_map = HashMap::new(); - - println!("DEBUG_ASH: acir_witness_idx_to_value_map: {:?}", acir_witness_idx_to_value_map); - println!("DEBUG_ASH: acir_public_inputs: {:?}", acir_public_inputs); - println!("DEBUG_ASH: acir_public_inputs_set: {:?}", acir_public_inputs_set); - println!("DEBUG_ASH: acir_to_r1cs_public_map: {:?}", acir_to_r1cs_public_map); // Set up transcript let io: IOPattern = self.whir_for_witness.create_io_pattern(); @@ -80,13 +72,8 @@ impl Prove for Prover { self.split_witness_builders.w1_layers, &acir_witness_idx_to_value_map, &mut merlin, - &acir_public_inputs_set, - &mut acir_to_r1cs_public_map, ); - println!("DEBUG_ASH: acir_to_r1cs_public_map after w1: {:?}", acir_to_r1cs_public_map); - - let w1 = witness[..self.whir_for_witness.w1_size] .iter() .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) @@ -106,11 +93,7 @@ impl Prove for Prover { self.split_witness_builders.w2_layers, &acir_witness_idx_to_value_map, &mut merlin, - &acir_public_inputs_set, - &mut acir_to_r1cs_public_map, - ); // DEBUG_ASH : if w2 didn't have pub witness, no need honestly for this - - println!("DEBUG_ASH: acir_to_r1cs_public_map after w2: {:?}", acir_to_r1cs_public_map); + ); let w2 = witness[self.whir_for_witness.w1_size..] .iter() @@ -133,21 +116,14 @@ impl Prove for Prover { .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; - // Gather public inputs from witness - let public_indices = acir_to_r1cs_public_map - .values() - .map(|&x| x) - .collect::>(); - - let public_inputs = PublicInputs::from_vec( - public_indices + // Gather public inputs from witness + let num_public_inputs = acir_public_inputs.len(); + let public_inputs = PublicInputs::from_vec_with_constant_one( + witness[0..=num_public_inputs] .iter() - .map(|&i| { - witness[i].ok_or_else(|| anyhow::anyhow!("Missing public input witness at index {i}")) - }) + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) .collect::>>()?, ); - drop(witness); let whir_r1cs_proof = self diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index cb752c26..967c258d 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -21,8 +21,6 @@ pub trait R1CSSolver { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, - acir_public_inputs_set: &HashSet, - acir_to_r1cs_public_map: &mut HashMap, ); #[cfg(test)] @@ -57,22 +55,12 @@ impl R1CSSolver for R1CS { plan: LayeredWitnessBuilders, acir_map: &WitnessMap, transcript: &mut ProverState, - acir_public_inputs_set: &HashSet, - acir_to_r1cs_public_map: &mut HashMap, ) { for layer in &plan.layers { match layer.typ { LayerType::Other => { // Execute regular operations for builder in &layer.witness_builders { - - if let WitnessBuilder::Acir(r1cs_witness_idx, acir_witness_idx) = builder { - if acir_public_inputs_set.contains(&(*acir_witness_idx as u32)) { - acir_to_r1cs_public_map - .insert(*acir_witness_idx as u32, *r1cs_witness_idx); - } - } - builder.solve(&acir_map, witness, transcript); } } From 801a04058ed506a1e9a4632a44a023ddf54e05fd Mon Sep 17 00:00:00 2001 From: ashpect Date: Fri, 12 Dec 2025 13:12:55 +0530 Subject: [PATCH 08/18] chore: cleanup, remove logging and fmt --- provekit/common/src/noir_proof_scheme.rs | 2 +- provekit/common/src/whir_r1cs.rs | 3 ++- provekit/common/src/witness/mod.rs | 9 ++++++--- .../common/src/witness/scheduling/splitter.rs | 11 ++++++----- provekit/common/src/witness/witness_builder.rs | 6 +----- provekit/prover/src/lib.rs | 17 ++++++++++------- provekit/prover/src/r1cs.rs | 1 - provekit/prover/src/whir_r1cs.rs | 12 ++++-------- provekit/r1cs-compiler/src/noir_proof_scheme.rs | 12 ++++++++---- provekit/verifier/src/whir_r1cs.rs | 17 +++++++++-------- 10 files changed, 47 insertions(+), 43 deletions(-) diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index f7e40fd2..7552ab26 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -2,7 +2,7 @@ use { crate::{ whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{NoirWitnessGenerator, SplitWitnessBuilders}, - NoirElement, R1CS, PublicInputs, + NoirElement, PublicInputs, R1CS, }, acir::circuit::Program, serde::{Deserialize, Serialize}, diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 572fa562..43aed2c2 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -34,7 +34,8 @@ impl WhirR1CSScheme { if self.num_challenges > 0 { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each - // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, statement_2 has 3 = 3, total = 7 + // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, + // statement_2 has 3 = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; let num_statement_constraints = 7; // (3+1) + (3) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 0b2ec1de..69481720 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -7,7 +7,10 @@ mod witness_generator; mod witness_io_pattern; use { - crate::{utils::{serde_ark, serde_ark_vec}, FieldElement}, + crate::{ + utils::{serde_ark, serde_ark_vec}, + FieldElement, + }, ark_ff::{BigInt, One, PrimeField}, serde::{Deserialize, Serialize}, sha2::{Digest, Sha256}, @@ -42,7 +45,6 @@ impl ConstantOrR1CSWitness { } } - #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); @@ -61,7 +63,8 @@ impl PublicInputs { Self(vec) } - /// Assuming the given vector already has a constant 1 field element at the start. + /// Assuming the given vector already has a constant 1 field element at the + /// start. pub fn from_vec_with_constant_one(vec: Vec) -> Self { Self(vec) } diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index b7223dc2..2a30e351 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -26,7 +26,10 @@ impl<'a> WitnessSplitter<'a> { /// (post-challenge). /// /// Returns (w1_builder_indices, w2_builder_indices) - pub fn split_builders(&self, acir_public_inputs_indices_set: HashSet) -> (Vec, Vec) { + pub fn split_builders( + &self, + acir_public_inputs_indices_set: HashSet, + ) -> (Vec, Vec) { let builder_count = self.witness_builders.len(); // Step 1: Find all Challenge builders @@ -154,10 +157,9 @@ impl<'a> WitnessSplitter<'a> { let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len(); - // If free builder writes a public witness, add it to w1_set. + // If free builder writes a public witness, add it to w1_set. if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { - println!("DEBUG_ASH w2 exists: acir_idx: {:?}, idx: {:?}", acir_idx, idx); w1_set.insert(idx); w1_witness_count += witness_count; continue; @@ -186,7 +188,7 @@ impl<'a> WitnessSplitter<'a> { (w1_indices, w2_indices) } - /// Rearranges w1 indices: constant builder (0) first, then public inputs, + /// Rearranges w1 indices: constant builder (0) first, then public inputs, /// then rest. fn rearrange_w1( &self, @@ -210,7 +212,6 @@ impl<'a> WitnessSplitter<'a> { continue; // Will add 0 first } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { - println!("DEBUG_ASH: acir_idx: {:?}, builder_idx: {:?}", acir_idx, builder_idx); public_input_builder_indices.push(builder_idx); continue; } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index f42851f1..9567eac7 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -14,8 +14,7 @@ use { FieldElement, R1CS, }, serde::{Deserialize, Serialize}, - std::collections::HashSet, - std::num::NonZeroU32, + std::{collections::HashSet, num::NonZeroU32}, }; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -194,9 +193,6 @@ impl WitnessBuilder { let splitter = WitnessSplitter::new(witness_builders); let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); - println!("Dx {:?}", w1_indices); - println!("DEBUG_ASH: w2_indices: {:?}", w2_indices); - // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices .iter() diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 2bd0639f..d7e8f846 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -7,8 +7,10 @@ use { noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, - std::collections::{HashMap, HashSet}, - std::path::Path, + std::{ + collections::{HashMap, HashSet}, + path::Path, + }, tracing::instrument, }; @@ -79,7 +81,6 @@ impl Prove for Prover { .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w1 are missing"))) .collect::>>()?; - println!("DEBUG_ASH: w1: {:?}", w1); let commitment_1 = self .whir_for_witness .commit(&mut merlin, &self.r1cs, w1, true) @@ -93,7 +94,7 @@ impl Prove for Prover { self.split_witness_builders.w2_layers, &acir_witness_idx_to_value_map, &mut merlin, - ); + ); let w2 = witness[self.whir_for_witness.w1_size..] .iter() @@ -116,7 +117,7 @@ impl Prove for Prover { .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; - // Gather public inputs from witness + // Gather public inputs from witness let num_public_inputs = acir_public_inputs.len(); let public_inputs = PublicInputs::from_vec_with_constant_one( witness[0..=num_public_inputs] @@ -131,8 +132,10 @@ impl Prove for Prover { .prove(merlin, self.r1cs, commitments, &public_inputs) .context("While proving R1CS instance")?; - println!("DEBUG_ASH: public_inputs: {:?}", public_inputs); - Ok(NoirProof { public_inputs, whir_r1cs_proof }) + Ok(NoirProof { + public_inputs, + whir_r1cs_proof, + }) } } diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index 967c258d..18b3ffca 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -10,7 +10,6 @@ use { FieldElement, NoirElement, R1CS, }, spongefish::ProverState, - std::collections::{HashMap, HashSet}, tracing::instrument, }; diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 751acd42..a0494d59 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -3,7 +3,6 @@ use { ark_ff::UniformRand, ark_std::{One, Zero}, provekit_common::{ - PublicInputs, skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, utils::{ pad_to_power_of_two, @@ -15,7 +14,7 @@ use { zk_utils::{create_masked_polynomial, generate_random_multilinear_polynomial}, HALF, }, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField}, @@ -181,9 +180,9 @@ impl WhirR1CSProver for WhirR1CSScheme { &commitment.random_polynomial, public_weight, ); - + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); - + run_zk_whir_pcs_prover( commitment.commitment_to_witness, statement, @@ -208,8 +207,6 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); - println!("DEBUG_ASH: &c1.parsed_commitment: {:?}", &c1.padded_witness); - let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, @@ -242,7 +239,7 @@ impl WhirR1CSProver for WhirR1CSScheme { &c1.random_polynomial, public_weight, ); - + let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); run_zk_whir_pcs_batch_prover( @@ -661,7 +658,6 @@ pub fn run_zk_whir_pcs_batch_prover( (randomness, deferred) } - fn update_statement_with_public_weights( statement: &mut Statement, witness: &Witness, diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 6c9ca322..0cde2f1b 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -10,9 +10,8 @@ use { witness::{NoirWitnessGenerator, WitnessBuilder}, NoirProofScheme, WhirR1CSScheme, }, - std::{fs::File, path::Path}, + std::{collections::HashSet, fs::File, path::Path}, tracing::{info, instrument}, - std::collections::HashSet, }; pub trait NoirProofSchemeBuilder { @@ -64,11 +63,16 @@ impl NoirProofSchemeBuilder for NoirProofScheme { // Extract ACIR public input indices set let acir_public_inputs_indices_set: HashSet = - main.public_inputs().indices().iter().cloned().collect(); + main.public_inputs().indices().iter().cloned().collect(); // Split witness builders and remap indices for sound challenge generation let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = - WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map, acir_public_inputs_indices_set); + WitnessBuilder::split_and_prepare_layers( + &witness_builders, + r1cs, + witness_map, + acir_public_inputs_indices_set, + ); info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 3ae9dc93..22296aae 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -2,10 +2,9 @@ use { anyhow::{ensure, Context, Result}, ark_std::{One, Zero}, provekit_common::{ - PublicInputs, skyscraper::SkyscraperSponge, utils::sumcheck::{calculate_eq, eval_cubic_poly}, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, @@ -89,11 +88,12 @@ impl WhirR1CSVerifier for WhirR1CSScheme { expected_public_inputs_hash, public_inputs_hash_buf[0] ); - + let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = + arthur.hint().unwrap(); update_statement_for_witness_verifier( self.m, &mut statement_1, @@ -135,11 +135,12 @@ impl WhirR1CSVerifier for WhirR1CSScheme { expected_public_inputs_hash, public_inputs_hash_buf[0] ); - + let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); + + let whir_pub_weights_query_answer: (FieldElement, FieldElement) = + arthur.hint().unwrap(); update_statement_for_witness_verifier( self.m, &mut statement, From 54c49f2b7ae32573c1ef39e8577c55f4aea94e00 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 24 Dec 2025 19:20:01 +0530 Subject: [PATCH 09/18] feat: fix ordering and hashing --- provekit/common/src/witness/mod.rs | 18 ++------ .../common/src/witness/scheduling/splitter.rs | 1 - provekit/prover/src/lib.rs | 21 ++++----- provekit/prover/src/whir_r1cs.rs | 45 ++++++++++++------- provekit/verifier/src/whir_r1cs.rs | 29 +++++++----- 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 69481720..ae139032 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -49,23 +49,13 @@ impl ConstantOrR1CSWitness { pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); impl PublicInputs { - /// Creates a new `PublicInputs` with a constant 1 field element at the - /// start. + /// Creates a new `PublicInputs` with an empty vector. pub fn new() -> Self { - Self(vec![FieldElement::one()]) + Self(Vec::new()) } - /// Creates a new `PublicInputs` from a vector, adding a constant 1 field - /// element at the start. To emulate the constant 1 witness in the R1CS - /// instance. - pub fn from_vec(mut vec: Vec) -> Self { - vec.insert(0, FieldElement::one()); - Self(vec) - } - - /// Assuming the given vector already has a constant 1 field element at the - /// start. - pub fn from_vec_with_constant_one(vec: Vec) -> Self { + /// Creates a new `PublicInputs` from a vector. + pub fn from_vec(vec: Vec) -> Self { Self(vec) } diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 2a30e351..2dbd2ba9 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -219,7 +219,6 @@ impl<'a> WitnessSplitter<'a> { rest_indices.push(builder_idx); } - public_input_builder_indices.sort_unstable(); rest_indices.sort_unstable(); // Reorder: 0 first, then public inputs, then rest diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index d7e8f846..bb89b790 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -7,10 +7,7 @@ use { noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, - std::{ - collections::{HashMap, HashSet}, - path::Path, - }, + std::path::Path, tracing::instrument, }; @@ -119,12 +116,16 @@ impl Prove for Prover { // Gather public inputs from witness let num_public_inputs = acir_public_inputs.len(); - let public_inputs = PublicInputs::from_vec_with_constant_one( - witness[0..=num_public_inputs] - .iter() - .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) - .collect::>>()?, - ); + let public_inputs = if num_public_inputs == 0 { + PublicInputs::new() + } else { + PublicInputs::from_vec( + witness[1..=num_public_inputs] + .iter() + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) + .collect::>>()?, + ) + }; drop(witness); let whir_r1cs_proof = self diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index a0494d59..9813cf50 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -173,13 +173,21 @@ impl WhirR1CSProver for WhirR1CSScheme { // VERIFY the size given by self.m let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = update_statement_with_public_weights( - &mut statement, - &commitment.commitment_to_witness, - &commitment.masked_polynomial, - &commitment.random_polynomial, - public_weight, - ); + let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + // If there are no public inputs, the hint is unused by the verifier and can be + // assigned an arbitrary value. + let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement, + &commitment.commitment_to_witness, + &commitment.masked_polynomial, + &commitment.random_polynomial, + public_weight, + ) + }; let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); @@ -230,15 +238,20 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; - // VERIFY the size given by self.m let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = update_statement_with_public_weights( - &mut statement_1, - &c1.commitment_to_witness, - &c1.masked_polynomial, - &c1.random_polynomial, - public_weight, - ); + let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement_1, + &c1.commitment_to_witness, + &c1.masked_polynomial, + &c1.random_polynomial, + public_weight, + ) + }; let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); @@ -677,7 +690,9 @@ fn get_public_weights( m: usize, ) -> Weights { // Add hash to transcript + info!("ASH_TEST : Public inputs: {:?}", public_inputs.0); let public_inputs_hash = public_inputs.hash(); + info!("ASH_TEST : Public inputs hash: {:?}", public_inputs_hash); let _ = merlin.add_scalars(&[public_inputs_hash]); // Get random point x diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 22296aae..54cabbcb 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -94,12 +94,15 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); - update_statement_for_witness_verifier( - self.m, - &mut statement_1, - &parsed_commitment_1, - whir_pub_weights_query_answer, - ); + + if public_inputs.len() > 0 { + update_statement_for_witness_verifier( + self.m, + &mut statement_1, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + } run_whir_pcs_batch_verifier( &mut arthur, @@ -141,12 +144,14 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let whir_pub_weights_query_answer: (FieldElement, FieldElement) = arthur.hint().unwrap(); - update_statement_for_witness_verifier( - self.m, - &mut statement, - &parsed_commitment_1, - whir_pub_weights_query_answer, - ); + if public_inputs.len() > 0 { + update_statement_for_witness_verifier( + self.m, + &mut statement, + &parsed_commitment_1, + whir_pub_weights_query_answer, + ); + } run_whir_pcs_verifier( &mut arthur, From bcef170e81ea774b1dccd78f5013d0ec248e099d Mon Sep 17 00:00:00 2001 From: ashpect Date: Sat, 3 Jan 2026 03:54:29 +0530 Subject: [PATCH 10/18] chore: cleanup, namin, etc --- provekit/common/src/whir_r1cs.rs | 4 ++-- provekit/common/src/witness/mod.rs | 10 ++++++++++ .../common/src/witness/scheduling/splitter.rs | 12 ++++++++---- provekit/prover/src/whir_r1cs.rs | 10 ++++------ provekit/verifier/src/whir_r1cs.rs | 18 ++++++++++-------- 5 files changed, 34 insertions(+), 20 deletions(-) diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index 43aed2c2..702a25c7 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -35,10 +35,10 @@ impl WhirR1CSScheme { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, - // statement_2 has 3 = 3, total = 7 + // statement_2 has 3 constraints = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = 7; // (3+1) + (3) + let num_statement_constraints = 7; let num_constraints_total = num_ood_constraints + num_statement_constraints; io = io diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index ae139032..d0fa8ea6 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -63,6 +63,10 @@ impl PublicInputs { self.0.len() } + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + /// Hashes the public input values using SHA-256 and converts the result to /// a FieldElement. pub fn hash(&self) -> FieldElement { @@ -86,3 +90,9 @@ impl PublicInputs { FieldElement::new(BigInt::new(limbs.try_into().unwrap())) } } + +impl Default for PublicInputs { + fn default() -> Self { + Self::new() + } +} diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 2dbd2ba9..94ffffe9 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -142,7 +142,7 @@ impl<'a> WitnessSplitter<'a> { // Step 7: Assign free builders greedily while respecting dependencies // Rule: if any dependency is in w2, the builder must also be in w2 // (because w1 is solved before w2) - // with the exception of public builders writing public witnesses) + // A free builder for public input witnesses goes in w1. let mut w1_set = mandatory_w1; let mut w2_set = mandatory_w2; @@ -188,8 +188,10 @@ impl<'a> WitnessSplitter<'a> { (w1_indices, w2_indices) } - /// Rearranges w1 indices: constant builder (0) first, then public inputs, - /// then rest. + /// Rearranges w1 builder indices into a canonical order: + /// 1. Constant builder (index 0) first, to preserve R1CS index 0 invariant + /// 2. Public input builders next, grouped together + /// 3. All other w1 builders last, sorted by index fn rearrange_w1( &self, w1_indices: Vec, @@ -200,8 +202,10 @@ impl<'a> WitnessSplitter<'a> { // Sanity Check: Make sure all public inputs and WITNESS_ONE_IDX are in // w1_indices. + // Convert to HashSet for O(1) lookups since we're checking many times + let w1_indices_set = w1_indices.iter().copied().collect::>(); for &idx in acir_public_inputs_indices_set.iter() { - if !w1_indices.contains(&(idx as usize)) { + if !w1_indices_set.contains(&(idx as usize)) { panic!("Public input {} is not in w1_indices", idx); } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 9813cf50..f22ab397 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -173,7 +173,7 @@ impl WhirR1CSProver for WhirR1CSScheme { // VERIFY the size given by self.m let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { // If there are no public inputs, the hint is unused by the verifier and can be // assigned an arbitrary value. let public_f_sum = FieldElement::zero(); @@ -189,7 +189,7 @@ impl WhirR1CSProver for WhirR1CSScheme { ) }; - let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; run_zk_whir_pcs_prover( commitment.commitment_to_witness, @@ -239,7 +239,7 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); - let (public_f_sum, public_g_sum) = if public_inputs.len() == 0 { + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { let public_f_sum = FieldElement::zero(); let public_g_sum = FieldElement::zero(); (public_f_sum, public_g_sum) @@ -253,7 +253,7 @@ impl WhirR1CSProver for WhirR1CSScheme { ) }; - let _ = merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum)); + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; run_zk_whir_pcs_batch_prover( &[c1.commitment_to_witness, c2.commitment_to_witness], @@ -690,9 +690,7 @@ fn get_public_weights( m: usize, ) -> Weights { // Add hash to transcript - info!("ASH_TEST : Public inputs: {:?}", public_inputs.0); let public_inputs_hash = public_inputs.hash(); - info!("ASH_TEST : Public inputs hash: {:?}", public_inputs_hash); let _ = merlin.add_scalars(&[public_inputs_hash]); // Get random point x diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 54cabbcb..17e102fb 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -92,15 +92,16 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = - arthur.hint().unwrap(); + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; - if public_inputs.len() > 0 { + if !public_inputs.is_empty() { update_statement_for_witness_verifier( self.m, &mut statement_1, &parsed_commitment_1, - whir_pub_weights_query_answer, + whir_public_weights_query_answer, ); } @@ -142,14 +143,15 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let mut public_weights_vector_random_buf = [FieldElement::zero()]; arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - let whir_pub_weights_query_answer: (FieldElement, FieldElement) = - arthur.hint().unwrap(); - if public_inputs.len() > 0 { + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + if !public_inputs.is_empty() { update_statement_for_witness_verifier( self.m, &mut statement, &parsed_commitment_1, - whir_pub_weights_query_answer, + whir_public_weights_query_answer, ); } From 6e5272491f1035b33880b47ffdac73c59c4c1f7c Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 06:14:52 +0530 Subject: [PATCH 11/18] feat: add gnark support for public witness --- recursive-verifier/app/circuit/circuit.go | 171 +++++++++++++++++- .../app/circuit/circuit_test.go | 4 +- recursive-verifier/app/circuit/common.go | 13 +- recursive-verifier/app/circuit/mtUtilities.go | 83 +++++++++ recursive-verifier/app/circuit/types.go | 20 ++ recursive-verifier/app/circuit/whir.go | 3 +- .../app/circuit/whir_utilities.go | 1 + recursive-verifier/app/utilities/utilities.go | 121 +++++++++++++ 8 files changed, 402 insertions(+), 14 deletions(-) diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 2f95d5e6..dd7c2543 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -39,6 +39,9 @@ type Circuit struct { WitnessClaimedEvaluations [][]frontend.Variable // [commitment_idx][eval_idx] WitnessBlindingEvaluations [][]frontend.Variable + // For public_f_sum and public_g_sum + PubWitnessEvaluations []frontend.Variable + // Batch mode only: batched polynomial for rounds 1+ WitnessMerkle Merkle @@ -46,9 +49,9 @@ type Circuit struct { MatrixB []MatrixCell MatrixC []MatrixCell - // Public Input - IO []byte - Transcript []uints.U8 `gnark:",public"` + IO []byte + Transcript []uints.U8 `gnark:",public"` + PublicInputs PublicInputs } func (circuit *Circuit) Define(api frontend.API) error { @@ -95,6 +98,26 @@ func (circuit *Circuit) Define(api frontend.API) error { return err } + // Read public inputs hash from transcript + publicInputsHashBuf := make([]frontend.Variable, 1) + if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { + return fmt.Errorf("failed to read public inputs hash: %w", err) + } + + // TODO : Compute expected public inputs hash and verify + expectedHash, err := hashPublicInputs(api, sc, circuit.PublicInputs) + if err != nil { + return fmt.Errorf("failed to compute public inputs hash: %w", err) + } + + api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) + + // Squeeze rand for public weights + publicWeightsChallenge := make([]frontend.Variable, 1) + if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { + return fmt.Errorf("failed to read public weights challenge: %w", err) + } + // WHIR verification var whirFoldingRandomness []frontend.Variable var az, bz, cz frontend.Variable @@ -115,6 +138,7 @@ func (circuit *Circuit) Define(api frontend.API) error { }, circuit.WHIRParamsWitness, // whirParams circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + circuit.PublicInputs, // publicInputs ) if err != nil { return err @@ -125,12 +149,15 @@ func (circuit *Circuit) Define(api frontend.API) error { bz = api.Add(circuit.WitnessClaimedEvaluations[0][1], circuit.WitnessClaimedEvaluations[1][1]) cz = api.Add(circuit.WitnessClaimedEvaluations[0][2], circuit.WitnessClaimedEvaluations[1][2]) } else { + log.Println("Single Mode") + extendedLinearStatementEvals := extendLinearStatement(circuit, [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, circuit.PubWitnessEvaluations) + // Single commitment mode whirFoldingRandomness, err = RunZKWhir( api, arthur, uapi, sc, circuit.WitnessMerkle, circuit.WitnessFirstRounds[0], circuit.WHIRParamsWitness, - [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + extendedLinearStatementEvals, circuit.WitnessLinearStatementEvaluations, batchingRandomness1, initialOODQueries1, @@ -150,23 +177,72 @@ func (circuit *Circuit) Define(api frontend.API) error { x := api.Mul(api.Sub(api.Mul(az, bz), cz), calculateEQ(api, spartanSumcheckRand, tRand)) api.AssertIsEqual(spartanSumcheckLastValue, x) + // TODO : generalize it later on if we have more different kinds of statements + // for handling geometric weights statement added at starting + offset := 1 + if circuit.NumChallenges > 0 { // Batch mode - check 6 deferred values matrixExtensionEvals := evaluateR1CSMatrixExtensionBatch(api, circuit, spartanSumcheckRand, whirFoldingRandomness, circuit.W1Size) for i := 0; i < 6; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } else { + // Single mode - existing logic matrixExtensionEvals := evaluateR1CSMatrixExtension(api, circuit, spartanSumcheckRand, whirFoldingRandomness) for i := 0; i < 3; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } + // Geomteric weights for public inputs + if !circuit.PublicInputs.IsEmpty() { + publicWeightEval := computePublicWeightEvaluation( + api, circuit.PublicInputs, whirFoldingRandomness, + circuit.WHIRParamsWitness.MVParamsNumberOfVariables, publicWeightsChallenge[0], + ) + + api.AssertIsEqual(publicWeightEval, circuit.WitnessLinearStatementEvaluations[0]) + } + return nil } +func computePublicWeightEvaluation( + api frontend.API, + publicInputs PublicInputs, + foldingRandomness []frontend.Variable, + m int, // domain size = 2^m + x frontend.Variable, +) frontend.Variable { + // Build public weight vector: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] where n = len(publicInputs.Values) and total length = 2^m + domainSize := 1 << m + publicWeights := make([]frontend.Variable, domainSize) + + for i := 0; i < domainSize; i++ { + publicWeights[i] = 0 + } + + // Set public weights: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] + currentPower := frontend.Variable(1) + for i := 0; i < len(publicInputs.Values); i++ { + publicWeights[i] = currentPower + currentPower = api.Mul(currentPower, x) + } + + // TODO : Replace it with geometric_till algo + // Evaluate the multilinear extension of publicWeights at foldingRandomness + // Formula: f(r) = Σ_{i=0}^{2^m-1} f[i] * eq_i(r) + // where eq_i(r) is the i-th Lagrange basis polynomial + eqPolys := calculateEQOverBooleanHypercube(api, foldingRandomness) + result := frontend.Variable(0) + for i := 0; i < len(publicWeights); i++ { + result = api.Add(result, api.Mul(publicWeights[i], eqPolys[i])) + } + return result +} + func verifyCircuit( deferred []Fp256, cfg Config, @@ -175,9 +251,11 @@ func verifyCircuit( vk *groth16.VerifyingKey, claimedEvaluations ClaimedEvaluations, claimedEvaluations2 ClaimedEvaluations, + publicWeightsClaimedEvaluation [2]Fp256, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, + publicInputs PublicInputs, ) error { transcriptT := make([]uints.U8, cfg.TranscriptLen) contTranscript := make([]uints.U8, cfg.TranscriptLen) @@ -189,9 +267,18 @@ func verifyCircuit( // Determine witness linear statement evals size based on mode var witnessLinearStatementEvalsSize int if cfg.NumChallenges > 0 { - witnessLinearStatementEvalsSize = 6 // 3 per commitment in batch mode + if !cfg.PublicInputs.IsEmpty() { + // 3 per commitment in batch mode + 1 public_input (geometric statement as a subset of linear statement) + witnessLinearStatementEvalsSize = 7 + } else { + witnessLinearStatementEvalsSize = 6 + } } else { - witnessLinearStatementEvalsSize = 3 + if !cfg.PublicInputs.IsEmpty() { + witnessLinearStatementEvalsSize = 4 + } else { + witnessLinearStatementEvalsSize = 3 + } } witnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) @@ -199,6 +286,9 @@ func verifyCircuit( contWitnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) contHidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) + if len(deferred) < 1+witnessLinearStatementEvalsSize { + return fmt.Errorf("deferred array too short: expected at least %d elements, got %d", 1+witnessLinearStatementEvalsSize, len(deferred)) + } hidingSpartanLinearStatementEvaluations[0] = typeConverters.LimbsToBigIntMod(deferred[0].Limbs) for i := 0; i < witnessLinearStatementEvalsSize; i++ { witnessLinearStatementEvaluations[i] = typeConverters.LimbsToBigIntMod(deferred[1+i].Limbs) @@ -258,6 +348,10 @@ func verifyCircuit( fSums2, gSums2 = parseClaimedEvaluations(claimedEvaluations2, true) } + // Parse public weights claimed evaluation + fSumPublicWeights, gSumPublicWeights := parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, true) + pubWitnessEvaluations := []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + // Build witness slices conditionally var witnessClaimedEvals, witnessBlindingEvals [][]frontend.Variable if cfg.NumChallenges > 0 { @@ -268,6 +362,11 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + // Empty container while circuit creation + publicInputsContainer := PublicInputs{ + Values: make([]frontend.Variable, len(publicInputs.Values)), + } + circuit := Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript, @@ -276,6 +375,7 @@ func verifyCircuit( LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, + PubWitnessEvaluations: pubWitnessEvaluations, WitnessLinearStatementEvaluations: contWitnessLinearStatementEvaluations, HidingSpartanLinearStatementEvaluations: contHidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, true), @@ -289,6 +389,7 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputsContainer, } ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) @@ -377,13 +478,19 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + fSumPublicWeights, gSumPublicWeights = parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, false) + pubWitnessEvaluations = []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + assignment := Circuit{ IO: []byte(cfg.IOPattern), Transcript: transcriptT, LogNumConstraints: cfg.LogNumConstraints, + LogNumVariables: cfg.LogNumVariables, + LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, WitnessLinearStatementEvaluations: witnessLinearStatementEvaluations, + PubWitnessEvaluations: pubWitnessEvaluations, HidingSpartanLinearStatementEvaluations: hidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, false), HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, false), @@ -396,13 +503,18 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputs, } witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) - publicWitness, _ := witness.Public() + publicWitness, err := witness.Public() + if err != nil { + log.Printf("Failed witess,Public(): %v", err) + return err + } opts := []backend.ProverOption{ - backend.WithSolverOptions(solver.WithHints(utilities.IndexOf)), + backend.WithSolverOptions(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), backend.WithIcicleAcceleration(), } @@ -436,3 +548,42 @@ func witnessFirstRounds(hints Hints, isContainer bool) []Merkle { } return result } + +func parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation [2]Fp256, isContainer bool) (frontend.Variable, frontend.Variable) { + var fSumPublicWeights, gSumPublicWeights frontend.Variable + + if !isContainer { + fSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[0].Limbs) + gSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[1].Limbs) + } + + return fSumPublicWeights, gSumPublicWeights +} + +func extendLinearStatement( + circuit *Circuit, + linearStatementEvaluations [][]frontend.Variable, + pubWitnessEvaluations []frontend.Variable, +) [][]frontend.Variable { + var extendedLinearStatementEvals [][]frontend.Variable + + if !circuit.PublicInputs.IsEmpty() { + // Extend the statement equivalent array by prepending the public constraint (public constraint is added in starting at prover side) + extendedLinearStatementEvals = make([][]frontend.Variable, 2) + + // f_sums: [public_f_sum, f_sums[0], f_sums[1]... ] + extendedLinearStatementEvals[0] = make([]frontend.Variable, len(linearStatementEvaluations[0])+1) + extendedLinearStatementEvals[0][0] = pubWitnessEvaluations[0] + copy(extendedLinearStatementEvals[0][1:], linearStatementEvaluations[0]) + + // g_sums: [public_g_sum, g_sums[0], g_sums[1]... ] + extendedLinearStatementEvals[1] = make([]frontend.Variable, len(linearStatementEvaluations[1])+1) + extendedLinearStatementEvals[1][0] = pubWitnessEvaluations[1] + copy(extendedLinearStatementEvals[1][1:], linearStatementEvaluations[1]) + } else { + // No public inputs, use original arrays + extendedLinearStatementEvals = linearStatementEvaluations + } + + return extendedLinearStatementEvals +} diff --git a/recursive-verifier/app/circuit/circuit_test.go b/recursive-verifier/app/circuit/circuit_test.go index 19e2c4a0..69c1fa46 100644 --- a/recursive-verifier/app/circuit/circuit_test.go +++ b/recursive-verifier/app/circuit/circuit_test.go @@ -70,7 +70,7 @@ func TestCircuitConstraints(t *testing.T) { test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), - test.WithSolverOpts(solver.WithHints(utilities.IndexOf)), + test.WithSolverOpts(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), ) } @@ -124,7 +124,7 @@ func TestCircuitConstraintsSolverOnly(t *testing.T) { } // Solve the constraint system - _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf)) + _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)) if err != nil { t.Fatalf("Constraint system not satisfied: %v", err) } diff --git a/recursive-verifier/app/circuit/common.go b/recursive-verifier/app/circuit/common.go index 16eea7c1..70728673 100644 --- a/recursive-verifier/app/circuit/common.go +++ b/recursive-verifier/app/circuit/common.go @@ -30,6 +30,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v var deferred []Fp256 var claimedEvaluations ClaimedEvaluations var claimedEvaluations2 ClaimedEvaluations + var publicWeightsEvaluations [2]Fp256 for _, op := range io.Ops { switch op.Kind { @@ -128,6 +129,16 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v if err != nil { return fmt.Errorf("failed to deserialize claimed_evaluations_2: %w", err) } + + case "public_weights_evaluations": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(config.Transcript[start:end]), + &publicWeightsEvaluations, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize public_weights_evaluations: %w", err) + } } if err != nil { @@ -204,7 +215,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v WitnessRoundHints: witnessRoundHints, } - err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, r1cs, interner, buildOps) + err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, publicWeightsEvaluations, r1cs, interner, buildOps, config.PublicInputs) if err != nil { return fmt.Errorf("verification failed: %w", err) } diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go index 264727b3..47afcd22 100644 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ b/recursive-verifier/app/circuit/mtUtilities.go @@ -1,6 +1,7 @@ package circuit import ( + "fmt" "reilabs/whir-verifier-circuit/app/utilities" "github.com/consensys/gnark/frontend" @@ -112,3 +113,85 @@ func rlcBatchedLeaves(api frontend.API, leaves [][]frontend.Variable, foldSize i } return collapsed } + +// hashPublicInputs computes the hash of public inputs by treating them as field elements +// This mimics the Rust PublicInputs::hash() function using SHA-256, TODO : Shift to skyscraper hash function later +func hashPublicInputs(api frontend.API, sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { + if len(publicInputs.Values) == 0 { + // Return zero if no public inputs + return frontend.Variable(0), nil + } + + // Use hint to compute SHA-256 hash outside the circuit + // The hint function will be called during witness generation + hashResult, err := api.Compiler().NewHint(utilities.HashPublicInputsHint, 1, publicInputs.Values...) + if err != nil { + return nil, fmt.Errorf("failed to create hash hint: %w", err) + } + + return hashResult[0], nil +} + +// verifyPublicInputsAndReadWeights reads and verifies the public inputs hash from the transcript, +// then reads the public weights challenge and query answer. +// Returns (publicWeightsChallenge, publicWeightsQueryAnswer, error) +func verifyPublicInputsAndReadWeights( + api frontend.API, + sc *skyscraper.Skyscraper, + arthur gnarkNimue.Arthur, + publicInputs PublicInputs, +) (frontend.Variable, []frontend.Variable, error) { + // Read public inputs hash from transcript + publicInputsHashBuf := make([]frontend.Variable, 1) + if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { + return nil, nil, fmt.Errorf("failed to read public inputs hash: %w", err) + } + + // Compute expected public inputs hash + expectedHash, err := hashPublicInputs(api, sc, publicInputs) + if err != nil { + return nil, nil, fmt.Errorf("failed to compute public inputs hash: %w", err) + } + + // Verify hash matches + api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) + + // Read public weights vector random challenge + publicWeightsChallenge := make([]frontend.Variable, 1) + if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { + return nil, nil, fmt.Errorf("failed to read public weights challenge: %w", err) + } + + // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) + publicWeightsQueryAnswer := make([]frontend.Variable, 2) + if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { + return nil, nil, fmt.Errorf("failed to read public weights query answer: %w", err) + } + + return publicWeightsChallenge[0], publicWeightsQueryAnswer, nil +} + +// readPublicWeightsQueryAnswer reads only the public weights query answer from the transcript. +// The challenge has already been read at the circuit level to match transcript order. +// Returns (publicWeightsQueryAnswer, error) +func readPublicWeightsQueryAnswer(arthur gnarkNimue.Arthur) ([]frontend.Variable, error) { + // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) + publicWeightsQueryAnswer := make([]frontend.Variable, 2) + if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { + return nil, fmt.Errorf("failed to read public weights query answer: %w", err) + } + + return publicWeightsQueryAnswer, nil +} + +// computePublicWeightsClaimedSum computes the claimed sum for the public weights constraint +// This is: public_f_sum + public_g_sum * batching_randomness +func computePublicWeightsClaimedSum( + api frontend.API, + publicWeightsQueryAnswer []frontend.Variable, + batchingRandomness frontend.Variable, +) frontend.Variable { + publicFSum := publicWeightsQueryAnswer[0] + publicGSum := publicWeightsQueryAnswer[1] + return api.Add(publicFSum, api.Mul(publicGSum, batchingRandomness)) +} diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index 420b43e0..b456231a 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -1,6 +1,8 @@ package circuit import ( + "reilabs/whir-verifier-circuit/app/utilities" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -101,6 +103,7 @@ type Config struct { BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` NumChallenges int `json:"num_challenges"` W1Size int `json:"w1_size"` + PublicInputs PublicInputs `json:"public_inputs"` } // Update Hints to support batch mode @@ -139,3 +142,20 @@ type DualClaimedEvaluations struct { First ClaimedEvaluations Second ClaimedEvaluations } + +type PublicInputs struct { + Values []frontend.Variable +} + +func (p *PublicInputs) UnmarshalJSON(data []byte) error { + values, err := utilities.UnmarshalPublicInputs(data) + if err != nil { + return err + } + p.Values = values + return nil +} + +func (p *PublicInputs) IsEmpty() bool { + return len(p.Values) == 0 +} \ No newline at end of file diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 0e153280..4e367cb3 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -59,7 +59,7 @@ func RunZKWhir( firstRound Merkle, whirParams WHIRParams, linearStatementEvaluations [][]frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, + linearStatementValuesAtPoints []frontend.Variable, // weights.evaluate(random_point) - this is what needs to be done batchingRandomness frontend.Variable, initialOODQueries []frontend.Variable, initialOODAnswers [][]frontend.Variable, @@ -238,6 +238,7 @@ func RunZKWhirBatch( // Common parameters whirParams WHIRParams, linearStatementValuesAtPoints []frontend.Variable, + publicInputs PublicInputs, ) (totalFoldingRandomness []frontend.Variable, err error) { numPolynomials := len(firstRounds) if numPolynomials == 0 { diff --git a/recursive-verifier/app/circuit/whir_utilities.go b/recursive-verifier/app/circuit/whir_utilities.go index 666dde44..ac75a9c3 100644 --- a/recursive-verifier/app/circuit/whir_utilities.go +++ b/recursive-verifier/app/circuit/whir_utilities.go @@ -140,6 +140,7 @@ func computeWPoly( value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[j], utilities.EqPolyOutside(api, utilities.ExpandFromUnivariate(api, initialData.InitialOODQueries[j], numberVars), totalFoldingRandomness))) } + // Values are directly used as all linearStatements are deffered and hints were given. Checking of hints will be done later on. for j, linearStatementValueAtPoint := range linearStatementValuesAtPoints { value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[len(initialData.InitialOODQueries)+j], linearStatementValueAtPoint)) } diff --git a/recursive-verifier/app/utilities/utilities.go b/recursive-verifier/app/utilities/utilities.go index 04a8e9d2..7222133b 100644 --- a/recursive-verifier/app/utilities/utilities.go +++ b/recursive-verifier/app/utilities/utilities.go @@ -1,6 +1,10 @@ package utilities import ( + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" "fmt" "math/big" "reilabs/whir-verifier-circuit/app/typeConverters" @@ -58,6 +62,77 @@ func IndexOf(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil } +// HashPublicInputsHint is a hint function that computes SHA-256 hash of public inputs +// matching the Rust PublicInputs::hash() implementation. +// It takes public input values, converts them to BigInt, extracts limbs, hashes them, +// and returns the hash as a field element. +func HashPublicInputsHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(outputs) != 1 { + return fmt.Errorf("expecting one output") + } + + if len(inputs) == 0 { + outputs[0] = big.NewInt(0) + return nil + } + + hasher := sha256.New() + + // Process each public input value + for _, input := range inputs { + // Convert field element to BigInt (it's already a BigInt, but ensure it's in range) + value := new(big.Int).Set(input) + + // Extract limbs (u64 values) from BigInt + // Field elements are represented as 4 u64 limbs in little-endian + limbs := make([]uint64, 4) + temp := new(big.Int).Set(value) + limbs[0] = temp.Uint64() // Least significant limb + temp.Rsh(temp, 64) + limbs[1] = temp.Uint64() + temp.Rsh(temp, 64) + limbs[2] = temp.Uint64() + temp.Rsh(temp, 64) + limbs[3] = temp.Uint64() // Most significant limb + + // Hash each limb as little-endian bytes (8 bytes per limb) + for _, limb := range limbs { + limbBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(limbBytes, limb) + hasher.Write(limbBytes) + } + } + + // Get the hash result (32 bytes) + hashResult := hasher.Sum(nil) + + // Convert hash result to field element by splitting into 4 u64 limbs + // Each chunk of 8 bytes becomes a u64 (little-endian) + limbs := make([]uint64, 4) + for i := 0; i < 4; i++ { + start := i * 8 + end := start + 8 + limbs[i] = binary.LittleEndian.Uint64(hashResult[start:end]) + } + + // Reconstruct field element from limbs + result := new(big.Int).SetUint64(limbs[0]) + temp := new(big.Int).SetUint64(limbs[1]) + result.Add(result, temp.Lsh(temp, 64)) + temp.SetUint64(limbs[2]) + result.Add(result, temp.Lsh(temp, 128)) + temp.SetUint64(limbs[3]) + result.Add(result, temp.Lsh(temp, 192)) + + // Apply modulus to ensure result is in field range + modulus := new(big.Int) + modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + result.Mod(result, modulus) + + outputs[0] = result + return nil +} + func Reverse[T any](s []T) []T { res := make([]T, len(s)) copy(res, s) @@ -210,3 +285,49 @@ func DotProduct(api frontend.API, a []frontend.Variable, b []frontend.Variable) } return acc } + +// ParseHexFieldElement parses a hex string representing a FieldElement (little-endian) +// and converts it to a big.Int. The hex string should be 64 characters (32 bytes). +func ParseHexFieldElement(hexStr string) (*big.Int, error) { + if len(hexStr) >= 2 && hexStr[0:2] == "0x" { + hexStr = hexStr[2:] + } + + bytes, err := hex.DecodeString(hexStr) + if err != nil { + return nil, fmt.Errorf("invalid hex string: %w", err) + } + + reversed := make([]byte, len(bytes)) + for i, b := range bytes { + reversed[len(bytes)-1-i] = b + } + + result := new(big.Int) + result.SetBytes(reversed) + + modulus := new(big.Int) + modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + result.Mod(result, modulus) + + return result, nil +} + +// UnmarshalPublicInputs parses a JSON array of hex-encoded FieldElement strings +// and returns them as frontend.Variable slice. +func UnmarshalPublicInputs(data []byte) ([]frontend.Variable, error) { + var arr []string + if err := json.Unmarshal(data, &arr); err != nil { + return nil, err + } + + values := make([]frontend.Variable, len(arr)) + for i, hexStr := range arr { + value, err := ParseHexFieldElement(hexStr) + if err != nil { + return nil, fmt.Errorf("failed to parse public input at index %d: %w", i, err) + } + values[i] = value + } + return values, nil +} From cb9ec4aed40c6acd83a9112d43deaaa46ec3a190 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 07:24:55 +0530 Subject: [PATCH 12/18] feat: tooling support --- tooling/cli/src/cmd/generate_gnark_inputs.rs | 1 + tooling/provekit-gnark/src/gnark_config.rs | 8 +++++++- tooling/verifier-server/src/services/verification.rs | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tooling/cli/src/cmd/generate_gnark_inputs.rs b/tooling/cli/src/cmd/generate_gnark_inputs.rs index 883f52b5..0e3cc6d0 100644 --- a/tooling/cli/src/cmd/generate_gnark_inputs.rs +++ b/tooling/cli/src/cmd/generate_gnark_inputs.rs @@ -62,6 +62,7 @@ impl Command for Args { prover.whir_for_witness.a_num_terms, prover.whir_for_witness.num_challenges, prover.whir_for_witness.w1_size, + &proof.public_inputs, &self.params_for_recursive_verifier, ); diff --git a/tooling/provekit-gnark/src/gnark_config.rs b/tooling/provekit-gnark/src/gnark_config.rs index 015cc68f..41439a75 100644 --- a/tooling/provekit-gnark/src/gnark_config.rs +++ b/tooling/provekit-gnark/src/gnark_config.rs @@ -1,6 +1,6 @@ use { ark_poly::EvaluationDomain, - provekit_common::{IOPattern, WhirConfig}, + provekit_common::{IOPattern, PublicInputs, WhirConfig}, serde::{Deserialize, Serialize}, std::{fs::File, io::Write}, tracing::instrument, @@ -29,6 +29,8 @@ pub struct GnarkConfig { pub num_challenges: usize, /// size of w1 pub w1_size: usize, + /// public inputs + pub public_inputs: PublicInputs, } #[derive(Debug, Serialize, Deserialize)] @@ -114,6 +116,7 @@ pub fn gnark_parameters( a_num_terms: usize, num_challenges: usize, w1_size: usize, + public_inputs: &PublicInputs, ) -> GnarkConfig { GnarkConfig { whir_config_witness: WHIRConfigGnark::new(whir_params_witness), @@ -126,6 +129,7 @@ pub fn gnark_parameters( transcript_len: transcript.to_vec().len(), num_challenges, w1_size, + public_inputs: public_inputs.clone(), } } @@ -141,6 +145,7 @@ pub fn write_gnark_parameters_to_file( a_num_terms: usize, num_challenges: usize, w1_size: usize, + public_inputs: &PublicInputs, file_path: &str, ) { let gnark_config = gnark_parameters( @@ -153,6 +158,7 @@ pub fn write_gnark_parameters_to_file( a_num_terms, num_challenges, w1_size, + public_inputs, ); let mut file_params = File::create(file_path).unwrap(); file_params diff --git a/tooling/verifier-server/src/services/verification.rs b/tooling/verifier-server/src/services/verification.rs index f8abd138..398b282c 100644 --- a/tooling/verifier-server/src/services/verification.rs +++ b/tooling/verifier-server/src/services/verification.rs @@ -97,6 +97,7 @@ impl VerificationService { whir_scheme.a_num_terms, whir_scheme.num_challenges, whir_scheme.w1_size, + &proof.public_inputs, gnark_params_path, ); From c9310eb6434ff91614e8a51030522f4c6ba4be48 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 07:37:09 +0530 Subject: [PATCH 13/18] feat: switch to skyscraper for hashing --- provekit/common/src/skyscraper/mod.rs | 6 +- provekit/common/src/witness/mod.rs | 36 ++++---- recursive-verifier/app/circuit/circuit.go | 7 +- .../app/circuit/circuit_test.go | 4 +- recursive-verifier/app/circuit/mtUtilities.go | 84 +++---------------- 5 files changed, 36 insertions(+), 101 deletions(-) diff --git a/provekit/common/src/skyscraper/mod.rs b/provekit/common/src/skyscraper/mod.rs index 3b6da92a..2caecdc2 100644 --- a/provekit/common/src/skyscraper/mod.rs +++ b/provekit/common/src/skyscraper/mod.rs @@ -2,4 +2,8 @@ mod pow; mod sponge; mod whir; -pub use self::{pow::SkyscraperPoW, sponge::SkyscraperSponge, whir::SkyscraperMerkleConfig}; +pub use self::{ + pow::SkyscraperPoW, + sponge::SkyscraperSponge, + whir::{SkyscraperCRH, SkyscraperMerkleConfig}, +}; diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index d0fa8ea6..a0ff25d1 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -8,12 +8,13 @@ mod witness_io_pattern; use { crate::{ + skyscraper::SkyscraperCRH, utils::{serde_ark, serde_ark_vec}, FieldElement, }, - ark_ff::{BigInt, One, PrimeField}, + ark_crypto_primitives::crh::CRHScheme, + ark_ff::One, serde::{Deserialize, Serialize}, - sha2::{Digest, Sha256}, }; pub use { binops::{BINOP_ATOMIC_BITS, BINOP_BITS, NUM_DIGITS}, @@ -49,12 +50,10 @@ impl ConstantOrR1CSWitness { pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); impl PublicInputs { - /// Creates a new `PublicInputs` with an empty vector. pub fn new() -> Self { Self(Vec::new()) } - /// Creates a new `PublicInputs` from a vector. pub fn from_vec(vec: Vec) -> Self { Self(vec) } @@ -67,27 +66,20 @@ impl PublicInputs { self.0.is_empty() } - /// Hashes the public input values using SHA-256 and converts the result to - /// a FieldElement. pub fn hash(&self) -> FieldElement { - let mut hasher = Sha256::new(); - - // Hash all public values from witness - for value in self.0.iter() { - let bigint = value.into_bigint(); - for limb in bigint.0.iter() { - hasher.update(&limb.to_le_bytes()); + match self.0.len() { + 0 => FieldElement::from(0u64), + 1 => { + // For single element, hash it with zero to ensure it gets properly hashed + let padded = vec![self.0[0], FieldElement::from(0u64)]; + SkyscraperCRH::evaluate(&(), &padded[..]) + .expect("hash should succeed") + } + _ => { + SkyscraperCRH::evaluate(&(), &self.0[..]) + .expect("hash should succeed for multiple inputs") } } - - let result = hasher.finalize(); - - let limbs = result - .chunks_exact(8) - .map(|s| u64::from_le_bytes(s.try_into().unwrap())) - .collect::>(); - - FieldElement::new(BigInt::new(limbs.try_into().unwrap())) } } diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index dd7c2543..c5a7fbb4 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -104,8 +104,7 @@ func (circuit *Circuit) Define(api frontend.API) error { return fmt.Errorf("failed to read public inputs hash: %w", err) } - // TODO : Compute expected public inputs hash and verify - expectedHash, err := hashPublicInputs(api, sc, circuit.PublicInputs) + expectedHash, err := hashPublicInputs(sc, circuit.PublicInputs) if err != nil { return fmt.Errorf("failed to compute public inputs hash: %w", err) } @@ -367,6 +366,8 @@ func verifyCircuit( Values: make([]frontend.Variable, len(publicInputs.Values)), } + log.Println("publicInputs", publicInputs) + circuit := Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript, @@ -514,7 +515,7 @@ func verifyCircuit( } opts := []backend.ProverOption{ - backend.WithSolverOptions(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), + backend.WithSolverOptions(solver.WithHints(utilities.IndexOf)), backend.WithIcicleAcceleration(), } diff --git a/recursive-verifier/app/circuit/circuit_test.go b/recursive-verifier/app/circuit/circuit_test.go index 69c1fa46..19e2c4a0 100644 --- a/recursive-verifier/app/circuit/circuit_test.go +++ b/recursive-verifier/app/circuit/circuit_test.go @@ -70,7 +70,7 @@ func TestCircuitConstraints(t *testing.T) { test.WithValidAssignment(assignment), test.WithCurves(ecc.BN254), test.WithBackends(backend.GROTH16), - test.WithSolverOpts(solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)), + test.WithSolverOpts(solver.WithHints(utilities.IndexOf)), ) } @@ -124,7 +124,7 @@ func TestCircuitConstraintsSolverOnly(t *testing.T) { } // Solve the constraint system - _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf, utilities.HashPublicInputsHint)) + _, err = ccs.Solve(witness, solver.WithHints(utilities.IndexOf)) if err != nil { t.Fatalf("Constraint system not satisfied: %v", err) } diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go index 47afcd22..4dc74eee 100644 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ b/recursive-verifier/app/circuit/mtUtilities.go @@ -1,7 +1,6 @@ package circuit import ( - "fmt" "reilabs/whir-verifier-circuit/app/utilities" "github.com/consensys/gnark/frontend" @@ -114,84 +113,23 @@ func rlcBatchedLeaves(api frontend.API, leaves [][]frontend.Variable, foldSize i return collapsed } -// hashPublicInputs computes the hash of public inputs by treating them as field elements -// This mimics the Rust PublicInputs::hash() function using SHA-256, TODO : Shift to skyscraper hash function later -func hashPublicInputs(api frontend.API, sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { +// hashPublicInputs computes the hash of public inputs as field elements sequentially +func hashPublicInputs(sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { + if len(publicInputs.Values) == 0 { - // Return zero if no public inputs return frontend.Variable(0), nil } - // Use hint to compute SHA-256 hash outside the circuit - // The hint function will be called during witness generation - hashResult, err := api.Compiler().NewHint(utilities.HashPublicInputsHint, 1, publicInputs.Values...) - if err != nil { - return nil, fmt.Errorf("failed to create hash hint: %w", err) - } - - return hashResult[0], nil -} - -// verifyPublicInputsAndReadWeights reads and verifies the public inputs hash from the transcript, -// then reads the public weights challenge and query answer. -// Returns (publicWeightsChallenge, publicWeightsQueryAnswer, error) -func verifyPublicInputsAndReadWeights( - api frontend.API, - sc *skyscraper.Skyscraper, - arthur gnarkNimue.Arthur, - publicInputs PublicInputs, -) (frontend.Variable, []frontend.Variable, error) { - // Read public inputs hash from transcript - publicInputsHashBuf := make([]frontend.Variable, 1) - if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { - return nil, nil, fmt.Errorf("failed to read public inputs hash: %w", err) - } - - // Compute expected public inputs hash - expectedHash, err := hashPublicInputs(api, sc, publicInputs) - if err != nil { - return nil, nil, fmt.Errorf("failed to compute public inputs hash: %w", err) - } - - // Verify hash matches - api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) - - // Read public weights vector random challenge - publicWeightsChallenge := make([]frontend.Variable, 1) - if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { - return nil, nil, fmt.Errorf("failed to read public weights challenge: %w", err) + // For single element, we hash it with a zero + if len(publicInputs.Values) == 1 { + return sc.CompressV2(publicInputs.Values[0], frontend.Variable(0)), nil } - // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) - publicWeightsQueryAnswer := make([]frontend.Variable, 2) - if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { - return nil, nil, fmt.Errorf("failed to read public weights query answer: %w", err) + // For 2+ elements, use standard approach + hash := sc.CompressV2(publicInputs.Values[0], publicInputs.Values[1]) + for i := 2; i < len(publicInputs.Values); i++ { + hash = sc.CompressV2(hash, publicInputs.Values[i]) } - return publicWeightsChallenge[0], publicWeightsQueryAnswer, nil -} - -// readPublicWeightsQueryAnswer reads only the public weights query answer from the transcript. -// The challenge has already been read at the circuit level to match transcript order. -// Returns (publicWeightsQueryAnswer, error) -func readPublicWeightsQueryAnswer(arthur gnarkNimue.Arthur) ([]frontend.Variable, error) { - // Read WHIR public weights query answer (2 field elements: f_sum, g_sum) - publicWeightsQueryAnswer := make([]frontend.Variable, 2) - if err := arthur.FillNextScalars(publicWeightsQueryAnswer); err != nil { - return nil, fmt.Errorf("failed to read public weights query answer: %w", err) - } - - return publicWeightsQueryAnswer, nil -} - -// computePublicWeightsClaimedSum computes the claimed sum for the public weights constraint -// This is: public_f_sum + public_g_sum * batching_randomness -func computePublicWeightsClaimedSum( - api frontend.API, - publicWeightsQueryAnswer []frontend.Variable, - batchingRandomness frontend.Variable, -) frontend.Variable { - publicFSum := publicWeightsQueryAnswer[0] - publicGSum := publicWeightsQueryAnswer[1] - return api.Add(publicFSum, api.Mul(publicGSum, batchingRandomness)) + return hash, nil } From 5f3e1917763b846c197bd9ec01ab473b30e52c97 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 08:06:42 +0530 Subject: [PATCH 14/18] feat: fix dualmode --- recursive-verifier/app/circuit/circuit.go | 38 ++++++++++++++++++----- recursive-verifier/app/circuit/whir.go | 12 +++++-- 2 files changed, 40 insertions(+), 10 deletions(-) diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index c5a7fbb4..1262d76a 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -122,7 +122,32 @@ func (circuit *Circuit) Define(api frontend.API) error { var az, bz, cz frontend.Variable if circuit.NumChallenges > 0 { - // Dual commitment mode: batch WHIR verification + // Only statement_1 (first commitment) gets extended with public weights, statement_2 remains unchanged + extendedLinearStatementEvalsBatch := make([][][]frontend.Variable, 2) + + if !circuit.PublicInputs.IsEmpty() { + extendedLinearStatementEvalsBatch[0] = extendLinearStatement( + circuit, + [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + circuit.PubWitnessEvaluations, + ) + + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } else { + // Use original arrays as before, no public inputs + extendedLinearStatementEvalsBatch[0] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[0], + circuit.WitnessBlindingEvaluations[0], + } + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } + whirFoldingRandomness, err = RunZKWhirBatch( api, arthur, uapi, sc, circuit.WitnessFirstRounds, // firstRounds []Merkle @@ -131,13 +156,10 @@ func (circuit *Circuit) Define(api frontend.API) error { [][][]frontend.Variable{initialOODAnswers1, initialOODAnswers2}, // initialOODAnswers []frontend.Variable{rootHash1, rootHash2}, // rootHashes circuit.WitnessMerkle, // batchedMerkle - [][][]frontend.Variable{ // linearStatementEvals - {circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, - {circuit.WitnessClaimedEvaluations[1], circuit.WitnessBlindingEvaluations[1]}, - }, - circuit.WHIRParamsWitness, // whirParams - circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints - circuit.PublicInputs, // publicInputs + extendedLinearStatementEvalsBatch, // linearStatementEvals (extended for first commitment) + circuit.WHIRParamsWitness, // whirParams + circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + circuit.PublicInputs, // publicInputs ) if err != nil { return err diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 4e367cb3..3d780dd7 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -59,7 +59,7 @@ func RunZKWhir( firstRound Merkle, whirParams WHIRParams, linearStatementEvaluations [][]frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, // weights.evaluate(random_point) - this is what needs to be done + linearStatementValuesAtPoints []frontend.Variable, batchingRandomness frontend.Variable, initialOODQueries []frontend.Variable, initialOODAnswers [][]frontend.Variable, @@ -256,7 +256,15 @@ func RunZKWhirBatch( for i := 0; i < numPolynomials; i++ { numOOD += len(initialOODQueries[i]) } - numStatementConstraints := numPolynomials * 3 // 3 per commitment (Az, Bz, Cz) + + numStatementConstraints := 0 + + // w1 has 4 (pub, Az, Bz, Cz) constraints, w2 and remaining have 3 (Az, Bz, Cz) constraints + if !publicInputs.IsEmpty() { + numStatementConstraints = 4 + 3*(numPolynomials-1) + } else { + numStatementConstraints = numPolynomials * 3 + } numConstraints := numOOD + numStatementConstraints // Step 3: Read N×M evaluation matrix from transcript From 545e59ea773d20888e4e5e262a16f26a7bc4f10e Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 08:38:36 +0530 Subject: [PATCH 15/18] chore: cleanup --- recursive-verifier/app/circuit/circuit.go | 8 ++++--- recursive-verifier/app/circuit/types.go | 26 +++++++++++------------ recursive-verifier/app/circuit/whir.go | 2 +- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 1262d76a..fcc1e216 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -198,9 +198,11 @@ func (circuit *Circuit) Define(api frontend.API) error { x := api.Mul(api.Sub(api.Mul(az, bz), cz), calculateEQ(api, spartanSumcheckRand, tRand)) api.AssertIsEqual(spartanSumcheckLastValue, x) - // TODO : generalize it later on if we have more different kinds of statements - // for handling geometric weights statement added at starting - offset := 1 + offset := 0 + if !circuit.PublicInputs.IsEmpty() { + // can be generalized later on if we have more different kinds of statements + offset = 1 + } if circuit.NumChallenges > 0 { // Batch mode - check 6 deferred values diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index b456231a..f6db49d0 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -91,18 +91,18 @@ type ProofObject struct { } type Config struct { - WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` - WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` - LogNumConstraints int `json:"log_num_constraints"` - LogNumVariables int `json:"log_num_variables"` - LogANumTerms int `json:"log_a_num_terms"` - IOPattern string `json:"io_pattern"` - Transcript []byte `json:"transcript"` - TranscriptLen int `json:"transcript_len"` - WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` - BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` - NumChallenges int `json:"num_challenges"` - W1Size int `json:"w1_size"` + WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` + WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` + LogNumConstraints int `json:"log_num_constraints"` + LogNumVariables int `json:"log_num_variables"` + LogANumTerms int `json:"log_a_num_terms"` + IOPattern string `json:"io_pattern"` + Transcript []byte `json:"transcript"` + TranscriptLen int `json:"transcript_len"` + WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` + BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` + NumChallenges int `json:"num_challenges"` + W1Size int `json:"w1_size"` PublicInputs PublicInputs `json:"public_inputs"` } @@ -158,4 +158,4 @@ func (p *PublicInputs) UnmarshalJSON(data []byte) error { func (p *PublicInputs) IsEmpty() bool { return len(p.Values) == 0 -} \ No newline at end of file +} diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 3d780dd7..3bdea8f1 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -256,7 +256,7 @@ func RunZKWhirBatch( for i := 0; i < numPolynomials; i++ { numOOD += len(initialOODQueries[i]) } - + numStatementConstraints := 0 // w1 has 4 (pub, Az, Bz, Cz) constraints, w2 and remaining have 3 (Az, Bz, Cz) constraints From 15bddd86e4399f9d716c52da791325556f80ad74 Mon Sep 17 00:00:00 2001 From: ashpect Date: Thu, 15 Jan 2026 08:58:05 +0530 Subject: [PATCH 16/18] chore: cargofmt --- provekit/common/src/witness/mod.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index a0ff25d1..a5e9a8c4 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -72,13 +72,10 @@ impl PublicInputs { 1 => { // For single element, hash it with zero to ensure it gets properly hashed let padded = vec![self.0[0], FieldElement::from(0u64)]; - SkyscraperCRH::evaluate(&(), &padded[..]) - .expect("hash should succeed") - } - _ => { - SkyscraperCRH::evaluate(&(), &self.0[..]) - .expect("hash should succeed for multiple inputs") + SkyscraperCRH::evaluate(&(), &padded[..]).expect("hash should succeed") } + _ => SkyscraperCRH::evaluate(&(), &self.0[..]) + .expect("hash should succeed for multiple inputs"), } } } From 493c1757c6ec74f191c6424ef0c0779b26f7ba24 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 21 Jan 2026 00:54:58 +0530 Subject: [PATCH 17/18] feat: native verifier deffered evals --- provekit/common/src/utils/zk_utils.rs | 28 ++++- provekit/prover/src/whir_r1cs.rs | 2 +- provekit/verifier/src/lib.rs | 8 +- provekit/verifier/src/whir_r1cs.rs | 173 ++++++++++++++++++++++++-- tooling/cli/src/cmd/prove.rs | 6 +- tooling/cli/src/cmd/verify.rs | 11 +- 6 files changed, 210 insertions(+), 18 deletions(-) diff --git a/provekit/common/src/utils/zk_utils.rs b/provekit/common/src/utils/zk_utils.rs index 82a65087..0a445802 100644 --- a/provekit/common/src/utils/zk_utils.rs +++ b/provekit/common/src/utils/zk_utils.rs @@ -1,5 +1,7 @@ use { - crate::FieldElement, ark_ff::UniformRand, rayon::prelude::*, + crate::FieldElement, + ark_ff::{Field, UniformRand}, + rayon::prelude::*, whir::poly_utils::evals::EvaluationsList, }; @@ -37,3 +39,27 @@ pub fn generate_random_multilinear_polynomial(num_vars: usize) -> Vec(mut a: F, n: usize, x: &[F]) -> F { + let k = x.len(); + assert!(n > 0 && n < (1 << k)); + let mut borrow_0 = F::one(); + let mut borrow_1 = F::zero(); + for (i, &xi) in x.iter().rev().enumerate() { + let bn = ((n - 1) >> i) & 1; + let b0 = F::one() - xi; + let b1 = a * xi; + (borrow_0, borrow_1) = if bn == 0 { + (b0 * borrow_0, (b0 + b1) * borrow_1 + b1 * borrow_0) + } else { + ((b0 + b1) * borrow_0 + b0 * borrow_1, b1 * borrow_1) + }; + a = a.square(); + } + borrow_0 +} \ No newline at end of file diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index f22ab397..4be932e7 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -712,7 +712,7 @@ fn get_public_weights( Weights::geometric( x, - public_inputs.0.len(), + public_inputs.len(), EvaluationsList::new(public_weights), ) } diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs index d1ec351a..9ac47f4d 100644 --- a/provekit/verifier/src/lib.rs +++ b/provekit/verifier/src/lib.rs @@ -3,21 +3,21 @@ mod whir_r1cs; use { crate::whir_r1cs::WhirR1CSVerifier, anyhow::Result, - provekit_common::{NoirProof, Verifier}, + provekit_common::{NoirProof, R1CS, Verifier}, tracing::instrument, }; pub trait Verify { - fn verify(&mut self, proof: &NoirProof) -> Result<()>; + fn verify(&mut self, proof: &NoirProof, r1cs: &R1CS) -> Result<()>; } impl Verify for Verifier { #[instrument(skip_all)] - fn verify(&mut self, proof: &NoirProof) -> Result<()> { + fn verify(&mut self, proof: &NoirProof, r1cs: &R1CS) -> Result<()> { self.whir_for_witness .take() .unwrap() - .verify(&proof.whir_r1cs_proof, &proof.public_inputs)?; + .verify(&proof.whir_r1cs_proof, &proof.public_inputs, r1cs)?; Ok(()) } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 17e102fb..53834362 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -3,8 +3,12 @@ use { ark_std::{One, Zero}, provekit_common::{ skyscraper::SkyscraperSponge, - utils::sumcheck::{calculate_eq, eval_cubic_poly}, + utils::sumcheck::{ + calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, + }, FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, + R1CS, + utils::zk_utils::geometric_till, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, @@ -29,13 +33,13 @@ pub struct DataFromSumcheckVerifier { } pub trait WhirR1CSVerifier { - fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()>; + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs, r1cs: &R1CS) -> Result<()>; } impl WhirR1CSVerifier for WhirR1CSScheme { #[instrument(skip_all)] #[allow(unused)] - fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs) -> Result<()> { + fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs, r1cs: &R1CS) -> Result<()> { let io = self.create_io_pattern(); let mut arthur = io.to_verifier_state(&proof.transcript); @@ -57,7 +61,7 @@ impl WhirR1CSVerifier for WhirR1CSScheme { .context("while verifying sumcheck")?; // Read hints and verify WHIR proof - let (az_at_alpha, bz_at_alpha, cz_at_alpha) = + let (az_at_alpha, bz_at_alpha, cz_at_alpha, whir_folding_randomness, deferred_evals, public_weights_challenge) = if let Some(parsed_commitment_2) = parsed_commitment_2 { // Dual commitment mode let sums_1: (Vec, Vec) = arthur.hint()?; @@ -105,7 +109,7 @@ impl WhirR1CSVerifier for WhirR1CSScheme { ); } - run_whir_pcs_batch_verifier( + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_batch_verifier( &mut arthur, &self.whir_witness, &[parsed_commitment_1, parsed_commitment_2], @@ -117,6 +121,9 @@ impl WhirR1CSVerifier for WhirR1CSScheme { whir_sums_1.0[0] + whir_sums_2.0[0], whir_sums_1.0[1] + whir_sums_2.0[1], whir_sums_1.0[2] + whir_sums_2.0[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + public_weights_vector_random_buf[0], ) } else { // Single commitment mode @@ -155,7 +162,7 @@ impl WhirR1CSVerifier for WhirR1CSScheme { ); } - run_whir_pcs_verifier( + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_verifier( &mut arthur, &parsed_commitment_1, &self.whir_witness, @@ -163,7 +170,14 @@ impl WhirR1CSVerifier for WhirR1CSScheme { ) .context("while verifying WHIR proof")?; - (whir_sums.0[0], whir_sums.0[1], whir_sums.0[2]) + ( + whir_sums.0[0], + whir_sums.0[1], + whir_sums.0[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + public_weights_vector_random_buf[0], + ) }; // Check the Spartan sumcheck relation @@ -177,6 +191,54 @@ impl WhirR1CSVerifier for WhirR1CSScheme { "last sumcheck value does not match" ); + // Check deferred linear and geometric constraints + let offset = if public_inputs.is_empty() { 0 } else { 1 }; + + // Linear deferred + if self.num_challenges > 0 { + let matrix_extension_evals = evaluate_r1cs_matrix_extension_batch( + r1cs, + &data_from_sumcheck_verifier.alpha, + &whir_folding_randomness, + self.w1_size, + ); + for i in 0..6 { + ensure!( + matrix_extension_evals[i] == deferred_evals[offset + i], + "Matrix extension evaluation {} does not match deferred value", + i + ); + } + } else { + let matrix_extension_evals = evaluate_r1cs_matrix_extension( + r1cs, + &data_from_sumcheck_verifier.alpha, + &whir_folding_randomness, + ); + + for i in 0..3 { + ensure!( + matrix_extension_evals[i] == deferred_evals[offset + i], + "Matrix extension evaluation {} does not match deferred value", + i + ); + } + } + + // Geometric deferred + if !public_inputs.is_empty() { + let public_weight_eval = compute_public_weight_evaluation( + public_inputs, + &whir_folding_randomness, + self.whir_witness.mv_parameters.num_variables, + public_weights_challenge, + ); + ensure!( + public_weight_eval == deferred_evals[0], + "Public weight evaluation does not match deferred value" + ); + } + Ok(()) } } @@ -305,3 +367,100 @@ pub fn run_whir_pcs_batch_verifier( .context("while verifying batch WHIR")?; Ok((folding_randomness, deferred)) } + +fn evaluate_r1cs_matrix_extension( + r1cs: &R1CS, + row_rand: &[FieldElement], + col_rand: &[FieldElement], +) -> [FieldElement; 3] { + let row_eval = calculate_evaluations_over_boolean_hypercube_for_eq(row_rand.to_vec()); + let col_eval = calculate_evaluations_over_boolean_hypercube_for_eq(col_rand.to_vec()); + + let mut ans_a = FieldElement::zero(); + let mut ans_b = FieldElement::zero(); + let mut ans_c = FieldElement::zero(); + + for ((row, col), val) in r1cs.a().iter() { + ans_a += val * row_eval[row] * col_eval[col]; + } + + for ((row, col), val) in r1cs.b().iter() { + ans_b += val * row_eval[row] * col_eval[col]; + } + + for ((row, col), val) in r1cs.c().iter() { + ans_c += val * row_eval[row] * col_eval[col]; + } + + [ans_a, ans_b, ans_c] +} + +fn evaluate_r1cs_matrix_extension_batch( + r1cs: &R1CS, + row_rand: &[FieldElement], + col_rand: &[FieldElement], + w1_size: usize, +) -> [FieldElement; 6] { + let row_eval = calculate_evaluations_over_boolean_hypercube_for_eq(row_rand.to_vec()); + let col_eval = calculate_evaluations_over_boolean_hypercube_for_eq(col_rand.to_vec()); + + let mut ans = [FieldElement::zero(); 6]; + + // Evaluate matrices - split by column based on w1_size + for ((row, col), val) in r1cs.a().iter() { + if col < w1_size { + ans[0] += val * row_eval[row] * col_eval[col]; + } else { + ans[3] += val * row_eval[row] * col_eval[col - w1_size]; + } + } + + for ((row, col), val) in r1cs.b().iter() { + if col < w1_size { + ans[1] += val * row_eval[row] * col_eval[col]; + } else { + ans[4] += val * row_eval[row] * col_eval[col - w1_size]; + } + } + + for ((row, col), val) in r1cs.c().iter() { + if col < w1_size { + ans[2] += val * row_eval[row] * col_eval[col]; + } else { + ans[5] += val * row_eval[row] * col_eval[col - w1_size]; + } + } + + ans +} + +fn compute_public_weight_evaluation( + public_inputs: &PublicInputs, + folding_randomness: &[FieldElement], + m: usize, + x: FieldElement, +) -> FieldElement { + let domain_size = 1 << m; + let mut public_weights = vec![FieldElement::zero(); domain_size]; + + let mut current_pow = FieldElement::one(); + for (idx, _) in public_inputs.0.iter().enumerate() { + public_weights[idx] = current_pow; + current_pow = current_pow * x; + } + + let mle = geometric_till(x, public_inputs.len(), folding_randomness); + + #[cfg(test)] + { + let eq_polys = calculate_evaluations_over_boolean_hypercube_for_eq(folding_randomness.to_vec()); + let sum: FieldElement = public_weights + .iter() + .zip(eq_polys.iter()) + .map(|(w, eq)| *w * eq) + .sum(); + assert!(sum == mle, "Sum does not match mle"); + } + + mle +} diff --git a/tooling/cli/src/cmd/prove.rs b/tooling/cli/src/cmd/prove.rs index 728b62c4..6169dd0e 100644 --- a/tooling/cli/src/cmd/prove.rs +++ b/tooling/cli/src/cmd/prove.rs @@ -48,8 +48,8 @@ impl Command for Args { let (constraints, witnesses) = prover.size(); info!(constraints, witnesses, "Read Noir proof scheme"); - // // Read the input toml - // let input_map = scheme.read_witness(&self.input_path)?; + #[cfg(test)] + let r1cs = prover.r1cs.clone(); // Generate the proof let proof = prover @@ -62,7 +62,7 @@ impl Command for Args { let mut verifier: Verifier = read(&self.verifier_path).context("while reading Provekit Verifier")?; verifier - .verify(&proof) + .verify(&proof, &r1cs) .context("While verifying Noir proof")?; } diff --git a/tooling/cli/src/cmd/verify.rs b/tooling/cli/src/cmd/verify.rs index 6b95ce62..762cca4d 100644 --- a/tooling/cli/src/cmd/verify.rs +++ b/tooling/cli/src/cmd/verify.rs @@ -2,7 +2,7 @@ use { super::Command, anyhow::{Context, Result}, argh::FromArgs, - provekit_common::{file::read, Verifier}, + provekit_common::{file::read, Prover, Verifier}, provekit_verifier::Verify, std::path::PathBuf, tracing::instrument, @@ -16,6 +16,10 @@ pub struct Args { #[argh(positional)] verifier_path: PathBuf, + /// path to the prover file (.pkp) + #[argh(positional)] + prover_path: PathBuf, + /// path to the proof file #[argh(positional)] proof_path: PathBuf, @@ -31,9 +35,12 @@ impl Command for Args { // Read the proof let proof = read(&self.proof_path).context("while reading proof")?; + // Read the prover and R1CS (similar to generate-gnark-inputs) + let prover: Prover = read(&self.prover_path).context("while reading Provekit Prover")?; + // Verify the proof verifier - .verify(&proof) + .verify(&proof, &prover.r1cs) .context("While verifying Noir proof")?; Ok(()) From 37d7c42c37a672f839eef9c61c43257b8d8b19b0 Mon Sep 17 00:00:00 2001 From: ashpect Date: Wed, 21 Jan 2026 00:59:37 +0530 Subject: [PATCH 18/18] chore: cleanup --- provekit/common/src/utils/zk_utils.rs | 5 +- provekit/prover/src/whir_r1cs.rs | 6 +- provekit/verifier/src/lib.rs | 11 +- provekit/verifier/src/whir_r1cs.rs | 252 ++++++++++++----------- tooling/provekit-bench/benches/bench.rs | 6 +- tooling/provekit-bench/tests/compiler.rs | 3 +- 6 files changed, 152 insertions(+), 131 deletions(-) diff --git a/provekit/common/src/utils/zk_utils.rs b/provekit/common/src/utils/zk_utils.rs index 0a445802..87fc806a 100644 --- a/provekit/common/src/utils/zk_utils.rs +++ b/provekit/common/src/utils/zk_utils.rs @@ -40,7 +40,8 @@ pub fn generate_random_multilinear_polynomial(num_vars: usize) -> Vec(mut a: F, n: usize, x: &[F]) -> F { a = a.square(); } borrow_0 -} \ No newline at end of file +} diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 4be932e7..55c95969 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -710,9 +710,5 @@ fn get_public_weights( current_pow = current_pow * x; } - Weights::geometric( - x, - public_inputs.len(), - EvaluationsList::new(public_weights), - ) + Weights::geometric(x, public_inputs.len(), EvaluationsList::new(public_weights)) } diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs index 9ac47f4d..ea70671c 100644 --- a/provekit/verifier/src/lib.rs +++ b/provekit/verifier/src/lib.rs @@ -3,7 +3,7 @@ mod whir_r1cs; use { crate::whir_r1cs::WhirR1CSVerifier, anyhow::Result, - provekit_common::{NoirProof, R1CS, Verifier}, + provekit_common::{NoirProof, Verifier, R1CS}, tracing::instrument, }; @@ -14,10 +14,11 @@ pub trait Verify { impl Verify for Verifier { #[instrument(skip_all)] fn verify(&mut self, proof: &NoirProof, r1cs: &R1CS) -> Result<()> { - self.whir_for_witness - .take() - .unwrap() - .verify(&proof.whir_r1cs_proof, &proof.public_inputs, r1cs)?; + self.whir_for_witness.take().unwrap().verify( + &proof.whir_r1cs_proof, + &proof.public_inputs, + r1cs, + )?; Ok(()) } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 53834362..c1065e2e 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -3,12 +3,13 @@ use { ark_std::{One, Zero}, provekit_common::{ skyscraper::SkyscraperSponge, - utils::sumcheck::{ - calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, + utils::{ + sumcheck::{ + calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, + }, + zk_utils::geometric_till, }, - FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, - R1CS, - utils::zk_utils::geometric_till, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, @@ -33,13 +34,23 @@ pub struct DataFromSumcheckVerifier { } pub trait WhirR1CSVerifier { - fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs, r1cs: &R1CS) -> Result<()>; + fn verify( + &self, + proof: &WhirR1CSProof, + public_inputs: &PublicInputs, + r1cs: &R1CS, + ) -> Result<()>; } impl WhirR1CSVerifier for WhirR1CSScheme { #[instrument(skip_all)] #[allow(unused)] - fn verify(&self, proof: &WhirR1CSProof, public_inputs: &PublicInputs, r1cs: &R1CS) -> Result<()> { + fn verify( + &self, + proof: &WhirR1CSProof, + public_inputs: &PublicInputs, + r1cs: &R1CS, + ) -> Result<()> { let io = self.create_io_pattern(); let mut arthur = io.to_verifier_state(&proof.transcript); @@ -61,124 +72,130 @@ impl WhirR1CSVerifier for WhirR1CSScheme { .context("while verifying sumcheck")?; // Read hints and verify WHIR proof - let (az_at_alpha, bz_at_alpha, cz_at_alpha, whir_folding_randomness, deferred_evals, public_weights_challenge) = - if let Some(parsed_commitment_2) = parsed_commitment_2 { - // Dual commitment mode - let sums_1: (Vec, Vec) = arthur.hint()?; - let sums_2: (Vec, Vec) = arthur.hint()?; - - let whir_sums_1: ([FieldElement; 3], [FieldElement; 3]) = - (sums_1.0.try_into().unwrap(), sums_1.1.try_into().unwrap()); - let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = - (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); - - let mut statement_1 = prepare_statement_for_witness_verifier::<3>( + let ( + az_at_alpha, + bz_at_alpha, + cz_at_alpha, + whir_folding_randomness, + deferred_evals, + public_weights_challenge, + ) = if let Some(parsed_commitment_2) = parsed_commitment_2 { + // Dual commitment mode + let sums_1: (Vec, Vec) = arthur.hint()?; + let sums_2: (Vec, Vec) = arthur.hint()?; + + let whir_sums_1: ([FieldElement; 3], [FieldElement; 3]) = + (sums_1.0.try_into().unwrap(), sums_1.1.try_into().unwrap()); + let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = + (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); + + let mut statement_1 = prepare_statement_for_witness_verifier::<3>( + self.m, + &parsed_commitment_1, + &whir_sums_1, + ); + let statement_2 = prepare_statement_for_witness_verifier::<3>( + self.m, + &parsed_commitment_2, + &whir_sums_2, + ); + + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + + if !public_inputs.is_empty() { + update_statement_for_witness_verifier( self.m, + &mut statement_1, &parsed_commitment_1, - &whir_sums_1, - ); - let statement_2 = prepare_statement_for_witness_verifier::<3>( - self.m, - &parsed_commitment_2, - &whir_sums_2, + whir_public_weights_query_answer, ); + } - let mut public_inputs_hash_buf = [FieldElement::zero()]; - arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; - let expected_public_inputs_hash = public_inputs.hash(); - ensure!( - public_inputs_hash_buf[0] == expected_public_inputs_hash, - "Public inputs hash mismatch: expected {:?}, got {:?}", - expected_public_inputs_hash, - public_inputs_hash_buf[0] - ); + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_batch_verifier( + &mut arthur, + &self.whir_witness, + &[parsed_commitment_1, parsed_commitment_2], + &[statement_1, statement_2], + ) + .context("while verifying WHIR batch proof")?; + + ( + whir_sums_1.0[0] + whir_sums_2.0[0], + whir_sums_1.0[1] + whir_sums_2.0[1], + whir_sums_1.0[2] + whir_sums_2.0[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + public_weights_vector_random_buf[0], + ) + } else { + // Single commitment mode + let sums: (Vec, Vec) = arthur.hint()?; + let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = + (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); + + let mut statement = prepare_statement_for_witness_verifier::<3>( + self.m, + &parsed_commitment_1, + &whir_sums, + ); + + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - let mut public_weights_vector_random_buf = [FieldElement::zero()]; - arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - - let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur - .hint() - .context("failed to read WHIR public weights query answer")?; - - if !public_inputs.is_empty() { - update_statement_for_witness_verifier( - self.m, - &mut statement_1, - &parsed_commitment_1, - whir_public_weights_query_answer, - ); - } - - let (whir_folding_randomness, deferred_evals) = run_whir_pcs_batch_verifier( - &mut arthur, - &self.whir_witness, - &[parsed_commitment_1, parsed_commitment_2], - &[statement_1, statement_2], - ) - .context("while verifying WHIR batch proof")?; - - ( - whir_sums_1.0[0] + whir_sums_2.0[0], - whir_sums_1.0[1] + whir_sums_2.0[1], - whir_sums_1.0[2] + whir_sums_2.0[2], - whir_folding_randomness.0.to_vec(), - deferred_evals, - public_weights_vector_random_buf[0], - ) - } else { - // Single commitment mode - let sums: (Vec, Vec) = arthur.hint()?; - let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = - (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); - - let mut statement = prepare_statement_for_witness_verifier::<3>( + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + if !public_inputs.is_empty() { + update_statement_for_witness_verifier( self.m, + &mut statement, &parsed_commitment_1, - &whir_sums, - ); - - let mut public_inputs_hash_buf = [FieldElement::zero()]; - arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; - let expected_public_inputs_hash = public_inputs.hash(); - ensure!( - public_inputs_hash_buf[0] == expected_public_inputs_hash, - "Public inputs hash mismatch: expected {:?}, got {:?}", - expected_public_inputs_hash, - public_inputs_hash_buf[0] + whir_public_weights_query_answer, ); + } - let mut public_weights_vector_random_buf = [FieldElement::zero()]; - arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; - - let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur - .hint() - .context("failed to read WHIR public weights query answer")?; - if !public_inputs.is_empty() { - update_statement_for_witness_verifier( - self.m, - &mut statement, - &parsed_commitment_1, - whir_public_weights_query_answer, - ); - } - - let (whir_folding_randomness, deferred_evals) = run_whir_pcs_verifier( - &mut arthur, - &parsed_commitment_1, - &self.whir_witness, - &statement, - ) - .context("while verifying WHIR proof")?; - - ( - whir_sums.0[0], - whir_sums.0[1], - whir_sums.0[2], - whir_folding_randomness.0.to_vec(), - deferred_evals, - public_weights_vector_random_buf[0], - ) - }; + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_verifier( + &mut arthur, + &parsed_commitment_1, + &self.whir_witness, + &statement, + ) + .context("while verifying WHIR proof")?; + + ( + whir_sums.0[0], + whir_sums.0[1], + whir_sums.0[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + public_weights_vector_random_buf[0], + ) + }; // Check the Spartan sumcheck relation ensure!( @@ -453,7 +470,8 @@ fn compute_public_weight_evaluation( #[cfg(test)] { - let eq_polys = calculate_evaluations_over_boolean_hypercube_for_eq(folding_randomness.to_vec()); + let eq_polys = + calculate_evaluations_over_boolean_hypercube_for_eq(folding_randomness.to_vec()); let sum: FieldElement = public_weights .iter() .zip(eq_polys.iter()) diff --git a/tooling/provekit-bench/benches/bench.rs b/tooling/provekit-bench/benches/bench.rs index ce703453..f1ab224e 100644 --- a/tooling/provekit-bench/benches/bench.rs +++ b/tooling/provekit-bench/benches/bench.rs @@ -63,7 +63,11 @@ fn verify_poseidon_1000(bencher: Bencher) { let mut verifier: Verifier = read(&proof_verifier_path).unwrap(); let proof_path = crate_dir.join("noir-proof.np"); let proof: NoirProof = read(&proof_path).unwrap(); - bencher.bench_local(|| black_box(&mut verifier).verify(black_box(&proof))); + let proof_prover_path = crate_dir.join("noir-provekit-prover.pkp"); + let prover: Prover = read(&proof_prover_path).unwrap(); + bencher.bench_local(|| { + black_box(&mut verifier).verify(black_box(&proof), black_box(&prover.r1cs)) + }); } fn main() { diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs index 9ac24ffd..4b64aaba 100644 --- a/tooling/provekit-bench/tests/compiler.rs +++ b/tooling/provekit-bench/tests/compiler.rs @@ -42,11 +42,12 @@ fn test_compiler(test_case_path: impl AsRef) { let prover = Prover::from_noir_proof_scheme(schema.clone()); let mut verifier = Verifier::from_noir_proof_scheme(schema.clone()); + let r1cs = prover.r1cs.clone(); let proof = prover .prove(&witness_file_path) .expect("While proving Noir program statement"); - verifier.verify(&proof).expect("Verifying proof"); + verifier.verify(&proof, &r1cs).expect("Verifying proof"); } pub fn compile_workspace(workspace_path: impl AsRef) -> Result {