Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions provekit/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ ruint.workspace = true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
sha2.workspace = true
zerocopy.workspace = true
zeroize.workspace = true
zstd.workspace = true
Expand Down
1 change: 1 addition & 0 deletions provekit/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub use {
verifier::Verifier,
whir::crypto::fields::Field256 as FieldElement,
whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme},
witness::PublicInputs,
};

#[cfg(test)]
Expand Down
3 changes: 2 additions & 1 deletion provekit/common/src/noir_proof_scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use {
crate::{
whir_r1cs::{WhirR1CSProof, WhirR1CSScheme},
witness::{NoirWitnessGenerator, SplitWitnessBuilders},
NoirElement, R1CS,
NoirElement, PublicInputs, R1CS,
},
acir::circuit::Program,
serde::{Deserialize, Serialize},
Expand All @@ -20,6 +20,7 @@ pub struct NoirProofScheme {

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct NoirProof {
pub public_inputs: PublicInputs,
pub whir_r1cs_proof: WhirR1CSProof,
}

Expand Down
6 changes: 5 additions & 1 deletion provekit/common/src/skyscraper/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,8 @@ mod pow;
mod sponge;
mod whir;

pub use self::{pow::SkyscraperPoW, sponge::SkyscraperSponge, whir::SkyscraperMerkleConfig};
pub use self::{
pow::SkyscraperPoW,
sponge::SkyscraperSponge,
whir::{SkyscraperCRH, SkyscraperMerkleConfig},
};
1 change: 1 addition & 0 deletions provekit/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod print_abi;
pub mod serde_ark;
pub mod serde_ark_option;
pub mod serde_ark_vec;
pub mod serde_hex;
pub mod serde_jsonify;
pub mod sumcheck;
Expand Down
87 changes: 87 additions & 0 deletions provekit/common/src/utils/serde_ark_vec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use {
crate::FieldElement,
ark_serialize::{CanonicalDeserialize, CanonicalSerialize},
serde::{
de::{Error as _, SeqAccess, Visitor},
ser::{Error as _, SerializeSeq},
Deserializer, Serializer,
},
std::fmt,
};

pub fn serialize<S>(vec: &Vec<FieldElement>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let is_human_readable = serializer.is_human_readable();
let mut seq = serializer.serialize_seq(Some(vec.len()))?;
for element in vec {
let mut buf = Vec::with_capacity(element.compressed_size());
element
.serialize_compressed(&mut buf)
.map_err(|e| S::Error::custom(format!("Failed to serialize: {e}")))?;

// Write bytes
if is_human_readable {
// ark_serialize doesn't have human-readable serialization. And Serde
// doesn't have good defaults for [u8]. So we implement hexadecimal
// serialization.
let hex = hex::encode(buf);
seq.serialize_element(&hex)?;
} else {
seq.serialize_element(&buf)?;
}
}
seq.end()
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<FieldElement>, D::Error>
where
D: Deserializer<'de>,
{
struct VecVisitor {
is_human_readable: bool,
}

impl<'de> Visitor<'de> for VecVisitor {
type Value = Vec<FieldElement>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a sequence of field elements")
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut vec = Vec::new();
if self.is_human_readable {
while let Some(hex) = seq.next_element::<String>()? {
let bytes = hex::decode(hex)
.map_err(|e| A::Error::custom(format!("invalid hex: {e}")))?;
let mut reader = &*bytes;
let element = FieldElement::deserialize_compressed(&mut reader)
.map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?;
if !reader.is_empty() {
return Err(A::Error::custom("while deserializing: trailing bytes"));
}
vec.push(element);
}
} else {
while let Some(bytes) = seq.next_element::<Vec<u8>>()? {
let mut reader = &*bytes;
let element = FieldElement::deserialize_compressed(&mut reader)
.map_err(|e| A::Error::custom(format!("deserialize failed: {e}")))?;
if !reader.is_empty() {
return Err(A::Error::custom("while deserializing: trailing bytes"));
}
vec.push(element);
}
}
Ok(vec)
}
}

let is_human_readable = deserializer.is_human_readable();
deserializer.deserialize_seq(VecVisitor { is_human_readable })
}
10 changes: 10 additions & 0 deletions provekit/common/src/utils/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ pub trait SumcheckIOPattern {
fn add_rand(self, num_rand: usize) -> Self;

fn add_zk_sumcheck_polynomials(self, num_vars: usize) -> Self;

/// Prover sends the hash of the public inputs
/// Verifier sends randomness to construct weights
fn add_public_inputs(self) -> Self;
}

impl<IOPattern> SumcheckIOPattern for IOPattern
Expand All @@ -136,6 +140,12 @@ where
self
}

fn add_public_inputs(mut self) -> Self {
self = self.add_scalars(1, "Public Inputs Hash");
self = self.challenge_scalars(1, "Public Weights Vector Random");
self
}

fn add_rand(self, num_rand: usize) -> Self {
self.challenge_scalars(num_rand, "rand")
}
Expand Down
29 changes: 28 additions & 1 deletion provekit/common/src/utils/zk_utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use {
crate::FieldElement, ark_ff::UniformRand, rayon::prelude::*,
crate::FieldElement,
ark_ff::{Field, UniformRand},
rayon::prelude::*,
whir::poly_utils::evals::EvaluationsList,
};

Expand Down Expand Up @@ -37,3 +39,28 @@ pub fn generate_random_multilinear_polynomial(num_vars: usize) -> Vec<FieldEleme

elements
}

