From 9d0021402c4ef83128dcd339606002b8148d8209 Mon Sep 17 00:00:00 2001
From: Batmend Batsaikhan
Date: Mon, 15 Sep 2025 17:46:16 +0800
Subject: [PATCH 1/7] SPARK prover with 2 gpa combined

---
 Cargo.toml                                    |   1 +
 provekit/common/src/lib.rs                    |   2 +-
 spark-prover/Cargo.toml                       |  26 ++
 spark-prover/README.md                        |  22 ++
 spark-prover/src/bin/generate_test_r1cs.rs    |  24 ++
 spark-prover/src/bin/generate_test_request.rs |  27 ++
 spark-prover/src/bin/spark-verifier.rs        | 255 ++++++++++++++++++
 spark-prover/src/gpa.rs                       | 183 +++++++++++++
 spark-prover/src/lib.rs                       |   5 +
 spark-prover/src/main.rs                      |  54 ++++
 spark-prover/src/memory.rs                    |  55 ++++
 spark-prover/src/spark.rs                     | 192 +++++++++++++
 spark-prover/src/utilities/iopattern/mod.rs   |  54 ++++
 spark-prover/src/utilities/matrix/mod.rs      |  86 ++++++
 spark-prover/src/utilities/mod.rs             |  76 ++++++
 spark-prover/src/whir.rs                      |  71 +++++
 16 files changed, 1132 insertions(+), 1 deletion(-)
 create mode 100644 spark-prover/Cargo.toml
 create mode 100644 spark-prover/README.md
 create mode 100644 spark-prover/src/bin/generate_test_r1cs.rs
 create mode 100644 spark-prover/src/bin/generate_test_request.rs
 create mode 100644 spark-prover/src/bin/spark-verifier.rs
 create mode 100644 spark-prover/src/gpa.rs
 create mode 100644 spark-prover/src/lib.rs
 create mode 100644 spark-prover/src/main.rs
 create mode 100644 spark-prover/src/memory.rs
 create mode 100644 spark-prover/src/spark.rs
 create mode 100644 spark-prover/src/utilities/iopattern/mod.rs
 create mode 100644 spark-prover/src/utilities/matrix/mod.rs
 create mode 100644 spark-prover/src/utilities/mod.rs
 create mode 100644 spark-prover/src/whir.rs

diff --git a/Cargo.toml b/Cargo.toml
index ae6eda32..f2349477 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ members = [
     "tooling/cli",
     "tooling/provekit-bench",
     "tooling/provekit-gnark",
+    "spark-prover",
 ]
 exclude = [
     "playground/passport-input-gen",
diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs
index b60f6921..0e9288f2 100644
--- a/provekit/common/src/lib.rs
+++ b/provekit/common/src/lib.rs
@@ -10,9 +10,9 @@ pub mod witness;
 
 use crate::{
     interner::{InternedFieldElement, Interner},
-    sparse_matrix::{HydratedSparseMatrix, SparseMatrix},
 };
 pub use {
+    sparse_matrix::{HydratedSparseMatrix, SparseMatrix},
     acir::FieldElement as NoirElement,
     noir_proof_scheme::{NoirProof, NoirProofScheme},
     r1cs::R1CS,
diff --git a/spark-prover/Cargo.toml b/spark-prover/Cargo.toml
new file mode 100644
index 00000000..7a488255
--- /dev/null
+++ b/spark-prover/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "spark-prover"
+version = "0.1.0"
+edition.workspace = true
+rust-version.workspace = true
+authors.workspace = true
+license.workspace = true
+homepage.workspace = true
+repository.workspace = true
+
+[dependencies]
+provekit-common.workspace = true
+provekit-r1cs-compiler.workspace = true
+serde_json.workspace = true
+serde.workspace = true
+anyhow.workspace = true
+spongefish.workspace = true
+whir.workspace = true
+ark-std.workspace = true
+ark-ff.workspace = true
+itertools = "0.14.0"
+
+
+[lints]
+workspace = true
+
diff --git a/spark-prover/README.md b/spark-prover/README.md
new file mode 100644
index 00000000..b9e5f93e
--- /dev/null
+++ b/spark-prover/README.md
@@ -0,0 +1,22 @@
+# SPARK
+An experimental Rust SPARK prover and a gnark recursive verifier circuit will be implemented and optimized here.
+
+## Running SPARK (under development)
+```cargo run --bin spark-prover```
+
+## Test R1CS generation (for development)
+A development utility is provided to generate test matrices.
+To generate a test R1CS, run the following command:
+
+```cargo run -p spark-prover --bin generate_test_r1cs```
+
+## Test request generation (for development)
+A development utility is provided to generate test requests.
+To generate a test request, run the following command:
+
+```cargo run -p spark-prover --bin generate_test_request```
+
+## Reference SPARK verifier (for development)
+A reference SPARK verifier is implemented to test the correctness of the SPARK proof while serving as a reference implementation for the gnark verifier circuit.
+
+```cargo run -p spark-prover --bin spark-verifier```
\ No newline at end of file
diff --git a/spark-prover/src/bin/generate_test_r1cs.rs b/spark-prover/src/bin/generate_test_r1cs.rs
new file mode 100644
index 00000000..1a1cb7e0
--- /dev/null
+++ b/spark-prover/src/bin/generate_test_r1cs.rs
@@ -0,0 +1,24 @@
+use {
+    provekit_common::{FieldElement, R1CS},
+    std::{fs::File, io::Write},
+};
+
+fn main() {
+    let mut r1cs = R1CS::new();
+    r1cs.grow_matrices(1024, 512);
+    let interned_1 = r1cs.interner.intern(FieldElement::from(1));
+
+    for i in 0..64 {
+        r1cs.a.set(i, i, interned_1);
+        r1cs.b.set(i, i, interned_1);
+        r1cs.c.set(i, i, interned_1);
+    }
+
+    let matrix_json =
+        serde_json::to_string(&r1cs).expect("Error: Failed to serialize R1CS to JSON");
+    let mut request_file = File::create("spark-prover/r1cs.json")
+        .expect("Error: Failed to create the r1cs.json file");
+    request_file
+        .write_all(matrix_json.as_bytes())
+        .expect("Error: Failed to write JSON data to r1cs.json");
+}
diff --git a/spark-prover/src/bin/generate_test_request.rs b/spark-prover/src/bin/generate_test_request.rs
new file mode 100644
index 00000000..4cc4f7f5
--- /dev/null
+++ b/spark-prover/src/bin/generate_test_request.rs
@@ -0,0 +1,27 @@
+use {
+    provekit_common::FieldElement,
+    spark_prover::utilities::{ClaimedValues, Point, SPARKRequest},
+    std::{fs::File, io::Write},
+};
+
+fn main() {
+    let spark_request = SPARKRequest {
+        point_to_evaluate: Point {
+            row: vec![FieldElement::from(0); 10],
+            col: vec![FieldElement::from(0); 9],
+        },
+        claimed_values: ClaimedValues {
+            a: FieldElement::from(1),
+            b: FieldElement::from(1),
+            c: FieldElement::from(1),
+        },
+    };
+
+    let request_json = serde_json::to_string(&spark_request)
+        .expect("Error: Failed to serialize the SPARK request to JSON");
+    let mut request_file = File::create("spark-prover/request.json")
+        .expect("Error: Failed to create the request.json file");
+    request_file
+        .write_all(request_json.as_bytes())
+        .expect("Error: Failed to write JSON data to request.json");
+}
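For orientation, this is the claim the request encodes: given the evaluation point `(r_row, r_col)` and the claimed values of the A, B, C matrix polynomials at that point, SPARK proves that the multilinear extension of each sparse matrix really evaluates to the claimed value. A minimal naive reference check, assuming the eq-table convention used by `memory.rs` (the function name `naive_matrix_eval` is illustrative, not part of the patch):

```rust
use provekit_common::FieldElement;

// Naive restatement of the claim SPARK proves for one matrix:
//   M~(r_row, r_col) = sum over nonzero (i, j) of val(i, j) * eq(r_row, i) * eq(r_col, j)
// where `eq_rx` / `eq_ry` are the eq-polynomial tables over the boolean
// hypercube, as produced by calculate_evaluations_over_boolean_hypercube_for_eq.
fn naive_matrix_eval(
    entries: &[(usize, usize, FieldElement)], // (row, col, val) triples
    eq_rx: &[FieldElement],
    eq_ry: &[FieldElement],
) -> FieldElement {
    entries
        .iter()
        .map(|&(r, c, v)| v * eq_rx[r] * eq_ry[c])
        .fold(FieldElement::from(0), |acc, x| acc + x)
}
```

The prover never computes this sum directly; it runs a sumcheck over the committed `val`, `e_rx`, `e_ry` vectors instead, which is what the verifier below checks.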
diff --git a/spark-prover/src/bin/spark-verifier.rs b/spark-prover/src/bin/spark-verifier.rs
new file mode 100644
index 00000000..9334ee82
--- /dev/null
+++ b/spark-prover/src/bin/spark-verifier.rs
@@ -0,0 +1,255 @@
+use {
+    anyhow::{ensure, Context, Result},
+    ark_std::{One, Zero},
+    provekit_common::{
+        utils::{
+            next_power_of_two,
+            sumcheck::{calculate_eq, eval_cubic_poly},
+        },
+        FieldElement, IOPattern, skyscraper::SkyscraperSponge,
+    },
+    spark_prover::utilities::{SPARKProof, SPARKRequest},
+    spongefish::{
+        codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField},
+        VerifierState,
+    },
+    std::fs::{self, File},
+    whir::{
+        poly_utils::multilinear::MultilinearPoint,
+        whir::{
+            committer::CommitmentReader,
+            statement::{Statement, Weights},
+            utils::HintDeserialize,
+            verifier::Verifier,
+        },
+    },
+};
+
+fn main() -> Result<()> {
+    let spark_proof_json_str = fs::read_to_string("spark-prover/spark_proof.json")
+        .context("Error: Failed to open the spark_proof.json file")?;
+    let spark_proof: SPARKProof = serde_json::from_str(&spark_proof_json_str)
+        .context("Error: Failed to deserialize JSON to SPARKProof")?;
+
+    let request_json_str = fs::read_to_string("spark-prover/request.json")
+        .context("Error: Failed to open the request.json file")?;
+    let request: SPARKRequest = serde_json::from_str(&request_json_str)
+        .context("Error: Failed to deserialize JSON to SPARKRequest")?;
+
+    let io = IOPattern::from_string(spark_proof.io_pattern);
+    let mut arthur = io.to_verifier_state(&spark_proof.transcript);
+
+    let commitment_reader = CommitmentReader::new(&spark_proof.whir_params.a);
+    let commitment_reader_row = CommitmentReader::new(&spark_proof.whir_params.row);
+
+    let val_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap();
+    let e_rx_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap();
+    let e_ry_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap();
+    let final_row_commitment = commitment_reader_row.parse_commitment(&mut arthur).unwrap();
+
+    let (randomness, last_sumcheck_value) = run_sumcheck_verifier_spark(
+        &mut arthur,
+        next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms),
+        request.claimed_values.a,
+    )
+    .context("While verifying SPARK sumcheck")?;
+
+    let final_folds: Vec<FieldElement> = arthur.hint()?;
+
+    let mut val_statement_verifier = Statement::<FieldElement>::new(next_power_of_two(
+        spark_proof.matrix_dimensions.a_nonzero_terms,
+    ));
+    val_statement_verifier.add_constraint(
+        Weights::evaluation(MultilinearPoint(randomness.clone())),
+        final_folds[0],
+    );
+    let val_verifier = Verifier::new(&spark_proof.whir_params.a);
+    val_verifier
+        .verify(&mut arthur, &val_commitment, &val_statement_verifier)
+        .context("while verifying WHIR")?;
+
+    let mut e_rx_statement_verifier = Statement::<FieldElement>::new(next_power_of_two(
+        spark_proof.matrix_dimensions.a_nonzero_terms,
+    ));
+    e_rx_statement_verifier.add_constraint(
+        Weights::evaluation(MultilinearPoint(randomness.clone())),
+        final_folds[1],
+    );
+    let e_rx_verifier = Verifier::new(&spark_proof.whir_params.a);
+    e_rx_verifier
+        .verify(&mut arthur, &e_rx_commitment, &e_rx_statement_verifier)
+        .context("while verifying WHIR")?;
+
+    let mut e_ry_statement_verifier = Statement::<FieldElement>::new(next_power_of_two(
+        spark_proof.matrix_dimensions.a_nonzero_terms,
+    ));
+    e_ry_statement_verifier.add_constraint(
+        Weights::evaluation(MultilinearPoint(randomness.clone())),
+        final_folds[2],
+    );
+    let e_ry_verifier = Verifier::new(&spark_proof.whir_params.a);
+    e_ry_verifier
+        .verify(&mut arthur, &e_ry_commitment, &e_ry_statement_verifier)
+        .context("while verifying WHIR")?;
+
+    let mut tau_and_gamma = [FieldElement::from(0); 2];
+    arthur.fill_challenge_scalars(&mut tau_and_gamma)?;
+    let tau = tau_and_gamma[0];
+    let gamma = tau_and_gamma[1];
+
+    let gpa_result = gpa_sumcheck_verifier(
+        &mut arthur,
+        next_power_of_two(spark_proof.matrix_dimensions.num_rows) + 2,
+    )?;
+
+    let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1);
+
+    let init_adr = calculate_adr(&evaluation_randomness.to_vec());
+    let init_mem = calculate_eq(
+        &request.point_to_evaluate.row,
+        &evaluation_randomness.to_vec(),
+    );
+    let init_cntr = FieldElement::from(0);
+
+    let init_opening = init_adr * gamma * gamma + init_mem * gamma + init_cntr - tau;
+
+    let final_cntr: FieldElement = arthur.hint()?;
+
+    let mut final_cntr_statement =
+        Statement::<FieldElement>::new(next_power_of_two(spark_proof.matrix_dimensions.num_rows));
+    final_cntr_statement.add_constraint(
+        Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())),
+        final_cntr,
+    );
+
+    let final_cntr_verifier = Verifier::new(&spark_proof.whir_params.row);
+    final_cntr_verifier
+        .verify(&mut arthur, &final_row_commitment, &final_cntr_statement)
+        .context("while verifying WHIR")?;
+
+    let final_adr = calculate_adr(&evaluation_randomness.to_vec());
+    let final_mem = calculate_eq(
+        &request.point_to_evaluate.row,
+        &evaluation_randomness.to_vec(),
+    );
+
+    let final_opening = final_adr * gamma * gamma + final_mem * gamma + final_cntr - tau;
+
+    let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0])
+        + final_opening * last_randomness[0];
+
+    ensure!(evaluated_value == gpa_result.last_sumcheck_value);
+
+    Ok(())
+}
+
+pub fn run_sumcheck_verifier_spark(
+    arthur: &mut VerifierState<SkyscraperSponge, FieldElement>,
+    variable_count: usize,
+    initial_sumcheck_val: FieldElement,
+) -> Result<(Vec<FieldElement>, FieldElement)> {
+    let mut saved_val_for_sumcheck_equality_assertion = initial_sumcheck_val;
+
+    let mut alpha = vec![FieldElement::zero(); variable_count];
+
+    for i in 0..variable_count {
+        let mut hhat_i = [FieldElement::zero(); 4];
+        let mut alpha_i = [FieldElement::zero(); 1];
+        arthur.fill_next_scalars(&mut hhat_i)?;
+        arthur.fill_challenge_scalars(&mut alpha_i)?;
+        alpha[i] = alpha_i[0];
+
+        let hhat_i_at_zero = eval_cubic_poly(&hhat_i, &FieldElement::zero());
+        let hhat_i_at_one = eval_cubic_poly(&hhat_i, &FieldElement::one());
+        ensure!(
+            saved_val_for_sumcheck_equality_assertion == hhat_i_at_zero + hhat_i_at_one,
+            "Sumcheck equality assertion failed"
+        );
+        saved_val_for_sumcheck_equality_assertion = eval_cubic_poly(&hhat_i, &alpha_i[0]);
+    }
+
+    Ok((alpha, saved_val_for_sumcheck_equality_assertion))
+}
+
+pub fn gpa_sumcheck_verifier(
+    arthur: &mut VerifierState<SkyscraperSponge, FieldElement>,
+    height_of_binary_tree: usize,
+) -> Result<GPASumcheckResult> {
+    let mut prev_rand = Vec::<FieldElement>::new();
+    let mut rand = Vec::<FieldElement>::new();
+    let mut claimed_values = [FieldElement::from(0); 2];
+    let mut l = [FieldElement::from(0); 2];
+    let mut r = [FieldElement::from(0); 1];
+    let mut h = [FieldElement::from(0); 4];
+    let mut alpha = [FieldElement::from(0); 1];
+
+    arthur
+        .fill_next_scalars(&mut claimed_values)
+        .expect("Failed to fill next scalars");
+    arthur
+        .fill_challenge_scalars(&mut r)
+        .expect("Failed to fill challenge scalars");
+    let mut last_sumcheck_value = eval_linear_poly(&claimed_values, &r[0]);
+
+    rand.push(r[0]);
+    prev_rand = rand;
+    rand = Vec::<FieldElement>::new();
+
+    for i in 1..(height_of_binary_tree - 1) {
+        for _ in 0..i {
+            arthur
+                .fill_next_scalars(&mut h)
+                .expect("Failed to fill next scalars");
+            arthur
+                .fill_challenge_scalars(&mut alpha)
+                .expect("Failed to fill challenge scalars");
+            assert_eq!(
+                eval_cubic_poly(&h, &FieldElement::from(0))
+                    + eval_cubic_poly(&h, &FieldElement::from(1)),
+                last_sumcheck_value
+            );
+            rand.push(alpha[0]);
+            last_sumcheck_value = eval_cubic_poly(&h, &alpha[0]);
+        }
+        arthur
+            .fill_next_scalars(&mut l)
+            .expect("Failed to fill next scalars");
+        arthur
+            .fill_challenge_scalars(&mut r)
+            .expect("Failed to fill challenge scalars");
+        let claimed_last_sch = calculate_eq(&prev_rand, &rand)
+            * eval_linear_poly(&l, &FieldElement::from(0))
+            * eval_linear_poly(&l, &FieldElement::from(1));
+        assert_eq!(claimed_last_sch, last_sumcheck_value);
+        rand.push(r[0]);
+        prev_rand = rand;
+        rand = Vec::<FieldElement>::new();
+        last_sumcheck_value = eval_linear_poly(&l, &r[0]);
+    }
+
+    Ok(GPASumcheckResult {
+        claimed_values: claimed_values.to_vec(),
+        last_sumcheck_value,
+        randomness: prev_rand,
+    })
+}
+
+pub struct GPASumcheckResult {
+    pub claimed_values: Vec<FieldElement>,
+    pub last_sumcheck_value: FieldElement,
+    pub randomness: Vec<FieldElement>,
+}
+
+pub fn eval_linear_poly(poly: &[FieldElement], point: &FieldElement) -> FieldElement {
+    poly[0] + *point * poly[1]
+}
+
+pub fn calculate_adr(alpha: &Vec<FieldElement>) -> FieldElement {
+    let mut ans = FieldElement::from(0);
+    let mut mult = FieldElement::from(1);
+    for a in alpha.iter().rev() {
+        ans = ans + *a * mult;
+        mult = mult * FieldElement::from(2);
+    }
+    ans
+}
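The verifier above repeats one primitive many times: the prover sends a cubic round polynomial, the verifier checks that its values at 0 and 1 add up to the running claim, then reduces the claim to the polynomial's value at a fresh challenge. A compact restatement, assuming only the low-to-high coefficient convention that `eval_cubic_poly` and `eval_linear_poly` use here (the helper names are illustrative):

```rust
use provekit_common::FieldElement;

// h(x) = c0 + c1*x + c2*x^2 + c3*x^3, same convention as eval_cubic_poly.
fn eval_cubic(c: &[FieldElement; 4], x: FieldElement) -> FieldElement {
    c[0] + x * (c[1] + x * (c[2] + x * c[3]))
}

// One sumcheck round: returns the next running claim if the round is consistent.
fn check_round(
    claim: FieldElement,   // running claim going into this round
    c: &[FieldElement; 4], // prover's cubic for this round
    alpha: FieldElement,   // verifier challenge
) -> Option<FieldElement> {
    // The two boolean evaluations must add up to the running claim;
    // the claim for the next round is h(alpha).
    (eval_cubic(c, FieldElement::from(0)) + eval_cubic(c, FieldElement::from(1)) == claim)
        .then(|| eval_cubic(c, alpha))
}
```

`calculate_adr` is the matching helper on the address side: it evaluates the multilinear extension of the address vector 0, 1, 2, ... at the GPA's evaluation point by reading the point as a big-endian binary expansion.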
diff --git a/spark-prover/src/gpa.rs b/spark-prover/src/gpa.rs
new file mode 100644
index 00000000..6b8f6ebb
--- /dev/null
+++ b/spark-prover/src/gpa.rs
@@ -0,0 +1,183 @@
+use {
+    provekit_common::{
+        utils::{
+            sumcheck::{
+                calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly,
+                sumcheck_fold_map_reduce,
+            },
+            HALF,
+        },
+        FieldElement, skyscraper::SkyscraperSponge,
+    },
+    spongefish::{
+        codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField},
+        ProverState,
+    },
+    whir::poly_utils::evals::EvaluationsList,
+};
+
+// TODO: Fix gpa and add line integration
+
+pub fn run_gpa(
+    merlin: &mut ProverState<SkyscraperSponge, FieldElement>,
+    left: &Vec<FieldElement>,
+    right: &Vec<FieldElement>,
+) -> Vec<FieldElement> {
+    let mut h = left.clone();
+    h.extend(right.iter().cloned());
+    let layers = calculate_binary_multiplication_tree(h);
+
+    let mut saved_val_for_sumcheck_equality_assertion;
+    let mut r;
+    let mut line_evaluations;
+    let mut alpha = Vec::<FieldElement>::new();
+
+    (r, saved_val_for_sumcheck_equality_assertion) = add_line_to_merlin(merlin, layers[1].clone());
+
+    for i in 2..layers.len() {
+        (line_evaluations, alpha) = run_gpa_sumcheck(
+            merlin,
+            &r,
+            layers[i].clone(),
+            saved_val_for_sumcheck_equality_assertion,
+            alpha,
+        );
+        (r, saved_val_for_sumcheck_equality_assertion) =
+            add_line_to_merlin(merlin, line_evaluations.to_vec());
+    }
+
+    alpha.push(r[0]);
+
+    return alpha;
+}
+
+fn calculate_binary_multiplication_tree(
+    array_to_prove: Vec<FieldElement>,
+) -> Vec<Vec<FieldElement>> {
+    let mut layers = vec![];
+    let mut current_layer = array_to_prove;
+
+    while current_layer.len() > 1 {
+        let mut next_layer = vec![];
+
+        for i in (0..current_layer.len()).step_by(2) {
+            let product = current_layer[i] * current_layer[i + 1];
+            next_layer.push(product);
+        }
+
+        layers.push(current_layer);
+        current_layer = next_layer;
+    }
+
+    layers.push(current_layer);
+    layers.reverse();
+    layers
+}
+
+fn add_line_to_merlin(
+    merlin: &mut ProverState<SkyscraperSponge, FieldElement>,
+    arr: Vec<FieldElement>,
+) -> ([FieldElement; 1], FieldElement) {
+    let l_evaluations = EvaluationsList::new(arr);
+    let l_temp = l_evaluations.to_coeffs();
+    let l: &[FieldElement] = l_temp.coeffs();
+    merlin.add_scalars(&l).expect("Failed to add l");
+
+    let mut r = [FieldElement::from(0); 1];
+    merlin
+        .fill_challenge_scalars(&mut r)
+        .expect("Failed to add a challenge scalar");
+
+    let saved_val_for_sumcheck_equality_assertion = l[0] + l[1] * r[0];
+
+    (r, saved_val_for_sumcheck_equality_assertion)
+}
+
+fn run_gpa_sumcheck(
+    merlin: &mut ProverState<SkyscraperSponge, FieldElement>,
+    r: &[FieldElement; 1],
+    layer: Vec<FieldElement>,
+    mut saved_val_for_sumcheck_equality_assertion: FieldElement,
+    mut alpha: Vec<FieldElement>,
+) -> ([FieldElement; 2], Vec<FieldElement>) {
+    let (mut v0, mut v1) = split_by_index(layer);
+    alpha.push(r[0]);
+    let mut eq_r = calculate_evaluations_over_boolean_hypercube_for_eq(&alpha);
+    let mut alpha_i_wrapped_in_vector = [FieldElement::from(0)];
+    let mut alpha = Vec::<FieldElement>::new();
+    let mut fold = None;
+
+    loop {
+        let [hhat_i_at_0, hhat_i_at_em1, hhat_i_at_inf_over_x_cube] =
+            sumcheck_fold_map_reduce([&mut eq_r, &mut v0, &mut v1], fold, |[eq_r, v0, v1]| {
+                [
+                    // Evaluation at 0
+                    eq_r.0 * v0.0 * v1.0,
+                    // Evaluation at -1
+                    (eq_r.0 + eq_r.0 - eq_r.1) * (v0.0 + v0.0 - v0.1) * (v1.0 + v1.0 - v1.1),
+                    // Evaluation at infinity
+                    (eq_r.1 - eq_r.0) * (v0.1 - v0.0) * (v1.1 - v1.0),
+                ]
+            });
+
+        if fold.is_some() {
+            eq_r.truncate(eq_r.len() / 2);
+            v0.truncate(v0.len() / 2);
+            v1.truncate(v1.len() / 2);
+        }
+
+        let mut hhat_i_coeffs = [FieldElement::from(0); 4];
+
+        hhat_i_coeffs[0] = hhat_i_at_0;
+        hhat_i_coeffs[2] = HALF
+            * (saved_val_for_sumcheck_equality_assertion + hhat_i_at_em1
+                - hhat_i_at_0
+                - hhat_i_at_0
+                - hhat_i_at_0);
+        hhat_i_coeffs[3] = hhat_i_at_inf_over_x_cube;
+        hhat_i_coeffs[1] = saved_val_for_sumcheck_equality_assertion
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[3]
+            - hhat_i_coeffs[2];
+
+        assert_eq!(
+            saved_val_for_sumcheck_equality_assertion,
+            hhat_i_coeffs[0]
+                + hhat_i_coeffs[0]
+                + hhat_i_coeffs[1]
+                + hhat_i_coeffs[2]
+                + hhat_i_coeffs[3]
+        );
+
+        let _ = merlin.add_scalars(&hhat_i_coeffs[..]);
+        let _ = merlin.fill_challenge_scalars(&mut alpha_i_wrapped_in_vector);
+        fold = Some(alpha_i_wrapped_in_vector[0]);
+        saved_val_for_sumcheck_equality_assertion =
+            eval_cubic_poly(&hhat_i_coeffs, &alpha_i_wrapped_in_vector[0]);
+        alpha.push(alpha_i_wrapped_in_vector[0]);
+        if eq_r.len() <= 2 {
+            break;
+        }
+    }
+
+    let folded_v0 = v0[0] + (v0[1] - v0[0]) * alpha_i_wrapped_in_vector[0];
+    let folded_v1 = v1[0] + (v1[1] - v1[0]) * alpha_i_wrapped_in_vector[0];
+
+    ([folded_v0, folded_v1], alpha)
+}
+
+fn split_by_index(input: Vec<FieldElement>) -> (Vec<FieldElement>, Vec<FieldElement>) {
+    let mut even_indexed = Vec::new();
+    let mut odd_indexed = Vec::new();
+
+    for (i, item) in input.into_iter().enumerate() {
+        if i % 2 == 0 {
+            even_indexed.push(item);
+        } else {
+            odd_indexed.push(item);
+        }
+    }
+
+    (even_indexed, odd_indexed)
+}
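A toy run of the binary multiplication tree, for intuition. `calculate_binary_multiplication_tree` pairs adjacent entries, pushes each layer, and reverses at the end, so `layers[0]` holds the grand product; each GPA sumcheck round then proves one layer equals the even/odd split of the layer below it (parent = v0 times v1 elementwise, via the eq-weighted sumcheck above). A sketch, with arbitrary small values:

```rust
use provekit_common::FieldElement;

// Tree over leaves [a, b, c, d] (after the final reverse):
//   layers[0] = [a*b*c*d]   root = grand product
//   layers[1] = [a*b, c*d]
//   layers[2] = [a, b, c, d]
fn tree_demo() {
    let leaves: Vec<FieldElement> = [2u64, 3, 5, 7].map(FieldElement::from).to_vec();
    // Layer above the leaves: adjacent pairs multiplied.
    let mid: Vec<FieldElement> = leaves.chunks(2).map(|p| p[0] * p[1]).collect(); // [6, 35]
    let root = mid[0] * mid[1]; // 210, the grand product
    assert_eq!(root, FieldElement::from(210));
}
```

`run_gpa` concatenates `left` and `right` into one leaf layer, so the committed root simultaneously fixes the products of both input vectors.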
diff --git a/spark-prover/src/lib.rs b/spark-prover/src/lib.rs
new file mode 100644
index 00000000..c2e08d50
--- /dev/null
+++ b/spark-prover/src/lib.rs
@@ -0,0 +1,5 @@
+pub mod gpa;
+pub mod memory;
+pub mod spark;
+pub mod utilities;
+pub mod whir;
diff --git a/spark-prover/src/main.rs b/spark-prover/src/main.rs
new file mode 100644
index 00000000..9932b9bd
--- /dev/null
+++ b/spark-prover/src/main.rs
@@ -0,0 +1,54 @@
+use {
+    anyhow::{Context, Result},
+    provekit_common::utils::next_power_of_two,
+    spark_prover::{
+        memory::{calculate_e_values_for_r1cs, calculate_memory},
+        spark::prove_spark_for_single_matrix,
+        utilities::{
+            calculate_matrix_dimensions, create_io_pattern, deserialize_r1cs, deserialize_request,
+            get_spark_r1cs, SPARKProof,
+        },
+        whir::create_whir_configs,
+    },
+    std::{fs::File, io::Write},
+};
+
+fn main() -> Result<()> {
+    // Run once when receiving the matrix
+    let r1cs = deserialize_r1cs("spark-prover/r1cs.json")
+        .context("Error: Failed to create the R1CS object")?;
+    let spark_r1cs = get_spark_r1cs(&r1cs);
+    let spark_whir_configs = create_whir_configs(&r1cs);
+
+    // Run for each request
+    let request = deserialize_request("spark-prover/request.json")
+        .context("Error: Failed to deserialize the request object")?;
+    let memory = calculate_memory(request.point_to_evaluate);
+    let e_values = calculate_e_values_for_r1cs(&memory, &r1cs);
+    let io_pattern = create_io_pattern(&r1cs, &spark_whir_configs);
+    let mut merlin = io_pattern.to_prover_state();
+
+    prove_spark_for_single_matrix(
+        &mut merlin,
+        spark_r1cs.a,
+        memory,
+        e_values.a,
+        request.claimed_values.a,
+        &spark_whir_configs,
+    )?;
+
+    let spark_proof = SPARKProof {
+        transcript: merlin.narg_string().to_vec(),
+        io_pattern: String::from_utf8(io_pattern.as_bytes().to_vec()).unwrap(),
+        whir_params: spark_whir_configs,
+        matrix_dimensions: calculate_matrix_dimensions(&r1cs),
+    };
+
+    let mut spark_proof_file = File::create("spark-prover/spark_proof.json")
+        .context("Error: Failed to create the spark proof file")?;
+    spark_proof_file
+        .write_all(serde_json::to_string(&spark_proof).unwrap().as_bytes())
+        .expect("Error: Failed to write the SPARK proof to spark_proof.json");
+
+    Ok(())
+}
diff --git a/spark-prover/src/memory.rs b/spark-prover/src/memory.rs
new file mode 100644
index 00000000..76016fe6
--- /dev/null
+++ b/spark-prover/src/memory.rs
@@ -0,0 +1,55 @@
+use {
+    crate::utilities::Point,
+    provekit_common::{
+        utils::sumcheck::calculate_evaluations_over_boolean_hypercube_for_eq, FieldElement,
+        HydratedSparseMatrix, R1CS,
+    },
+};
+
+#[derive(Debug)]
+pub struct Memory {
+    pub eq_rx: Vec<FieldElement>,
+    pub eq_ry: Vec<FieldElement>,
+}
+
+#[derive(Debug)]
+pub struct EValuesForMatrix {
+    pub e_rx: Vec<FieldElement>,
+    pub e_ry: Vec<FieldElement>,
+}
+
+#[derive(Debug)]
+pub struct EValues {
+    pub a: EValuesForMatrix,
+    pub b: EValuesForMatrix,
+    pub c: EValuesForMatrix,
+}
+
+pub fn calculate_memory(point_to_evaluate: Point) -> Memory {
+    Memory {
+        eq_rx: calculate_evaluations_over_boolean_hypercube_for_eq(&point_to_evaluate.row),
+        eq_ry: calculate_evaluations_over_boolean_hypercube_for_eq(&point_to_evaluate.col),
+    }
+}
+
+pub fn calculate_e_values_for_r1cs(memory: &Memory, r1cs: &R1CS) -> EValues {
+    EValues {
+        a: calculate_e_values_for_matrix(memory, &r1cs.a()),
+        b: calculate_e_values_for_matrix(memory, &r1cs.b()),
+        c: calculate_e_values_for_matrix(memory, &r1cs.c()),
+    }
+}
+
+pub fn calculate_e_values_for_matrix(
+    memory: &Memory,
+    matrix: &HydratedSparseMatrix,
+) -> EValuesForMatrix {
+    let mut e_rx = Vec::<FieldElement>::new();
+    let mut e_ry = Vec::<FieldElement>::new();
+
+    for ((r, c), _) in matrix.iter() {
+        e_rx.push(memory.eq_rx[r]);
+        e_ry.push(memory.eq_ry[c]);
+    }
+    EValuesForMatrix { e_rx, e_ry }
+}
diff --git a/spark-prover/src/spark.rs b/spark-prover/src/spark.rs
new file mode 100644
index 00000000..e76f6a6f
--- /dev/null
+++ b/spark-prover/src/spark.rs
@@ -0,0 +1,192 @@
+use {
+    crate::{
+        gpa::run_gpa,
+        memory::{EValuesForMatrix, Memory},
+        utilities::matrix::SparkMatrix,
+        whir::{commit_to_vector, produce_whir_proof, SPARKWHIRConfigs},
+    },
+    anyhow::Result,
+    itertools::izip,
+    provekit_common::{
+        utils::{
+            sumcheck::{eval_cubic_poly, sumcheck_fold_map_reduce},
+            HALF,
+        },
+        FieldElement, skyscraper::SkyscraperSponge,
+    },
+    spongefish::{
+        codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField},
+        ProverState,
+    },
+    whir::{
+        poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint},
+        whir::{committer::CommitmentWriter, utils::HintSerialize},
+    },
+};
+
+pub fn prove_spark_for_single_matrix(
+    merlin: &mut ProverState<SkyscraperSponge, FieldElement>,
+    matrix: SparkMatrix,
+    memory: Memory,
+    e_values: EValuesForMatrix,
+    claimed_value: FieldElement,
+    whir_configs: &SPARKWHIRConfigs,
+) -> Result<()> {
+    let committer_a = CommitmentWriter::new(whir_configs.a.clone());
+    let committer_row = CommitmentWriter::new(whir_configs.row.clone());
+
+    let val_witness = commit_to_vector(&committer_a, merlin, matrix.coo.val.clone());
+    let e_rx_witness = commit_to_vector(&committer_a, merlin, e_values.e_rx.clone());
+    let e_ry_witness = commit_to_vector(&committer_a, merlin, e_values.e_ry.clone());
+
+    let final_row_witness =
+        commit_to_vector(&committer_row, merlin, matrix.timestamps.final_row.clone());
+
+    let mles = [matrix.coo.val.clone(), e_values.e_rx, e_values.e_ry];
+    let (sumcheck_final_folds, folding_randomness) =
+        run_spark_sumcheck(merlin, mles, claimed_value)?;
+
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(folding_randomness.clone()),
+        sumcheck_final_folds[0],
+        whir_configs.a.clone(),
+        val_witness,
+    )?;
+
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(folding_randomness.clone()),
+        sumcheck_final_folds[1],
+        whir_configs.b.clone(),
+        e_rx_witness,
+    )?;
+
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(folding_randomness.clone()),
+        sumcheck_final_folds[2],
+        whir_configs.c.clone(),
+        e_ry_witness,
+    )?;
+
+    // Rowwise
+
+    let mut tau_and_gamma = [FieldElement::from(0); 2];
+    merlin.fill_challenge_scalars(&mut tau_and_gamma)?;
+    let tau = tau_and_gamma[0];
+    let gamma = tau_and_gamma[1];
+
+    let init_address: Vec<FieldElement> = (0..memory.eq_rx.len() as u64)
+        .map(FieldElement::from)
+        .collect();
+    let init_value = memory.eq_rx.clone();
+    let init_timestamp = vec![FieldElement::from(0); memory.eq_rx.len()];
+
+    let init_vec: Vec<FieldElement> = izip!(init_address, init_value, init_timestamp)
+        .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+        .collect();
+
+    let final_address: Vec<FieldElement> = (0..memory.eq_rx.len() as u64)
+        .map(FieldElement::from)
+        .collect();
+    let final_value = memory.eq_rx.clone();
+    let final_timestamp = matrix.timestamps.final_row.clone();
+
+    let final_vec: Vec<FieldElement> = izip!(final_address, final_value, final_timestamp)
+        .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+        .collect();
+
+    let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec);
+
+    let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
+
+    // TODO: Can I avoid evaluating here?
+    let final_row_eval = EvaluationsList::new(matrix.timestamps.final_row.clone())
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&final_row_eval)?;
+
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(evaluation_randomness.to_vec()),
+        final_row_eval,
+        whir_configs.row.clone(),
+        final_row_witness,
+    )?;
+
+    Ok(())
+}
+
+pub fn run_spark_sumcheck(
+    merlin: &mut ProverState<SkyscraperSponge, FieldElement>,
+    mles: [Vec<FieldElement>; 3],
+    mut claimed_value: FieldElement,
+) -> Result<([FieldElement; 3], Vec<FieldElement>)> {
+    let mut sumcheck_randomness = [FieldElement::from(0)];
+    let mut sumcheck_randomness_accumulator = Vec::<FieldElement>::new();
+    let mut fold = None;
+
+    let mut m0 = mles[0].clone();
+    let mut m1 = mles[1].clone();
+    let mut m2 = mles[2].clone();
+
+    loop {
+        let [hhat_i_at_0, hhat_i_at_em1, hhat_i_at_inf_over_x_cube] =
+            sumcheck_fold_map_reduce([&mut m0, &mut m1, &mut m2], fold, |[m0, m1, m2]| {
+                [
+                    // Evaluation at 0
+                    m0.0 * m1.0 * m2.0,
+                    // Evaluation at -1
+                    (m0.0 + m0.0 - m0.1) * (m1.0 + m1.0 - m1.1) * (m2.0 + m2.0 - m2.1),
+                    // Evaluation at infinity
+                    (m0.1 - m0.0) * (m1.1 - m1.0) * (m2.1 - m2.0),
+                ]
+            });
+
+        if fold.is_some() {
+            m0.truncate(m0.len() / 2);
+            m1.truncate(m1.len() / 2);
+            m2.truncate(m2.len() / 2);
+        }
+
+        let mut hhat_i_coeffs = [FieldElement::from(0); 4];
+
+        hhat_i_coeffs[0] = hhat_i_at_0;
+        hhat_i_coeffs[2] =
+            HALF * (claimed_value + hhat_i_at_em1 - hhat_i_at_0 - hhat_i_at_0 - hhat_i_at_0);
+        hhat_i_coeffs[3] = hhat_i_at_inf_over_x_cube;
+        hhat_i_coeffs[1] = claimed_value
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[3]
+            - hhat_i_coeffs[2];
+
+        assert_eq!(
+            claimed_value,
+            hhat_i_coeffs[0]
+                + hhat_i_coeffs[0]
+                + hhat_i_coeffs[1]
+                + hhat_i_coeffs[2]
+                + hhat_i_coeffs[3]
+        );
+
+        merlin.add_scalars(&hhat_i_coeffs[..])?;
+        merlin.fill_challenge_scalars(&mut sumcheck_randomness)?;
+        fold = Some(sumcheck_randomness[0]);
+        claimed_value = eval_cubic_poly(&hhat_i_coeffs, &sumcheck_randomness[0]);
+        sumcheck_randomness_accumulator.push(sumcheck_randomness[0]);
+        if m0.len() <= 2 {
+            break;
+        }
+    }
+
+    let folded_v0 = m0[0] + (m0[1] - m0[0]) * sumcheck_randomness[0];
+    let folded_v1 = m1[0] + (m1[1] - m1[0]) * sumcheck_randomness[0];
+    let folded_v2 = m2[0] + (m2[1] - m2[0]) * sumcheck_randomness[0];
+
+    merlin.hint::<Vec<FieldElement>>(&[folded_v0, folded_v1, folded_v2].to_vec())?;
+    Ok((
+        [folded_v0, folded_v1, folded_v2],
+        sumcheck_randomness_accumulator,
+    ))
+}
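The rowwise memory argument above packs each (address, value, timestamp) tuple into a single field element before feeding it to the grand product, exactly as in the `izip!` closures. Because `gamma` and `tau` are verifier challenges drawn after the commitments, equal multisets of tuples yield equal products of fingerprints with overwhelming probability. The packing, factored out as a sketch:

```rust
use provekit_common::FieldElement;

// Fingerprint of one memory tuple, matching the izip! closures in spark.rs:
//   fingerprint = adr * gamma^2 + val * gamma + ts - tau
fn fingerprint(
    adr: FieldElement,
    val: FieldElement,
    ts: FieldElement,
    gamma: FieldElement, // verifier challenge
    tau: FieldElement,   // verifier challenge
) -> FieldElement {
    adr * gamma * gamma + val * gamma + ts - tau
}
```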
diff --git a/spark-prover/src/utilities/iopattern/mod.rs b/spark-prover/src/utilities/iopattern/mod.rs
new file mode 100644
index 00000000..33d1f313
--- /dev/null
+++ b/spark-prover/src/utilities/iopattern/mod.rs
@@ -0,0 +1,54 @@
+use {
+    crate::whir::SPARKWHIRConfigs,
+    provekit_common::{
+        utils::{next_power_of_two, sumcheck::SumcheckIOPattern},
+        FieldElement, IOPattern, R1CS,
+    },
+    spongefish::codecs::arkworks_algebra::FieldDomainSeparator,
+    whir::whir::domainsep::WhirDomainSeparator,
+};
+
+pub trait SPARKDomainSeparator {
+    fn add_tau_and_gamma(self) -> Self;
+
+    fn add_line(self) -> Self;
+}
+
+impl SPARKDomainSeparator for IOPattern
+where
+    IOPattern: FieldDomainSeparator<FieldElement>,
+{
+    fn add_tau_and_gamma(self) -> Self {
+        self.challenge_scalars(2, "tau and gamma")
+    }
+
+    fn add_line(self) -> Self {
+        self.add_scalars(2, "gpa line")
+            .challenge_scalars(1, "gpa line random")
+    }
+}
+
+pub fn create_io_pattern(r1cs: &R1CS, configs: &SPARKWHIRConfigs) -> IOPattern {
+    let mut io = IOPattern::new("💥")
+        .commit_statement(&configs.a)
+        .commit_statement(&configs.a)
+        .commit_statement(&configs.a)
+        .commit_statement(&configs.row)
+        .add_sumcheck_polynomials(next_power_of_two(r1cs.a.num_entries()))
+        .hint("sumcheck_last_folds")
+        .add_whir_proof(&configs.a)
+        .add_whir_proof(&configs.a)
+        .add_whir_proof(&configs.a)
+        .add_tau_and_gamma();
+
+    for i in 0..=next_power_of_two(r1cs.a.num_rows) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("Row final counter claimed evaluation")
+        .add_whir_proof(&configs.row);
+
+    io
+}
diff --git a/spark-prover/src/utilities/matrix/mod.rs b/spark-prover/src/utilities/matrix/mod.rs
new file mode 100644
index 00000000..5cba9545
--- /dev/null
+++ b/spark-prover/src/utilities/matrix/mod.rs
@@ -0,0 +1,86 @@
+use provekit_common::{FieldElement, HydratedSparseMatrix, SparseMatrix, R1CS};
+
+#[derive(Debug)]
+pub struct SparkR1CS {
+    pub a: SparkMatrix,
+    pub b: SparkMatrix,
+    pub c: SparkMatrix,
+}
+#[derive(Debug)]
+pub struct SparkMatrix {
+    pub coo: COOMatrix,
+    pub timestamps: TimeStamps,
+}
+#[derive(Debug)]
+pub struct COOMatrix {
+    pub row: Vec<FieldElement>,
+    pub col: Vec<FieldElement>,
+    pub val: Vec<FieldElement>,
+}
+#[derive(Debug)]
+pub struct TimeStamps {
+    pub read_row: Vec<FieldElement>,
+    pub read_col: Vec<FieldElement>,
+    pub final_row: Vec<FieldElement>,
+    pub final_col: Vec<FieldElement>,
+}
+
+pub fn get_spark_r1cs(r1cs: &R1CS) -> SparkR1CS {
+    SparkR1CS {
+        a: get_spark_matrix(&r1cs.a()),
+        b: get_spark_matrix(&r1cs.b()),
+        c: get_spark_matrix(&r1cs.c()),
+    }
+}
+
+pub fn get_spark_matrix(matrix: &HydratedSparseMatrix) -> SparkMatrix {
+    SparkMatrix {
+        coo: get_coordinate_rep_of_a_matrix(matrix),
+        timestamps: calculate_timestamps(matrix),
+    }
+}
+
+pub fn get_coordinate_rep_of_a_matrix(matrix: &HydratedSparseMatrix) -> COOMatrix {
+    let mut row = Vec::<FieldElement>::new();
+    let mut col = Vec::<FieldElement>::new();
+    let mut val = Vec::<FieldElement>::new();
+
+    for ((r, c), value) in matrix.iter() {
+        row.push(FieldElement::from(r as u64));
+        col.push(FieldElement::from(c as u64));
+        val.push(value.clone());
+    }
+
+    COOMatrix { row, col, val }
+}
+
+pub fn calculate_timestamps(matrix: &HydratedSparseMatrix) -> TimeStamps {
+    let mut read_row_counters = vec![0; matrix.matrix.num_rows];
+    let mut read_row = Vec::<FieldElement>::new();
+    let mut read_col_counters = vec![0; matrix.matrix.num_cols];
+    let mut read_col = Vec::<FieldElement>::new();
+
+    for ((r, c), _) in matrix.iter() {
+        read_row.push(FieldElement::from(read_row_counters[r] as u64));
+        read_row_counters[r] += 1;
+        read_col.push(FieldElement::from(read_col_counters[c] as u64));
+        read_col_counters[c] += 1;
+    }
+
+    let final_row = read_row_counters
+        .iter()
+        .map(|&x| FieldElement::from(x as u64))
+        .collect::<Vec<FieldElement>>();
+
+    let final_col = read_col_counters
+        .iter()
+        .map(|&x| FieldElement::from(x as u64))
+        .collect::<Vec<FieldElement>>();
+
+    TimeStamps {
+        read_row,
+        read_col,
+        final_row,
+        final_col,
+    }
+}
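A small worked run of the counter logic in `calculate_timestamps`, for a hypothetical 2x2 matrix whose nonzero entries appear in iteration order at (0,0), (0,1), (1,1):

```rust
// read_row[k]  = number of earlier entries in the same row as entry k
// final_row[r] = total number of entries in row r (and likewise for columns)
fn timestamps_demo() {
    let entries = [(0usize, 0usize), (0, 1), (1, 1)];
    let (mut row_ctr, mut col_ctr) = (vec![0u64; 2], vec![0u64; 2]);
    let (mut read_row, mut read_col) = (Vec::new(), Vec::new());
    for &(r, c) in &entries {
        read_row.push(row_ctr[r]);
        row_ctr[r] += 1;
        read_col.push(col_ctr[c]);
        col_ctr[c] += 1;
    }
    assert_eq!(read_row, vec![0, 1, 0]);
    assert_eq!(read_col, vec![0, 0, 1]);
    assert_eq!(row_ctr, vec![2, 1]); // becomes final_row, lifted to the field
    assert_eq!(col_ctr, vec![1, 2]); // becomes final_col
}
```

These per-entry read counters and per-address final counters are exactly the timestamps consumed by the memory-checking grand products in `spark.rs`.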
+ } + + COOMatrix { row, col, val } +} + +pub fn calculate_timestamps(matrix: &HydratedSparseMatrix) -> TimeStamps { + let mut read_row_counters = vec![0; matrix.matrix.num_rows]; + let mut read_row = Vec::::new(); + let mut read_col_counters = vec![0; matrix.matrix.num_cols]; + let mut read_col = Vec::::new(); + + for ((r, c), _) in matrix.iter() { + read_row.push(FieldElement::from(read_row_counters[r] as u64)); + read_row_counters[r] += 1; + read_col.push(FieldElement::from(read_col_counters[c] as u64)); + read_col_counters[c] += 1; + } + + let final_row = read_row_counters + .iter() + .map(|&x| FieldElement::from(x as u64)) + .collect::>(); + + let final_col = read_col_counters + .iter() + .map(|&x| FieldElement::from(x as u64)) + .collect::>(); + + TimeStamps { + read_row, + read_col, + final_row, + final_col, + } +} diff --git a/spark-prover/src/utilities/mod.rs b/spark-prover/src/utilities/mod.rs new file mode 100644 index 00000000..b967aef1 --- /dev/null +++ b/spark-prover/src/utilities/mod.rs @@ -0,0 +1,76 @@ +mod iopattern; +pub mod matrix; +use { + crate::whir::SPARKWHIRConfigs, + anyhow::{Context, Result}, + provekit_common::{ + utils::{serde_ark, sumcheck::calculate_evaluations_over_boolean_hypercube_for_eq}, + FieldElement, HydratedSparseMatrix, WhirConfig, R1CS, + }, + serde::{Deserialize, Serialize}, + std::fs, +}; +pub use {iopattern::create_io_pattern, matrix::get_spark_r1cs}; + +pub fn deserialize_r1cs(path_str: &str) -> Result { + let json_str = + fs::read_to_string(path_str).context("Error: Failed to open the r1cs.json file")?; + serde_json::from_str(&json_str).context("Error: Failed to deserialize JSON to R1CS") +} + +pub fn deserialize_request(path_str: &str) -> Result { + let json_str = + fs::read_to_string(path_str).context("Error: Failed to open the request.json file")?; + serde_json::from_str(&json_str).context("Error: Failed to deserialize JSON to R1CS") +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SPARKRequest { + pub point_to_evaluate: Point, + pub claimed_values: ClaimedValues, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Point { + #[serde(with = "serde_ark")] + pub row: Vec, + #[serde(with = "serde_ark")] + pub col: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct ClaimedValues { + #[serde(with = "serde_ark")] + pub a: FieldElement, + #[serde(with = "serde_ark")] + pub b: FieldElement, + #[serde(with = "serde_ark")] + pub c: FieldElement, +} + +#[derive(Serialize, Deserialize)] +pub struct SPARKProof { + pub transcript: Vec, + pub io_pattern: String, + pub whir_params: SPARKWHIRConfigs, + pub matrix_dimensions: MatrixDimensions, +} + +#[derive(Serialize, Deserialize)] +pub struct MatrixDimensions { + pub num_rows: usize, + pub num_cols: usize, + pub a_nonzero_terms: usize, + pub b_nonzero_terms: usize, + pub c_nonzero_terms: usize, +} + +pub fn calculate_matrix_dimensions(r1cs: &R1CS) -> MatrixDimensions { + MatrixDimensions { + num_rows: r1cs.a.num_rows, + num_cols: r1cs.a.num_cols, + a_nonzero_terms: r1cs.a.num_entries(), + b_nonzero_terms: r1cs.b.num_entries(), + c_nonzero_terms: r1cs.c.num_entries(), + } +} diff --git a/spark-prover/src/whir.rs b/spark-prover/src/whir.rs new file mode 100644 index 00000000..fd9d81d6 --- /dev/null +++ b/spark-prover/src/whir.rs @@ -0,0 +1,71 @@ +use { + anyhow::{Context, Result}, + provekit_common::{ + WhirR1CSScheme, utils::next_power_of_two, FieldElement, skyscraper::SkyscraperMerkleConfig, + 
+        skyscraper::SkyscraperPoW, skyscraper::SkyscraperSponge, WhirConfig, R1CS,
+    },
+    provekit_r1cs_compiler::WhirR1CSSchemeBuilder,
+    serde::{Deserialize, Serialize},
+    spongefish::ProverState,
+    whir::{
+        poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint},
+        whir::{
+            committer::{CommitmentWriter, Witness},
+            prover::Prover,
+            statement::{Statement, Weights},
+        },
+    },
+};
+
+pub fn commit_to_vector(
+    committer: &CommitmentWriter<FieldElement, SkyscraperMerkleConfig, SkyscraperPoW>,
+    merlin: &mut ProverState<SkyscraperSponge, FieldElement>,
+    vector: Vec<FieldElement>,
+) -> Witness<FieldElement, SkyscraperMerkleConfig> {
+    assert!(
+        vector.len().is_power_of_two(),
+        "Committed vector length must be a power of two"
+    );
+    let evals = EvaluationsList::new(vector);
+    let coeffs = evals.to_coeffs();
+    committer
+        .commit(merlin, coeffs)
+        .expect("WHIR prover failed to commit")
+}
+
+#[derive(Serialize, Deserialize)]
+pub struct SPARKWHIRConfigs {
+    pub row: WhirConfig,
+    pub col: WhirConfig,
+    pub a: WhirConfig,
+    pub b: WhirConfig,
+    pub c: WhirConfig,
+}
+
+pub fn create_whir_configs(r1cs: &R1CS) -> SPARKWHIRConfigs {
+    SPARKWHIRConfigs {
+        row: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_rows), 1),
+        col: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_cols), 1),
+        a: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 1),
+        b: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.b.num_entries()), 1),
+        c: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.c.num_entries()), 1),
+    }
+}
+
+pub fn produce_whir_proof(
+    merlin: &mut ProverState<SkyscraperSponge, FieldElement>,
+    evaluation_point: MultilinearPoint<FieldElement>,
+    evaluated_value: FieldElement,
+    config: WhirConfig,
+    witness: Witness<FieldElement, SkyscraperMerkleConfig>,
+) -> Result<()> {
+    let mut statement = Statement::<FieldElement>::new(evaluation_point.num_variables());
+    statement.add_constraint(Weights::evaluation(evaluation_point), evaluated_value);
+    let prover = Prover(config);
+
+    prover
+        .prove(merlin, statement, witness)
+        .context("while generating WHIR proof")?;
+
+    Ok(())
+}

From 03e2726d6a548758faee4e03cb8df31c646717db Mon Sep 17 00:00:00 2001
From: Batmend Batsaikhan
Date: Tue, 16 Sep 2025 16:58:49 +0800
Subject: [PATCH 2/7] Adds RS and WS

---
 provekit/common/src/lib.rs                  |  6 +-
 spark-prover/src/bin/generate_test_r1cs.rs  |  4 +-
 spark-prover/src/bin/spark-verifier.rs      | 75 ++++++++++++++++++-
 spark-prover/src/gpa.rs                     |  3 +-
 spark-prover/src/spark.rs                   | 82 +++++++++++++++++++--
 spark-prover/src/utilities/iopattern/mod.rs | 18 +++++
 spark-prover/src/whir.rs                    |  5 +-
 7 files changed, 176 insertions(+), 17 deletions(-)

diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs
index 0e9288f2..68efb571 100644
--- a/provekit/common/src/lib.rs
+++ b/provekit/common/src/lib.rs
@@ -8,14 +8,12 @@ pub mod utils;
 mod whir_r1cs;
 pub mod witness;
 
-use crate::{
-    interner::{InternedFieldElement, Interner},
-};
+use crate::interner::{InternedFieldElement, Interner};
 pub use {
-    sparse_matrix::{HydratedSparseMatrix, SparseMatrix},
     acir::FieldElement as NoirElement,
     noir_proof_scheme::{NoirProof, NoirProofScheme},
     r1cs::R1CS,
+    sparse_matrix::{HydratedSparseMatrix, SparseMatrix},
     whir::crypto::fields::Field256 as FieldElement,
     whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme},
 };
diff --git a/spark-prover/src/bin/generate_test_r1cs.rs b/spark-prover/src/bin/generate_test_r1cs.rs
index 1a1cb7e0..5c0c3fc7 100644
--- a/spark-prover/src/bin/generate_test_r1cs.rs
+++ b/spark-prover/src/bin/generate_test_r1cs.rs
@@ -16,8 +16,8 @@ fn main() {
 
     let matrix_json =
         serde_json::to_string(&r1cs).expect("Error: Failed to serialize R1CS to JSON");
-    let mut request_file = File::create("spark-prover/r1cs.json")
-        .expect("Error: Failed to create the r1cs.json file");
+    let mut request_file =
+        File::create("spark-prover/r1cs.json").expect("Error: Failed to create the r1cs.json file");
     request_file
         .write_all(matrix_json.as_bytes())
         .expect("Error: Failed to write JSON data to r1cs.json");
R1CS to JSON"); - let mut request_file = File::create("spark-prover/r1cs.json") - .expect("Error: Failed to create the r1cs.json file"); + let mut request_file = + File::create("spark-prover/r1cs.json").expect("Error: Failed to create the r1cs.json file"); request_file .write_all(matrix_json.as_bytes()) .expect("Error: Failed to write JSON data to r1cs.json"); diff --git a/spark-prover/src/bin/spark-verifier.rs b/spark-prover/src/bin/spark-verifier.rs index 9334ee82..f8fcc03a 100644 --- a/spark-prover/src/bin/spark-verifier.rs +++ b/spark-prover/src/bin/spark-verifier.rs @@ -2,11 +2,12 @@ use { anyhow::{ensure, Context, Result}, ark_std::{One, Zero}, provekit_common::{ + skyscraper::SkyscraperSponge, utils::{ next_power_of_two, sumcheck::{calculate_eq, eval_cubic_poly}, }, - FieldElement, IOPattern, skyscraper::SkyscraperSponge, + FieldElement, IOPattern, }, spark_prover::utilities::{SPARKProof, SPARKRequest}, spongefish::{ @@ -46,6 +47,8 @@ fn main() -> Result<()> { let e_rx_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); let e_ry_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); let final_row_commitment = commitment_reader_row.parse_commitment(&mut arthur).unwrap(); + let row_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); + let read_ts_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); let (randomness, last_sumcheck_value) = run_sumcheck_verifier_spark( &mut arthur, @@ -102,6 +105,9 @@ fn main() -> Result<()> { next_power_of_two(spark_proof.matrix_dimensions.num_rows) + 2, )?; + let claimed_init = gpa_result.claimed_values[0]; + let claimed_final = gpa_result.claimed_values[1]; + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); let init_adr = calculate_adr(&evaluation_randomness.to_vec()); @@ -140,6 +146,73 @@ fn main() -> Result<()> { ensure!(evaluated_value == gpa_result.last_sumcheck_value); + // let mut rs_address: FieldElement = arthur.hint()?; + let gpa_result = gpa_sumcheck_verifier( + &mut arthur, + next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms) + 2, + )?; + + let claimed_rs = gpa_result.claimed_values[0]; + let claimed_ws = gpa_result.claimed_values[1]; + + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let rs_adr = arthur.hint()?; + + let mut rs_adr_statement = Statement::::new(next_power_of_two( + spark_proof.matrix_dimensions.a_nonzero_terms, + )); + rs_adr_statement.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + rs_adr, + ); + + let rs_adr_verifier = Verifier::new(&spark_proof.whir_params.a); + rs_adr_verifier + .verify(&mut arthur, &row_commitment, &rs_adr_statement) + .context("while verifying WHIR")?; + + let rs_mem = arthur.hint()?; + + let mut rs_val_statement = Statement::::new(next_power_of_two( + spark_proof.matrix_dimensions.a_nonzero_terms, + )); + rs_val_statement.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + rs_mem, + ); + + let rs_val_verifier = Verifier::new(&spark_proof.whir_params.a); + rs_val_verifier + .verify(&mut arthur, &e_rx_commitment, &rs_val_statement) + .context("while verifying WHIR")?; + + let rs_timestamp = arthur.hint()?; + + let mut rs_timestamp_statement = Statement::::new(next_power_of_two( + spark_proof.matrix_dimensions.a_nonzero_terms, + )); + rs_timestamp_statement.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), 
diff --git a/spark-prover/src/gpa.rs b/spark-prover/src/gpa.rs
index 6b8f6ebb..26423f1e 100644
--- a/spark-prover/src/gpa.rs
+++ b/spark-prover/src/gpa.rs
@@ -1,5 +1,6 @@
 use {
     provekit_common::{
+        skyscraper::SkyscraperSponge,
         utils::{
             sumcheck::{
                 calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly,
@@ -7,7 +8,7 @@ use {
             },
             HALF,
         },
-        FieldElement, skyscraper::SkyscraperSponge,
+        FieldElement,
     },
     spongefish::{
         codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField},
diff --git a/spark-prover/src/spark.rs b/spark-prover/src/spark.rs
index e76f6a6f..6f240e2f 100644
--- a/spark-prover/src/spark.rs
+++ b/spark-prover/src/spark.rs
@@ -8,11 +8,12 @@ use {
     anyhow::Result,
     itertools::izip,
     provekit_common::{
+        skyscraper::SkyscraperSponge,
         utils::{
             sumcheck::{eval_cubic_poly, sumcheck_fold_map_reduce},
             HALF,
         },
-        FieldElement, skyscraper::SkyscraperSponge,
+        FieldElement,
     },
     spongefish::{
         codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField},
@@ -39,10 +40,17 @@ pub fn prove_spark_for_single_matrix(
     let e_rx_witness = commit_to_vector(&committer_a, merlin, e_values.e_rx.clone());
     let e_ry_witness = commit_to_vector(&committer_a, merlin, e_values.e_ry.clone());
 
-    let final_row_witness =
+    let final_row_ts_witness =
         commit_to_vector(&committer_row, merlin, matrix.timestamps.final_row.clone());
+    let row_witness = commit_to_vector(&committer_a, merlin, matrix.coo.row.clone());
+    let read_ts_witness =
+        commit_to_vector(&committer_a, merlin, matrix.timestamps.read_row.clone());
 
-    let mles = [matrix.coo.val.clone(), e_values.e_rx, e_values.e_ry];
+    let mles = [
+        matrix.coo.val.clone(),
+        e_values.e_rx.clone(),
+        e_values.e_ry.clone(),
+    ];
     let (sumcheck_final_folds, folding_randomness) =
         run_spark_sumcheck(merlin, mles, claimed_value)?;
 
@@ -51,7 +59,7 @@ pub fn prove_spark_for_single_matrix(
         MultilinearPoint(folding_randomness.clone()),
         sumcheck_final_folds[0],
         whir_configs.a.clone(),
-        val_witness,
+        val_witness.clone(),
     )?;
 
     produce_whir_proof(
@@ -59,7 +67,7 @@ pub fn prove_spark_for_single_matrix(
         MultilinearPoint(folding_randomness.clone()),
         sumcheck_final_folds[1],
         whir_configs.b.clone(),
-        e_rx_witness,
+        e_rx_witness.clone(),
     )?;
 
     produce_whir_proof(
@@ -100,7 +108,7 @@ pub fn prove_spark_for_single_matrix(
     let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec);
 
     let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
-
+
     // TODO: Can I avoid evaluating here?
     let final_row_eval = EvaluationsList::new(matrix.timestamps.final_row.clone())
         .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
@@ -111,7 +119,67 @@ pub fn prove_spark_for_single_matrix(
         MultilinearPoint(evaluation_randomness.to_vec()),
         final_row_eval,
         whir_configs.row.clone(),
-        final_row_witness,
+        final_row_ts_witness,
+    )?;
+
+    let rs_address = matrix.coo.row.clone();
+    let rs_value = e_values.e_rx.clone();
+    let rs_timestamp = matrix.timestamps.read_row.clone();
+
+    let rs_vec: Vec<FieldElement> =
+        izip!(rs_address.clone(), rs_value.clone(), rs_timestamp.clone())
+            .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+            .collect();
+
+    let ws_address = matrix.coo.row.clone();
+    let ws_value = e_values.e_rx.clone();
+    let ws_timestamp: Vec<FieldElement> = matrix
+        .timestamps
+        .read_row
+        .into_iter()
+        .map(|a| a + FieldElement::from(1))
+        .collect();
+
+    let ws_vec: Vec<FieldElement> =
+        izip!(ws_address.clone(), ws_value.clone(), ws_timestamp.clone())
+            .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+            .collect();
+
+    let gpa_randomness = run_gpa(merlin, &rs_vec, &ws_vec);
+
+    let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
+
+    let rs_address_eval = EvaluationsList::new(rs_address)
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&rs_address_eval)?;
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(evaluation_randomness.to_vec()),
+        rs_address_eval,
+        whir_configs.a.clone(),
+        row_witness.clone(),
+    )?;
+
+    let rs_value_eval = EvaluationsList::new(rs_value)
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&rs_value_eval)?;
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(evaluation_randomness.to_vec()),
+        rs_value_eval,
+        whir_configs.a.clone(),
+        e_rx_witness.clone(),
+    )?;
+
+    let rs_timestamp_eval = EvaluationsList::new(rs_timestamp)
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&rs_timestamp_eval)?;
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(evaluation_randomness.to_vec()),
+        rs_timestamp_eval,
+        whir_configs.a.clone(),
+        read_ts_witness.clone(),
+    )?;
 
     Ok(())
 }
diff --git a/spark-prover/src/utilities/iopattern/mod.rs b/spark-prover/src/utilities/iopattern/mod.rs
index 33d1f313..b071fbfc 100644
--- a/spark-prover/src/utilities/iopattern/mod.rs
+++ b/spark-prover/src/utilities/iopattern/mod.rs
@@ -34,6 +34,8 @@ pub fn create_io_pattern(r1cs: &R1CS, configs: &SPARKWHIRConfigs) -> IOPattern {
         .commit_statement(&configs.a)
         .commit_statement(&configs.a)
         .commit_statement(&configs.row)
+        .commit_statement(&configs.a)
+        .commit_statement(&configs.a)
         .add_sumcheck_polynomials(next_power_of_two(r1cs.a.num_entries()))
         .hint("sumcheck_last_folds")
         .add_whir_proof(&configs.a)
@@ -50,5 +52,21 @@ pub fn create_io_pattern(r1cs: &R1CS, configs: &SPARKWHIRConfigs) -> IOPattern {
         .hint("Row final counter claimed evaluation")
         .add_whir_proof(&configs.row);
 
+    for i in 0..=next_power_of_two(r1cs.a.num_entries()) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("RS address claimed evaluation")
+        .add_whir_proof(&configs.a);
+
+    io = io
+        .hint("RS value claimed evaluation")
+        .add_whir_proof(&configs.a);
+
+    io = io
+        .hint("RS timestamp claimed evaluation")
+        .add_whir_proof(&configs.a);
     io
 }
diff --git a/spark-prover/src/whir.rs b/spark-prover/src/whir.rs
index fd9d81d6..7fcbc3c0 100644
--- a/spark-prover/src/whir.rs
+++ b/spark-prover/src/whir.rs
@@ -1,8 +1,9 @@
 use {
     anyhow::{Context, Result},
     provekit_common::{
- WhirR1CSScheme, utils::next_power_of_two, FieldElement, skyscraper::SkyscraperMerkleConfig, - skyscraper::SkyscraperPoW, skyscraper::SkyscraperSponge, WhirConfig, R1CS, + skyscraper::{SkyscraperMerkleConfig, SkyscraperPoW, SkyscraperSponge}, + utils::next_power_of_two, + FieldElement, WhirConfig, WhirR1CSScheme, R1CS, }, provekit_r1cs_compiler::WhirR1CSSchemeBuilder, serde::{Deserialize, Serialize}, From dd1dd4216db90823efa0538250d7d936fee47c06 Mon Sep 17 00:00:00 2001 From: Batmend Batsaikhan Date: Wed, 17 Sep 2025 14:24:13 +0800 Subject: [PATCH 3/7] wip: batch whir --- spark-prover/src/bin/spark-verifier.rs | 284 ++++++++++---------- spark-prover/src/spark.rs | 271 ++++++++++--------- spark-prover/src/utilities/iopattern/mod.rs | 50 ++-- spark-prover/src/whir.rs | 2 + 4 files changed, 296 insertions(+), 311 deletions(-) diff --git a/spark-prover/src/bin/spark-verifier.rs b/spark-prover/src/bin/spark-verifier.rs index f8fcc03a..33723ef2 100644 --- a/spark-prover/src/bin/spark-verifier.rs +++ b/spark-prover/src/bin/spark-verifier.rs @@ -42,10 +42,10 @@ fn main() -> Result<()> { let commitment_reader = CommitmentReader::new(&spark_proof.whir_params.a); let commitment_reader_row = CommitmentReader::new(&spark_proof.whir_params.row); + let a_spark_sumcheck_commitment_reader = CommitmentReader::new(&spark_proof.whir_params.a_spark_sumcheck); - let val_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); - let e_rx_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); - let e_ry_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); + let a_spark_sumcheck_commitment = a_spark_sumcheck_commitment_reader.parse_commitment(&mut arthur)?; + let final_row_commitment = commitment_reader_row.parse_commitment(&mut arthur).unwrap(); let row_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); let read_ts_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); @@ -59,159 +59,147 @@ fn main() -> Result<()> { let final_folds: Vec = arthur.hint()?; - let mut val_statement_verifier = Statement::::new(next_power_of_two( + let mut a_spark_sumcheck_statement_verifier = Statement::::new(next_power_of_two( spark_proof.matrix_dimensions.a_nonzero_terms, )); - val_statement_verifier.add_constraint( - Weights::evaluation(MultilinearPoint(randomness.clone())), - final_folds[0], - ); - let val_verifier = Verifier::new(&spark_proof.whir_params.a); - val_verifier - .verify(&mut arthur, &val_commitment, &val_statement_verifier) - .context("while verifying WHIR")?; - let mut e_rx_statement_verifier = Statement::::new(next_power_of_two( - spark_proof.matrix_dimensions.a_nonzero_terms, - )); - e_rx_statement_verifier.add_constraint( - Weights::evaluation(MultilinearPoint(randomness.clone())), - final_folds[1], - ); - let e_rx_verifier = Verifier::new(&spark_proof.whir_params.a); - e_rx_verifier - .verify(&mut arthur, &e_rx_commitment, &e_rx_statement_verifier) - .context("while verifying WHIR")?; + // a_spark_sumcheck_statement_verifier.add_constraint( + // Weights::evaluation(MultilinearPoint(randomness.clone())), + // final_folds[0] + + // final_folds[1] * a_spark_sumcheck_commitment.batching_randomness + + // final_folds[2] * a_spark_sumcheck_commitment.batching_randomness * a_spark_sumcheck_commitment.batching_randomness, + // ); + println!("{:?}", final_folds[0] + final_folds[1] * a_spark_sumcheck_commitment.batching_randomness); - let mut e_ry_statement_verifier = Statement::::new(next_power_of_two( - 
spark_proof.matrix_dimensions.a_nonzero_terms, - )); - e_ry_statement_verifier.add_constraint( + a_spark_sumcheck_statement_verifier.add_constraint( Weights::evaluation(MultilinearPoint(randomness.clone())), - final_folds[2], + final_folds[0] + final_folds[1] * a_spark_sumcheck_commitment.batching_randomness, ); - let e_ry_verifier = Verifier::new(&spark_proof.whir_params.a); - e_ry_verifier - .verify(&mut arthur, &e_ry_commitment, &e_ry_statement_verifier) - .context("while verifying WHIR")?; - - let mut tau_and_gamma = [FieldElement::from(0); 2]; - arthur.fill_challenge_scalars(&mut tau_and_gamma)?; - let tau = tau_and_gamma[0]; - let gamma = tau_and_gamma[1]; - - let gpa_result = gpa_sumcheck_verifier( - &mut arthur, - next_power_of_two(spark_proof.matrix_dimensions.num_rows) + 2, - )?; - - let claimed_init = gpa_result.claimed_values[0]; - let claimed_final = gpa_result.claimed_values[1]; - - let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); - - let init_adr = calculate_adr(&evaluation_randomness.to_vec()); - let init_mem = calculate_eq( - &request.point_to_evaluate.row, - &evaluation_randomness.to_vec(), - ); - let init_cntr = FieldElement::from(0); - - let init_opening = init_adr * gamma * gamma + init_mem * gamma + init_cntr - tau; - - let mut final_cntr: FieldElement = arthur.hint()?; - - let mut final_cntr_statement = - Statement::::new(next_power_of_two(spark_proof.matrix_dimensions.num_rows)); - final_cntr_statement.add_constraint( - Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - final_cntr, - ); - - let final_cntr_verifier = Verifier::new(&spark_proof.whir_params.row); - final_cntr_verifier - .verify(&mut arthur, &final_row_commitment, &final_cntr_statement) - .context("while verifying WHIR")?; - - let final_adr = calculate_adr(&evaluation_randomness.to_vec()); - let final_mem = calculate_eq( - &request.point_to_evaluate.row, - &evaluation_randomness.to_vec(), - ); - - let final_opening = final_adr * gamma * gamma + final_mem * gamma + final_cntr - tau; - - let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) - + final_opening * last_randomness[0]; - - ensure!(evaluated_value == gpa_result.last_sumcheck_value); - - // let mut rs_address: FieldElement = arthur.hint()?; - let gpa_result = gpa_sumcheck_verifier( - &mut arthur, - next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms) + 2, - )?; - - let claimed_rs = gpa_result.claimed_values[0]; - let claimed_ws = gpa_result.claimed_values[1]; - - let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); - - let rs_adr = arthur.hint()?; - - let mut rs_adr_statement = Statement::::new(next_power_of_two( - spark_proof.matrix_dimensions.a_nonzero_terms, - )); - rs_adr_statement.add_constraint( - Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - rs_adr, - ); - - let rs_adr_verifier = Verifier::new(&spark_proof.whir_params.a); - rs_adr_verifier - .verify(&mut arthur, &row_commitment, &rs_adr_statement) - .context("while verifying WHIR")?; - - let rs_mem = arthur.hint()?; - - let mut rs_val_statement = Statement::::new(next_power_of_two( - spark_proof.matrix_dimensions.a_nonzero_terms, - )); - rs_val_statement.add_constraint( - Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - rs_mem, - ); - - let rs_val_verifier = Verifier::new(&spark_proof.whir_params.a); - rs_val_verifier - .verify(&mut arthur, &e_rx_commitment, &rs_val_statement) - .context("while 
verifying WHIR")?; - - let rs_timestamp = arthur.hint()?; - - let mut rs_timestamp_statement = Statement::::new(next_power_of_two( - spark_proof.matrix_dimensions.a_nonzero_terms, - )); - rs_timestamp_statement.add_constraint( - Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - rs_timestamp, - ); - - let rs_timestamp_verifier = Verifier::new(&spark_proof.whir_params.a); - rs_timestamp_verifier - .verify(&mut arthur, &read_ts_commitment, &rs_timestamp_statement) - .context("while verifying WHIR")?; - - let rs_opening = rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp - tau; - let ws_opening = - rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp + FieldElement::from(1) - tau; - - let evaluated_value = - rs_opening * (FieldElement::one() - last_randomness[0]) + ws_opening * last_randomness[0]; + // let + let a_spark_sumcheck_verifier = Verifier::new(&spark_proof.whir_params.a_spark_sumcheck); + a_spark_sumcheck_verifier.verify(&mut arthur, &a_spark_sumcheck_commitment, &a_spark_sumcheck_statement_verifier)?; + + // val_verifier + // .verify(&mut arthur, &val_commitment, &val_statement_verifier) + // .context("while verifying WHIR")?; + + // let mut tau_and_gamma = [FieldElement::from(0); 2]; + // arthur.fill_challenge_scalars(&mut tau_and_gamma)?; + // let tau = tau_and_gamma[0]; + // let gamma = tau_and_gamma[1]; + + // let gpa_result = gpa_sumcheck_verifier( + // &mut arthur, + // next_power_of_two(spark_proof.matrix_dimensions.num_rows) + 2, + // )?; + + // let claimed_init = gpa_result.claimed_values[0]; + // let claimed_final = gpa_result.claimed_values[1]; + + // let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + // let init_adr = calculate_adr(&evaluation_randomness.to_vec()); + // let init_mem = calculate_eq( + // &request.point_to_evaluate.row, + // &evaluation_randomness.to_vec(), + // ); + // let init_cntr = FieldElement::from(0); + + // let init_opening = init_adr * gamma * gamma + init_mem * gamma + init_cntr - tau; + + // let mut final_cntr: FieldElement = arthur.hint()?; + + // let mut final_cntr_statement = + // Statement::::new(next_power_of_two(spark_proof.matrix_dimensions.num_rows)); + // final_cntr_statement.add_constraint( + // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + // final_cntr, + // ); + + // let final_cntr_verifier = Verifier::new(&spark_proof.whir_params.row); + // final_cntr_verifier + // .verify(&mut arthur, &final_row_commitment, &final_cntr_statement) + // .context("while verifying WHIR")?; + + // let final_adr = calculate_adr(&evaluation_randomness.to_vec()); + // let final_mem = calculate_eq( + // &request.point_to_evaluate.row, + // &evaluation_randomness.to_vec(), + // ); + + // let final_opening = final_adr * gamma * gamma + final_mem * gamma + final_cntr - tau; + + // let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + // + final_opening * last_randomness[0]; + + // ensure!(evaluated_value == gpa_result.last_sumcheck_value); + + // // let mut rs_address: FieldElement = arthur.hint()?; + // let gpa_result = gpa_sumcheck_verifier( + // &mut arthur, + // next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms) + 2, + // )?; + + // let claimed_rs = gpa_result.claimed_values[0]; + // let claimed_ws = gpa_result.claimed_values[1]; + + // let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + // let rs_adr = arthur.hint()?; + + // let mut rs_adr_statement = 
Statement::::new(next_power_of_two( + // spark_proof.matrix_dimensions.a_nonzero_terms, + // )); + // rs_adr_statement.add_constraint( + // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + // rs_adr, + // ); + + // let rs_adr_verifier = Verifier::new(&spark_proof.whir_params.a); + // rs_adr_verifier + // .verify(&mut arthur, &row_commitment, &rs_adr_statement) + // .context("while verifying WHIR")?; + + // let rs_mem = arthur.hint()?; + + // let mut rs_val_statement = Statement::::new(next_power_of_two( + // spark_proof.matrix_dimensions.a_nonzero_terms, + // )); + // rs_val_statement.add_constraint( + // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + // rs_mem, + // ); + + // let rs_val_verifier = Verifier::new(&spark_proof.whir_params.a); + // rs_val_verifier + // .verify(&mut arthur, &e_rx_commitment, &rs_val_statement) + // .context("while verifying WHIR")?; + + // let rs_timestamp = arthur.hint()?; + + // let mut rs_timestamp_statement = Statement::::new(next_power_of_two( + // spark_proof.matrix_dimensions.a_nonzero_terms, + // )); + // rs_timestamp_statement.add_constraint( + // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + // rs_timestamp, + // ); + + // let rs_timestamp_verifier = Verifier::new(&spark_proof.whir_params.a); + // rs_timestamp_verifier + // .verify(&mut arthur, &read_ts_commitment, &rs_timestamp_statement) + // .context("while verifying WHIR")?; - ensure!(evaluated_value == gpa_result.last_sumcheck_value); + // let rs_opening = rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp - tau; + // let ws_opening = + // rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp + FieldElement::from(1) - tau; - ensure!(claimed_init * claimed_ws == claimed_final * claimed_rs); + // let evaluated_value = + // rs_opening * (FieldElement::one() - last_randomness[0]) + ws_opening * last_randomness[0]; + + // ensure!(evaluated_value == gpa_result.last_sumcheck_value); + + // ensure!(claimed_init * claimed_ws == claimed_final * claimed_rs); Ok(()) } diff --git a/spark-prover/src/spark.rs b/spark-prover/src/spark.rs index 6f240e2f..15894b33 100644 --- a/spark-prover/src/spark.rs +++ b/spark-prover/src/spark.rs @@ -5,7 +5,7 @@ use { utilities::matrix::SparkMatrix, whir::{commit_to_vector, produce_whir_proof, SPARKWHIRConfigs}, }, - anyhow::Result, + anyhow::{ensure, Result}, itertools::izip, provekit_common::{ skyscraper::SkyscraperSponge, @@ -21,7 +21,7 @@ use { }, whir::{ poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, - whir::{committer::CommitmentWriter, utils::HintSerialize}, + whir::{committer::CommitmentWriter, prover::Prover, statement::{Statement, Weights}, utils::HintSerialize}, }, }; @@ -35,16 +35,18 @@ pub fn prove_spark_for_single_matrix( ) -> Result<()> { let committer_a = CommitmentWriter::new(whir_configs.a.clone()); let committer_row = CommitmentWriter::new(whir_configs.row.clone()); + let a_spark_sumcheck_committer = CommitmentWriter::new(whir_configs.a_spark_sumcheck.clone()); - let val_witness = commit_to_vector(&committer_a, merlin, matrix.coo.val.clone()); - let e_rx_witness = commit_to_vector(&committer_a, merlin, e_values.e_rx.clone()); - let e_ry_witness = commit_to_vector(&committer_a, merlin, e_values.e_ry.clone()); + let val_coeff = EvaluationsList::new(matrix.coo.val.clone()).to_coeffs(); + let e_rx_coeff = EvaluationsList::new(e_values.e_rx.clone()).to_coeffs(); + let e_ry_coeff = EvaluationsList::new(e_values.e_ry.clone()).to_coeffs(); 
- let final_row_ts_witness = - commit_to_vector(&committer_row, merlin, matrix.timestamps.final_row.clone()); + // let spark_sumcheck_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[val_coeff, e_rx_coeff, e_ry_coeff])?; + let spark_sumcheck_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[val_coeff, e_rx_coeff]).expect("Failed to commit"); + + let final_row_ts_witness = commit_to_vector(&committer_row, merlin, matrix.timestamps.final_row.clone()); let row_witness = commit_to_vector(&committer_a, merlin, matrix.coo.row.clone()); - let read_ts_witness = - commit_to_vector(&committer_a, merlin, matrix.timestamps.read_row.clone()); + let read_ts_witness = commit_to_vector(&committer_a, merlin, matrix.timestamps.read_row.clone()); let mles = [ matrix.coo.val.clone(), @@ -54,133 +56,130 @@ pub fn prove_spark_for_single_matrix( let (sumcheck_final_folds, folding_randomness) = run_spark_sumcheck(merlin, mles, claimed_value)?; - produce_whir_proof( - merlin, - MultilinearPoint(folding_randomness.clone()), - sumcheck_final_folds[0], - whir_configs.a.clone(), - val_witness.clone(), - )?; - - produce_whir_proof( - merlin, - MultilinearPoint(folding_randomness.clone()), - sumcheck_final_folds[1], - whir_configs.b.clone(), - e_rx_witness.clone(), - )?; - - produce_whir_proof( - merlin, - MultilinearPoint(folding_randomness.clone()), - sumcheck_final_folds[2], - whir_configs.c.clone(), - e_ry_witness, - )?; - - // Rowwise - - let mut tau_and_gamma = [FieldElement::from(0); 2]; - merlin.fill_challenge_scalars(&mut tau_and_gamma)?; - let tau = tau_and_gamma[0]; - let gamma = tau_and_gamma[1]; - - let init_address: Vec = (0..memory.eq_rx.len() as u64) - .map(FieldElement::from) - .collect(); - let init_value = memory.eq_rx.clone(); - let init_timestamp = vec![FieldElement::from(0); memory.eq_rx.len()]; - - let init_vec: Vec = izip!(init_address, init_value, init_timestamp) - .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - .collect(); - - let final_address: Vec = (0..memory.eq_rx.len() as u64) - .map(FieldElement::from) - .collect(); - let final_value = memory.eq_rx.clone(); - let final_timestamp = matrix.timestamps.final_row.clone(); - - let final_vec: Vec = izip!(final_address, final_value, final_timestamp) - .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - .collect(); - - let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec); - - let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); - - // TODO: Can I avoid evaluating here? 
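Throughout these hunks, each memory tuple `(address, value, timestamp)` is compressed into a single field element as `a·γ² + v·γ + t − τ`, where `γ` and `τ` are transcript challenges. Two multisets of tuples are then equal, with overwhelming probability over the challenges, iff the grand products of their shifted fingerprints agree. A minimal sketch of the map, with `fingerprint` as an illustrative name rather than an existing helper:

```rust
/// Sketch: the tuple fingerprint used by the offline memory check.
/// (addr, val, ts) |-> addr*gamma^2 + val*gamma + ts - tau
fn fingerprint<F>(addr: F, val: F, ts: F, gamma: F, tau: F) -> F
where
    F: Copy
        + core::ops::Add<Output = F>
        + core::ops::Mul<Output = F>
        + core::ops::Sub<Output = F>,
{
    addr * gamma * gamma + val * gamma + ts - tau
}
```

The multiset identity being enforced is `Init ∪ WS == Final ∪ RS`, which is why the verifier's final check multiplies the claimed grand products pairwise: `claimed_init * claimed_ws == claimed_rs * claimed_final`.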
- let final_row_eval = EvaluationsList::new(matrix.timestamps.final_row.clone()) - .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - merlin.hint(&final_row_eval)?; - - produce_whir_proof( - merlin, - MultilinearPoint(evaluation_randomness.to_vec()), - final_row_eval, - whir_configs.row.clone(), - final_row_ts_witness, - )?; - - let rs_address = matrix.coo.row.clone(); - let rs_value = e_values.e_rx.clone(); - let rs_timestamp = matrix.timestamps.read_row.clone(); - - let rs_vec: Vec = - izip!(rs_address.clone(), rs_value.clone(), rs_timestamp.clone()) - .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - .collect(); - - let ws_address = matrix.coo.row.clone(); - let ws_value = e_values.e_rx.clone(); - let ws_timestamp: Vec = matrix - .timestamps - .read_row - .into_iter() - .map(|a| a + FieldElement::from(1)) - .collect(); - - let ws_vec: Vec = - izip!(ws_address.clone(), ws_value.clone(), ws_timestamp.clone()) - .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - .collect(); - - let gpa_randomness = run_gpa(merlin, &rs_vec, &ws_vec); - - let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); - - let rs_address_eval = EvaluationsList::new(rs_address) - .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - merlin.hint(&rs_address_eval)?; - produce_whir_proof( - merlin, - MultilinearPoint(evaluation_randomness.to_vec()), - rs_address_eval, - whir_configs.a.clone(), - row_witness.clone(), - )?; - - let rs_value_eval = EvaluationsList::new(rs_value) - .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - merlin.hint(&rs_value_eval)?; - produce_whir_proof( - merlin, - MultilinearPoint(evaluation_randomness.to_vec()), - rs_value_eval, - whir_configs.a.clone(), - e_rx_witness.clone(), - )?; - - let rs_timestamp_eval = EvaluationsList::new(rs_timestamp) - .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - merlin.hint(&rs_timestamp_eval)?; - produce_whir_proof( - merlin, - MultilinearPoint(evaluation_randomness.to_vec()), - rs_timestamp_eval, - whir_configs.a.clone(), - read_ts_witness.clone(), - )?; + let mut spark_sumcheck_statement = Statement::::new(folding_randomness.len()); + // let claimed_batched_value = + // sumcheck_final_folds[0] + + // sumcheck_final_folds[1] * spark_sumcheck_witness.batching_randomness + + // sumcheck_final_folds[2] * spark_sumcheck_witness.batching_randomness * spark_sumcheck_witness.batching_randomness; + + let claimed_batched_value = + sumcheck_final_folds[0] + + sumcheck_final_folds[1] * spark_sumcheck_witness.batching_randomness; + + let actual_val = spark_sumcheck_witness.batched_poly().evaluate(&MultilinearPoint(folding_randomness.clone())); + ensure!(actual_val == claimed_batched_value); + println!("{:?}", actual_val); + println!("{:?}", claimed_batched_value); + + spark_sumcheck_statement.add_constraint( + Weights::evaluation(MultilinearPoint(folding_randomness.clone())), claimed_batched_value); + + let prover = Prover(whir_configs.a_spark_sumcheck.clone()); + prover.prove(merlin, spark_sumcheck_statement, spark_sumcheck_witness)?; + + // // Rowwise + + // let mut tau_and_gamma = [FieldElement::from(0); 2]; + // merlin.fill_challenge_scalars(&mut tau_and_gamma)?; + // let tau = tau_and_gamma[0]; + // let gamma = tau_and_gamma[1]; + + // let init_address: Vec = (0..memory.eq_rx.len() as u64) + // .map(FieldElement::from) + // .collect(); + // let init_value = memory.eq_rx.clone(); + // let init_timestamp = vec![FieldElement::from(0); 
memory.eq_rx.len()]; + + // let init_vec: Vec = izip!(init_address, init_value, init_timestamp) + // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + // .collect(); + + // let final_address: Vec = (0..memory.eq_rx.len() as u64) + // .map(FieldElement::from) + // .collect(); + // let final_value = memory.eq_rx.clone(); + // let final_timestamp = matrix.timestamps.final_row.clone(); + + // let final_vec: Vec = izip!(final_address, final_value, final_timestamp) + // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + // .collect(); + + // let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec); + + // let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + // // TODO: Can I avoid evaluating here? + // let final_row_eval = EvaluationsList::new(matrix.timestamps.final_row.clone()) + // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + // merlin.hint(&final_row_eval)?; + + // produce_whir_proof( + // merlin, + // MultilinearPoint(evaluation_randomness.to_vec()), + // final_row_eval, + // whir_configs.row.clone(), + // final_row_ts_witness, + // )?; + + // let rs_address = matrix.coo.row.clone(); + // let rs_value = e_values.e_rx.clone(); + // let rs_timestamp = matrix.timestamps.read_row.clone(); + + // let rs_vec: Vec = + // izip!(rs_address.clone(), rs_value.clone(), rs_timestamp.clone()) + // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + // .collect(); + + // let ws_address = matrix.coo.row.clone(); + // let ws_value = e_values.e_rx.clone(); + // let ws_timestamp: Vec = matrix + // .timestamps + // .read_row + // .into_iter() + // .map(|a| a + FieldElement::from(1)) + // .collect(); + + // let ws_vec: Vec = + // izip!(ws_address.clone(), ws_value.clone(), ws_timestamp.clone()) + // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + // .collect(); + + // let gpa_randomness = run_gpa(merlin, &rs_vec, &ws_vec); + + // let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + // let rs_address_eval = EvaluationsList::new(rs_address) + // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + // merlin.hint(&rs_address_eval)?; + // produce_whir_proof( + // merlin, + // MultilinearPoint(evaluation_randomness.to_vec()), + // rs_address_eval, + // whir_configs.a.clone(), + // row_witness.clone(), + // )?; + + // let rs_value_eval = EvaluationsList::new(rs_value) + // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + // merlin.hint(&rs_value_eval)?; + // produce_whir_proof( + // merlin, + // MultilinearPoint(evaluation_randomness.to_vec()), + // rs_value_eval, + // whir_configs.a.clone(), + // e_rx_witness.clone(), + // )?; + + // let rs_timestamp_eval = EvaluationsList::new(rs_timestamp) + // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + // merlin.hint(&rs_timestamp_eval)?; + // produce_whir_proof( + // merlin, + // MultilinearPoint(evaluation_randomness.to_vec()), + // rs_timestamp_eval, + // whir_configs.a.clone(), + // read_ts_witness.clone(), + // )?; Ok(()) } diff --git a/spark-prover/src/utilities/iopattern/mod.rs b/spark-prover/src/utilities/iopattern/mod.rs index b071fbfc..ac05f7ab 100644 --- a/spark-prover/src/utilities/iopattern/mod.rs +++ b/spark-prover/src/utilities/iopattern/mod.rs @@ -30,43 +30,39 @@ where pub fn create_io_pattern(r1cs: &R1CS, configs: &SPARKWHIRConfigs) -> IOPattern { let mut io = IOPattern::new("💥") - .commit_statement(&configs.a) - .commit_statement(&configs.a) - 
.commit_statement(&configs.a) + .commit_statement(&configs.a_spark_sumcheck) .commit_statement(&configs.row) .commit_statement(&configs.a) .commit_statement(&configs.a) .add_sumcheck_polynomials(next_power_of_two(r1cs.a.num_entries())) .hint("sumcheck_last_folds") - .add_whir_proof(&configs.a) - .add_whir_proof(&configs.a) - .add_whir_proof(&configs.a) - .add_tau_and_gamma(); + .add_whir_proof(&configs.a_spark_sumcheck); + // .add_tau_and_gamma(); - for i in 0..=next_power_of_two(r1cs.a.num_rows) { - io = io.add_sumcheck_polynomials(i); - io = io.add_line(); - } + // for i in 0..=next_power_of_two(r1cs.a.num_rows) { + // io = io.add_sumcheck_polynomials(i); + // io = io.add_line(); + // } - io = io - .hint("Row final counter claimed evaluation") - .add_whir_proof(&configs.row); + // io = io + // .hint("Row final counter claimed evaluation") + // .add_whir_proof(&configs.row); - for i in 0..=next_power_of_two(r1cs.a.num_entries()) { - io = io.add_sumcheck_polynomials(i); - io = io.add_line(); - } + // for i in 0..=next_power_of_two(r1cs.a.num_entries()) { + // io = io.add_sumcheck_polynomials(i); + // io = io.add_line(); + // } - io = io - .hint("RS address claimed evaluation") - .add_whir_proof(&configs.a); + // io = io + // .hint("RS address claimed evaluation") + // .add_whir_proof(&configs.a); - io = io - .hint("RS value claimed evaluation") - .add_whir_proof(&configs.a); + // io = io + // .hint("RS value claimed evaluation") + // .add_whir_proof(&configs.a); - io = io - .hint("RS timestamp claimed evaluation") - .add_whir_proof(&configs.a); + // io = io + // .hint("RS timestamp claimed evaluation") + // .add_whir_proof(&configs.a); io } diff --git a/spark-prover/src/whir.rs b/spark-prover/src/whir.rs index 7fcbc3c0..100a7547 100644 --- a/spark-prover/src/whir.rs +++ b/spark-prover/src/whir.rs @@ -41,6 +41,7 @@ pub struct SPARKWHIRConfigs { pub a: WhirConfig, pub b: WhirConfig, pub c: WhirConfig, + pub a_spark_sumcheck: WhirConfig, } pub fn create_whir_configs(r1cs: &R1CS) -> SPARKWHIRConfigs { @@ -50,6 +51,7 @@ pub fn create_whir_configs(r1cs: &R1CS) -> SPARKWHIRConfigs { a: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 1), b: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.b.num_entries()), 1), c: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.c.num_entries()), 1), + a_spark_sumcheck: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 2), } } From f5ca88d17e5654a747c1abd84ee715287ca6efef Mon Sep 17 00:00:00 2001 From: Batmend Batsaikhan Date: Wed, 17 Sep 2025 16:10:06 +0800 Subject: [PATCH 4/7] Adds test --- spark-prover/src/bin/test-batched-whir.rs | 35 +++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 spark-prover/src/bin/test-batched-whir.rs diff --git a/spark-prover/src/bin/test-batched-whir.rs b/spark-prover/src/bin/test-batched-whir.rs new file mode 100644 index 00000000..056b60af --- /dev/null +++ b/spark-prover/src/bin/test-batched-whir.rs @@ -0,0 +1,35 @@ +use provekit_common::{FieldElement, IOPattern, WhirR1CSScheme}; +use provekit_r1cs_compiler::WhirR1CSSchemeBuilder; +use whir::{poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{committer::{CommitmentReader, CommitmentWriter}, domainsep::WhirDomainSeparator, prover::Prover, statement::{Statement, Weights}, verifier::Verifier}}; +use anyhow::Result; + +fn main() -> Result<()> { + let whir_config = WhirR1CSScheme::new_whir_config_for_size(6, 2); + let mut io = 
IOPattern::new("💥") + .commit_statement(&whir_config) + .add_whir_proof(&whir_config); + let mut merlin = io.to_prover_state(); + + let poly1 = EvaluationsList::new([FieldElement::from(1); 64].to_vec()).to_coeffs(); + let poly2 = EvaluationsList::new([FieldElement::from(2); 64].to_vec()).to_coeffs(); + let committer = CommitmentWriter::new(whir_config.clone()); + let witness = committer.commit_batch(&mut merlin, &[poly1, poly2]).expect("Failed to commit"); + + let mut statement = Statement::::new(6); + statement.add_constraint(Weights::evaluation(MultilinearPoint([FieldElement::from(0); 6].to_vec())), FieldElement::from(3)); + let prover = Prover(whir_config.clone()); + let proof = prover.prove(&mut merlin, statement.clone(), witness.clone())?; + + + let mut arthur = io.to_verifier_state(merlin.narg_string()); + let commitment_reader = CommitmentReader::new(&whir_config); + let commitment = commitment_reader.parse_commitment(&mut arthur)?; + + let claimed_ans = FieldElement::from(1) + FieldElement::from(2) * commitment.batching_randomness; + let actual_ans = witness.batched_poly().evaluate(&MultilinearPoint([FieldElement::from(0); 6].to_vec())); + + let verifier = Verifier::new(&whir_config); + verifier.verify(&mut arthur, &commitment, &statement)?; + + Ok(()) +} \ No newline at end of file From eb30b9f36018aac1646a5969131ffe75081df6d0 Mon Sep 17 00:00:00 2001 From: Batmend Batsaikhan Date: Wed, 17 Sep 2025 17:38:16 +0800 Subject: [PATCH 5/7] Creates variable for number of variables --- spark-prover/src/bin/test-batched-whir.rs | 30 ++++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/spark-prover/src/bin/test-batched-whir.rs b/spark-prover/src/bin/test-batched-whir.rs index 056b60af..05e5d2f2 100644 --- a/spark-prover/src/bin/test-batched-whir.rs +++ b/spark-prover/src/bin/test-batched-whir.rs @@ -4,19 +4,35 @@ use whir::{poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, use anyhow::Result; fn main() -> Result<()> { - let whir_config = WhirR1CSScheme::new_whir_config_for_size(6, 2); + const NUM_VARIABLES: usize = 5; // Change this + + let whir_config = WhirR1CSScheme::new_whir_config_for_size(NUM_VARIABLES, 2); let mut io = IOPattern::new("💥") .commit_statement(&whir_config) .add_whir_proof(&whir_config); let mut merlin = io.to_prover_state(); - let poly1 = EvaluationsList::new([FieldElement::from(1); 64].to_vec()).to_coeffs(); - let poly2 = EvaluationsList::new([FieldElement::from(2); 64].to_vec()).to_coeffs(); + let poly1 = EvaluationsList::new([FieldElement::from(1); 1<::new(6); - statement.add_constraint(Weights::evaluation(MultilinearPoint([FieldElement::from(0); 6].to_vec())), FieldElement::from(3)); + println!("{:?}", witness.batched_poly()); + + // let actual_ans = witness.batched_poly().evaluate(&MultilinearPoint([FieldElement::from(0); 7].to_vec())); + + let mut statement = Statement::::new(NUM_VARIABLES); + + let weight = Weights::linear(EvaluationsList::new([FieldElement::from(0); 1< Result<()> { let commitment = commitment_reader.parse_commitment(&mut arthur)?; let claimed_ans = FieldElement::from(1) + FieldElement::from(2) * commitment.batching_randomness; - let actual_ans = witness.batched_poly().evaluate(&MultilinearPoint([FieldElement::from(0); 6].to_vec())); + + // println!("{:?}", claimed_ans); + // println!("{:?}", actual_ans); let verifier = Verifier::new(&whir_config); verifier.verify(&mut arthur, &commitment, &statement)?; From a9bd8f96a65712dc0f093b644d2139c354f41570 Mon Sep 17 00:00:00 2001 From: Batmend 
Batsaikhan Date: Thu, 18 Sep 2025 10:32:50 +0800 Subject: [PATCH 6/7] Adds batched WHIR to sumcheck+rowwise --- spark-prover/src/bin/generate_test_r1cs.rs | 2 +- spark-prover/src/bin/spark-verifier.rs | 188 ++++++++---------- spark-prover/src/spark.rs | 200 +++++++++++--------- spark-prover/src/utilities/iopattern/mod.rs | 41 ++-- spark-prover/src/whir.rs | 2 +- 5 files changed, 205 insertions(+), 228 deletions(-) diff --git a/spark-prover/src/bin/generate_test_r1cs.rs b/spark-prover/src/bin/generate_test_r1cs.rs index 5c0c3fc7..298b489f 100644 --- a/spark-prover/src/bin/generate_test_r1cs.rs +++ b/spark-prover/src/bin/generate_test_r1cs.rs @@ -8,7 +8,7 @@ fn main() { r1cs.grow_matrices(1024, 512); let interned_1 = r1cs.interner.intern(FieldElement::from(1)); - for i in 0..64 { + for i in 0..256 { r1cs.a.set(i, i, interned_1); r1cs.b.set(i, i, interned_1); r1cs.c.set(i, i, interned_1); diff --git a/spark-prover/src/bin/spark-verifier.rs b/spark-prover/src/bin/spark-verifier.rs index 33723ef2..90cb0d0a 100644 --- a/spark-prover/src/bin/spark-verifier.rs +++ b/spark-prover/src/bin/spark-verifier.rs @@ -40,15 +40,13 @@ fn main() -> Result<()> { let io = IOPattern::from_string(spark_proof.io_pattern); let mut arthur = io.to_verifier_state(&spark_proof.transcript); - let commitment_reader = CommitmentReader::new(&spark_proof.whir_params.a); let commitment_reader_row = CommitmentReader::new(&spark_proof.whir_params.row); let a_spark_sumcheck_commitment_reader = CommitmentReader::new(&spark_proof.whir_params.a_spark_sumcheck); let a_spark_sumcheck_commitment = a_spark_sumcheck_commitment_reader.parse_commitment(&mut arthur)?; + let a_spark_rowwise_commitment = a_spark_sumcheck_commitment_reader.parse_commitment(&mut arthur)?; let final_row_commitment = commitment_reader_row.parse_commitment(&mut arthur).unwrap(); - let row_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); - let read_ts_commitment = commitment_reader.parse_commitment(&mut arthur).unwrap(); let (randomness, last_sumcheck_value) = run_sumcheck_verifier_spark( &mut arthur, @@ -59,147 +57,111 @@ fn main() -> Result<()> { let final_folds: Vec = arthur.hint()?; + assert!(last_sumcheck_value == final_folds[0] * final_folds[1] * final_folds[2]); + let mut a_spark_sumcheck_statement_verifier = Statement::::new(next_power_of_two( spark_proof.matrix_dimensions.a_nonzero_terms, )); - // a_spark_sumcheck_statement_verifier.add_constraint( - // Weights::evaluation(MultilinearPoint(randomness.clone())), - // final_folds[0] + - // final_folds[1] * a_spark_sumcheck_commitment.batching_randomness + - // final_folds[2] * a_spark_sumcheck_commitment.batching_randomness * a_spark_sumcheck_commitment.batching_randomness, - // ); - println!("{:?}", final_folds[0] + final_folds[1] * a_spark_sumcheck_commitment.batching_randomness); - a_spark_sumcheck_statement_verifier.add_constraint( Weights::evaluation(MultilinearPoint(randomness.clone())), - final_folds[0] + final_folds[1] * a_spark_sumcheck_commitment.batching_randomness, + final_folds[0] + + final_folds[1] * a_spark_sumcheck_commitment.batching_randomness + + final_folds[2] * a_spark_sumcheck_commitment.batching_randomness * a_spark_sumcheck_commitment.batching_randomness, ); - // let + let a_spark_sumcheck_verifier = Verifier::new(&spark_proof.whir_params.a_spark_sumcheck); a_spark_sumcheck_verifier.verify(&mut arthur, &a_spark_sumcheck_commitment, &a_spark_sumcheck_statement_verifier)?; - // val_verifier - // .verify(&mut arthur, &val_commitment, &val_statement_verifier) - // 
.context("while verifying WHIR")?; - - // let mut tau_and_gamma = [FieldElement::from(0); 2]; - // arthur.fill_challenge_scalars(&mut tau_and_gamma)?; - // let tau = tau_and_gamma[0]; - // let gamma = tau_and_gamma[1]; - - // let gpa_result = gpa_sumcheck_verifier( - // &mut arthur, - // next_power_of_two(spark_proof.matrix_dimensions.num_rows) + 2, - // )?; - - // let claimed_init = gpa_result.claimed_values[0]; - // let claimed_final = gpa_result.claimed_values[1]; - - // let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); - - // let init_adr = calculate_adr(&evaluation_randomness.to_vec()); - // let init_mem = calculate_eq( - // &request.point_to_evaluate.row, - // &evaluation_randomness.to_vec(), - // ); - // let init_cntr = FieldElement::from(0); + // Rowwise - // let init_opening = init_adr * gamma * gamma + init_mem * gamma + init_cntr - tau; + let mut tau_and_gamma = [FieldElement::from(0); 2]; + arthur.fill_challenge_scalars(&mut tau_and_gamma)?; + let tau = tau_and_gamma[0]; + let gamma = tau_and_gamma[1]; - // let mut final_cntr: FieldElement = arthur.hint()?; + let gpa_result = gpa_sumcheck_verifier( + &mut arthur, + next_power_of_two(spark_proof.matrix_dimensions.num_rows) + 2, + )?; - // let mut final_cntr_statement = - // Statement::::new(next_power_of_two(spark_proof.matrix_dimensions.num_rows)); - // final_cntr_statement.add_constraint( - // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - // final_cntr, - // ); + let claimed_init = gpa_result.claimed_values[0]; + let claimed_final = gpa_result.claimed_values[1]; + + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let init_adr = calculate_adr(&evaluation_randomness.to_vec()); + let init_mem = calculate_eq( + &request.point_to_evaluate.row, + &evaluation_randomness.to_vec(), + ); + let init_cntr = FieldElement::from(0); - // let final_cntr_verifier = Verifier::new(&spark_proof.whir_params.row); - // final_cntr_verifier - // .verify(&mut arthur, &final_row_commitment, &final_cntr_statement) - // .context("while verifying WHIR")?; + let init_opening = init_adr * gamma * gamma + init_mem * gamma + init_cntr - tau; - // let final_adr = calculate_adr(&evaluation_randomness.to_vec()); - // let final_mem = calculate_eq( - // &request.point_to_evaluate.row, - // &evaluation_randomness.to_vec(), - // ); + let final_cntr: FieldElement = arthur.hint()?; - // let final_opening = final_adr * gamma * gamma + final_mem * gamma + final_cntr - tau; + let mut final_cntr_statement = + Statement::::new(next_power_of_two(spark_proof.matrix_dimensions.num_rows)); + final_cntr_statement.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + final_cntr, + ); - // let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) - // + final_opening * last_randomness[0]; + let final_cntr_verifier = Verifier::new(&spark_proof.whir_params.row); + final_cntr_verifier + .verify(&mut arthur, &final_row_commitment, &final_cntr_statement) + .context("while verifying WHIR")?; - // ensure!(evaluated_value == gpa_result.last_sumcheck_value); + let final_adr = calculate_adr(&evaluation_randomness.to_vec()); + let final_mem = calculate_eq( + &request.point_to_evaluate.row, + &evaluation_randomness.to_vec(), + ); - // // let mut rs_address: FieldElement = arthur.hint()?; - // let gpa_result = gpa_sumcheck_verifier( - // &mut arthur, - // next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms) + 2, 
- // )?; + let final_opening = final_adr * gamma * gamma + final_mem * gamma + final_cntr - tau; - // let claimed_rs = gpa_result.claimed_values[0]; - // let claimed_ws = gpa_result.claimed_values[1]; + let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + + final_opening * last_randomness[0]; - // let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + ensure!(evaluated_value == gpa_result.last_sumcheck_value); - // let rs_adr = arthur.hint()?; + let gpa_result = gpa_sumcheck_verifier( + &mut arthur, + next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms) + 2, + )?; - // let mut rs_adr_statement = Statement::::new(next_power_of_two( - // spark_proof.matrix_dimensions.a_nonzero_terms, - // )); - // rs_adr_statement.add_constraint( - // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - // rs_adr, - // ); + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); - // let rs_adr_verifier = Verifier::new(&spark_proof.whir_params.a); - // rs_adr_verifier - // .verify(&mut arthur, &row_commitment, &rs_adr_statement) - // .context("while verifying WHIR")?; + let claimed_rs = gpa_result.claimed_values[0]; + let claimed_ws = gpa_result.claimed_values[1]; - // let rs_mem = arthur.hint()?; + let rs_adr: FieldElement = arthur.hint()?; + let rs_mem: FieldElement = arthur.hint()?; + let rs_timestamp: FieldElement = arthur.hint()?; - // let mut rs_val_statement = Statement::::new(next_power_of_two( - // spark_proof.matrix_dimensions.a_nonzero_terms, - // )); - // rs_val_statement.add_constraint( - // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - // rs_mem, - // ); + let rs_opening = rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp - tau; + let ws_opening = rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp + FieldElement::from(1) - tau; + + let evaluated_value = rs_opening * (FieldElement::one() - last_randomness[0]) + + ws_opening * last_randomness[0]; - // let rs_val_verifier = Verifier::new(&spark_proof.whir_params.a); - // rs_val_verifier - // .verify(&mut arthur, &e_rx_commitment, &rs_val_statement) - // .context("while verifying WHIR")?; + ensure!(evaluated_value == gpa_result.last_sumcheck_value); - // let rs_timestamp = arthur.hint()?; + let mut a_spark_rowwise_statement_verifier = Statement::::new(next_power_of_two( + spark_proof.matrix_dimensions.a_nonzero_terms, + )); - // let mut rs_timestamp_statement = Statement::::new(next_power_of_two( - // spark_proof.matrix_dimensions.a_nonzero_terms, - // )); - // rs_timestamp_statement.add_constraint( - // Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), - // rs_timestamp, - // ); - - // let rs_timestamp_verifier = Verifier::new(&spark_proof.whir_params.a); - // rs_timestamp_verifier - // .verify(&mut arthur, &read_ts_commitment, &rs_timestamp_statement) - // .context("while verifying WHIR")?; + a_spark_rowwise_statement_verifier.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + rs_adr + + rs_mem * a_spark_rowwise_commitment.batching_randomness + + rs_timestamp * a_spark_rowwise_commitment.batching_randomness * a_spark_rowwise_commitment.batching_randomness, + ); - // let rs_opening = rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp - tau; - // let ws_opening = - // rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp + FieldElement::from(1) - tau; + a_spark_sumcheck_verifier.verify(&mut arthur, 
&a_spark_rowwise_commitment, &a_spark_rowwise_statement_verifier)?; - // let evaluated_value = - // rs_opening * (FieldElement::one() - last_randomness[0]) + ws_opening * last_randomness[0]; - - // ensure!(evaluated_value == gpa_result.last_sumcheck_value); - - // ensure!(claimed_init * claimed_ws == claimed_final * claimed_rs); + ensure!(claimed_init * claimed_ws == claimed_rs * claimed_final); Ok(()) } diff --git a/spark-prover/src/spark.rs b/spark-prover/src/spark.rs index 15894b33..bc4c88e5 100644 --- a/spark-prover/src/spark.rs +++ b/spark-prover/src/spark.rs @@ -41,116 +41,136 @@ pub fn prove_spark_for_single_matrix( let e_rx_coeff = EvaluationsList::new(e_values.e_rx.clone()).to_coeffs(); let e_ry_coeff = EvaluationsList::new(e_values.e_ry.clone()).to_coeffs(); - // let spark_sumcheck_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[val_coeff, e_rx_coeff, e_ry_coeff])?; - let spark_sumcheck_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[val_coeff, e_rx_coeff]).expect("Failed to commit"); + let spark_sumcheck_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[val_coeff, e_rx_coeff, e_ry_coeff])?; + + let row_addr_coeff = EvaluationsList::new(matrix.coo.row.clone()).to_coeffs(); + let row_val_coeff = EvaluationsList::new(e_values.e_rx.clone()).to_coeffs(); + let row_timestamp_coeff = EvaluationsList::new(matrix.timestamps.read_row.clone()).to_coeffs(); + + let spark_rowwise_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[row_addr_coeff, row_val_coeff, row_timestamp_coeff])?; let final_row_ts_witness = commit_to_vector(&committer_row, merlin, matrix.timestamps.final_row.clone()); - let row_witness = commit_to_vector(&committer_a, merlin, matrix.coo.row.clone()); - let read_ts_witness = commit_to_vector(&committer_a, merlin, matrix.timestamps.read_row.clone()); let mles = [ matrix.coo.val.clone(), e_values.e_rx.clone(), e_values.e_ry.clone(), ]; + let (sumcheck_final_folds, folding_randomness) = run_spark_sumcheck(merlin, mles, claimed_value)?; let mut spark_sumcheck_statement = Statement::::new(folding_randomness.len()); - // let claimed_batched_value = - // sumcheck_final_folds[0] + - // sumcheck_final_folds[1] * spark_sumcheck_witness.batching_randomness + - // sumcheck_final_folds[2] * spark_sumcheck_witness.batching_randomness * spark_sumcheck_witness.batching_randomness; - + let claimed_batched_value = sumcheck_final_folds[0] + - sumcheck_final_folds[1] * spark_sumcheck_witness.batching_randomness; - - let actual_val = spark_sumcheck_witness.batched_poly().evaluate(&MultilinearPoint(folding_randomness.clone())); - ensure!(actual_val == claimed_batched_value); - println!("{:?}", actual_val); - println!("{:?}", claimed_batched_value); + sumcheck_final_folds[1] * spark_sumcheck_witness.batching_randomness + + sumcheck_final_folds[2] * spark_sumcheck_witness.batching_randomness * spark_sumcheck_witness.batching_randomness; spark_sumcheck_statement.add_constraint( Weights::evaluation(MultilinearPoint(folding_randomness.clone())), claimed_batched_value); - let prover = Prover(whir_configs.a_spark_sumcheck.clone()); - prover.prove(merlin, spark_sumcheck_statement, spark_sumcheck_witness)?; - - // // Rowwise - - // let mut tau_and_gamma = [FieldElement::from(0); 2]; - // merlin.fill_challenge_scalars(&mut tau_and_gamma)?; - // let tau = tau_and_gamma[0]; - // let gamma = tau_and_gamma[1]; - - // let init_address: Vec = (0..memory.eq_rx.len() as u64) - // .map(FieldElement::from) - // .collect(); - // let init_value = memory.eq_rx.clone(); - // 
let init_timestamp = vec![FieldElement::from(0); memory.eq_rx.len()]; - - // let init_vec: Vec = izip!(init_address, init_value, init_timestamp) - // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - // .collect(); - - // let final_address: Vec = (0..memory.eq_rx.len() as u64) - // .map(FieldElement::from) - // .collect(); - // let final_value = memory.eq_rx.clone(); - // let final_timestamp = matrix.timestamps.final_row.clone(); - - // let final_vec: Vec = izip!(final_address, final_value, final_timestamp) - // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - // .collect(); - - // let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec); - - // let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); - - // // TODO: Can I avoid evaluating here? - // let final_row_eval = EvaluationsList::new(matrix.timestamps.final_row.clone()) - // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - // merlin.hint(&final_row_eval)?; - - // produce_whir_proof( - // merlin, - // MultilinearPoint(evaluation_randomness.to_vec()), - // final_row_eval, - // whir_configs.row.clone(), - // final_row_ts_witness, - // )?; - - // let rs_address = matrix.coo.row.clone(); - // let rs_value = e_values.e_rx.clone(); - // let rs_timestamp = matrix.timestamps.read_row.clone(); - - // let rs_vec: Vec = - // izip!(rs_address.clone(), rs_value.clone(), rs_timestamp.clone()) - // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - // .collect(); + let sumcheck_prover = Prover(whir_configs.a_spark_sumcheck.clone()); + sumcheck_prover.prove(merlin, spark_sumcheck_statement, spark_sumcheck_witness)?; + + // Rowwise + + let mut tau_and_gamma = [FieldElement::from(0); 2]; + merlin.fill_challenge_scalars(&mut tau_and_gamma)?; + let tau = tau_and_gamma[0]; + let gamma = tau_and_gamma[1]; + + let init_address: Vec = (0..memory.eq_rx.len() as u64) + .map(FieldElement::from) + .collect(); + let init_value = memory.eq_rx.clone(); + let init_timestamp = vec![FieldElement::from(0); memory.eq_rx.len()]; + + let init_vec: Vec = izip!(init_address, init_value, init_timestamp) + .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + .collect(); + + let final_address: Vec = (0..memory.eq_rx.len() as u64) + .map(FieldElement::from) + .collect(); + let final_value = memory.eq_rx.clone(); + let final_timestamp = matrix.timestamps.final_row.clone(); + + let final_vec: Vec = izip!(final_address, final_value, final_timestamp) + .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + .collect(); + + let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec); + + let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + // TODO: Can I avoid evaluating here? 
+ let final_row_eval = EvaluationsList::new(matrix.timestamps.final_row.clone()) + .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + merlin.hint(&final_row_eval)?; + + produce_whir_proof( + merlin, + MultilinearPoint(evaluation_randomness.to_vec()), + final_row_eval, + whir_configs.row.clone(), + final_row_ts_witness, + )?; + + let rs_address = matrix.coo.row.clone(); + let rs_value = e_values.e_rx.clone(); + let rs_timestamp = matrix.timestamps.read_row.clone(); + + let rs_vec: Vec = + izip!(rs_address.clone(), rs_value.clone(), rs_timestamp.clone()) + .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + .collect(); + + let ws_address = matrix.coo.row.clone(); + let ws_value = e_values.e_rx.clone(); + let ws_timestamp: Vec = matrix + .timestamps + .read_row + .into_iter() + .map(|a| a + FieldElement::from(1)) + .collect(); + + let ws_vec: Vec = + izip!(ws_address.clone(), ws_value.clone(), ws_timestamp.clone()) + .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) + .collect(); + + let gpa_randomness = run_gpa(merlin, &rs_vec, &ws_vec); + + let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + let rs_address_eval = EvaluationsList::new(rs_address) + .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + merlin.hint(&rs_address_eval)?; + + let rs_value_eval = EvaluationsList::new(rs_value) + .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + merlin.hint(&rs_value_eval)?; - // let ws_address = matrix.coo.row.clone(); - // let ws_value = e_values.e_rx.clone(); - // let ws_timestamp: Vec = matrix - // .timestamps - // .read_row - // .into_iter() - // .map(|a| a + FieldElement::from(1)) - // .collect(); + let rs_timestamp_eval = EvaluationsList::new(rs_timestamp) + .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); + merlin.hint(&rs_timestamp_eval)?; - // let ws_vec: Vec = - // izip!(ws_address.clone(), ws_value.clone(), ws_timestamp.clone()) - // .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau) - // .collect(); + let mut spark_rowwise_statement = Statement::::new(evaluation_randomness.len()); - // let gpa_randomness = run_gpa(merlin, &rs_vec, &ws_vec); + let claimed_rowwise_eval = + rs_address_eval + + rs_value_eval * spark_rowwise_witness.batching_randomness + + rs_timestamp_eval * spark_rowwise_witness.batching_randomness * spark_rowwise_witness.batching_randomness; - // let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + assert!(claimed_rowwise_eval == spark_rowwise_witness.batched_poly().evaluate(&MultilinearPoint(evaluation_randomness.to_vec()))); - // let rs_address_eval = EvaluationsList::new(rs_address) - // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - // merlin.hint(&rs_address_eval)?; + spark_rowwise_statement.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), claimed_rowwise_eval); + + let sumcheck_prover = Prover(whir_configs.a_spark_sumcheck.clone()); + sumcheck_prover.prove(merlin, spark_rowwise_statement, spark_rowwise_witness)?; + // produce_whir_proof( // merlin, // MultilinearPoint(evaluation_randomness.to_vec()), @@ -159,9 +179,6 @@ pub fn prove_spark_for_single_matrix( // row_witness.clone(), // )?; - // let rs_value_eval = EvaluationsList::new(rs_value) - // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - // merlin.hint(&rs_value_eval)?; // produce_whir_proof( // merlin, // 
MultilinearPoint(evaluation_randomness.to_vec()), @@ -170,9 +187,6 @@ pub fn prove_spark_for_single_matrix( // e_rx_witness.clone(), // )?; - // let rs_timestamp_eval = EvaluationsList::new(rs_timestamp) - // .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone())); - // merlin.hint(&rs_timestamp_eval)?; // produce_whir_proof( // merlin, // MultilinearPoint(evaluation_randomness.to_vec()), diff --git a/spark-prover/src/utilities/iopattern/mod.rs b/spark-prover/src/utilities/iopattern/mod.rs index ac05f7ab..6235957c 100644 --- a/spark-prover/src/utilities/iopattern/mod.rs +++ b/spark-prover/src/utilities/iopattern/mod.rs @@ -30,39 +30,40 @@ where pub fn create_io_pattern(r1cs: &R1CS, configs: &SPARKWHIRConfigs) -> IOPattern { let mut io = IOPattern::new("💥") + .commit_statement(&configs.a_spark_sumcheck) .commit_statement(&configs.a_spark_sumcheck) .commit_statement(&configs.row) - .commit_statement(&configs.a) - .commit_statement(&configs.a) .add_sumcheck_polynomials(next_power_of_two(r1cs.a.num_entries())) .hint("sumcheck_last_folds") .add_whir_proof(&configs.a_spark_sumcheck); - // .add_tau_and_gamma(); + + io = io.add_tau_and_gamma(); - // for i in 0..=next_power_of_two(r1cs.a.num_rows) { - // io = io.add_sumcheck_polynomials(i); - // io = io.add_line(); - // } + for i in 0..=next_power_of_two(r1cs.a.num_rows) { + io = io.add_sumcheck_polynomials(i); + io = io.add_line(); + } - // io = io - // .hint("Row final counter claimed evaluation") - // .add_whir_proof(&configs.row); + io = io + .hint("Row final counter claimed evaluation") + .add_whir_proof(&configs.row); - // for i in 0..=next_power_of_two(r1cs.a.num_entries()) { - // io = io.add_sumcheck_polynomials(i); - // io = io.add_line(); - // } + for i in 0..=next_power_of_two(r1cs.a.num_entries()) { + io = io.add_sumcheck_polynomials(i); + io = io.add_line(); + } - // io = io - // .hint("RS address claimed evaluation") - // .add_whir_proof(&configs.a); + io = io + .hint("RS address claimed evaluation") + .hint("RS value claimed evaluation") + .hint("RS timestamp claimed evaluation") + .add_whir_proof(&configs.a_spark_sumcheck); // io = io - // .hint("RS value claimed evaluation") + // .add_whir_proof(&configs.a); + // .add_whir_proof(&configs.a); // .add_whir_proof(&configs.a); // io = io - // .hint("RS timestamp claimed evaluation") - // .add_whir_proof(&configs.a); io } diff --git a/spark-prover/src/whir.rs b/spark-prover/src/whir.rs index 100a7547..5580ae80 100644 --- a/spark-prover/src/whir.rs +++ b/spark-prover/src/whir.rs @@ -51,7 +51,7 @@ pub fn create_whir_configs(r1cs: &R1CS) -> SPARKWHIRConfigs { a: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 1), b: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.b.num_entries()), 1), c: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.c.num_entries()), 1), - a_spark_sumcheck: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 2), + a_spark_sumcheck: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 3), } } From 6a67ee07e11b76dc5d8adbc32c2b5a5c189e2084 Mon Sep 17 00:00:00 2001 From: Batmend Batsaikhan Date: Thu, 18 Sep 2025 16:37:38 +0800 Subject: [PATCH 7/7] Adds B and C matrices --- spark-prover/src/bin/spark-verifier.rs | 224 ++++++++++++++++---- spark-prover/src/bin/test-batched-whir.rs | 53 ----- spark-prover/src/main.rs | 20 +- spark-prover/src/spark.rs | 203 +++++++++++++----- spark-prover/src/utilities/iopattern/mod.rs | 165 +++++++++++++- 
spark-prover/src/whir.rs | 8 +- 6 files changed, 512 insertions(+), 161 deletions(-) delete mode 100644 spark-prover/src/bin/test-batched-whir.rs diff --git a/spark-prover/src/bin/spark-verifier.rs b/spark-prover/src/bin/spark-verifier.rs index 90cb0d0a..e64c4a4f 100644 --- a/spark-prover/src/bin/spark-verifier.rs +++ b/spark-prover/src/bin/spark-verifier.rs @@ -7,7 +7,7 @@ use { next_power_of_two, sumcheck::{calculate_eq, eval_cubic_poly}, }, - FieldElement, IOPattern, + FieldElement, IOPattern, WhirConfig, }, spark_prover::utilities::{SPARKProof, SPARKRequest}, spongefish::{ @@ -37,43 +37,101 @@ fn main() -> Result<()> { let request: SPARKRequest = serde_json::from_str(&request_json_str) .context("Error: Failed to deserialize JSON to R1CS")?; - let io = IOPattern::from_string(spark_proof.io_pattern); + let io = IOPattern::from_string(spark_proof.io_pattern.clone()); let mut arthur = io.to_verifier_state(&spark_proof.transcript); - let commitment_reader_row = CommitmentReader::new(&spark_proof.whir_params.row); - let a_spark_sumcheck_commitment_reader = CommitmentReader::new(&spark_proof.whir_params.a_spark_sumcheck); + verify_spark_single_matrix( + &spark_proof.whir_params.row, + &spark_proof.whir_params.col, + &spark_proof.whir_params.a_3batched, + spark_proof.matrix_dimensions.num_rows, + spark_proof.matrix_dimensions.num_cols, + spark_proof.matrix_dimensions.a_nonzero_terms, + &mut arthur, + &request, + &request.claimed_values.a, + )?; + + verify_spark_single_matrix( + &spark_proof.whir_params.row, + &spark_proof.whir_params.col, + &spark_proof.whir_params.b_3batched, + spark_proof.matrix_dimensions.num_rows, + spark_proof.matrix_dimensions.num_cols, + spark_proof.matrix_dimensions.b_nonzero_terms, + &mut arthur, + &request, + &request.claimed_values.b, + )?; + + verify_spark_single_matrix( + &spark_proof.whir_params.row, + &spark_proof.whir_params.col, + &spark_proof.whir_params.c_3batched, + spark_proof.matrix_dimensions.num_rows, + spark_proof.matrix_dimensions.num_cols, + spark_proof.matrix_dimensions.c_nonzero_terms, + &mut arthur, + &request, + &request.claimed_values.c, + )?; + + Ok(()) +} + +pub fn verify_spark_single_matrix( + row_config: &WhirConfig, + col_config: &WhirConfig, + num_nonzero_term_batched3_config: &WhirConfig, + num_rows: usize, + num_cols: usize, + num_nonzero_terms: usize, + arthur: &mut VerifierState, + request: &SPARKRequest, + claimed_value: &FieldElement, +) -> Result<()> { + let commitment_reader_row = CommitmentReader::new(row_config); + let commitment_reader_col = CommitmentReader::new(col_config); + + // Matrix A - let a_spark_sumcheck_commitment = a_spark_sumcheck_commitment_reader.parse_commitment(&mut arthur)?; - let a_spark_rowwise_commitment = a_spark_sumcheck_commitment_reader.parse_commitment(&mut arthur)?; + let a_3batched_commitment_reader = CommitmentReader::new(num_nonzero_term_batched3_config); + + let a_sumcheck_commitment = a_3batched_commitment_reader.parse_commitment(arthur)?; + let a_rowwise_commitment = a_3batched_commitment_reader.parse_commitment(arthur)?; + let a_colwise_commitment = a_3batched_commitment_reader.parse_commitment(arthur)?; - let final_row_commitment = commitment_reader_row.parse_commitment(&mut arthur).unwrap(); + let a_row_finalts_commitment = commitment_reader_row.parse_commitment(arthur).unwrap(); + let a_col_finalts_commitment = commitment_reader_col.parse_commitment(arthur).unwrap(); + + // Matrix A - Sumcheck - let (randomness, last_sumcheck_value) = run_sumcheck_verifier_spark( - &mut arthur, - 
next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms), - request.claimed_values.a, + let (randomness, a_last_sumcheck_value) = run_sumcheck_verifier_spark( + arthur, + next_power_of_two(num_nonzero_terms), + *claimed_value, ) .context("While verifying SPARK sumcheck")?; let final_folds: Vec = arthur.hint()?; - assert!(last_sumcheck_value == final_folds[0] * final_folds[1] * final_folds[2]); + assert!(a_last_sumcheck_value == final_folds[0] * final_folds[1] * final_folds[2]); let mut a_spark_sumcheck_statement_verifier = Statement::::new(next_power_of_two( - spark_proof.matrix_dimensions.a_nonzero_terms, + num_nonzero_terms, )); a_spark_sumcheck_statement_verifier.add_constraint( Weights::evaluation(MultilinearPoint(randomness.clone())), final_folds[0] + - final_folds[1] * a_spark_sumcheck_commitment.batching_randomness + - final_folds[2] * a_spark_sumcheck_commitment.batching_randomness * a_spark_sumcheck_commitment.batching_randomness, + final_folds[1] * a_sumcheck_commitment.batching_randomness + + final_folds[2] * a_sumcheck_commitment.batching_randomness * a_sumcheck_commitment.batching_randomness, ); - let a_spark_sumcheck_verifier = Verifier::new(&spark_proof.whir_params.a_spark_sumcheck); - a_spark_sumcheck_verifier.verify(&mut arthur, &a_spark_sumcheck_commitment, &a_spark_sumcheck_statement_verifier)?; + let a_spark_sumcheck_verifier = Verifier::new(num_nonzero_term_batched3_config); + a_spark_sumcheck_verifier.verify(arthur, &a_sumcheck_commitment, &a_spark_sumcheck_statement_verifier)?; - // Rowwise + // Matrix A - Rowwise let mut tau_and_gamma = [FieldElement::from(0); 2]; arthur.fill_challenge_scalars(&mut tau_and_gamma)?; @@ -81,8 +139,8 @@ fn main() -> Result<()> { let gamma = tau_and_gamma[1]; let gpa_result = gpa_sumcheck_verifier( - &mut arthur, - next_power_of_two(spark_proof.matrix_dimensions.num_rows) + 2, + arthur, + next_power_of_two(num_rows) + 2, )?; let claimed_init = gpa_result.claimed_values[0]; @@ -102,15 +160,15 @@ fn main() -> Result<()> { let final_cntr: FieldElement = arthur.hint()?; let mut final_cntr_statement = - Statement::::new(next_power_of_two(spark_proof.matrix_dimensions.num_rows)); + Statement::::new(next_power_of_two(num_rows)); final_cntr_statement.add_constraint( Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), final_cntr, ); - let final_cntr_verifier = Verifier::new(&spark_proof.whir_params.row); + let final_cntr_verifier = Verifier::new(row_config); final_cntr_verifier - .verify(&mut arthur, &final_row_commitment, &final_cntr_statement) + .verify(arthur, &a_row_finalts_commitment, &final_cntr_statement) .context("while verifying WHIR")?; let final_adr = calculate_adr(&evaluation_randomness.to_vec()); @@ -124,11 +182,11 @@ fn main() -> Result<()> { let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + final_opening * last_randomness[0]; - ensure!(evaluated_value == gpa_result.last_sumcheck_value); + ensure!(evaluated_value == gpa_result.a_last_sumcheck_value); let gpa_result = gpa_sumcheck_verifier( - &mut arthur, - next_power_of_two(spark_proof.matrix_dimensions.a_nonzero_terms) + 2, + arthur, + next_power_of_two(num_nonzero_terms) + 2, )?; let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); @@ -146,20 +204,110 @@ fn main() -> Result<()> { let evaluated_value = rs_opening * (FieldElement::one() - last_randomness[0]) + ws_opening * last_randomness[0]; - ensure!(evaluated_value == gpa_result.last_sumcheck_value); + ensure!(evaluated_value == 
gpa_result.a_last_sumcheck_value); let mut a_spark_rowwise_statement_verifier = Statement::::new(next_power_of_two( - spark_proof.matrix_dimensions.a_nonzero_terms, + num_nonzero_terms, )); a_spark_rowwise_statement_verifier.add_constraint( Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), rs_adr + - rs_mem * a_spark_rowwise_commitment.batching_randomness + - rs_timestamp * a_spark_rowwise_commitment.batching_randomness * a_spark_rowwise_commitment.batching_randomness, + rs_mem * a_rowwise_commitment.batching_randomness + + rs_timestamp * a_rowwise_commitment.batching_randomness * a_rowwise_commitment.batching_randomness, + ); + + a_spark_sumcheck_verifier.verify(arthur, &a_rowwise_commitment, &a_spark_rowwise_statement_verifier)?; + + ensure!(claimed_init * claimed_ws == claimed_rs * claimed_final); + + // Matrix A - Colwise + + let mut tau_and_gamma = [FieldElement::from(0); 2]; + arthur.fill_challenge_scalars(&mut tau_and_gamma)?; + let tau = tau_and_gamma[0]; + let gamma = tau_and_gamma[1]; + + let gpa_result = gpa_sumcheck_verifier( + arthur, + next_power_of_two(num_cols) + 2, + )?; + + let claimed_init = gpa_result.claimed_values[0]; + let claimed_final = gpa_result.claimed_values[1]; + + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let init_adr = calculate_adr(&evaluation_randomness.to_vec()); + let init_mem = calculate_eq( + &request.point_to_evaluate.col, + &evaluation_randomness.to_vec(), + ); + let init_cntr = FieldElement::from(0); + + let init_opening = init_adr * gamma * gamma + init_mem * gamma + init_cntr - tau; + + let final_cntr: FieldElement = arthur.hint()?; + + let mut final_cntr_statement = + Statement::::new(next_power_of_two(num_cols)); + final_cntr_statement.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + final_cntr, + ); + + let final_cntr_verifier = Verifier::new(col_config); + final_cntr_verifier + .verify(arthur, &a_col_finalts_commitment, &final_cntr_statement) + .context("while verifying WHIR")?; + + let final_adr = calculate_adr(&evaluation_randomness.to_vec()); + let final_mem = calculate_eq( + &request.point_to_evaluate.col, + &evaluation_randomness.to_vec(), + ); + + let final_opening = final_adr * gamma * gamma + final_mem * gamma + final_cntr - tau; + + let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + + final_opening * last_randomness[0]; + + ensure!(evaluated_value == gpa_result.a_last_sumcheck_value); + + let gpa_result = gpa_sumcheck_verifier( + arthur, + next_power_of_two(num_nonzero_terms) + 2, + )?; + + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let claimed_rs = gpa_result.claimed_values[0]; + let claimed_ws = gpa_result.claimed_values[1]; + + let rs_adr: FieldElement = arthur.hint()?; + let rs_mem: FieldElement = arthur.hint()?; + let rs_timestamp: FieldElement = arthur.hint()?; + + let rs_opening = rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp - tau; + let ws_opening = rs_adr * gamma * gamma + rs_mem * gamma + rs_timestamp + FieldElement::from(1) - tau; + + let evaluated_value = rs_opening * (FieldElement::one() - last_randomness[0]) + + ws_opening * last_randomness[0]; + + ensure!(evaluated_value == gpa_result.a_last_sumcheck_value); + + let mut a_spark_colwise_statement_verifier = Statement::::new(next_power_of_two( + num_nonzero_terms, + )); + + a_spark_colwise_statement_verifier.add_constraint( + 
Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())), + rs_adr + + rs_mem * a_colwise_commitment.batching_randomness + + rs_timestamp * a_colwise_commitment.batching_randomness * a_colwise_commitment.batching_randomness, ); - a_spark_sumcheck_verifier.verify(&mut arthur, &a_spark_rowwise_commitment, &a_spark_rowwise_statement_verifier)?; + a_spark_sumcheck_verifier.verify(arthur, &a_colwise_commitment, &a_spark_colwise_statement_verifier)?; ensure!(claimed_init * claimed_ws == claimed_rs * claimed_final); @@ -212,7 +360,7 @@ pub fn gpa_sumcheck_verifier( arthur .fill_challenge_scalars(&mut r) .expect("Failed to fill next scalars"); - let mut last_sumcheck_value = eval_linear_poly(&claimed_values, &r[0]); + let mut a_last_sumcheck_value = eval_linear_poly(&claimed_values, &r[0]); rand.push(r[0]); prev_rand = rand; @@ -229,10 +377,10 @@ pub fn gpa_sumcheck_verifier( assert_eq!( eval_cubic_poly(&h, &FieldElement::from(0)) + eval_cubic_poly(&h, &FieldElement::from(1)), - last_sumcheck_value + a_last_sumcheck_value ); rand.push(alpha[0]); - last_sumcheck_value = eval_cubic_poly(&h, &alpha[0]); + a_last_sumcheck_value = eval_cubic_poly(&h, &alpha[0]); } arthur .fill_next_scalars(&mut l) @@ -243,23 +391,23 @@ pub fn gpa_sumcheck_verifier( let claimed_last_sch = calculate_eq(&prev_rand, &rand) * eval_linear_poly(&l, &FieldElement::from(0)) * eval_linear_poly(&l, &FieldElement::from(1)); - assert_eq!(claimed_last_sch, last_sumcheck_value); + assert_eq!(claimed_last_sch, a_last_sumcheck_value); rand.push(r[0]); prev_rand = rand; rand = Vec::::new(); - last_sumcheck_value = eval_linear_poly(&l, &r[0]); + a_last_sumcheck_value = eval_linear_poly(&l, &r[0]); } Ok(GPASumcheckResult { claimed_values: claimed_values.to_vec(), - last_sumcheck_value, + a_last_sumcheck_value, randomness: prev_rand, }) } pub struct GPASumcheckResult { pub claimed_values: Vec, - pub last_sumcheck_value: FieldElement, + pub a_last_sumcheck_value: FieldElement, pub randomness: Vec, } diff --git a/spark-prover/src/bin/test-batched-whir.rs b/spark-prover/src/bin/test-batched-whir.rs deleted file mode 100644 index 05e5d2f2..00000000 --- a/spark-prover/src/bin/test-batched-whir.rs +++ /dev/null @@ -1,53 +0,0 @@ -use provekit_common::{FieldElement, IOPattern, WhirR1CSScheme}; -use provekit_r1cs_compiler::WhirR1CSSchemeBuilder; -use whir::{poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, whir::{committer::{CommitmentReader, CommitmentWriter}, domainsep::WhirDomainSeparator, prover::Prover, statement::{Statement, Weights}, verifier::Verifier}}; -use anyhow::Result; - -fn main() -> Result<()> { - const NUM_VARIABLES: usize = 5; // Change this - - let whir_config = WhirR1CSScheme::new_whir_config_for_size(NUM_VARIABLES, 2); - let mut io = IOPattern::new("💥") - .commit_statement(&whir_config) - .add_whir_proof(&whir_config); - let mut merlin = io.to_prover_state(); - - let poly1 = EvaluationsList::new([FieldElement::from(1); 1<::new(NUM_VARIABLES); - - let weight = Weights::linear(EvaluationsList::new([FieldElement::from(0); 1< Result<()> { prove_spark_for_single_matrix( &mut merlin, spark_r1cs.a, - memory, + &memory, e_values.a, request.claimed_values.a, &spark_whir_configs, )?; + prove_spark_for_single_matrix( + &mut merlin, + spark_r1cs.b, + &memory, + e_values.b, + request.claimed_values.b, + &spark_whir_configs, + )?; + + prove_spark_for_single_matrix( + &mut merlin, + spark_r1cs.c, + &memory, + e_values.c, + request.claimed_values.c, + &spark_whir_configs, + )?; + let spark_proof = 
         SPARKProof {
             transcript: merlin.narg_string().to_vec(),
             io_pattern: String::from_utf8(io_pattern.as_bytes().to_vec()).unwrap(),
diff --git a/spark-prover/src/spark.rs b/spark-prover/src/spark.rs
index bc4c88e5..a5c45dc6 100644
--- a/spark-prover/src/spark.rs
+++ b/spark-prover/src/spark.rs
@@ -28,28 +28,37 @@ use {
 pub fn prove_spark_for_single_matrix(
     merlin: &mut ProverState,
     matrix: SparkMatrix,
-    memory: Memory,
+    memory: &Memory,
     e_values: EValuesForMatrix,
     claimed_value: FieldElement,
     whir_configs: &SPARKWHIRConfigs,
 ) -> Result<()> {
-    let committer_a = CommitmentWriter::new(whir_configs.a.clone());
-    let committer_row = CommitmentWriter::new(whir_configs.row.clone());
-    let a_spark_sumcheck_committer = CommitmentWriter::new(whir_configs.a_spark_sumcheck.clone());
-
-    let val_coeff = EvaluationsList::new(matrix.coo.val.clone()).to_coeffs();
-    let e_rx_coeff = EvaluationsList::new(e_values.e_rx.clone()).to_coeffs();
-    let e_ry_coeff = EvaluationsList::new(e_values.e_ry.clone()).to_coeffs();
-
-    let spark_sumcheck_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[val_coeff, e_rx_coeff, e_ry_coeff])?;
-
-    let row_addr_coeff = EvaluationsList::new(matrix.coo.row.clone()).to_coeffs();
-    let row_val_coeff = EvaluationsList::new(e_values.e_rx.clone()).to_coeffs();
-    let row_timestamp_coeff = EvaluationsList::new(matrix.timestamps.read_row.clone()).to_coeffs();
-
-    let spark_rowwise_witness = a_spark_sumcheck_committer.commit_batch(merlin, &[row_addr_coeff, row_val_coeff, row_timestamp_coeff])?;
-
-    let final_row_ts_witness = commit_to_vector(&committer_row, merlin, matrix.timestamps.final_row.clone());
+    let row_committer = CommitmentWriter::new(whir_configs.row.clone());
+    let col_committer = CommitmentWriter::new(whir_configs.col.clone());
+    let a_3batched_committer = CommitmentWriter::new(whir_configs.a_3batched.clone());
+
+    let sumcheck_witness = a_3batched_committer.commit_batch(merlin, &[
+        EvaluationsList::new(matrix.coo.val.clone()).to_coeffs(),
+        EvaluationsList::new(e_values.e_rx.clone()).to_coeffs(),
+        EvaluationsList::new(e_values.e_ry.clone()).to_coeffs(),
+    ])?;
+
+    let rowwise_witness = a_3batched_committer.commit_batch(merlin, &[
+        EvaluationsList::new(matrix.coo.row.clone()).to_coeffs(),
+        EvaluationsList::new(e_values.e_rx.clone()).to_coeffs(),
+        EvaluationsList::new(matrix.timestamps.read_row.clone()).to_coeffs(),
+    ])?;
+
+    let colwise_witness = a_3batched_committer.commit_batch(merlin, &[
+        EvaluationsList::new(matrix.coo.col.clone()).to_coeffs(),
+        EvaluationsList::new(e_values.e_ry.clone()).to_coeffs(),
+        EvaluationsList::new(matrix.timestamps.read_col.clone()).to_coeffs(),
+    ])?;
+
+    let final_row_ts_witness = commit_to_vector(&row_committer, merlin, matrix.timestamps.final_row.clone());
+    let final_col_ts_witness = commit_to_vector(&col_committer, merlin, matrix.timestamps.final_col.clone());
+
+    // Sumcheck
 
     let mles = [
         matrix.coo.val.clone(),
@@ -60,21 +69,23 @@ pub fn prove_spark_for_single_matrix(
     let (sumcheck_final_folds, folding_randomness) =
         run_spark_sumcheck(merlin, mles, claimed_value)?;
 
-    let mut spark_sumcheck_statement = Statement::<FieldElement>::new(folding_randomness.len());
+    let mut sumcheck_statement = Statement::<FieldElement>::new(folding_randomness.len());
 
     let claimed_batched_value = sumcheck_final_folds[0] +
-        sumcheck_final_folds[1] * spark_sumcheck_witness.batching_randomness +
-        sumcheck_final_folds[2] * spark_sumcheck_witness.batching_randomness * spark_sumcheck_witness.batching_randomness;
+        sumcheck_final_folds[1] * sumcheck_witness.batching_randomness +
+        sumcheck_final_folds[2] * sumcheck_witness.batching_randomness * sumcheck_witness.batching_randomness;
 
-    spark_sumcheck_statement.add_constraint(
+    sumcheck_statement.add_constraint(
         Weights::evaluation(MultilinearPoint(folding_randomness.clone())),
         claimed_batched_value);
 
-    let sumcheck_prover = Prover(whir_configs.a_spark_sumcheck.clone());
-    sumcheck_prover.prove(merlin, spark_sumcheck_statement, spark_sumcheck_witness)?;
+    let sumcheck_prover = Prover(whir_configs.a_3batched.clone());
+    sumcheck_prover.prove(merlin, sumcheck_statement, sumcheck_witness)?;
 
     // Rowwise
 
+    // Rowwise Init Final GPA
+
     let mut tau_and_gamma = [FieldElement::from(0); 2];
     merlin.fill_challenge_scalars(&mut tau_and_gamma)?;
     let tau = tau_and_gamma[0];
@@ -102,9 +113,8 @@ pub fn prove_spark_for_single_matrix(
     let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec);
 
-    let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
+    let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
 
-    // TODO: Can I avoid evaluating here?
     let final_row_eval = EvaluationsList::new(matrix.timestamps.final_row.clone())
         .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
     merlin.hint(&final_row_eval)?;
@@ -117,6 +127,8 @@ pub fn prove_spark_for_single_matrix(
         final_row_ts_witness,
     )?;
 
+    // Rowwise RS WS GPA
+
     let rs_address = matrix.coo.row.clone();
     let rs_value = e_values.e_rx.clone();
     let rs_timestamp = matrix.timestamps.read_row.clone();
@@ -142,7 +154,7 @@ pub fn prove_spark_for_single_matrix(
     let gpa_randomness = run_gpa(merlin, &rs_vec, &ws_vec);
 
-    let (combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
+    let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
 
     let rs_address_eval = EvaluationsList::new(rs_address)
         .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
@@ -156,44 +168,121 @@ pub fn prove_spark_for_single_matrix(
         .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
     merlin.hint(&rs_timestamp_eval)?;
 
-    let mut spark_rowwise_statement = Statement::<FieldElement>::new(evaluation_randomness.len());
+    let mut rowwise_statement = Statement::<FieldElement>::new(evaluation_randomness.len());
 
     let claimed_rowwise_eval = rs_address_eval +
-        rs_value_eval * spark_rowwise_witness.batching_randomness +
-        rs_timestamp_eval * spark_rowwise_witness.batching_randomness * spark_rowwise_witness.batching_randomness;
+        rs_value_eval * rowwise_witness.batching_randomness +
+        rs_timestamp_eval * rowwise_witness.batching_randomness * rowwise_witness.batching_randomness;
 
-    assert!(claimed_rowwise_eval == spark_rowwise_witness.batched_poly().evaluate(&MultilinearPoint(evaluation_randomness.to_vec())));
+    assert!(claimed_rowwise_eval == rowwise_witness.batched_poly().evaluate(&MultilinearPoint(evaluation_randomness.to_vec())));
 
-    spark_rowwise_statement.add_constraint(
+    rowwise_statement.add_constraint(
         Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())),
         claimed_rowwise_eval);
 
-    let sumcheck_prover = Prover(whir_configs.a_spark_sumcheck.clone());
-    sumcheck_prover.prove(merlin, spark_rowwise_statement, spark_rowwise_witness)?;
+    let sumcheck_prover = Prover(whir_configs.a_3batched.clone());
+    sumcheck_prover.prove(merlin, rowwise_statement, rowwise_witness)?;
+
+    // Colwise
+
+    // Colwise Init Final GPA
+
+    let mut tau_and_gamma = [FieldElement::from(0); 2];
+    merlin.fill_challenge_scalars(&mut tau_and_gamma)?;
+    let tau = tau_and_gamma[0];
+    let gamma = tau_and_gamma[1];
+
+    let init_address: Vec<FieldElement> = (0..memory.eq_ry.len() as u64)
+        .map(FieldElement::from)
+        .collect();
+    let init_value = memory.eq_ry.clone();
+    let init_timestamp = vec![FieldElement::from(0); memory.eq_ry.len()];
+
+    let init_vec: Vec<FieldElement> = izip!(init_address, init_value, init_timestamp)
+        .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+        .collect();
+
+    let final_address: Vec<FieldElement> = (0..memory.eq_ry.len() as u64)
+        .map(FieldElement::from)
+        .collect();
+    let final_value = memory.eq_ry.clone();
+    let final_timestamp = matrix.timestamps.final_col.clone();
+
+    let final_vec: Vec<FieldElement> = izip!(final_address, final_value, final_timestamp)
+        .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+        .collect();
+
+    let gpa_randomness = run_gpa(merlin, &init_vec, &final_vec);
+
+    let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
+
+    let final_col_eval = EvaluationsList::new(matrix.timestamps.final_col.clone())
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&final_col_eval)?;
+
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(evaluation_randomness.to_vec()),
+        final_col_eval,
+        whir_configs.col.clone(),
+        final_col_ts_witness,
+    )?;
+
+    // Colwise RS WS GPA
+
+    let rs_address = matrix.coo.col.clone();
+    let rs_value = e_values.e_ry.clone();
+    let rs_timestamp = matrix.timestamps.read_col.clone();
+
+    let rs_vec: Vec<FieldElement> =
+        izip!(rs_address.clone(), rs_value.clone(), rs_timestamp.clone())
+            .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+            .collect();
+
+    let ws_address = matrix.coo.col.clone();
+    let ws_value = e_values.e_ry.clone();
+    let ws_timestamp: Vec<FieldElement> = matrix
+        .timestamps
+        .read_col
+        .into_iter()
+        .map(|a| a + FieldElement::from(1))
+        .collect();
+
+    let ws_vec: Vec<FieldElement> =
+        izip!(ws_address.clone(), ws_value.clone(), ws_timestamp.clone())
+            .map(|(a, v, t)| a * gamma * gamma + v * gamma + t - tau)
+            .collect();
+
+    let gpa_randomness = run_gpa(merlin, &rs_vec, &ws_vec);
+
+    let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1);
+
+    let rs_address_eval = EvaluationsList::new(rs_address)
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&rs_address_eval)?;
+
+    let rs_value_eval = EvaluationsList::new(rs_value)
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&rs_value_eval)?;
+
+    let rs_timestamp_eval = EvaluationsList::new(rs_timestamp)
+        .evaluate(&MultilinearPoint(evaluation_randomness.to_vec().clone()));
+    merlin.hint(&rs_timestamp_eval)?;
+
+    let mut colwise_statement = Statement::<FieldElement>::new(evaluation_randomness.len());
+
+    let claimed_colwise_eval =
+        rs_address_eval +
+        rs_value_eval * colwise_witness.batching_randomness +
+        rs_timestamp_eval * colwise_witness.batching_randomness * colwise_witness.batching_randomness;
+
+    assert!(claimed_colwise_eval == colwise_witness.batched_poly().evaluate(&MultilinearPoint(evaluation_randomness.to_vec())));
+
+    colwise_statement.add_constraint(
+        Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec().clone())),
+        claimed_colwise_eval);
 
-    // produce_whir_proof(
-    //     merlin,
-    //     MultilinearPoint(evaluation_randomness.to_vec()),
-    //     rs_address_eval,
-    //     whir_configs.a.clone(),
-    //     row_witness.clone(),
-    // )?;
-
-    // produce_whir_proof(
-    //     merlin,
-    //     MultilinearPoint(evaluation_randomness.to_vec()),
-    //     rs_value_eval,
-    //     whir_configs.a.clone(),
-    //     e_rx_witness.clone(),
-    // )?;
-
-    // produce_whir_proof(
-    //     merlin,
-    //     MultilinearPoint(evaluation_randomness.to_vec()),
-    //     rs_timestamp_eval,
-    //     whir_configs.a.clone(),
-    //     read_ts_witness.clone(),
-    // )?;
+    let sumcheck_prover = Prover(whir_configs.a_3batched.clone());
+    sumcheck_prover.prove(merlin, colwise_statement, colwise_witness)?;
 
     Ok(())
 }
diff --git a/spark-prover/src/utilities/iopattern/mod.rs b/spark-prover/src/utilities/iopattern/mod.rs
index 6235957c..78480f42 100644
--- a/spark-prover/src/utilities/iopattern/mod.rs
+++ b/spark-prover/src/utilities/iopattern/mod.rs
@@ -29,14 +29,22 @@ where
 }
 
 pub fn create_io_pattern(r1cs: &R1CS, configs: &SPARKWHIRConfigs) -> IOPattern {
-    let mut io = IOPattern::new("💥")
-        .commit_statement(&configs.a_spark_sumcheck)
-        .commit_statement(&configs.a_spark_sumcheck)
+    let mut io = IOPattern::new("💥");
+
+    // Matrix A
+
+    io = io
+        .commit_statement(&configs.a_3batched)
+        .commit_statement(&configs.a_3batched)
+        .commit_statement(&configs.a_3batched)
         .commit_statement(&configs.row)
+        .commit_statement(&configs.col)
         .add_sumcheck_polynomials(next_power_of_two(r1cs.a.num_entries()))
         .hint("sumcheck_last_folds")
-        .add_whir_proof(&configs.a_spark_sumcheck);
+        .add_whir_proof(&configs.a_3batched);
+
+    // Rowwise
 
     io = io.add_tau_and_gamma();
 
     for i in 0..=next_power_of_two(r1cs.a.num_rows) {
@@ -57,13 +65,150 @@ pub fn create_io_pattern(r1cs: &R1CS, configs: &SPARKWHIRConfigs) -> IOPattern {
         .hint("RS address claimed evaluation")
         .hint("RS value claimed evaluation")
         .hint("RS timestamp claimed evaluation")
-        .add_whir_proof(&configs.a_spark_sumcheck);
+        .add_whir_proof(&configs.a_3batched);
+
+    // Colwise
+
+    io = io.add_tau_and_gamma();
+
+    for i in 0..=next_power_of_two(r1cs.a.num_cols) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("Col final counter claimed evaluation")
+        .add_whir_proof(&configs.col);
+
+    for i in 0..=next_power_of_two(r1cs.a.num_entries()) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("RS address claimed evaluation")
+        .hint("RS value claimed evaluation")
+        .hint("RS timestamp claimed evaluation")
+        .add_whir_proof(&configs.a_3batched);
+
+    // Matrix B
+
+    io = io
+        .commit_statement(&configs.b_3batched)
+        .commit_statement(&configs.b_3batched)
+        .commit_statement(&configs.b_3batched)
+        .commit_statement(&configs.row)
+        .commit_statement(&configs.col)
+        .add_sumcheck_polynomials(next_power_of_two(r1cs.b.num_entries()))
+        .hint("sumcheck_last_folds")
+        .add_whir_proof(&configs.b_3batched);
+
+    // Rowwise
+
+    io = io.add_tau_and_gamma();
+
+    for i in 0..=next_power_of_two(r1cs.b.num_rows) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("Row final counter claimed evaluation")
+        .add_whir_proof(&configs.row);
+
+    for i in 0..=next_power_of_two(r1cs.b.num_entries()) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("RS address claimed evaluation")
+        .hint("RS value claimed evaluation")
+        .hint("RS timestamp claimed evaluation")
+        .add_whir_proof(&configs.b_3batched);
+
+    // Colwise
+
+    io = io.add_tau_and_gamma();
+
+    for i in 0..=next_power_of_two(r1cs.b.num_cols) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("Col final counter claimed evaluation")
+        .add_whir_proof(&configs.col);
+
+    for i in 0..=next_power_of_two(r1cs.b.num_entries()) {
+        io = io.add_sumcheck_polynomials(i);
+        io = io.add_line();
+    }
+
+    io = io
+        .hint("RS address claimed evaluation")
+        .hint("RS value claimed evaluation")
.hint("RS timestamp claimed evaluation") + .add_whir_proof(&configs.b_3batched); + + // Matrix C + + io = io + .commit_statement(&configs.c_3batched) + .commit_statement(&configs.c_3batched) + .commit_statement(&configs.c_3batched) + .commit_statement(&configs.row) + .commit_statement(&configs.col) + .add_sumcheck_polynomials(next_power_of_two(r1cs.c.num_entries())) + .hint("sumcheck_last_folds") + .add_whir_proof(&configs.c_3batched); + + // Rowwise + + io = io.add_tau_and_gamma(); + + for i in 0..=next_power_of_two(r1cs.c.num_rows) { + io = io.add_sumcheck_polynomials(i); + io = io.add_line(); + } + + io = io + .hint("Row final counter claimed evaluation") + .add_whir_proof(&configs.row); + + for i in 0..=next_power_of_two(r1cs.c.num_entries()) { + io = io.add_sumcheck_polynomials(i); + io = io.add_line(); + } + + io = io + .hint("RS address claimed evaluation") + .hint("RS value claimed evaluation") + .hint("RS timestamp claimed evaluation") + .add_whir_proof(&configs.c_3batched); - // io = io - // .add_whir_proof(&configs.a); - // .add_whir_proof(&configs.a); - // .add_whir_proof(&configs.a); + // Colwise - // io = io + io = io.add_tau_and_gamma(); + + for i in 0..=next_power_of_two(r1cs.c.num_cols) { + io = io.add_sumcheck_polynomials(i); + io = io.add_line(); + } + + io = io + .hint("Col final counter claimed evaluation") + .add_whir_proof(&configs.col); + + for i in 0..=next_power_of_two(r1cs.c.num_entries()) { + io = io.add_sumcheck_polynomials(i); + io = io.add_line(); + } + + io = io + .hint("RS address claimed evaluation") + .hint("RS value claimed evaluation") + .hint("RS timestamp claimed evaluation") + .add_whir_proof(&configs.c_3batched); io } diff --git a/spark-prover/src/whir.rs b/spark-prover/src/whir.rs index 5580ae80..2d684cb6 100644 --- a/spark-prover/src/whir.rs +++ b/spark-prover/src/whir.rs @@ -41,7 +41,9 @@ pub struct SPARKWHIRConfigs { pub a: WhirConfig, pub b: WhirConfig, pub c: WhirConfig, - pub a_spark_sumcheck: WhirConfig, + pub a_3batched: WhirConfig, + pub b_3batched: WhirConfig, + pub c_3batched: WhirConfig, } pub fn create_whir_configs(r1cs: &R1CS) -> SPARKWHIRConfigs { @@ -51,7 +53,9 @@ pub fn create_whir_configs(r1cs: &R1CS) -> SPARKWHIRConfigs { a: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 1), b: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.b.num_entries()), 1), c: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.c.num_entries()), 1), - a_spark_sumcheck: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 3), + a_3batched: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.a.num_entries()), 3), + b_3batched: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.b.num_entries()), 3), + c_3batched: WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(r1cs.c.num_entries()), 3), } }