From beb4ec8dc9db572539a81c094ae82ac35f7c26c9 Mon Sep 17 00:00:00 2001 From: kunxian xia Date: Tue, 24 Mar 2026 23:35:26 +0800 Subject: [PATCH] build round 2 univariate poly from original evaluations --- crates/multilinear_extensions/src/mle.rs | 70 ++++++ crates/sumcheck/src/prover.rs | 303 +++++++++++++++++++---- crates/sumcheck/src/structs.rs | 2 + 3 files changed, 333 insertions(+), 42 deletions(-) diff --git a/crates/multilinear_extensions/src/mle.rs b/crates/multilinear_extensions/src/mle.rs index f1370da..7589c84 100644 --- a/crates/multilinear_extensions/src/mle.rs +++ b/crates/multilinear_extensions/src/mle.rs @@ -567,6 +567,76 @@ impl<'a, E: ExtensionField> MultilinearExtension<'a, E> { self.num_vars = nv - partial_point.len(); } + /// Reduce the number of variables by 2 in one pass. + /// + /// This avoids calling `fix_variables` twice and directly computes + /// `f(r0, r1, ..)` from 4-point blocks. + pub fn fix_two_variables(&self, r0: E, r1: E) -> Self { + assert!(self.num_vars() >= 2, "num_vars {} < 2", self.num_vars()); + let nv = self.num_vars(); + match self.evaluations() { + FieldType::Base(slice) => MultilinearExtension::from_evaluations_ext_vec( + nv - 2, + slice + .chunks(4) + .map(|buf| { + let y0 = r0 * (buf[1] - buf[0]) + buf[0]; + let y1 = r0 * (buf[3] - buf[2]) + buf[2]; + y0 + (y1 - y0) * r1 + }) + .collect(), + ), + FieldType::Ext(slice) => MultilinearExtension::from_evaluations_ext_vec( + nv - 2, + slice + .chunks(4) + .map(|buf| { + let y0 = buf[0] + (buf[1] - buf[0]) * r0; + let y1 = buf[2] + (buf[3] - buf[2]) * r0; + y0 + (y1 - y0) * r1 + }) + .collect(), + ), + FieldType::Unreachable => unreachable!(), + } + } + + /// In-place variant of `fix_two_variables`. + pub fn fix_two_variables_in_place(&mut self, r0: E, r1: E) { + assert!(self.is_mut()); + assert!(self.num_vars() >= 2, "num_vars {} < 2", self.num_vars()); + let nv = self.num_vars(); + + match &mut self.evaluations { + FieldType::Base(slice) => { + let slice_ext = slice + .chunks(4) + .map(|buf| { + let y0 = r0 * (buf[1] - buf[0]) + buf[0]; + let y1 = r0 * (buf[3] - buf[2]) + buf[2]; + y0 + (y1 - y0) * r1 + }) + .collect(); + let _ = mem::replace( + &mut self.evaluations, + FieldType::Ext(SmartSlice::Owned(slice_ext)), + ); + } + FieldType::Ext(slice) => { + let slice_mut = slice.to_mut(); + (0..slice_mut.len()).step_by(4).for_each(|b| { + let y0 = slice_mut[b] + (slice_mut[b + 1] - slice_mut[b]) * r0; + let y1 = slice_mut[b + 2] + (slice_mut[b + 3] - slice_mut[b + 2]) * r0; + slice_mut[b >> 2] = y0 + (y1 - y0) * r1; + }); + slice.truncate_mut(1 << (nv - 2)); + } + FieldType::Unreachable => unreachable!(), + } + + self.num_vars = nv - 2; + } + /// Evaluate the MLE at a give point. /// Returns an error if the MLE length does not match the point. pub fn evaluate(&self, point: &[E]) -> E { diff --git a/crates/sumcheck/src/prover.rs b/crates/sumcheck/src/prover.rs index 900b873..ef91351 100644 --- a/crates/sumcheck/src/prover.rs +++ b/crates/sumcheck/src/prover.rs @@ -294,6 +294,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // This accounts for multiple phases and potential continuation challenges, // ensuring we avoid reallocations when the protocol spans multiple rounds challenges: Vec::with_capacity(2 * polynomial.aux_info.max_num_variables), + pending_r0: None, round: 0, poly: polynomial, poly_meta: poly_meta.unwrap_or_else(|| vec![PolyMeta::Normal; num_polys]), @@ -335,6 +336,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // // eval g over r_m, and mutate g to g(r_1, ... r_m,, x_{m+1}... x_n) let span = entered_span!("fix_variables"); + let mut build_with_deferred_first_round = None; if self.round > 0 { assert!( challenge.is_some(), @@ -344,7 +346,13 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let chal = challenge.unwrap(); self.challenges.push(chal); let r = self.challenges.last().unwrap(); - self.fix_var(r.elements); + if self.round == 1 { + // Avoid writing round-1 folded MLE evaluations into memory. + self.pending_r0 = Some(r.elements); + build_with_deferred_first_round = Some(r.elements); + } else { + self.fix_var(r.elements); + } } exit_span!(span); // exit_span!fix_argument); @@ -354,7 +362,113 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Step 2: generate sum for the partial evaluated polynomial: // f(r_1, ... r_m,, x_{m+1}... x_n) let span = entered_span!("build_uni_poly"); - let AdditiveVec(mut uni_polys) = self.poly.products.iter().fold( + let AdditiveVec(mut uni_polys) = if let Some(r0) = build_with_deferred_first_round { + self.build_uni_poly_round2(r0) + } else { + self.poly.products.iter().fold( + AdditiveVec::new(self.poly.aux_info.max_degree + 1), + |mut uni_polys, MonomialTerms { terms }| { + for Term { + scalar, + product: prod, + } in terms + { + let f = &self.poly.flattened_ml_extensions; + let f_type = &self.poly_meta; + let get_poly_meta = || f_type[prod[0]]; + let mut uni_variate: Vec = + vec![E::ZERO; self.poly.aux_info.max_degree + 1]; + let uni_variate_monomial: Vec = match prod.len() { + 1 => sumcheck_code_gen!(1, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 2 => sumcheck_code_gen!(2, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 3 => sumcheck_code_gen!(3, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 4 => sumcheck_code_gen!(4, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 5 => sumcheck_code_gen!(5, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + 6 => sumcheck_code_gen!(6, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + _ => unimplemented!("do not support degree {} > 6", prod.len()), + }; + + uni_variate + .iter_mut() + .zip(uni_variate_monomial) + .take(prod.len() + 1) + .for_each(|(eval, monimial_eval,)| either::for_both!(scalar, scalar => *eval = monimial_eval**scalar)); + + + if prod.len() < self.poly.aux_info.max_degree { + // Perform extrapolation using the precomputed extrapolation table + extrapolate_from_table(&mut uni_variate, prod.len() + 1); + } + + uni_polys += AdditiveVec(uni_variate); + } + uni_polys + }, + ) + }; + exit_span!(span); + + exit_span!(start); + + assert!(uni_polys.len() > 1); + // NOTE remove uni_polys.eval(0) from lagrange domain + // as verifier can derive via claim - uni_polys.eval(1) + uni_polys.remove(0); + + IOPProverMessage { + evaluations: uni_polys, + } + } + + fn mle_eval_at_index(mle: &Arc>, index: usize) -> E { + match mle.evaluations() { + FieldType::Base(slice) => E::from(slice[index]), + FieldType::Ext(slice) => slice[index], + FieldType::Unreachable => unreachable!(), + } + } + + fn mle_eval_round2_without_cached_first_round( + mle: &Arc>, + offset: usize, + r0: E, + x: E, + ) -> E { + let f00 = Self::mle_eval_at_index(mle, offset); // f(0,0,b) + let f10 = Self::mle_eval_at_index(mle, offset + 1); // f(1,0,b) + let f01 = Self::mle_eval_at_index(mle, offset + 2); // f(0,1,b) + let f11 = Self::mle_eval_at_index(mle, offset + 3); // f(1,1,b) + + let y0 = f00 + (f10 - f00) * r0; // f(r0,0,b) + let y1 = f01 + (f11 - f01) * r0; // f(r0,1,b) + y0 + (y1 - y0) * x // f(r0,x,b) + } + + fn mle_eval_single_var( + mle: &Arc>, + offset: usize, + x: E, + ) -> E { + let y0 = Self::mle_eval_at_index(mle, offset); + let y1 = Self::mle_eval_at_index(mle, offset + 1); + y0 + (y1 - y0) * x + } + + /// build univariate polynomial for round 2 directly from original MLE evaluations + /// h(x) = \sum_b f(r0, x, b) + /// = eq(r0,0)*f(0,x,b) + eq(r0,1)*f(1,x,b) + /// = eq(r0,0)*eq(x,0)*f(0,0,b) + /// + eq(r0,0)*eq(x,1)*f(0,1,b) + /// + eq(r0,1)*eq(x,0)*f(1,0,b) + /// + eq(r0,1)*eq(x,1)*f(1,1,b) + fn build_uni_poly_round2(&self, r0: E) -> AdditiveVec { + self.poly.products.iter().fold( AdditiveVec::new(self.poly.aux_info.max_degree + 1), |mut uni_polys, MonomialTerms { terms }| { for Term { @@ -363,57 +477,124 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { } in terms { let f = &self.poly.flattened_ml_extensions; - let f_type = &self.poly_meta; - let get_poly_meta = || f_type[prod[0]]; - let mut uni_variate: Vec = vec![E::ZERO; self.poly.aux_info.max_degree + 1]; - let uni_variate_monomial: Vec = match prod.len() { - 1 => sumcheck_code_gen!(1, false, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 2 => sumcheck_code_gen!(2, false, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 3 => sumcheck_code_gen!(3, false, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 4 => sumcheck_code_gen!(4, false, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 5 => sumcheck_code_gen!(5, false, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - 6 => sumcheck_code_gen!(6, false, |i| &f[prod[i]], || get_poly_meta()) - .to_vec(), - _ => unimplemented!("do not support degree {} > 6", prod.len()), + let poly_type = self.poly_meta[prod[0]]; + let num_var = f[prod[0]].num_vars(); + // `self.round` has already been incremented before this builder runs. + // Conceptually, round-1 has been fixed, so this is the post-fix expected + // variable count used by the existing code path. + let expected_numvars_post_fix = self.expected_numvars_at_round(); + // Since we intentionally deferred the first-round fix, in-memory MLEs still + // carry one extra variable for the max-numvar branch. + let expected_numvars_unfixed = expected_numvars_post_fix + 1; + + let mut uni_variate = vec![E::ZERO; self.poly.aux_info.max_degree + 1]; + let degree = prod.len(); + + let one_term_product_at = |idx: usize| { + prod.iter() + .map(|&poly_idx| Self::mle_eval_at_index(&f[poly_idx], idx)) + .product::() }; + match poly_type { + PolyMeta::Phase2Only => { + if self.is_main_worker { + let eval_len = f[prod[0]].evaluations().len(); + let mut sum = if eval_len == 1 { + one_term_product_at(0) + } else { + (0..largest_even_below(eval_len)) + .map(one_term_product_at) + .sum() + }; + let multiplicity = self + .expected_numvars_at_round() + .saturating_add(self.phase2_numvar.unwrap_or(0)) + .saturating_sub(1) + .saturating_sub(num_var); + if multiplicity > 0 { + let factor = 1u64 + .checked_shl(multiplicity as u32) + .expect("phase2 multiplicity overflow"); + sum *= E::from_canonical_u64(factor); + } + uni_variate.iter_mut().take(degree + 1).for_each(|v| *v = sum); + } + } + PolyMeta::Normal if num_var + 1 == expected_numvars_unfixed => { + let eval_len = f[prod[0]].evaluations().len(); + for x_idx in 0..=degree { + let x = E::from_canonical_u64(x_idx as u64); + let sum = (0..eval_len) + .step_by(2) + .map(|b| { + prod.iter() + .map(|&poly_idx| { + Self::mle_eval_single_var(&f[poly_idx], b, x) + }) + .product::() + }) + .sum(); + uni_variate[x_idx] = sum; + } + } + PolyMeta::Normal if num_var + 1 < expected_numvars_unfixed => { + let eval_len = f[prod[0]].evaluations().len(); + let mut sum = if eval_len == 1 { + one_term_product_at(0) + } else { + (0..largest_even_below(eval_len)).map(one_term_product_at).sum() + }; + let multiplicity = expected_numvars_post_fix + .saturating_sub(1) + .saturating_sub(num_var); + if multiplicity > 0 { + let factor = 1u64 + .checked_shl(multiplicity as u32) + .expect("normal multiplicity overflow"); + sum *= E::from_canonical_u64(factor); + } + uni_variate.iter_mut().take(degree + 1).for_each(|v| *v = sum); + } + PolyMeta::Normal => { + debug_assert_eq!(num_var, expected_numvars_unfixed); + let eval_len = f[prod[0]].evaluations().len(); + for x_idx in 0..=degree { + let x = E::from_canonical_u64(x_idx as u64); + let sum = (0..eval_len) + .step_by(4) + .map(|b| { + prod.iter() + .map(|&poly_idx| { + Self::mle_eval_round2_without_cached_first_round( + &f[poly_idx], + b, + r0, + x, + ) + }) + .product::() + }) + .sum(); + uni_variate[x_idx] = sum; + } + } + } + uni_variate .iter_mut() - .zip(uni_variate_monomial) - .take(prod.len() + 1) - .for_each(|(eval, monimial_eval,)| either::for_both!(scalar, scalar => *eval = monimial_eval**scalar)); - + .take(degree + 1) + .for_each(|eval| either::for_both!(scalar, scalar => *eval *= *scalar)); - if prod.len() < self.poly.aux_info.max_degree { - // Perform extrapolation using the precomputed extrapolation table - extrapolate_from_table( - &mut uni_variate, - prod.len() + 1, - ); + if degree < self.poly.aux_info.max_degree { + extrapolate_from_table(&mut uni_variate, degree + 1); } uni_polys += AdditiveVec(uni_variate); } uni_polys }, - ); - exit_span!(span); - - exit_span!(start); - - assert!(uni_polys.len() > 1); - // NOTE remove uni_polys.eval(0) from lagrange domain - // as verifier can derive via claim - uni_polys.eval(1) - uni_polys.remove(0); - - IOPProverMessage { - evaluations: uni_polys, - } + ) } /// collect all mle evaluation (claim) after sumcheck @@ -448,6 +629,11 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { /// fix_var pub fn fix_var(&mut self, r: E) { + if let Some(r0) = self.pending_r0.take() { + self.fix_two_vars(r0, r); + return; + } + let expected_numvars_at_round = self.expected_numvars_at_round(); self.poly .flattened_ml_extensions @@ -467,6 +653,38 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { } }); } + + fn fix_two_vars(&mut self, r0: E, r1: E) { + // At this point we are consuming round-2 challenge `r1` while `r0` was deferred. + // - MLEs with num_vars = expected + 1 still need both `r0` and `r1`. + // - MLEs with num_vars = expected were independent of the first round variable, + // so they only need `r1`. + let expected_numvars_at_round = self.expected_numvars_at_round(); + self.poly + .flattened_ml_extensions + .iter_mut() + .zip_eq(&self.poly_meta) + .for_each(|(poly, poly_type)| { + debug_assert!(poly.num_vars() > 0); + if matches!(poly_type, PolyMeta::Normal) { + if poly.num_vars() == expected_numvars_at_round + 1 { + if !poly.is_mut() { + *poly = Arc::new(poly.fix_two_variables(r0, r1)); + } else { + let poly = Arc::get_mut(poly).unwrap(); + poly.fix_two_variables_in_place(r0, r1) + } + } else if poly.num_vars() == expected_numvars_at_round { + if !poly.is_mut() { + *poly = Arc::new(poly.fix_variables(&[r1])); + } else { + let poly = Arc::get_mut(poly).unwrap(); + poly.fix_variables_in_place(&[r1]) + } + } + } + }); + } } /// parallel version @@ -551,6 +769,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { is_main_worker: true, max_num_variables: polynomial.aux_info.max_num_variables, challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), + pending_r0: None, round: 0, poly: polynomial, poly_meta, diff --git a/crates/sumcheck/src/structs.rs b/crates/sumcheck/src/structs.rs index ddabc55..b4051f4 100644 --- a/crates/sumcheck/src/structs.rs +++ b/crates/sumcheck/src/structs.rs @@ -33,6 +33,8 @@ pub struct IOPProverState<'a, E: ExtensionField> { pub is_main_worker: bool, /// sampled randomness given by the verifier pub challenges: Vec>, + /// Defer fixing the MLEs until the second-round challenge arrives. + pub(crate) pending_r0: Option, /// the current round number pub(crate) round: usize, /// pointer to the virtual polynomial