diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index d39db74e..20fe2897 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -39,6 +39,7 @@ ruint.workspace = true serde.workspace = true serde_json.workspace = true tracing.workspace = true +sha2.workspace = true zerocopy.workspace = true zeroize.workspace = true zstd.workspace = true diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 680715d8..31cb4b56 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -22,6 +22,7 @@ pub use { verifier::Verifier, whir::crypto::fields::Field256 as FieldElement, whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, + witness::PublicInputs, }; #[cfg(test)] diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index b8c8ff93..7552ab26 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -2,7 +2,7 @@ use { crate::{ whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{NoirWitnessGenerator, SplitWitnessBuilders}, - NoirElement, R1CS, + NoirElement, PublicInputs, R1CS, }, acir::circuit::Program, serde::{Deserialize, Serialize}, @@ -20,6 +20,7 @@ pub struct NoirProofScheme { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct NoirProof { + pub public_inputs: PublicInputs, pub whir_r1cs_proof: WhirR1CSProof, } diff --git a/provekit/common/src/skyscraper/mod.rs b/provekit/common/src/skyscraper/mod.rs index 3b6da92a..2caecdc2 100644 --- a/provekit/common/src/skyscraper/mod.rs +++ b/provekit/common/src/skyscraper/mod.rs @@ -2,4 +2,8 @@ mod pow; mod sponge; mod whir; -pub use self::{pow::SkyscraperPoW, sponge::SkyscraperSponge, whir::SkyscraperMerkleConfig}; +pub use self::{ + pow::SkyscraperPoW, + sponge::SkyscraperSponge, + whir::{SkyscraperCRH, SkyscraperMerkleConfig}, +}; diff --git a/provekit/common/src/utils/mod.rs b/provekit/common/src/utils/mod.rs index a5f6aa5b..43943c56 100644 --- a/provekit/common/src/utils/mod.rs +++ b/provekit/common/src/utils/mod.rs @@ -2,6 +2,7 @@ mod print_abi; pub mod serde_ark; pub mod serde_ark_option; +pub mod serde_ark_vec; pub mod serde_hex; pub mod serde_jsonify; pub mod sumcheck; diff --git a/provekit/common/src/utils/serde_ark_vec.rs b/provekit/common/src/utils/serde_ark_vec.rs new file mode 100644 index 00000000..b26c309e --- /dev/null +++ b/provekit/common/src/utils/serde_ark_vec.rs @@ -0,0 +1,87 @@ +use { + crate::FieldElement, + ark_serialize::{CanonicalDeserialize, CanonicalSerialize}, + serde::{ + de::{Error as _, SeqAccess, Visitor}, + ser::{Error as _, SerializeSeq}, + Deserializer, Serializer, + }, + std::fmt, +}; + +pub fn serialize(vec: &Vec, serializer: S) -> Result +where + S: Serializer, +{ + let is_human_readable = serializer.is_human_readable(); + let mut seq = serializer.serialize_seq(Some(vec.len()))?; + for element in vec { + let mut buf = Vec::with_capacity(element.compressed_size()); + element + .serialize_compressed(&mut buf) + .map_err(|e| S::Error::custom(format!("Failed to serialize: {e}")))?; + + // Write bytes + if is_human_readable { + // ark_serialize doesn't have human-readable serialization. And Serde + // doesn't have good defaults for [u8]. So we implement hexadecimal + // serialization. 
+            let hex = hex::encode(buf);
+            seq.serialize_element(&hex)?;
+        } else {
+            seq.serialize_element(&buf)?;
+        }
+    }
+    seq.end()
+}
+
+pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<FieldElement>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    struct VecVisitor {
+        is_human_readable: bool,
+    }
+
+    impl<'de> Visitor<'de> for VecVisitor {
+        type Value = Vec<FieldElement>;
+
+        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+            formatter.write_str("a sequence of field elements")
+        }
+
+        fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+        where
+            A: SeqAccess<'de>,
+        {
+            let mut vec = Vec::new();
+            if self.is_human_readable {
+                while let Some(hex) = seq.next_element::<String>()? {
+                    let bytes = hex::decode(hex)
+                        .map_err(|e| A::Error::custom(format!("invalid hex: {e}")))?;
+                    let mut reader = &*bytes;
+                    let element = FieldElement::deserialize_compressed(&mut reader)
+                        .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?;
+                    if !reader.is_empty() {
+                        return Err(A::Error::custom("while deserializing: trailing bytes"));
+                    }
+                    vec.push(element);
+                }
+            } else {
+                while let Some(bytes) = seq.next_element::<Vec<u8>>()? {
+                    let mut reader = &*bytes;
+                    let element = FieldElement::deserialize_compressed(&mut reader)
+                        .map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?;
+                    if !reader.is_empty() {
+                        return Err(A::Error::custom("while deserializing: trailing bytes"));
+                    }
+                    vec.push(element);
+                }
+            }
+            Ok(vec)
+        }
+    }
+
+    let is_human_readable = deserializer.is_human_readable();
+    deserializer.deserialize_seq(VecVisitor { is_human_readable })
+}
diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs
index 7e1c5a24..6baef51d 100644
--- a/provekit/common/src/utils/sumcheck.rs
+++ b/provekit/common/src/utils/sumcheck.rs
@@ -114,6 +114,10 @@ pub trait SumcheckIOPattern {
     fn add_rand(self, num_rand: usize) -> Self;
 
     fn add_zk_sumcheck_polynomials(self, num_vars: usize) -> Self;
+
+    /// Prover sends the hash of the public inputs.
+    /// Verifier sends randomness used to construct the public-input weights.
+    fn add_public_inputs(self) -> Self;
 }
 
 impl SumcheckIOPattern for IOPattern
@@ -136,6 +140,12 @@ where
         self
     }
 
+    fn add_public_inputs(mut self) -> Self {
+        self = self.add_scalars(1, "Public Inputs Hash");
+        self = self.challenge_scalars(1, "Public Weights Vector Random");
+        self
+    }
+
     fn add_rand(self, num_rand: usize) -> Self {
         self.challenge_scalars(num_rand, "rand")
     }
diff --git a/provekit/common/src/utils/zk_utils.rs b/provekit/common/src/utils/zk_utils.rs
index 82a65087..87fc806a 100644
--- a/provekit/common/src/utils/zk_utils.rs
+++ b/provekit/common/src/utils/zk_utils.rs
@@ -1,5 +1,7 @@
 use {
-    crate::FieldElement, ark_ff::UniformRand, rayon::prelude::*,
+    crate::FieldElement,
+    ark_ff::{Field, UniformRand},
+    rayon::prelude::*,
     whir::poly_utils::evals::EvaluationsList,
 };
 
@@ -37,3 +39,28 @@ pub fn generate_random_multilinear_polynomial(num_vars: usize) -> Vec<FieldElement>
+
+/// Evaluates, at the point `x` in F^k (with k = x.len()), the multilinear extension
+/// of the geometric vector (1, a, a^2, ..., a^{n-1}, 0, ..., 0) of length 2^k,
+/// using O(k) field operations.
+pub fn geometric_till<F: Field>(mut a: F, n: usize, x: &[F]) -> F {
+    let k = x.len();
+    assert!(n > 0 && n < (1 << k));
+    let mut borrow_0 = F::one();
+    let mut borrow_1 = F::zero();
+    for (i, &xi) in x.iter().rev().enumerate() {
+        let bn = ((n - 1) >> i) & 1;
+        let b0 = F::one() - xi;
+        let b1 = a * xi;
+        (borrow_0, borrow_1) = if bn == 0 {
+            (b0 * borrow_0, (b0 + b1) * borrow_1 + b1 * borrow_0)
+        } else {
+            ((b0 + b1) * borrow_0 + b0 * borrow_1, b1 * borrow_1)
+        };
+        a = a.square();
+    }
+    borrow_0
+}
diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs
index 302bfd79..702a25c7 100644
--- a/provekit/common/src/whir_r1cs.rs
+++
b/provekit/common/src/whir_r1cs.rs @@ -34,10 +34,11 @@ impl WhirR1CSScheme { if self.num_challenges > 0 { // Compute total constraints: OOD + statement // OOD: 2 witnesses × committment_ood_samples each - // Statement: 2 statements × 3 constraints each = 6 + // Statement: statement_1 has 3 constraints + 1 public weights constraint = 4, + // statement_2 has 3 constraints = 3, total = 7 let num_witnesses = 2; let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples; - let num_statement_constraints = 6; // 2 statements × 3 constraints + let num_statement_constraints = 7; let num_constraints_total = num_ood_constraints + num_statement_constraints; io = io @@ -50,6 +51,8 @@ impl WhirR1CSScheme { .add_whir_proof(&self.whir_for_hiding_spartan) .hint("claimed_evaluations_1") .hint("claimed_evaluations_2") + .add_public_inputs() + .hint("public_weights_evaluations") .add_whir_batch_proof(&self.whir_witness, num_witnesses, num_constraints_total); } else { io = io @@ -59,6 +62,8 @@ impl WhirR1CSScheme { .add_zk_sumcheck_polynomials(self.m_0) .add_whir_proof(&self.whir_for_hiding_spartan) .hint("claimed_evaluations") + .add_public_inputs() + .hint("public_weights_evaluations") .add_whir_proof(&self.whir_witness); } diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 4361a5dc..a5e9a8c4 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -7,7 +7,12 @@ mod witness_generator; mod witness_io_pattern; use { - crate::{utils::serde_ark, FieldElement}, + crate::{ + skyscraper::SkyscraperCRH, + utils::{serde_ark, serde_ark_vec}, + FieldElement, + }, + ark_crypto_primitives::crh::CRHScheme, ark_ff::One, serde::{Deserialize, Serialize}, }; @@ -40,3 +45,43 @@ impl ConstantOrR1CSWitness { } } } + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec); + +impl PublicInputs { + pub fn new() -> Self { + Self(Vec::new()) + } + + pub fn from_vec(vec: Vec) -> Self { + Self(vec) + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn hash(&self) -> FieldElement { + match self.0.len() { + 0 => FieldElement::from(0u64), + 1 => { + // For single element, hash it with zero to ensure it gets properly hashed + let padded = vec![self.0[0], FieldElement::from(0u64)]; + SkyscraperCRH::evaluate(&(), &padded[..]).expect("hash should succeed") + } + _ => SkyscraperCRH::evaluate(&(), &self.0[..]) + .expect("hash should succeed for multiple inputs"), + } + } +} + +impl Default for PublicInputs { + fn default() -> Self { + Self::new() + } +} diff --git a/provekit/common/src/witness/scheduling/splitter.rs b/provekit/common/src/witness/scheduling/splitter.rs index 57a44367..94ffffe9 100644 --- a/provekit/common/src/witness/scheduling/splitter.rs +++ b/provekit/common/src/witness/scheduling/splitter.rs @@ -26,7 +26,10 @@ impl<'a> WitnessSplitter<'a> { /// (post-challenge). 
     ///
     /// Returns (w1_builder_indices, w2_builder_indices)
-    pub fn split_builders(&self) -> (Vec<usize>, Vec<usize>) {
+    pub fn split_builders(
+        &self,
+        acir_public_inputs_indices_set: HashSet<u32>,
+    ) -> (Vec<usize>, Vec<usize>) {
         let builder_count = self.witness_builders.len();
 
         // Step 1: Find all Challenge builders
@@ -40,7 +43,11 @@ impl<'a> WitnessSplitter<'a> {
             .collect();
 
         if challenge_builders.is_empty() {
-            return ((0..builder_count).collect(), Vec::new());
+            let w1_indices = self.rearrange_w1(
+                (0..builder_count).collect(),
+                &acir_public_inputs_indices_set,
+            );
+            return (w1_indices, Vec::new());
         }
 
         // Step 2: Forward DFS from challenges to find mandatory_w2
@@ -135,6 +142,7 @@ impl<'a> WitnessSplitter<'a> {
         // Step 7: Assign free builders greedily while respecting dependencies
         // Rule: if any dependency is in w2, the builder must also be in w2
         // (because w1 is solved before w2)
+        // A free builder for public input witnesses goes in w1.
         let mut w1_set = mandatory_w1;
         let mut w2_set = mandatory_w2;
 
@@ -149,6 +157,15 @@ impl<'a> WitnessSplitter<'a> {
             let witness_count =
                 DependencyInfo::extract_writes(&self.witness_builders[idx]).len();
 
+            // If free builder writes a public witness, add it to w1_set.
+            if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] {
+                if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) {
+                    w1_set.insert(idx);
+                    w1_witness_count += witness_count;
+                    continue;
+                }
+            }
+
             if must_be_w2 {
                 w2_set.insert(idx);
                 w2_witness_count += witness_count;
@@ -165,9 +182,53 @@ impl<'a> WitnessSplitter<'a> {
         let mut w1_indices: Vec<usize> = w1_set.into_iter().collect();
         let mut w2_indices: Vec<usize> = w2_set.into_iter().collect();
 
-        w1_indices.sort_unstable();
+        w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set);
         w2_indices.sort_unstable();
 
         (w1_indices, w2_indices)
     }
+
+    /// Rearranges w1 builder indices into a canonical order:
+    /// 1. Constant builder (index 0) first, to preserve R1CS index 0 invariant
+    /// 2. Public input builders next, grouped together
+    /// 3. All other w1 builders last, sorted by index
+    fn rearrange_w1(
+        &self,
+        w1_indices: Vec<usize>,
+        acir_public_inputs_indices_set: &HashSet<u32>,
+    ) -> Vec<usize> {
+        let mut public_input_builder_indices = Vec::new();
+        let mut rest_indices = Vec::new();
+
+        // Sanity Check: Make sure all public inputs and WITNESS_ONE_IDX are in
+        // w1_indices.
+ // Convert to HashSet for O(1) lookups since we're checking many times + let w1_indices_set = w1_indices.iter().copied().collect::>(); + for &idx in acir_public_inputs_indices_set.iter() { + if !w1_indices_set.contains(&(idx as usize)) { + panic!("Public input {} is not in w1_indices", idx); + } + } + + // Separate into: 0, public inputs, and rest + for builder_idx in w1_indices { + if builder_idx == 0 { + continue; // Will add 0 first + } else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] { + if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) { + public_input_builder_indices.push(builder_idx); + continue; + } + } + rest_indices.push(builder_idx); + } + + rest_indices.sort_unstable(); + + // Reorder: 0 first, then public inputs, then rest + let mut new_w1_indices = vec![0]; + new_w1_indices.extend(public_input_builder_indices); + new_w1_indices.extend(rest_indices); + new_w1_indices + } } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index d212c9bc..9567eac7 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -14,7 +14,7 @@ use { FieldElement, R1CS, }, serde::{Deserialize, Serialize}, - std::num::NonZeroU32, + std::{collections::HashSet, num::NonZeroU32}, }; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -174,6 +174,7 @@ impl WitnessBuilder { witness_builders: &[WitnessBuilder], r1cs: R1CS, witness_map: Vec>, + acir_public_inputs_indices_set: HashSet, ) -> (SplitWitnessBuilders, R1CS, Vec>, usize) { if witness_builders.is_empty() { return ( @@ -190,7 +191,7 @@ impl WitnessBuilder { // Step 1: Analyze dependencies and split into w1/w2 let splitter = WitnessSplitter::new(witness_builders); - let (w1_indices, w2_indices) = splitter.split_builders(); + let (w1_indices, w2_indices) = splitter.split_builders(acir_public_inputs_indices_set); // Step 2: Extract w1 and w2 builders in order let w1_builders: Vec = w1_indices diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 031a29dd..bb89b790 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -6,7 +6,7 @@ use { nargo::foreign_calls::DefaultForeignCallBuilder, noir_artifact_cli::fs::inputs::read_inputs_from_file, noirc_abi::InputMap, - provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover}, + provekit_common::{FieldElement, IOPattern, NoirElement, NoirProof, Prover, PublicInputs}, std::path::Path, tracing::instrument, }; @@ -56,6 +56,7 @@ impl Prove for Prover { read_inputs_from_file(prover_toml.as_ref(), self.witness_generator.abi())?; let acir_witness_idx_to_value_map = self.generate_witness(input_map)?; + let acir_public_inputs = self.program.functions[0].public_inputs().indices(); // Set up transcript let io: IOPattern = self.whir_for_witness.create_io_pattern(); @@ -112,14 +113,30 @@ impl Prove for Prover { self.r1cs .test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) .context("While verifying R1CS instance")?; + + // Gather public inputs from witness + let num_public_inputs = acir_public_inputs.len(); + let public_inputs = if num_public_inputs == 0 { + PublicInputs::new() + } else { + PublicInputs::from_vec( + witness[1..=num_public_inputs] + .iter() + .map(|w| w.ok_or_else(|| anyhow::anyhow!("Missing public input witness"))) + .collect::>>()?, + ) + }; drop(witness); let whir_r1cs_proof = self .whir_for_witness - .prove(merlin, self.r1cs, commitments) + 
.prove(merlin, self.r1cs, commitments, &public_inputs) .context("While proving R1CS instance")?; - Ok(NoirProof { whir_r1cs_proof }) + Ok(NoirProof { + public_inputs, + whir_r1cs_proof, + }) } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 60bb54cd..55c95969 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -14,7 +14,7 @@ use { zk_utils::{create_masked_polynomial, generate_random_multilinear_polynomial}, HALF, }, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField}, @@ -54,6 +54,7 @@ pub trait WhirR1CSProver { merlin: ProverState, r1cs: R1CS, commitments: Vec, + public_inputs: &PublicInputs, ) -> Result; } @@ -121,6 +122,7 @@ impl WhirR1CSProver for WhirR1CSScheme { mut merlin: ProverState, r1cs: R1CS, mut commitments: Vec, + public_inputs: &PublicInputs, ) -> Result { ensure!(!commitments.is_empty(), "Need at least one commitment"); @@ -159,16 +161,36 @@ impl WhirR1CSProver for WhirR1CSScheme { let commitment = commitments.into_iter().next().unwrap(); let alphas: [Vec; 3] = alphas.try_into().unwrap(); - let (statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( + let (mut statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( self.m, &commitment.commitment_to_witness, - commitment.masked_polynomial, - commitment.random_polynomial, + &commitment.masked_polynomial, + &commitment.random_polynomial, &alphas, ); merlin.hint::<(Vec, Vec)>(&(f_sums, g_sums))?; + // VERIFY the size given by self.m + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { + // If there are no public inputs, the hint is unused by the verifier and can be + // assigned an arbitrary value. 
+ let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement, + &commitment.commitment_to_witness, + &commitment.masked_polynomial, + &commitment.random_polynomial, + public_weight, + ) + }; + + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; + run_zk_whir_pcs_prover( commitment.commitment_to_witness, statement, @@ -193,12 +215,12 @@ impl WhirR1CSProver for WhirR1CSScheme { let alphas_1: [Vec; 3] = alphas_1.try_into().unwrap(); let alphas_2: [Vec; 3] = alphas_2.try_into().unwrap(); - let (statement_1, f_sums_1, g_sums_1) = + let (mut statement_1, f_sums_1, g_sums_1) = create_combined_statement_over_two_polynomials::<3>( self.m, &c1.commitment_to_witness, - c1.masked_polynomial, - c1.random_polynomial, + &c1.masked_polynomial, + &c1.random_polynomial, &alphas_1, ); drop(alphas_1); @@ -207,8 +229,8 @@ impl WhirR1CSProver for WhirR1CSScheme { create_combined_statement_over_two_polynomials::<3>( self.m, &c2.commitment_to_witness, - c2.masked_polynomial, - c2.random_polynomial, + &c2.masked_polynomial, + &c2.random_polynomial, &alphas_2, ); drop(alphas_2); @@ -216,6 +238,23 @@ impl WhirR1CSProver for WhirR1CSScheme { merlin.hint::<(Vec, Vec)>(&(f_sums_1, g_sums_1))?; merlin.hint::<(Vec, Vec)>(&(f_sums_2, g_sums_2))?; + let public_weight = get_public_weights(public_inputs, &mut merlin, self.m); + let (public_f_sum, public_g_sum) = if public_inputs.is_empty() { + let public_f_sum = FieldElement::zero(); + let public_g_sum = FieldElement::zero(); + (public_f_sum, public_g_sum) + } else { + update_statement_with_public_weights( + &mut statement_1, + &c1.commitment_to_witness, + &c1.masked_polynomial, + &c1.random_polynomial, + public_weight, + ) + }; + + merlin.hint::<(FieldElement, FieldElement)>(&(public_f_sum, public_g_sum))?; + run_zk_whir_pcs_batch_prover( &[c1.commitment_to_witness, c2.commitment_to_witness], &[statement_1, statement_2], @@ -514,8 +553,8 @@ pub fn run_zk_sumcheck_prover( create_combined_statement_over_two_polynomials::<1>( blinding_polynomial_variables + 1, &commitment_to_blinding_polynomial, - blindings_mask_polynomial, - blindings_blind_polynomial, + &blindings_mask_polynomial, + &blindings_blind_polynomial, &[expand_powers(alpha.as_slice())], ); @@ -548,8 +587,8 @@ fn expand_powers(values: &[FieldElement]) -> Vec { fn create_combined_statement_over_two_polynomials( cfg_nv: usize, witness: &Witness, - f_polynomial: EvaluationsList, - g_polynomial: EvaluationsList, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, alphas: &[Vec; N], ) -> ( Statement, @@ -579,8 +618,8 @@ fn create_combined_statement_over_two_polynomials( w_full.resize(final_len, FieldElement::zero()); let weight = Weights::linear(EvaluationsList::new(w_full)); - let f = weight.weighted_sum(&f_polynomial); - let g = weight.weighted_sum(&g_polynomial); + let f = weight.weighted_sum(f_polynomial); + let g = weight.weighted_sum(g_polynomial); statement.add_constraint(weight, f + witness.batching_randomness * g); f_sums.push(f); @@ -631,3 +670,45 @@ pub fn run_zk_whir_pcs_batch_prover( (randomness, deferred) } + +fn update_statement_with_public_weights( + statement: &mut Statement, + witness: &Witness, + f_polynomial: &EvaluationsList, + g_polynomial: &EvaluationsList, + public_weights: Weights, +) -> (FieldElement, FieldElement) { + let f = public_weights.weighted_sum(f_polynomial); + let g = public_weights.weighted_sum(g_polynomial); + 
statement.add_constraint_in_front(public_weights, f + witness.batching_randomness * g); + (f, g) +} + +fn get_public_weights( + public_inputs: &PublicInputs, + merlin: &mut ProverState, + m: usize, +) -> Weights { + // Add hash to transcript + let public_inputs_hash = public_inputs.hash(); + let _ = merlin.add_scalars(&[public_inputs_hash]); + + // Get random point x + let mut x_buf = [FieldElement::zero()]; + merlin + .fill_challenge_scalars(&mut x_buf) + .expect("Failed to get challenge from Merlin"); + let x = x_buf[0]; + + let domain_size = 1 << m; + let mut public_weights = vec![FieldElement::zero(); domain_size]; + + // Set public weights for public inputs [1,x,x^2,x^3...x^n-1,0,0,0...0] + let mut current_pow = FieldElement::one(); + for (idx, _) in public_inputs.0.iter().enumerate() { + public_weights[idx] = current_pow; + current_pow = current_pow * x; + } + + Weights::geometric(x, public_inputs.len(), EvaluationsList::new(public_weights)) +} diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 58e9346a..0cde2f1b 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -10,7 +10,7 @@ use { witness::{NoirWitnessGenerator, WitnessBuilder}, NoirProofScheme, WhirR1CSScheme, }, - std::{fs::File, path::Path}, + std::{collections::HashSet, fs::File, path::Path}, tracing::{info, instrument}, }; @@ -61,9 +61,18 @@ impl NoirProofSchemeBuilder for NoirProofScheme { r1cs.c.num_entries() ); + // Extract ACIR public input indices set + let acir_public_inputs_indices_set: HashSet = + main.public_inputs().indices().iter().cloned().collect(); + // Split witness builders and remap indices for sound challenge generation let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = - WitnessBuilder::split_and_prepare_layers(&witness_builders, r1cs, witness_map); + WitnessBuilder::split_and_prepare_layers( + &witness_builders, + r1cs, + witness_map, + acir_public_inputs_indices_set, + ); info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, diff --git a/provekit/verifier/src/lib.rs b/provekit/verifier/src/lib.rs index abdb369f..ea70671c 100644 --- a/provekit/verifier/src/lib.rs +++ b/provekit/verifier/src/lib.rs @@ -3,21 +3,22 @@ mod whir_r1cs; use { crate::whir_r1cs::WhirR1CSVerifier, anyhow::Result, - provekit_common::{NoirProof, Verifier}, + provekit_common::{NoirProof, Verifier, R1CS}, tracing::instrument, }; pub trait Verify { - fn verify(&mut self, proof: &NoirProof) -> Result<()>; + fn verify(&mut self, proof: &NoirProof, r1cs: &R1CS) -> Result<()>; } impl Verify for Verifier { #[instrument(skip_all)] - fn verify(&mut self, proof: &NoirProof) -> Result<()> { - self.whir_for_witness - .take() - .unwrap() - .verify(&proof.whir_r1cs_proof)?; + fn verify(&mut self, proof: &NoirProof, r1cs: &R1CS) -> Result<()> { + self.whir_for_witness.take().unwrap().verify( + &proof.whir_r1cs_proof, + &proof.public_inputs, + r1cs, + )?; Ok(()) } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 8a668fa0..c1065e2e 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -3,8 +3,13 @@ use { ark_std::{One, Zero}, provekit_common::{ skyscraper::SkyscraperSponge, - utils::sumcheck::{calculate_eq, eval_cubic_poly}, - FieldElement, WhirConfig, WhirR1CSProof, WhirR1CSScheme, + utils::{ + sumcheck::{ + calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, 
eval_cubic_poly, + }, + zk_utils::geometric_till, + }, + FieldElement, PublicInputs, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, @@ -29,13 +34,23 @@ pub struct DataFromSumcheckVerifier { } pub trait WhirR1CSVerifier { - fn verify(&self, proof: &WhirR1CSProof) -> Result<()>; + fn verify( + &self, + proof: &WhirR1CSProof, + public_inputs: &PublicInputs, + r1cs: &R1CS, + ) -> Result<()>; } impl WhirR1CSVerifier for WhirR1CSScheme { #[instrument(skip_all)] #[allow(unused)] - fn verify(&self, proof: &WhirR1CSProof) -> Result<()> { + fn verify( + &self, + proof: &WhirR1CSProof, + public_inputs: &PublicInputs, + r1cs: &R1CS, + ) -> Result<()> { let io = self.create_io_pattern(); let mut arthur = io.to_verifier_state(&proof.transcript); @@ -57,63 +72,130 @@ impl WhirR1CSVerifier for WhirR1CSScheme { .context("while verifying sumcheck")?; // Read hints and verify WHIR proof - let (az_at_alpha, bz_at_alpha, cz_at_alpha) = - if let Some(parsed_commitment_2) = parsed_commitment_2 { - // Dual commitment mode - let sums_1: (Vec, Vec) = arthur.hint()?; - let sums_2: (Vec, Vec) = arthur.hint()?; - - let whir_sums_1: ([FieldElement; 3], [FieldElement; 3]) = - (sums_1.0.try_into().unwrap(), sums_1.1.try_into().unwrap()); - let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = - (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); - - let statement_1 = prepare_statement_for_witness_verifier::<3>( + let ( + az_at_alpha, + bz_at_alpha, + cz_at_alpha, + whir_folding_randomness, + deferred_evals, + public_weights_challenge, + ) = if let Some(parsed_commitment_2) = parsed_commitment_2 { + // Dual commitment mode + let sums_1: (Vec, Vec) = arthur.hint()?; + let sums_2: (Vec, Vec) = arthur.hint()?; + + let whir_sums_1: ([FieldElement; 3], [FieldElement; 3]) = + (sums_1.0.try_into().unwrap(), sums_1.1.try_into().unwrap()); + let whir_sums_2: ([FieldElement; 3], [FieldElement; 3]) = + (sums_2.0.try_into().unwrap(), sums_2.1.try_into().unwrap()); + + let mut statement_1 = prepare_statement_for_witness_verifier::<3>( + self.m, + &parsed_commitment_1, + &whir_sums_1, + ); + let statement_2 = prepare_statement_for_witness_verifier::<3>( + self.m, + &parsed_commitment_2, + &whir_sums_2, + ); + + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + + if !public_inputs.is_empty() { + update_statement_for_witness_verifier( self.m, + &mut statement_1, &parsed_commitment_1, - &whir_sums_1, - ); - let statement_2 = prepare_statement_for_witness_verifier::<3>( - self.m, - &parsed_commitment_2, - &whir_sums_2, + whir_public_weights_query_answer, ); + } + + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_batch_verifier( + &mut arthur, + &self.whir_witness, + &[parsed_commitment_1, parsed_commitment_2], + &[statement_1, statement_2], + ) + .context("while verifying WHIR batch proof")?; + + ( + whir_sums_1.0[0] + 
whir_sums_2.0[0], + whir_sums_1.0[1] + whir_sums_2.0[1], + whir_sums_1.0[2] + whir_sums_2.0[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + public_weights_vector_random_buf[0], + ) + } else { + // Single commitment mode + let sums: (Vec, Vec) = arthur.hint()?; + let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = + (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); - run_whir_pcs_batch_verifier( - &mut arthur, - &self.whir_witness, - &[parsed_commitment_1, parsed_commitment_2], - &[statement_1, statement_2], - ) - .context("while verifying WHIR batch proof")?; - - ( - whir_sums_1.0[0] + whir_sums_2.0[0], - whir_sums_1.0[1] + whir_sums_2.0[1], - whir_sums_1.0[2] + whir_sums_2.0[2], - ) - } else { - // Single commitment mode - let sums: (Vec, Vec) = arthur.hint()?; - let whir_sums: ([FieldElement; 3], [FieldElement; 3]) = - (sums.0.try_into().unwrap(), sums.1.try_into().unwrap()); - - let statement = prepare_statement_for_witness_verifier::<3>( + let mut statement = prepare_statement_for_witness_verifier::<3>( + self.m, + &parsed_commitment_1, + &whir_sums, + ); + + let mut public_inputs_hash_buf = [FieldElement::zero()]; + arthur.fill_next_scalars(&mut public_inputs_hash_buf)?; + let expected_public_inputs_hash = public_inputs.hash(); + ensure!( + public_inputs_hash_buf[0] == expected_public_inputs_hash, + "Public inputs hash mismatch: expected {:?}, got {:?}", + expected_public_inputs_hash, + public_inputs_hash_buf[0] + ); + + let mut public_weights_vector_random_buf = [FieldElement::zero()]; + arthur.fill_challenge_scalars(&mut public_weights_vector_random_buf)?; + + let whir_public_weights_query_answer: (FieldElement, FieldElement) = arthur + .hint() + .context("failed to read WHIR public weights query answer")?; + if !public_inputs.is_empty() { + update_statement_for_witness_verifier( self.m, + &mut statement, &parsed_commitment_1, - &whir_sums, + whir_public_weights_query_answer, ); + } - run_whir_pcs_verifier( - &mut arthur, - &parsed_commitment_1, - &self.whir_witness, - &statement, - ) - .context("while verifying WHIR proof")?; + let (whir_folding_randomness, deferred_evals) = run_whir_pcs_verifier( + &mut arthur, + &parsed_commitment_1, + &self.whir_witness, + &statement, + ) + .context("while verifying WHIR proof")?; - (whir_sums.0[0], whir_sums.0[1], whir_sums.0[2]) - }; + ( + whir_sums.0[0], + whir_sums.0[1], + whir_sums.0[2], + whir_folding_randomness.0.to_vec(), + deferred_evals, + public_weights_vector_random_buf[0], + ) + }; // Check the Spartan sumcheck relation ensure!( @@ -126,6 +208,54 @@ impl WhirR1CSVerifier for WhirR1CSScheme { "last sumcheck value does not match" ); + // Check deferred linear and geometric constraints + let offset = if public_inputs.is_empty() { 0 } else { 1 }; + + // Linear deferred + if self.num_challenges > 0 { + let matrix_extension_evals = evaluate_r1cs_matrix_extension_batch( + r1cs, + &data_from_sumcheck_verifier.alpha, + &whir_folding_randomness, + self.w1_size, + ); + for i in 0..6 { + ensure!( + matrix_extension_evals[i] == deferred_evals[offset + i], + "Matrix extension evaluation {} does not match deferred value", + i + ); + } + } else { + let matrix_extension_evals = evaluate_r1cs_matrix_extension( + r1cs, + &data_from_sumcheck_verifier.alpha, + &whir_folding_randomness, + ); + + for i in 0..3 { + ensure!( + matrix_extension_evals[i] == deferred_evals[offset + i], + "Matrix extension evaluation {} does not match deferred value", + i + ); + } + } + + // Geometric deferred + if !public_inputs.is_empty() { + let 
public_weight_eval = compute_public_weight_evaluation( + public_inputs, + &whir_folding_randomness, + self.whir_witness.mv_parameters.num_variables, + public_weights_challenge, + ); + ensure!( + public_weight_eval == deferred_evals[0], + "Public weight evaluation does not match deferred value" + ); + } + Ok(()) } } @@ -147,6 +277,20 @@ fn prepare_statement_for_witness_verifier( statement_verifier } +fn update_statement_for_witness_verifier( + m: usize, + statement_verifier: &mut Statement, + parsed_commitment: &ParsedCommitment, + whir_public_weights_query_answer: (FieldElement, FieldElement), +) { + let (public_f_sum, public_g_sum) = whir_public_weights_query_answer; + let public_weight = Weights::linear(EvaluationsList::new(vec![FieldElement::zero(); 1 << m])); + statement_verifier.add_constraint_in_front( + public_weight, + public_f_sum + public_g_sum * parsed_commitment.batching_randomness, + ); +} + #[instrument(skip_all)] pub fn run_sumcheck_verifier( arthur: &mut VerifierState, @@ -240,3 +384,101 @@ pub fn run_whir_pcs_batch_verifier( .context("while verifying batch WHIR")?; Ok((folding_randomness, deferred)) } + +fn evaluate_r1cs_matrix_extension( + r1cs: &R1CS, + row_rand: &[FieldElement], + col_rand: &[FieldElement], +) -> [FieldElement; 3] { + let row_eval = calculate_evaluations_over_boolean_hypercube_for_eq(row_rand.to_vec()); + let col_eval = calculate_evaluations_over_boolean_hypercube_for_eq(col_rand.to_vec()); + + let mut ans_a = FieldElement::zero(); + let mut ans_b = FieldElement::zero(); + let mut ans_c = FieldElement::zero(); + + for ((row, col), val) in r1cs.a().iter() { + ans_a += val * row_eval[row] * col_eval[col]; + } + + for ((row, col), val) in r1cs.b().iter() { + ans_b += val * row_eval[row] * col_eval[col]; + } + + for ((row, col), val) in r1cs.c().iter() { + ans_c += val * row_eval[row] * col_eval[col]; + } + + [ans_a, ans_b, ans_c] +} + +fn evaluate_r1cs_matrix_extension_batch( + r1cs: &R1CS, + row_rand: &[FieldElement], + col_rand: &[FieldElement], + w1_size: usize, +) -> [FieldElement; 6] { + let row_eval = calculate_evaluations_over_boolean_hypercube_for_eq(row_rand.to_vec()); + let col_eval = calculate_evaluations_over_boolean_hypercube_for_eq(col_rand.to_vec()); + + let mut ans = [FieldElement::zero(); 6]; + + // Evaluate matrices - split by column based on w1_size + for ((row, col), val) in r1cs.a().iter() { + if col < w1_size { + ans[0] += val * row_eval[row] * col_eval[col]; + } else { + ans[3] += val * row_eval[row] * col_eval[col - w1_size]; + } + } + + for ((row, col), val) in r1cs.b().iter() { + if col < w1_size { + ans[1] += val * row_eval[row] * col_eval[col]; + } else { + ans[4] += val * row_eval[row] * col_eval[col - w1_size]; + } + } + + for ((row, col), val) in r1cs.c().iter() { + if col < w1_size { + ans[2] += val * row_eval[row] * col_eval[col]; + } else { + ans[5] += val * row_eval[row] * col_eval[col - w1_size]; + } + } + + ans +} + +fn compute_public_weight_evaluation( + public_inputs: &PublicInputs, + folding_randomness: &[FieldElement], + m: usize, + x: FieldElement, +) -> FieldElement { + let domain_size = 1 << m; + let mut public_weights = vec![FieldElement::zero(); domain_size]; + + let mut current_pow = FieldElement::one(); + for (idx, _) in public_inputs.0.iter().enumerate() { + public_weights[idx] = current_pow; + current_pow = current_pow * x; + } + + let mle = geometric_till(x, public_inputs.len(), folding_randomness); + + #[cfg(test)] + { + let eq_polys = + 
calculate_evaluations_over_boolean_hypercube_for_eq(folding_randomness.to_vec()); + let sum: FieldElement = public_weights + .iter() + .zip(eq_polys.iter()) + .map(|(w, eq)| *w * eq) + .sum(); + assert!(sum == mle, "Sum does not match mle"); + } + + mle +} diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 2f95d5e6..fcc1e216 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -39,6 +39,9 @@ type Circuit struct { WitnessClaimedEvaluations [][]frontend.Variable // [commitment_idx][eval_idx] WitnessBlindingEvaluations [][]frontend.Variable + // For public_f_sum and public_g_sum + PubWitnessEvaluations []frontend.Variable + // Batch mode only: batched polynomial for rounds 1+ WitnessMerkle Merkle @@ -46,9 +49,9 @@ type Circuit struct { MatrixB []MatrixCell MatrixC []MatrixCell - // Public Input - IO []byte - Transcript []uints.U8 `gnark:",public"` + IO []byte + Transcript []uints.U8 `gnark:",public"` + PublicInputs PublicInputs } func (circuit *Circuit) Define(api frontend.API) error { @@ -95,12 +98,56 @@ func (circuit *Circuit) Define(api frontend.API) error { return err } + // Read public inputs hash from transcript + publicInputsHashBuf := make([]frontend.Variable, 1) + if err := arthur.FillNextScalars(publicInputsHashBuf); err != nil { + return fmt.Errorf("failed to read public inputs hash: %w", err) + } + + expectedHash, err := hashPublicInputs(sc, circuit.PublicInputs) + if err != nil { + return fmt.Errorf("failed to compute public inputs hash: %w", err) + } + + api.AssertIsEqual(publicInputsHashBuf[0], expectedHash) + + // Squeeze rand for public weights + publicWeightsChallenge := make([]frontend.Variable, 1) + if err := arthur.FillChallengeScalars(publicWeightsChallenge); err != nil { + return fmt.Errorf("failed to read public weights challenge: %w", err) + } + // WHIR verification var whirFoldingRandomness []frontend.Variable var az, bz, cz frontend.Variable if circuit.NumChallenges > 0 { - // Dual commitment mode: batch WHIR verification + // Only statement_1 (first commitment) gets extended with public weights, statement_2 remains unchanged + extendedLinearStatementEvalsBatch := make([][][]frontend.Variable, 2) + + if !circuit.PublicInputs.IsEmpty() { + extendedLinearStatementEvalsBatch[0] = extendLinearStatement( + circuit, + [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + circuit.PubWitnessEvaluations, + ) + + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } else { + // Use original arrays as before, no public inputs + extendedLinearStatementEvalsBatch[0] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[0], + circuit.WitnessBlindingEvaluations[0], + } + extendedLinearStatementEvalsBatch[1] = [][]frontend.Variable{ + circuit.WitnessClaimedEvaluations[1], + circuit.WitnessBlindingEvaluations[1], + } + } + whirFoldingRandomness, err = RunZKWhirBatch( api, arthur, uapi, sc, circuit.WitnessFirstRounds, // firstRounds []Merkle @@ -109,12 +156,10 @@ func (circuit *Circuit) Define(api frontend.API) error { [][][]frontend.Variable{initialOODAnswers1, initialOODAnswers2}, // initialOODAnswers []frontend.Variable{rootHash1, rootHash2}, // rootHashes circuit.WitnessMerkle, // batchedMerkle - [][][]frontend.Variable{ // linearStatementEvals - {circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, - 
{circuit.WitnessClaimedEvaluations[1], circuit.WitnessBlindingEvaluations[1]}, - }, - circuit.WHIRParamsWitness, // whirParams - circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + extendedLinearStatementEvalsBatch, // linearStatementEvals (extended for first commitment) + circuit.WHIRParamsWitness, // whirParams + circuit.WitnessLinearStatementEvaluations, // linearStatementValuesAtPoints + circuit.PublicInputs, // publicInputs ) if err != nil { return err @@ -125,12 +170,15 @@ func (circuit *Circuit) Define(api frontend.API) error { bz = api.Add(circuit.WitnessClaimedEvaluations[0][1], circuit.WitnessClaimedEvaluations[1][1]) cz = api.Add(circuit.WitnessClaimedEvaluations[0][2], circuit.WitnessClaimedEvaluations[1][2]) } else { + log.Println("Single Mode") + extendedLinearStatementEvals := extendLinearStatement(circuit, [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, circuit.PubWitnessEvaluations) + // Single commitment mode whirFoldingRandomness, err = RunZKWhir( api, arthur, uapi, sc, circuit.WitnessMerkle, circuit.WitnessFirstRounds[0], circuit.WHIRParamsWitness, - [][]frontend.Variable{circuit.WitnessClaimedEvaluations[0], circuit.WitnessBlindingEvaluations[0]}, + extendedLinearStatementEvals, circuit.WitnessLinearStatementEvaluations, batchingRandomness1, initialOODQueries1, @@ -150,23 +198,74 @@ func (circuit *Circuit) Define(api frontend.API) error { x := api.Mul(api.Sub(api.Mul(az, bz), cz), calculateEQ(api, spartanSumcheckRand, tRand)) api.AssertIsEqual(spartanSumcheckLastValue, x) + offset := 0 + if !circuit.PublicInputs.IsEmpty() { + // can be generalized later on if we have more different kinds of statements + offset = 1 + } + if circuit.NumChallenges > 0 { // Batch mode - check 6 deferred values matrixExtensionEvals := evaluateR1CSMatrixExtensionBatch(api, circuit, spartanSumcheckRand, whirFoldingRandomness, circuit.W1Size) for i := 0; i < 6; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } else { + // Single mode - existing logic matrixExtensionEvals := evaluateR1CSMatrixExtension(api, circuit, spartanSumcheckRand, whirFoldingRandomness) for i := 0; i < 3; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[offset+i]) } } + // Geomteric weights for public inputs + if !circuit.PublicInputs.IsEmpty() { + publicWeightEval := computePublicWeightEvaluation( + api, circuit.PublicInputs, whirFoldingRandomness, + circuit.WHIRParamsWitness.MVParamsNumberOfVariables, publicWeightsChallenge[0], + ) + + api.AssertIsEqual(publicWeightEval, circuit.WitnessLinearStatementEvaluations[0]) + } + return nil } +func computePublicWeightEvaluation( + api frontend.API, + publicInputs PublicInputs, + foldingRandomness []frontend.Variable, + m int, // domain size = 2^m + x frontend.Variable, +) frontend.Variable { + // Build public weight vector: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] where n = len(publicInputs.Values) and total length = 2^m + domainSize := 1 << m + publicWeights := make([]frontend.Variable, domainSize) + + for i := 0; i < domainSize; i++ { + publicWeights[i] = 0 + } + + // Set public weights: [1, x, x^2, ..., x^(n-1), 0, 0, ..., 0] + currentPower := frontend.Variable(1) + for i := 0; i < len(publicInputs.Values); i++ { + 
publicWeights[i] = currentPower + currentPower = api.Mul(currentPower, x) + } + + // TODO : Replace it with geometric_till algo + // Evaluate the multilinear extension of publicWeights at foldingRandomness + // Formula: f(r) = Σ_{i=0}^{2^m-1} f[i] * eq_i(r) + // where eq_i(r) is the i-th Lagrange basis polynomial + eqPolys := calculateEQOverBooleanHypercube(api, foldingRandomness) + result := frontend.Variable(0) + for i := 0; i < len(publicWeights); i++ { + result = api.Add(result, api.Mul(publicWeights[i], eqPolys[i])) + } + return result +} + func verifyCircuit( deferred []Fp256, cfg Config, @@ -175,9 +274,11 @@ func verifyCircuit( vk *groth16.VerifyingKey, claimedEvaluations ClaimedEvaluations, claimedEvaluations2 ClaimedEvaluations, + publicWeightsClaimedEvaluation [2]Fp256, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, + publicInputs PublicInputs, ) error { transcriptT := make([]uints.U8, cfg.TranscriptLen) contTranscript := make([]uints.U8, cfg.TranscriptLen) @@ -189,9 +290,18 @@ func verifyCircuit( // Determine witness linear statement evals size based on mode var witnessLinearStatementEvalsSize int if cfg.NumChallenges > 0 { - witnessLinearStatementEvalsSize = 6 // 3 per commitment in batch mode + if !cfg.PublicInputs.IsEmpty() { + // 3 per commitment in batch mode + 1 public_input (geometric statement as a subset of linear statement) + witnessLinearStatementEvalsSize = 7 + } else { + witnessLinearStatementEvalsSize = 6 + } } else { - witnessLinearStatementEvalsSize = 3 + if !cfg.PublicInputs.IsEmpty() { + witnessLinearStatementEvalsSize = 4 + } else { + witnessLinearStatementEvalsSize = 3 + } } witnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) @@ -199,6 +309,9 @@ func verifyCircuit( contWitnessLinearStatementEvaluations := make([]frontend.Variable, witnessLinearStatementEvalsSize) contHidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) + if len(deferred) < 1+witnessLinearStatementEvalsSize { + return fmt.Errorf("deferred array too short: expected at least %d elements, got %d", 1+witnessLinearStatementEvalsSize, len(deferred)) + } hidingSpartanLinearStatementEvaluations[0] = typeConverters.LimbsToBigIntMod(deferred[0].Limbs) for i := 0; i < witnessLinearStatementEvalsSize; i++ { witnessLinearStatementEvaluations[i] = typeConverters.LimbsToBigIntMod(deferred[1+i].Limbs) @@ -258,6 +371,10 @@ func verifyCircuit( fSums2, gSums2 = parseClaimedEvaluations(claimedEvaluations2, true) } + // Parse public weights claimed evaluation + fSumPublicWeights, gSumPublicWeights := parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, true) + pubWitnessEvaluations := []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + // Build witness slices conditionally var witnessClaimedEvals, witnessBlindingEvals [][]frontend.Variable if cfg.NumChallenges > 0 { @@ -268,6 +385,13 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + // Empty container while circuit creation + publicInputsContainer := PublicInputs{ + Values: make([]frontend.Variable, len(publicInputs.Values)), + } + + log.Println("publicInputs", publicInputs) + circuit := Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript, @@ -276,6 +400,7 @@ func verifyCircuit( LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, + PubWitnessEvaluations: pubWitnessEvaluations, WitnessLinearStatementEvaluations: 
contWitnessLinearStatementEvaluations, HidingSpartanLinearStatementEvaluations: contHidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, true), @@ -289,6 +414,7 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputsContainer, } ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) @@ -377,13 +503,19 @@ func verifyCircuit( witnessBlindingEvals = [][]frontend.Variable{gSums} } + fSumPublicWeights, gSumPublicWeights = parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation, false) + pubWitnessEvaluations = []frontend.Variable{fSumPublicWeights, gSumPublicWeights} + assignment := Circuit{ IO: []byte(cfg.IOPattern), Transcript: transcriptT, LogNumConstraints: cfg.LogNumConstraints, + LogNumVariables: cfg.LogNumVariables, + LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: witnessClaimedEvals, WitnessBlindingEvaluations: witnessBlindingEvals, WitnessLinearStatementEvaluations: witnessLinearStatementEvaluations, + PubWitnessEvaluations: pubWitnessEvaluations, HidingSpartanLinearStatementEvaluations: hidingSpartanLinearStatementEvaluations, HidingSpartanFirstRound: newMerkle(hints.spartanHidingHint.firstRoundMerklePaths.path, false), HidingSpartanMerkle: newMerkle(hints.spartanHidingHint.roundHints, false), @@ -396,10 +528,15 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + PublicInputs: publicInputs, } witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) - publicWitness, _ := witness.Public() + publicWitness, err := witness.Public() + if err != nil { + log.Printf("Failed witess,Public(): %v", err) + return err + } opts := []backend.ProverOption{ backend.WithSolverOptions(solver.WithHints(utilities.IndexOf)), @@ -436,3 +573,42 @@ func witnessFirstRounds(hints Hints, isContainer bool) []Merkle { } return result } + +func parsePublicWeightsClaimedEvaluation(publicWeightsClaimedEvaluation [2]Fp256, isContainer bool) (frontend.Variable, frontend.Variable) { + var fSumPublicWeights, gSumPublicWeights frontend.Variable + + if !isContainer { + fSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[0].Limbs) + gSumPublicWeights = typeConverters.LimbsToBigIntMod(publicWeightsClaimedEvaluation[1].Limbs) + } + + return fSumPublicWeights, gSumPublicWeights +} + +func extendLinearStatement( + circuit *Circuit, + linearStatementEvaluations [][]frontend.Variable, + pubWitnessEvaluations []frontend.Variable, +) [][]frontend.Variable { + var extendedLinearStatementEvals [][]frontend.Variable + + if !circuit.PublicInputs.IsEmpty() { + // Extend the statement equivalent array by prepending the public constraint (public constraint is added in starting at prover side) + extendedLinearStatementEvals = make([][]frontend.Variable, 2) + + // f_sums: [public_f_sum, f_sums[0], f_sums[1]... ] + extendedLinearStatementEvals[0] = make([]frontend.Variable, len(linearStatementEvaluations[0])+1) + extendedLinearStatementEvals[0][0] = pubWitnessEvaluations[0] + copy(extendedLinearStatementEvals[0][1:], linearStatementEvaluations[0]) + + // g_sums: [public_g_sum, g_sums[0], g_sums[1]... 
] + extendedLinearStatementEvals[1] = make([]frontend.Variable, len(linearStatementEvaluations[1])+1) + extendedLinearStatementEvals[1][0] = pubWitnessEvaluations[1] + copy(extendedLinearStatementEvals[1][1:], linearStatementEvaluations[1]) + } else { + // No public inputs, use original arrays + extendedLinearStatementEvals = linearStatementEvaluations + } + + return extendedLinearStatementEvals +} diff --git a/recursive-verifier/app/circuit/common.go b/recursive-verifier/app/circuit/common.go index 16eea7c1..70728673 100644 --- a/recursive-verifier/app/circuit/common.go +++ b/recursive-verifier/app/circuit/common.go @@ -30,6 +30,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v var deferred []Fp256 var claimedEvaluations ClaimedEvaluations var claimedEvaluations2 ClaimedEvaluations + var publicWeightsEvaluations [2]Fp256 for _, op := range io.Ops { switch op.Kind { @@ -128,6 +129,16 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v if err != nil { return fmt.Errorf("failed to deserialize claimed_evaluations_2: %w", err) } + + case "public_weights_evaluations": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(config.Transcript[start:end]), + &publicWeightsEvaluations, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize public_weights_evaluations: %w", err) + } } if err != nil { @@ -204,7 +215,7 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v WitnessRoundHints: witnessRoundHints, } - err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, r1cs, interner, buildOps) + err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, claimedEvaluations2, publicWeightsEvaluations, r1cs, interner, buildOps, config.PublicInputs) if err != nil { return fmt.Errorf("verification failed: %w", err) } diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go index 264727b3..4dc74eee 100644 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ b/recursive-verifier/app/circuit/mtUtilities.go @@ -112,3 +112,24 @@ func rlcBatchedLeaves(api frontend.API, leaves [][]frontend.Variable, foldSize i } return collapsed } + +// hashPublicInputs computes the hash of public inputs as field elements sequentially +func hashPublicInputs(sc *skyscraper.Skyscraper, publicInputs PublicInputs) (frontend.Variable, error) { + + if len(publicInputs.Values) == 0 { + return frontend.Variable(0), nil + } + + // For single element, we hash it with a zero + if len(publicInputs.Values) == 1 { + return sc.CompressV2(publicInputs.Values[0], frontend.Variable(0)), nil + } + + // For 2+ elements, use standard approach + hash := sc.CompressV2(publicInputs.Values[0], publicInputs.Values[1]) + for i := 2; i < len(publicInputs.Values); i++ { + hash = sc.CompressV2(hash, publicInputs.Values[i]) + } + + return hash, nil +} diff --git a/recursive-verifier/app/circuit/types.go b/recursive-verifier/app/circuit/types.go index 420b43e0..f6db49d0 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -1,6 +1,8 @@ package circuit import ( + "reilabs/whir-verifier-circuit/app/utilities" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -89,18 +91,19 @@ type ProofObject struct { } type Config struct { - WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` - WHIRConfigHidingSpartan WHIRConfig 
`json:"whir_config_hiding_spartan"` - LogNumConstraints int `json:"log_num_constraints"` - LogNumVariables int `json:"log_num_variables"` - LogANumTerms int `json:"log_a_num_terms"` - IOPattern string `json:"io_pattern"` - Transcript []byte `json:"transcript"` - TranscriptLen int `json:"transcript_len"` - WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` - BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` - NumChallenges int `json:"num_challenges"` - W1Size int `json:"w1_size"` + WHIRConfigWitness WHIRConfig `json:"whir_config_witness"` + WHIRConfigHidingSpartan WHIRConfig `json:"whir_config_hiding_spartan"` + LogNumConstraints int `json:"log_num_constraints"` + LogNumVariables int `json:"log_num_variables"` + LogANumTerms int `json:"log_a_num_terms"` + IOPattern string `json:"io_pattern"` + Transcript []byte `json:"transcript"` + TranscriptLen int `json:"transcript_len"` + WitnessStatementEvaluations []string `json:"witness_statement_evaluations"` + BlindingStatementEvaluations []string `json:"blinding_statement_evaluations"` + NumChallenges int `json:"num_challenges"` + W1Size int `json:"w1_size"` + PublicInputs PublicInputs `json:"public_inputs"` } // Update Hints to support batch mode @@ -139,3 +142,20 @@ type DualClaimedEvaluations struct { First ClaimedEvaluations Second ClaimedEvaluations } + +type PublicInputs struct { + Values []frontend.Variable +} + +func (p *PublicInputs) UnmarshalJSON(data []byte) error { + values, err := utilities.UnmarshalPublicInputs(data) + if err != nil { + return err + } + p.Values = values + return nil +} + +func (p *PublicInputs) IsEmpty() bool { + return len(p.Values) == 0 +} diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index 0e153280..3bdea8f1 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -238,6 +238,7 @@ func RunZKWhirBatch( // Common parameters whirParams WHIRParams, linearStatementValuesAtPoints []frontend.Variable, + publicInputs PublicInputs, ) (totalFoldingRandomness []frontend.Variable, err error) { numPolynomials := len(firstRounds) if numPolynomials == 0 { @@ -255,7 +256,15 @@ func RunZKWhirBatch( for i := 0; i < numPolynomials; i++ { numOOD += len(initialOODQueries[i]) } - numStatementConstraints := numPolynomials * 3 // 3 per commitment (Az, Bz, Cz) + + numStatementConstraints := 0 + + // w1 has 4 (pub, Az, Bz, Cz) constraints, w2 and remaining have 3 (Az, Bz, Cz) constraints + if !publicInputs.IsEmpty() { + numStatementConstraints = 4 + 3*(numPolynomials-1) + } else { + numStatementConstraints = numPolynomials * 3 + } numConstraints := numOOD + numStatementConstraints // Step 3: Read N×M evaluation matrix from transcript diff --git a/recursive-verifier/app/circuit/whir_utilities.go b/recursive-verifier/app/circuit/whir_utilities.go index 666dde44..ac75a9c3 100644 --- a/recursive-verifier/app/circuit/whir_utilities.go +++ b/recursive-verifier/app/circuit/whir_utilities.go @@ -140,6 +140,7 @@ func computeWPoly( value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[j], utilities.EqPolyOutside(api, utilities.ExpandFromUnivariate(api, initialData.InitialOODQueries[j], numberVars), totalFoldingRandomness))) } + // Values are directly used as all linearStatements are deffered and hints were given. Checking of hints will be done later on. 
for j, linearStatementValueAtPoint := range linearStatementValuesAtPoints { value = api.Add(value, api.Mul(initialData.InitialCombinationRandomness[len(initialData.InitialOODQueries)+j], linearStatementValueAtPoint)) } diff --git a/recursive-verifier/app/utilities/utilities.go b/recursive-verifier/app/utilities/utilities.go index 04a8e9d2..7222133b 100644 --- a/recursive-verifier/app/utilities/utilities.go +++ b/recursive-verifier/app/utilities/utilities.go @@ -1,6 +1,10 @@ package utilities import ( + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "encoding/json" "fmt" "math/big" "reilabs/whir-verifier-circuit/app/typeConverters" @@ -58,6 +62,77 @@ func IndexOf(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { return nil } +// HashPublicInputsHint is a hint function that computes SHA-256 hash of public inputs +// matching the Rust PublicInputs::hash() implementation. +// It takes public input values, converts them to BigInt, extracts limbs, hashes them, +// and returns the hash as a field element. +func HashPublicInputsHint(_ *big.Int, inputs []*big.Int, outputs []*big.Int) error { + if len(outputs) != 1 { + return fmt.Errorf("expecting one output") + } + + if len(inputs) == 0 { + outputs[0] = big.NewInt(0) + return nil + } + + hasher := sha256.New() + + // Process each public input value + for _, input := range inputs { + // Convert field element to BigInt (it's already a BigInt, but ensure it's in range) + value := new(big.Int).Set(input) + + // Extract limbs (u64 values) from BigInt + // Field elements are represented as 4 u64 limbs in little-endian + limbs := make([]uint64, 4) + temp := new(big.Int).Set(value) + limbs[0] = temp.Uint64() // Least significant limb + temp.Rsh(temp, 64) + limbs[1] = temp.Uint64() + temp.Rsh(temp, 64) + limbs[2] = temp.Uint64() + temp.Rsh(temp, 64) + limbs[3] = temp.Uint64() // Most significant limb + + // Hash each limb as little-endian bytes (8 bytes per limb) + for _, limb := range limbs { + limbBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(limbBytes, limb) + hasher.Write(limbBytes) + } + } + + // Get the hash result (32 bytes) + hashResult := hasher.Sum(nil) + + // Convert hash result to field element by splitting into 4 u64 limbs + // Each chunk of 8 bytes becomes a u64 (little-endian) + limbs := make([]uint64, 4) + for i := 0; i < 4; i++ { + start := i * 8 + end := start + 8 + limbs[i] = binary.LittleEndian.Uint64(hashResult[start:end]) + } + + // Reconstruct field element from limbs + result := new(big.Int).SetUint64(limbs[0]) + temp := new(big.Int).SetUint64(limbs[1]) + result.Add(result, temp.Lsh(temp, 64)) + temp.SetUint64(limbs[2]) + result.Add(result, temp.Lsh(temp, 128)) + temp.SetUint64(limbs[3]) + result.Add(result, temp.Lsh(temp, 192)) + + // Apply modulus to ensure result is in field range + modulus := new(big.Int) + modulus.SetString("21888242871839275222246405745257275088548364400416034343698204186575808495617", 10) + result.Mod(result, modulus) + + outputs[0] = result + return nil +} + func Reverse[T any](s []T) []T { res := make([]T, len(s)) copy(res, s) @@ -210,3 +285,49 @@ func DotProduct(api frontend.API, a []frontend.Variable, b []frontend.Variable) } return acc } + +// ParseHexFieldElement parses a hex string representing a FieldElement (little-endian) +// and converts it to a big.Int. The hex string should be 64 characters (32 bytes). 
diff --git a/tooling/cli/src/cmd/generate_gnark_inputs.rs b/tooling/cli/src/cmd/generate_gnark_inputs.rs
index 883f52b5..0e3cc6d0 100644
--- a/tooling/cli/src/cmd/generate_gnark_inputs.rs
+++ b/tooling/cli/src/cmd/generate_gnark_inputs.rs
@@ -62,6 +62,7 @@ impl Command for Args {
             prover.whir_for_witness.a_num_terms,
             prover.whir_for_witness.num_challenges,
             prover.whir_for_witness.w1_size,
+            &proof.public_inputs,
            &self.params_for_recursive_verifier,
         );
 
diff --git a/tooling/cli/src/cmd/prove.rs b/tooling/cli/src/cmd/prove.rs
index 728b62c4..6169dd0e 100644
--- a/tooling/cli/src/cmd/prove.rs
+++ b/tooling/cli/src/cmd/prove.rs
@@ -48,8 +48,8 @@ impl Command for Args {
         let (constraints, witnesses) = prover.size();
         info!(constraints, witnesses, "Read Noir proof scheme");
 
-        // // Read the input toml
-        // let input_map = scheme.read_witness(&self.input_path)?;
+        #[cfg(test)]
+        let r1cs = prover.r1cs.clone();
 
         // Generate the proof
         let proof = prover
@@ -62,7 +62,7 @@ impl Command for Args {
 
         let mut verifier: Verifier =
             read(&self.verifier_path).context("while reading Provekit Verifier")?;
         verifier
-            .verify(&proof)
+            .verify(&proof, &r1cs)
             .context("While verifying Noir proof")?;
     }
diff --git a/tooling/cli/src/cmd/verify.rs b/tooling/cli/src/cmd/verify.rs
index 6b95ce62..762cca4d 100644
--- a/tooling/cli/src/cmd/verify.rs
+++ b/tooling/cli/src/cmd/verify.rs
@@ -2,7 +2,7 @@ use {
     super::Command,
     anyhow::{Context, Result},
     argh::FromArgs,
-    provekit_common::{file::read, Verifier},
+    provekit_common::{file::read, Prover, Verifier},
     provekit_verifier::Verify,
     std::path::PathBuf,
     tracing::instrument,
@@ -16,6 +16,10 @@ pub struct Args {
     #[argh(positional)]
     verifier_path: PathBuf,
 
+    /// path to the prover file (.pkp)
+    #[argh(positional)]
+    prover_path: PathBuf,
+
     /// path to the proof file
     #[argh(positional)]
     proof_path: PathBuf,
@@ -31,9 +35,12 @@ impl Command for Args {
         // Read the proof
         let proof = read(&self.proof_path).context("while reading proof")?;
 
+        // Read the prover and R1CS (similar to generate-gnark-inputs)
+        let prover: Prover = read(&self.prover_path).context("while reading Provekit Prover")?;
+
         // Verify the proof
         verifier
-            .verify(&proof)
+            .verify(&proof, &prover.r1cs)
             .context("While verifying Noir proof")?;
 
         Ok(())
diff --git a/tooling/provekit-bench/benches/bench.rs b/tooling/provekit-bench/benches/bench.rs
index ce703453..f1ab224e 100644
--- a/tooling/provekit-bench/benches/bench.rs
+++ b/tooling/provekit-bench/benches/bench.rs
@@ -63,7 +63,11 @@ fn verify_poseidon_1000(bencher: Bencher) {
     let mut verifier: Verifier = read(&proof_verifier_path).unwrap();
     let proof_path = crate_dir.join("noir-proof.np");
     let proof: NoirProof = read(&proof_path).unwrap();
-    bencher.bench_local(|| black_box(&mut verifier).verify(black_box(&proof)));
+    let proof_prover_path = crate_dir.join("noir-provekit-prover.pkp");
+    let prover: Prover = read(&proof_prover_path).unwrap();
+    bencher.bench_local(|| {
+        black_box(&mut verifier).verify(black_box(&proof), black_box(&prover.r1cs))
+    });
 }
 
 fn main() {
diff --git a/tooling/provekit-bench/tests/compiler.rs b/tooling/provekit-bench/tests/compiler.rs
index 9ac24ffd..4b64aaba 100644
--- a/tooling/provekit-bench/tests/compiler.rs
+++ b/tooling/provekit-bench/tests/compiler.rs
@@ -42,11 +42,12 @@ fn test_compiler(test_case_path: impl AsRef<Path>) {
     let prover = Prover::from_noir_proof_scheme(schema.clone());
     let mut verifier = Verifier::from_noir_proof_scheme(schema.clone());
+    let r1cs = prover.r1cs.clone();
 
     let proof = prover
         .prove(&witness_file_path)
         .expect("While proving Noir program statement");
 
-    verifier.verify(&proof).expect("Verifying proof");
+    verifier.verify(&proof, &r1cs).expect("Verifying proof");
 }
 
 pub fn compile_workspace(workspace_path: impl AsRef<Path>) -> Result {
diff --git a/tooling/provekit-gnark/src/gnark_config.rs b/tooling/provekit-gnark/src/gnark_config.rs
index 015cc68f..41439a75 100644
--- a/tooling/provekit-gnark/src/gnark_config.rs
+++ b/tooling/provekit-gnark/src/gnark_config.rs
@@ -1,6 +1,6 @@
 use {
     ark_poly::EvaluationDomain,
-    provekit_common::{IOPattern, WhirConfig},
+    provekit_common::{IOPattern, PublicInputs, WhirConfig},
     serde::{Deserialize, Serialize},
     std::{fs::File, io::Write},
     tracing::instrument,
@@ -29,6 +29,8 @@ pub struct GnarkConfig {
     pub num_challenges: usize,
     /// size of w1
     pub w1_size: usize,
+    /// public inputs
+    pub public_inputs: PublicInputs,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -114,6 +116,7 @@ pub fn gnark_parameters(
     a_num_terms: usize,
     num_challenges: usize,
     w1_size: usize,
+    public_inputs: &PublicInputs,
 ) -> GnarkConfig {
     GnarkConfig {
         whir_config_witness: WHIRConfigGnark::new(whir_params_witness),
@@ -126,6 +129,7 @@ pub fn gnark_parameters(
         transcript_len: transcript.to_vec().len(),
         num_challenges,
         w1_size,
+        public_inputs: public_inputs.clone(),
     }
 }
 
@@ -141,6 +145,7 @@ pub fn write_gnark_parameters_to_file(
     a_num_terms: usize,
     num_challenges: usize,
     w1_size: usize,
+    public_inputs: &PublicInputs,
     file_path: &str,
 ) {
     let gnark_config = gnark_parameters(
@@ -153,6 +158,7 @@ pub fn write_gnark_parameters_to_file(
         a_num_terms,
         num_challenges,
         w1_size,
+        public_inputs,
     );
     let mut file_params = File::create(file_path).unwrap();
     file_params
diff --git a/tooling/verifier-server/src/services/verification.rs b/tooling/verifier-server/src/services/verification.rs
index f8abd138..398b282c 100644
--- a/tooling/verifier-server/src/services/verification.rs
+++ b/tooling/verifier-server/src/services/verification.rs
@@ -97,6 +97,7 @@ impl VerificationService {
             whir_scheme.a_num_terms,
             whir_scheme.num_challenges,
             whir_scheme.w1_size,
+            &proof.public_inputs,
             gnark_params_path,
         );
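End to end, the gnark config now carries the public inputs that the recursive verifier parses on the Go side. A small illustrative sketch of that decoding, assuming the public inputs are emitted as a JSON array of little-endian hex field elements (the shape UnmarshalPublicInputs expects); the literal values below are made up and this program is not part of the change:

package main

import (
    "fmt"

    "reilabs/whir-verifier-circuit/app/utilities"
)

func main() {
    // Each entry is a 32-byte little-endian hex encoding; these two decode to 1 and 2.
    raw := []byte(`["0x0100000000000000000000000000000000000000000000000000000000000000", "0x0200000000000000000000000000000000000000000000000000000000000000"]`)
    values, err := utilities.UnmarshalPublicInputs(raw)
    if err != nil {
        panic(err)
    }
    fmt.Println(len(values), "public inputs parsed")
}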