/// Evaluates the mle of a polynomial from evaluations in a geometric
/// progression.
///
/// The evaluation list is of the form [1,a,a^2,a^3,...,a^{n-1},0,...,0]
/// a is the base of the geometric progression.
/// n is the number of non-zero terms in the progression.
pub fn geometric_till<F: Field>(mut a: F, n: usize, x: &[F]) -> F {
let k = x.len();
assert!(n > 0 && n < (1 << k));
let mut borrow_0 = F::one();
let mut borrow_1 = F::zero();
for (i, &xi) in x.iter().rev().enumerate() {
let bn = ((n - 1) >> i) & 1;
let b0 = F::one() - xi;
let b1 = a * xi;
(borrow_0, borrow_1) = if bn == 0 {
(b0 * borrow_0, (b0 + b1) * borrow_1 + b1 * borrow_0)
} else {
((b0 + b1) * borrow_0 + b0 * borrow_1, b1 * borrow_1)
};
a = a.square();
}
borrow_0
}
9 changes: 7 additions & 2 deletions provekit/common/src/whir_r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,11 @@ impl WhirR1CSScheme {
if self.num_challenges > 0 {
// Compute total constraints: OOD + statement
// OOD: 2 witnesses × committment_ood_samples each
// Statement: 2 statements × 3 constraints each = 6
// Statement: statement_1 has 3 constraints + 1 public weights constraint = 4,
// statement_2 has 3 constraints = 3, total = 7
let num_witnesses = 2;
let num_ood_constraints = num_witnesses * self.whir_witness.committment_ood_samples;
let num_statement_constraints = 6; // 2 statements × 3 constraints
let num_statement_constraints = 7;
let num_constraints_total = num_ood_constraints + num_statement_constraints;

io = io
Expand All @@ -50,6 +51,8 @@ impl WhirR1CSScheme {
.add_whir_proof(&self.whir_for_hiding_spartan)
.hint("claimed_evaluations_1")
.hint("claimed_evaluations_2")
.add_public_inputs()
.hint("public_weights_evaluations")
.add_whir_batch_proof(&self.whir_witness, num_witnesses, num_constraints_total);
} else {
io = io
Expand All @@ -59,6 +62,8 @@ impl WhirR1CSScheme {
.add_zk_sumcheck_polynomials(self.m_0)
.add_whir_proof(&self.whir_for_hiding_spartan)
.hint("claimed_evaluations")
.add_public_inputs()
.hint("public_weights_evaluations")
.add_whir_proof(&self.whir_witness);
}

Expand Down
47 changes: 46 additions & 1 deletion provekit/common/src/witness/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ mod witness_generator;
mod witness_io_pattern;

use {
crate::{utils::serde_ark, FieldElement},
crate::{
skyscraper::SkyscraperCRH,
utils::{serde_ark, serde_ark_vec},
FieldElement,
},
ark_crypto_primitives::crh::CRHScheme,
ark_ff::One,
serde::{Deserialize, Serialize},
};
Expand Down Expand Up @@ -40,3 +45,43 @@ impl ConstantOrR1CSWitness {
}
}
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PublicInputs(#[serde(with = "serde_ark_vec")] pub Vec<FieldElement>);

impl PublicInputs {
pub fn new() -> Self {
Self(Vec::new())
}

pub fn from_vec(vec: Vec<FieldElement>) -> Self {
Self(vec)
}

pub fn len(&self) -> usize {
self.0.len()
}

pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

pub fn hash(&self) -> FieldElement {
match self.0.len() {
0 => FieldElement::from(0u64),
1 => {
// For single element, hash it with zero to ensure it gets properly hashed
let padded = vec![self.0[0], FieldElement::from(0u64)];
SkyscraperCRH::evaluate(&(), &padded[..]).expect("hash should succeed")
}
_ => SkyscraperCRH::evaluate(&(), &self.0[..])
.expect("hash should succeed for multiple inputs"),
}
}
}

impl Default for PublicInputs {
fn default() -> Self {
Self::new()
}
}
67 changes: 64 additions & 3 deletions provekit/common/src/witness/scheduling/splitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ impl<'a> WitnessSplitter<'a> {
/// (post-challenge).
///
/// Returns (w1_builder_indices, w2_builder_indices)
pub fn split_builders(&self) -> (Vec<usize>, Vec<usize>) {
pub fn split_builders(
&self,
acir_public_inputs_indices_set: HashSet<u32>,
) -> (Vec<usize>, Vec<usize>) {
let builder_count = self.witness_builders.len();

// Step 1: Find all Challenge builders
Expand All @@ -40,7 +43,11 @@ impl<'a> WitnessSplitter<'a> {
.collect();

if challenge_builders.is_empty() {
return ((0..builder_count).collect(), Vec::new());
let w1_indices = self.rearrange_w1(
(0..builder_count).collect(),
&acir_public_inputs_indices_set,
);
return (w1_indices, Vec::new());
}

// Step 2: Forward DFS from challenges to find mandatory_w2
Expand Down Expand Up @@ -135,6 +142,7 @@ impl<'a> WitnessSplitter<'a> {
// Step 7: Assign free builders greedily while respecting dependencies
// Rule: if any dependency is in w2, the builder must also be in w2
// (because w1 is solved before w2)
// A free builder for public input witnesses goes in w1.
let mut w1_set = mandatory_w1;
let mut w2_set = mandatory_w2;

Expand All @@ -149,6 +157,15 @@ impl<'a> WitnessSplitter<'a> {

let witness_count = DependencyInfo::extract_writes(&self.witness_builders[idx]).len();

// If free builder writes a public witness, add it to w1_set.
if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[idx] {
if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) {
w1_set.insert(idx);
w1_witness_count += witness_count;
continue;
}
}

if must_be_w2 {
w2_set.insert(idx);
w2_witness_count += witness_count;
Expand All @@ -165,9 +182,53 @@ impl<'a> WitnessSplitter<'a> {
let mut w1_indices: Vec<usize> = w1_set.into_iter().collect();
let mut w2_indices: Vec<usize> = w2_set.into_iter().collect();

w1_indices.sort_unstable();
w1_indices = self.rearrange_w1(w1_indices, &acir_public_inputs_indices_set);
w2_indices.sort_unstable();

(w1_indices, w2_indices)
}

/// Rearranges w1 builder indices into a canonical order:
/// 1. Constant builder (index 0) first, to preserve R1CS index 0 invariant
/// 2. Public input builders next, grouped together
/// 3. All other w1 builders last, sorted by index
fn rearrange_w1(
&self,
w1_indices: Vec<usize>,
acir_public_inputs_indices_set: &HashSet<u32>,
) -> Vec<usize> {
let mut public_input_builder_indices = Vec::new();
let mut rest_indices = Vec::new();

// Sanity Check: Make sure all public inputs and WITNESS_ONE_IDX are in
// w1_indices.
// Convert to HashSet for O(1) lookups since we're checking many times
let w1_indices_set = w1_indices.iter().copied().collect::<HashSet<_>>();
for &idx in acir_public_inputs_indices_set.iter() {
if !w1_indices_set.contains(&(idx as usize)) {
panic!("Public input {} is not in w1_indices", idx);
}
}

// Separate into: 0, public inputs, and rest
for builder_idx in w1_indices {
if builder_idx == 0 {
continue; // Will add 0 first
} else if let WitnessBuilder::Acir(_, acir_idx) = &self.witness_builders[builder_idx] {
if acir_public_inputs_indices_set.contains(&(*acir_idx as u32)) {
public_input_builder_indices.push(builder_idx);
continue;
}
}
rest_indices.push(builder_idx);
}

rest_indices.sort_unstable();

// Reorder: 0 first, then public inputs, then rest
let mut new_w1_indices = vec![0];
new_w1_indices.extend(public_input_builder_indices);
new_w1_indices.extend(rest_indices);
new_w1_indices
}
}
Loading
Loading