Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/mpcs/src/basefold/commit_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ where
let evaluations: AdditiveVec<E> =
prover_msgs
.into_iter()
.fold(AdditiveVec::new(3), |mut acc, prover_msg| {
.fold(AdditiveVec::new(2), |mut acc, prover_msg| {
acc += AdditiveVec(prover_msg.evaluations);
acc
});
Expand Down
34 changes: 20 additions & 14 deletions crates/mpcs/src/basefold/query_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,22 +282,28 @@ pub fn batch_verifier_query_phase<E: ExtensionField, S: EncodingScheme<E>>(
.sum::<E>();
}
}
assert_eq!(expected_sum, {
sumcheck_messages[0].evaluations[0] + sumcheck_messages[0].evaluations[1]
});
// 2. check every round of sumcheck match with prev claims
for i in 0..fold_challenges.len() - 1 {
assert_eq!(
extrapolate_uni_poly(&sumcheck_messages[i].evaluations, fold_challenges[i]),
{ sumcheck_messages[i + 1].evaluations[0] + sumcheck_messages[i + 1].evaluations[1] }
);

assert_eq!(
sumcheck_messages.len(),
fold_challenges.len(),
"sumcheck messages and fold challenges length mismatch"
);

// Reconstruct the implicit P(0) evaluation for each round and update the claim in place.
let mut current_claim = expected_sum;
for (msg, challenge) in sumcheck_messages.iter().zip(fold_challenges.iter()) {
let eval_1 = msg
.evaluations
.first()
.copied()
.expect("sumcheck prover message missing evaluations");
let eval_0 = current_claim - eval_1;
current_claim = extrapolate_uni_poly(eval_0, &msg.evaluations, *challenge);
}
// 3. check final evaluation are correct

// check final evaluation are correct
assert_eq!(
extrapolate_uni_poly(
&sumcheck_messages[fold_challenges.len() - 1].evaluations,
fold_challenges[fold_challenges.len() - 1]
),
current_claim,
// \sum_i eq(p,[r,i]) * f(r,i)
izip!(
final_message,
Expand Down
17 changes: 13 additions & 4 deletions crates/sumcheck/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<'a, E: ExtensionField> Phase1Workers<'a, E> {
.workers_states
.par_iter_mut()
.map(|state| state.run_round())
.reduce(|| AdditiveVec::new(max_degree + 1), |a, b| a + b);
.reduce(|| AdditiveVec::new(max_degree), |a, b| a + b);

transcript.append_field_element_exts(&evaluations.0);

Expand Down Expand Up @@ -353,7 +353,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
let span = entered_span!("build_uni_poly");
let AdditiveVec(uni_polys) = self.poly.products.iter().fold(
let AdditiveVec(mut uni_polys) = self.poly.products.iter().fold(
AdditiveVec::new(self.poly.aux_info.max_degree + 1),
|mut uni_polys, MonomialTerms { terms }| {
for Term {
Expand Down Expand Up @@ -405,6 +405,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {

exit_span!(start);

assert!(uni_polys.len() > 1);
// NOTE remove uni_polys.eval(0) from lagrange domain
// as verifier can derive via claim - uni_polys.eval(1)
uni_polys.remove(0);

IOPProverMessage {
evaluations: uni_polys,
}
Expand Down Expand Up @@ -603,7 +608,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
// Step 2: generate sum for the partial evaluated polynomial:
// f(r_1, ... r_m,, x_{m+1}... x_n)
let span = entered_span!("build_uni_poly");
let AdditiveVec(uni_polys) = self
let AdditiveVec(mut uni_polys) = self
.poly
.products
.par_iter()
Expand Down Expand Up @@ -654,9 +659,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
.reduce_with(|acc, item| acc + item)
.unwrap();
exit_span!(span);

exit_span!(start);

assert!(uni_polys.len() > 1);
// NOTE remove uni_polys.eval(0) from lagrange domain
// as verifier can derive via claim - uni_polys.eval(1)
uni_polys.remove(0);

IOPProverMessage {
evaluations: uni_polys,
}
Expand Down
5 changes: 0 additions & 5 deletions crates/sumcheck/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@ use transcript::Challenge;
pub struct IOPProof<E: ExtensionField> {
pub proofs: Vec<IOPProverMessage<E>>,
}
impl<E: ExtensionField> IOPProof<E> {
pub fn extract_sum(&self) -> E {
self.proofs[0].evaluations[0] + self.proofs[0].evaluations[1]
}
}

/// A message from the prover to the verifier at a given round
/// is a list of evaluations.
Expand Down
20 changes: 4 additions & 16 deletions crates/sumcheck/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,21 +190,6 @@ fn test_normal_polynomial_helper<E: ExtensionField>() {
test_sumcheck_internal::<E>(nv, num_multiplicands_range, num_products);
}

#[test]
fn test_extract_sum() {
test_extract_sum_helper::<GoldilocksExt2>();
test_extract_sum_helper::<BabyBearExt4>();
}

fn test_extract_sum_helper<E: ExtensionField>() {
let mut rng = thread_rng();
let mut transcript = BasicTranscript::new(b"test");
let (poly, asserted_sum) = VirtualPolynomial::<E>::random(&[8], (2, 3), 3, &mut rng);
#[allow(deprecated)]
let (proof, _) = IOPProverState::<E>::prove_parallel(poly, &mut transcript);
assert_eq!(proof.extract_sum(), asserted_sum);
}

struct DensePolynomial(Vec<GoldilocksExt2>);

impl DensePolynomial {
Expand Down Expand Up @@ -236,7 +221,10 @@ fn test_extrapolation() {
.map(|i| poly.evaluate(&GoldilocksExt2::from_canonical_u64(i as u64)))
.collect::<Vec<_>>();
let query = GoldilocksExt2::random(&mut prng);
assert_eq!(poly.evaluate(&query), extrapolate_uni_poly(&evals, query));
assert_eq!(
poly.evaluate(&query),
extrapolate_uni_poly(evals[0], &evals[1..], query)
);
}

run_extrapolation_test(1);
Expand Down
55 changes: 30 additions & 25 deletions crates/sumcheck/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub fn extrapolate_from_table<E: ExtensionField>(uni_variate: &mut [E], start: u
}
}

fn extrapolate_uni_poly_deg_1<F: Field>(p_i: &[F; 2], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_1<F: Field>(p0: F, p1: F, eval_at: F) -> F {
let x0 = F::ZERO;
let x1 = F::ONE;

Expand All @@ -69,13 +69,13 @@ fn extrapolate_uni_poly_deg_1<F: Field>(p_i: &[F; 2], eval_at: F) -> F {
let inv_d0 = d0.inverse();
let inv_d1 = d1.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;

l * (t0 + t1)
}

fn extrapolate_uni_poly_deg_2<F: Field>(p_i: &[F; 3], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_2<F: Field>(p0: F, p1: F, p2: F, eval_at: F) -> F {
let x0 = F::from_canonical_u64(0);
let x1 = F::from_canonical_u64(1);
let x2 = F::from_canonical_u64(2);
Expand All @@ -97,14 +97,14 @@ fn extrapolate_uni_poly_deg_2<F: Field>(p_i: &[F; 3], eval_at: F) -> F {
let inv_d1 = d1.inverse();
let inv_d2 = d2.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t2 = w2 * p_i[2] * inv_d2;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;
let t2 = w2 * p2 * inv_d2;

l * (t0 + t1 + t2)
}

fn extrapolate_uni_poly_deg_3<F: Field>(p_i: &[F; 4], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_3<F: Field>(p0: F, p1: F, p2: F, p3: F, eval_at: F) -> F {
let x0 = F::from_canonical_u64(0);
let x1 = F::from_canonical_u64(1);
let x2 = F::from_canonical_u64(2);
Expand All @@ -131,15 +131,15 @@ fn extrapolate_uni_poly_deg_3<F: Field>(p_i: &[F; 4], eval_at: F) -> F {
let inv_d2 = d2.inverse();
let inv_d3 = d3.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t2 = w2 * p_i[2] * inv_d2;
let t3 = w3 * p_i[3] * inv_d3;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;
let t2 = w2 * p2 * inv_d2;
let t3 = w3 * p3 * inv_d3;

l * (t0 + t1 + t2 + t3)
}

fn extrapolate_uni_poly_deg_4<F: Field>(p_i: &[F; 5], eval_at: F) -> F {
fn extrapolate_uni_poly_deg_4<F: Field>(p0: F, p1: F, p2: F, p3: F, p4: F, eval_at: F) -> F {
let x0 = F::from_canonical_u64(0);
let x1 = F::from_canonical_u64(1);
let x2 = F::from_canonical_u64(2);
Expand Down Expand Up @@ -171,11 +171,11 @@ fn extrapolate_uni_poly_deg_4<F: Field>(p_i: &[F; 5], eval_at: F) -> F {
let inv_d3 = d3.inverse();
let inv_d4 = d4.inverse();

let t0 = w0 * p_i[0] * inv_d0;
let t1 = w1 * p_i[1] * inv_d1;
let t2 = w2 * p_i[2] * inv_d2;
let t3 = w3 * p_i[3] * inv_d3;
let t4 = w4 * p_i[4] * inv_d4;
let t0 = w0 * p0 * inv_d0;
let t1 = w1 * p1 * inv_d1;
let t2 = w2 * p2 * inv_d2;
let t3 = w3 * p3 * inv_d3;
let t4 = w4 * p4 * inv_d4;

l * (t0 + t1 + t2 + t3 + t4)
}
Expand All @@ -195,18 +195,23 @@ fn extrapolate_uni_poly_deg_4<F: Field>(p_i: &[F; 5], eval_at: F) -> F {
/// with unrolled loops for performance
///
/// # Arguments
/// * `p_i` - Values of the polynomial at consecutive integer points.
/// * `p0` - Polynomial evaluation at point 0.
/// * `p_i` - Values of the polynomial at consecutive integer points starting from 1.
/// * `eval_at` - The point at which to evaluate the interpolated polynomial.
///
/// # Returns
/// The value of the polynomial `eval_at`.
pub fn extrapolate_uni_poly<F: Field>(p: &[F], eval_at: F) -> F {
pub fn extrapolate_uni_poly<F: Field>(p0: F, p: &[F], eval_at: F) -> F {
assert!(
!p.is_empty(),
"at least one evaluation beyond p(0) is required"
);
match p.len() {
2 => extrapolate_uni_poly_deg_1(p.try_into().unwrap(), eval_at),
3 => extrapolate_uni_poly_deg_2(p.try_into().unwrap(), eval_at),
4 => extrapolate_uni_poly_deg_3(p.try_into().unwrap(), eval_at),
5 => extrapolate_uni_poly_deg_4(p.try_into().unwrap(), eval_at),
_ => unimplemented!("Extrapolation for degree {} not implemented", p.len() - 1),
1 => extrapolate_uni_poly_deg_1(p0, p[0], eval_at),
2 => extrapolate_uni_poly_deg_2(p0, p[0], p[1], eval_at),
3 => extrapolate_uni_poly_deg_3(p0, p[0], p[1], p[2], eval_at),
4 => extrapolate_uni_poly_deg_4(p0, p[0], p[1], p[2], p[3], eval_at),
_ => unimplemented!("Extrapolation for degree {} not implemented", p.len()),
}
}

Expand Down
49 changes: 29 additions & 20 deletions crates/sumcheck/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,39 +122,48 @@ impl<E: ExtensionField> IOPVerifierState<E> {

// the deferred check during the interactive phase:
// 2. set `expected` to P(r)`
let mut expected_vec = self
let (expected_vec, evals_0) = self
.polynomials_received
.iter()
.zip(self.challenges.iter())
.map(|(evaluations, challenge)| {
if evaluations.len() != self.max_degree + 1 {
panic!(
"incorrect number of evaluations: {} vs {}",
evaluations.len(),
self.max_degree + 1
);
}
extrapolate_uni_poly::<E>(evaluations, challenge.elements)
})
.collect::<Vec<_>>();

// l-append asserted_sum to the first position of the expected vector
expected_vec.insert(0, *asserted_sum);

for (i, (evaluations, &expected)) in self
.fold(
(vec![*asserted_sum], vec![]),
|(mut claims, mut evals_0), (evaluations, challenge)| {
let last_claim = claims.last().copied().unwrap();
if evaluations.len() != self.max_degree {
panic!(
"incorrect number of evaluations: {} vs {}",
evaluations.len(),
self.max_degree
);
}
// https://eprint.iacr.org/2024/108.pdf sec 3.1 derive eval_0 = claim - eval_1
let eval_0 = last_claim - evaluations.first().copied().unwrap();
evals_0.push(eval_0);
claims.push(extrapolate_uni_poly::<E>(
eval_0,
evaluations,
challenge.elements,
));
(claims, evals_0)
},
);

for (i, ((evaluations, &expected), eval_0)) in self
.polynomials_received
.iter()
.zip(expected_vec.iter())
.zip(&expected_vec)
.zip(&evals_0)
.enumerate()
.take(self.num_vars)
{
// the deferred check during the interactive phase:
// 1. check if the received 'P(0) + P(1) = expected`.
if evaluations[0] + evaluations[1] != expected {
if *eval_0 + evaluations[0] != expected {
panic!(
"{}th round's prover message is not consistent with the claim. {:?} {:?}",
i,
evaluations[0] + evaluations[1],
*eval_0 + evaluations[0],
expected
);
}
Expand Down