diff --git a/crates/mpcs/src/basefold/commit_phase.rs b/crates/mpcs/src/basefold/commit_phase.rs index d2544c6..17d7b93 100644 --- a/crates/mpcs/src/basefold/commit_phase.rs +++ b/crates/mpcs/src/basefold/commit_phase.rs @@ -417,7 +417,7 @@ where let evaluations: AdditiveVec = prover_msgs .into_iter() - .fold(AdditiveVec::new(3), |mut acc, prover_msg| { + .fold(AdditiveVec::new(2), |mut acc, prover_msg| { acc += AdditiveVec(prover_msg.evaluations); acc }); diff --git a/crates/mpcs/src/basefold/query_phase.rs b/crates/mpcs/src/basefold/query_phase.rs index b501075..8daf593 100644 --- a/crates/mpcs/src/basefold/query_phase.rs +++ b/crates/mpcs/src/basefold/query_phase.rs @@ -282,22 +282,28 @@ pub fn batch_verifier_query_phase>( .sum::(); } } - assert_eq!(expected_sum, { - sumcheck_messages[0].evaluations[0] + sumcheck_messages[0].evaluations[1] - }); - // 2. check every round of sumcheck match with prev claims - for i in 0..fold_challenges.len() - 1 { - assert_eq!( - extrapolate_uni_poly(&sumcheck_messages[i].evaluations, fold_challenges[i]), - { sumcheck_messages[i + 1].evaluations[0] + sumcheck_messages[i + 1].evaluations[1] } - ); + + assert_eq!( + sumcheck_messages.len(), + fold_challenges.len(), + "sumcheck messages and fold challenges length mismatch" + ); + + // Reconstruct the implicit P(0) evaluation for each round and update the claim in place. + let mut current_claim = expected_sum; + for (msg, challenge) in sumcheck_messages.iter().zip(fold_challenges.iter()) { + let eval_1 = msg + .evaluations + .first() + .copied() + .expect("sumcheck prover message missing evaluations"); + let eval_0 = current_claim - eval_1; + current_claim = extrapolate_uni_poly(eval_0, &msg.evaluations, *challenge); } - // 3. check final evaluation are correct + + // check final evaluation are correct assert_eq!( - extrapolate_uni_poly( - &sumcheck_messages[fold_challenges.len() - 1].evaluations, - fold_challenges[fold_challenges.len() - 1] - ), + current_claim, // \sum_i eq(p,[r,i]) * f(r,i) izip!( final_message, diff --git a/crates/sumcheck/src/prover.rs b/crates/sumcheck/src/prover.rs index ff53cab..bd9efb5 100644 --- a/crates/sumcheck/src/prover.rs +++ b/crates/sumcheck/src/prover.rs @@ -45,7 +45,7 @@ impl<'a, E: ExtensionField> Phase1Workers<'a, E> { .workers_states .par_iter_mut() .map(|state| state.run_round()) - .reduce(|| AdditiveVec::new(max_degree + 1), |a, b| a + b); + .reduce(|| AdditiveVec::new(max_degree), |a, b| a + b); transcript.append_field_element_exts(&evaluations.0); @@ -353,7 +353,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) let span = entered_span!("build_uni_poly"); - let AdditiveVec(uni_polys) = self.poly.products.iter().fold( + let AdditiveVec(mut uni_polys) = self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), |mut uni_polys, MonomialTerms { terms }| { for Term { @@ -405,6 +405,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { exit_span!(start); + assert!(uni_polys.len() > 1); + // NOTE remove uni_polys.eval(0) from lagrange domain + // as verifier can derive via claim - uni_polys.eval(1) + uni_polys.remove(0); + IOPProverMessage { evaluations: uni_polys, } @@ -603,7 +608,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) let span = entered_span!("build_uni_poly"); - let AdditiveVec(uni_polys) = self + let AdditiveVec(mut uni_polys) = self .poly .products .par_iter() @@ -654,9 +659,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .reduce_with(|acc, item| acc + item) .unwrap(); exit_span!(span); - exit_span!(start); + assert!(uni_polys.len() > 1); + // NOTE remove uni_polys.eval(0) from lagrange domain + // as verifier can derive via claim - uni_polys.eval(1) + uni_polys.remove(0); + IOPProverMessage { evaluations: uni_polys, } diff --git a/crates/sumcheck/src/structs.rs b/crates/sumcheck/src/structs.rs index bfbfb2b..ddabc55 100644 --- a/crates/sumcheck/src/structs.rs +++ b/crates/sumcheck/src/structs.rs @@ -15,11 +15,6 @@ use transcript::Challenge; pub struct IOPProof { pub proofs: Vec>, } -impl IOPProof { - pub fn extract_sum(&self) -> E { - self.proofs[0].evaluations[0] + self.proofs[0].evaluations[1] - } -} /// A message from the prover to the verifier at a given round /// is a list of evaluations. diff --git a/crates/sumcheck/src/test.rs b/crates/sumcheck/src/test.rs index 4b8b908..fc37738 100644 --- a/crates/sumcheck/src/test.rs +++ b/crates/sumcheck/src/test.rs @@ -190,21 +190,6 @@ fn test_normal_polynomial_helper() { test_sumcheck_internal::(nv, num_multiplicands_range, num_products); } -#[test] -fn test_extract_sum() { - test_extract_sum_helper::(); - test_extract_sum_helper::(); -} - -fn test_extract_sum_helper() { - let mut rng = thread_rng(); - let mut transcript = BasicTranscript::new(b"test"); - let (poly, asserted_sum) = VirtualPolynomial::::random(&[8], (2, 3), 3, &mut rng); - #[allow(deprecated)] - let (proof, _) = IOPProverState::::prove_parallel(poly, &mut transcript); - assert_eq!(proof.extract_sum(), asserted_sum); -} - struct DensePolynomial(Vec); impl DensePolynomial { @@ -236,7 +221,10 @@ fn test_extrapolation() { .map(|i| poly.evaluate(&GoldilocksExt2::from_canonical_u64(i as u64))) .collect::>(); let query = GoldilocksExt2::random(&mut prng); - assert_eq!(poly.evaluate(&query), extrapolate_uni_poly(&evals, query)); + assert_eq!( + poly.evaluate(&query), + extrapolate_uni_poly(evals[0], &evals[1..], query) + ); } run_extrapolation_test(1); diff --git a/crates/sumcheck/src/util.rs b/crates/sumcheck/src/util.rs index 18c0627..955bdab 100644 --- a/crates/sumcheck/src/util.rs +++ b/crates/sumcheck/src/util.rs @@ -53,7 +53,7 @@ pub fn extrapolate_from_table(uni_variate: &mut [E], start: u } } -fn extrapolate_uni_poly_deg_1(p_i: &[F; 2], eval_at: F) -> F { +fn extrapolate_uni_poly_deg_1(p0: F, p1: F, eval_at: F) -> F { let x0 = F::ZERO; let x1 = F::ONE; @@ -69,13 +69,13 @@ fn extrapolate_uni_poly_deg_1(p_i: &[F; 2], eval_at: F) -> F { let inv_d0 = d0.inverse(); let inv_d1 = d1.inverse(); - let t0 = w0 * p_i[0] * inv_d0; - let t1 = w1 * p_i[1] * inv_d1; + let t0 = w0 * p0 * inv_d0; + let t1 = w1 * p1 * inv_d1; l * (t0 + t1) } -fn extrapolate_uni_poly_deg_2(p_i: &[F; 3], eval_at: F) -> F { +fn extrapolate_uni_poly_deg_2(p0: F, p1: F, p2: F, eval_at: F) -> F { let x0 = F::from_canonical_u64(0); let x1 = F::from_canonical_u64(1); let x2 = F::from_canonical_u64(2); @@ -97,14 +97,14 @@ fn extrapolate_uni_poly_deg_2(p_i: &[F; 3], eval_at: F) -> F { let inv_d1 = d1.inverse(); let inv_d2 = d2.inverse(); - let t0 = w0 * p_i[0] * inv_d0; - let t1 = w1 * p_i[1] * inv_d1; - let t2 = w2 * p_i[2] * inv_d2; + let t0 = w0 * p0 * inv_d0; + let t1 = w1 * p1 * inv_d1; + let t2 = w2 * p2 * inv_d2; l * (t0 + t1 + t2) } -fn extrapolate_uni_poly_deg_3(p_i: &[F; 4], eval_at: F) -> F { +fn extrapolate_uni_poly_deg_3(p0: F, p1: F, p2: F, p3: F, eval_at: F) -> F { let x0 = F::from_canonical_u64(0); let x1 = F::from_canonical_u64(1); let x2 = F::from_canonical_u64(2); @@ -131,15 +131,15 @@ fn extrapolate_uni_poly_deg_3(p_i: &[F; 4], eval_at: F) -> F { let inv_d2 = d2.inverse(); let inv_d3 = d3.inverse(); - let t0 = w0 * p_i[0] * inv_d0; - let t1 = w1 * p_i[1] * inv_d1; - let t2 = w2 * p_i[2] * inv_d2; - let t3 = w3 * p_i[3] * inv_d3; + let t0 = w0 * p0 * inv_d0; + let t1 = w1 * p1 * inv_d1; + let t2 = w2 * p2 * inv_d2; + let t3 = w3 * p3 * inv_d3; l * (t0 + t1 + t2 + t3) } -fn extrapolate_uni_poly_deg_4(p_i: &[F; 5], eval_at: F) -> F { +fn extrapolate_uni_poly_deg_4(p0: F, p1: F, p2: F, p3: F, p4: F, eval_at: F) -> F { let x0 = F::from_canonical_u64(0); let x1 = F::from_canonical_u64(1); let x2 = F::from_canonical_u64(2); @@ -171,11 +171,11 @@ fn extrapolate_uni_poly_deg_4(p_i: &[F; 5], eval_at: F) -> F { let inv_d3 = d3.inverse(); let inv_d4 = d4.inverse(); - let t0 = w0 * p_i[0] * inv_d0; - let t1 = w1 * p_i[1] * inv_d1; - let t2 = w2 * p_i[2] * inv_d2; - let t3 = w3 * p_i[3] * inv_d3; - let t4 = w4 * p_i[4] * inv_d4; + let t0 = w0 * p0 * inv_d0; + let t1 = w1 * p1 * inv_d1; + let t2 = w2 * p2 * inv_d2; + let t3 = w3 * p3 * inv_d3; + let t4 = w4 * p4 * inv_d4; l * (t0 + t1 + t2 + t3 + t4) } @@ -195,18 +195,23 @@ fn extrapolate_uni_poly_deg_4(p_i: &[F; 5], eval_at: F) -> F { /// with unrolled loops for performance /// /// # Arguments -/// * `p_i` - Values of the polynomial at consecutive integer points. +/// * `p0` - Polynomial evaluation at point 0. +/// * `p_i` - Values of the polynomial at consecutive integer points starting from 1. /// * `eval_at` - The point at which to evaluate the interpolated polynomial. /// /// # Returns /// The value of the polynomial `eval_at`. -pub fn extrapolate_uni_poly(p: &[F], eval_at: F) -> F { +pub fn extrapolate_uni_poly(p0: F, p: &[F], eval_at: F) -> F { + assert!( + !p.is_empty(), + "at least one evaluation beyond p(0) is required" + ); match p.len() { - 2 => extrapolate_uni_poly_deg_1(p.try_into().unwrap(), eval_at), - 3 => extrapolate_uni_poly_deg_2(p.try_into().unwrap(), eval_at), - 4 => extrapolate_uni_poly_deg_3(p.try_into().unwrap(), eval_at), - 5 => extrapolate_uni_poly_deg_4(p.try_into().unwrap(), eval_at), - _ => unimplemented!("Extrapolation for degree {} not implemented", p.len() - 1), + 1 => extrapolate_uni_poly_deg_1(p0, p[0], eval_at), + 2 => extrapolate_uni_poly_deg_2(p0, p[0], p[1], eval_at), + 3 => extrapolate_uni_poly_deg_3(p0, p[0], p[1], p[2], eval_at), + 4 => extrapolate_uni_poly_deg_4(p0, p[0], p[1], p[2], p[3], eval_at), + _ => unimplemented!("Extrapolation for degree {} not implemented", p.len()), } } diff --git a/crates/sumcheck/src/verifier.rs b/crates/sumcheck/src/verifier.rs index 9046940..b1c7e56 100644 --- a/crates/sumcheck/src/verifier.rs +++ b/crates/sumcheck/src/verifier.rs @@ -122,39 +122,48 @@ impl IOPVerifierState { // the deferred check during the interactive phase: // 2. set `expected` to P(r)` - let mut expected_vec = self + let (expected_vec, evals_0) = self .polynomials_received .iter() .zip(self.challenges.iter()) - .map(|(evaluations, challenge)| { - if evaluations.len() != self.max_degree + 1 { - panic!( - "incorrect number of evaluations: {} vs {}", - evaluations.len(), - self.max_degree + 1 - ); - } - extrapolate_uni_poly::(evaluations, challenge.elements) - }) - .collect::>(); - - // l-append asserted_sum to the first position of the expected vector - expected_vec.insert(0, *asserted_sum); - - for (i, (evaluations, &expected)) in self + .fold( + (vec![*asserted_sum], vec![]), + |(mut claims, mut evals_0), (evaluations, challenge)| { + let last_claim = claims.last().copied().unwrap(); + if evaluations.len() != self.max_degree { + panic!( + "incorrect number of evaluations: {} vs {}", + evaluations.len(), + self.max_degree + ); + } + // https://eprint.iacr.org/2024/108.pdf sec 3.1 derive eval_0 = claim - eval_1 + let eval_0 = last_claim - evaluations.first().copied().unwrap(); + evals_0.push(eval_0); + claims.push(extrapolate_uni_poly::( + eval_0, + evaluations, + challenge.elements, + )); + (claims, evals_0) + }, + ); + + for (i, ((evaluations, &expected), eval_0)) in self .polynomials_received .iter() - .zip(expected_vec.iter()) + .zip(&expected_vec) + .zip(&evals_0) .enumerate() .take(self.num_vars) { // the deferred check during the interactive phase: // 1. check if the received 'P(0) + P(1) = expected`. - if evaluations[0] + evaluations[1] != expected { + if *eval_0 + evaluations[0] != expected { panic!( "{}th round's prover message is not consistent with the claim. {:?} {:?}", i, - evaluations[0] + evaluations[1], + *eval_0 + evaluations[0], expected ); }