diff --git a/.github/workflows/circuit_keys.yml b/.github/workflows/circuit_keys.yml
index 3c4914bd..4d9f373e 100644
--- a/.github/workflows/circuit_keys.yml
+++ b/.github/workflows/circuit_keys.yml
@@ -35,6 +35,9 @@ jobs:
       - name: Build provekit-cli
         run: cargo build --release --bin provekit-cli
 
+      - name: Build spark-cli
+        run: cargo build --release --bin spark-cli
+
       - name: Setup Go
         uses: actions/setup-go@v5
         with:
@@ -53,14 +56,20 @@ jobs:
       - name: Generate Gnark inputs
         working-directory: ${{ env.CIRCUIT_DIR }}
         run: |
-          cargo run --release --bin provekit-cli prepare ./target/${{ env.CIRCUIT_NAME }}.json -p ./noir-provekit-prover.pkp -v ./noir-provekit-verifier.pkv
-          cargo run --release --bin provekit-cli prove ./noir-provekit-prover.pkp ./Prover.toml -o ./noir-proof.np
-          cargo run --release --bin provekit-cli generate-gnark-inputs ./noir-provekit-prover.pkp ./noir-proof.np
+          cargo run --release --bin provekit-cli prepare ./target/${{ env.CIRCUIT_NAME }}.json --pkp ./prover.pkp --pkv ./verifier.pkv
+          cargo run --release --bin provekit-cli prove ./prover.pkp ./Prover.toml -o ./proof.np
+          cargo run --release --bin provekit-cli generate-gnark-inputs ./prover.pkp ./proof.np
+
+      - name: Generate Spark Gnark inputs
+        working-directory: ${{ env.CIRCUIT_DIR }}
+        run: |
+          cargo run --release --bin spark-cli -- prove --noir-proof-scheme ./prover.pkp --noir-proof ./proof.np
+          cargo run --release --bin spark-cli -- verify --spark-proof spark_proof.json --noir-proof ./proof.np
 
       - name: Run Gnark verifier
         working-directory: recursive-verifier
         run: |
-          go run cmd/cli/main.go --config "../${{ env.CIRCUIT_DIR }}/params_for_recursive_verifier" --r1cs "../${{ env.CIRCUIT_DIR }}/r1cs.json" --saveKeys "./keys"
+          go run cmd/cli/main.go --evaluation spark --config "../${{ env.CIRCUIT_DIR }}/params_for_recursive_verifier" --r1cs "../${{ env.CIRCUIT_DIR }}/r1cs.json" --saveKeys "./keys" --spark_config "../${{ env.CIRCUIT_DIR }}/gnark_spark_proof.json"
 
       - name: List keys
         working-directory: recursive-verifier
diff --git a/.github/workflows/end-to-end.yml b/.github/workflows/end-to-end.yml
index 29e4b44f..a63b5228 100644
--- a/.github/workflows/end-to-end.yml
+++ b/.github/workflows/end-to-end.yml
@@ -51,9 +51,11 @@ jobs:
       - name: Generate Gnark inputs
         working-directory: noir-examples/noir-passport-examples/complete_age_check
         run: |
-          cargo run --release --bin provekit-cli prepare ./target/complete_age_check.json -p ./noir-provekit-prover.pkp -v ./noir-provekit-verifier.pkv
-          cargo run --release --bin provekit-cli prove ./noir-provekit-prover.pkp ./Prover.toml -o ./noir-proof.np
-          cargo run --release --bin provekit-cli generate-gnark-inputs ./noir-provekit-prover.pkp ./noir-proof.np
+          cargo run --release --bin provekit-cli prepare ./target/complete_age_check.json --pkp ./prover.pkp --pkv ./verifier.pkv
+          cargo run --release --bin provekit-cli prove ./prover.pkp ./Prover.toml -o ./proof.np
+          cargo run --release --bin provekit-cli generate-gnark-inputs ./prover.pkp ./proof.np
+          cargo run --release --bin spark-cli -- prove --noir-proof-scheme ./prover.pkp --noir-proof ./proof.np
+          cargo run --release --bin spark-cli -- verify --spark-proof spark_proof.json --noir-proof ./proof.np
 
       - name: Run Gnark verifier
         working-directory: recursive-verifier
@@ -85,7 +87,7 @@ jobs:
           MONITOR_PID=$!
          # Run the main process
-         ./gnark-verifier --config "../noir-examples/noir-passport-examples/complete_age_check/params_for_recursive_verifier" --r1cs "../noir-examples/noir-passport-examples/complete_age_check/r1cs.json" --pk_url ${{ vars.AGE_CHECK_PK }} --vk_url ${{ vars.AGE_CHECK_VK }}
+         ./gnark-verifier --evaluation spark --config "../noir-examples/noir-passport-examples/complete_age_check/params_for_recursive_verifier" --spark_config "../noir-examples/noir-passport-examples/complete_age_check/gnark_spark_proof.json" --r1cs "../noir-examples/noir-passport-examples/complete_age_check/r1cs.json" --pk_url ${{ vars.AGE_CHECK_SPARK_PK }} --vk_url ${{ vars.AGE_CHECK_SPARK_VK }}
 
          # Stop monitoring
          kill $MONITOR_PID
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index f770c0ae..94f7b19b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,6 +14,8 @@
 *.np
 params_for_recursive_verifier
 params
+*.pkp
+*.pkv
 artifacts/
 
 # Don't ignore benchmarking artifacts
diff --git a/Cargo.toml b/Cargo.toml
index a0a0f09f..519664df 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,9 +10,11 @@ members = [
     "provekit/r1cs-compiler",
     "provekit/prover",
     "provekit/verifier",
+    "provekit/spark",
     "tooling/cli",
     "tooling/provekit-bench",
     "tooling/provekit-gnark",
+    "tooling/spark-cli",
     "tooling/verifier-server",
     "ntt",
 ]
@@ -86,6 +88,7 @@ provekit-prover = { path = "provekit/prover" }
 provekit-r1cs-compiler = { path = "provekit/r1cs-compiler" }
 provekit-verifier = { path = "provekit/verifier" }
 provekit-verifier-server = { path = "tooling/verifier-server" }
+provekit-spark = { path = "provekit/spark" }
 
 # 3rd party
 anyhow = "1.0.93"
diff --git a/README.md b/README.md
index 38dca83b..6eadc336 100644
--- a/README.md
+++ b/README.md
@@ -12,45 +12,64 @@ noirup --version v1.0.0-beta.11
 
 ## Demo instructions
 
-> _NOTE:_ The example below is being run for single example `poseidon-rounds`. You can use different example to run same commands.
+> _NOTE:_ The example below is run for the single example `complete_age_check`. You can run the same commands with a different example.
 
-Compile the Noir circuit:
+1. Compile the Noir circuit:
 
 ```sh
-cd noir-examples/poseidon-rounds
+cd noir-examples/noir-passport-examples/complete_age_check
 nargo compile
 ```
 
-Prepare the Noir program (generates prover and verifier files):
+2. Prepare the Noir program (generates prover and verifier files):
 
 ```sh
-cargo run --release --bin provekit-cli prepare ./target/basic.json --pkp ./prover.pkp --pkv ./verifier.pkv
+cargo run --release --bin provekit-cli prepare ./target/complete_age_check.json --pkp ./prover.pkp --pkv ./verifier.pkv
 ```
 
-Generate the Noir Proof using the input Toml:
+3. Generate the Noir Proof using the input Toml:
 
 ```sh
 cargo run --release --bin provekit-cli prove ./prover.pkp ./Prover.toml -o ./proof.np
 ```
 
-Verify the Noir Proof:
+(Optional) Verify the Noir Proof:
 
 ```sh
 cargo run --release --bin provekit-cli verify ./verifier.pkv ./proof.np
 ```
 
-Generate inputs for Gnark circuit:
+4. Generate inputs for the Gnark circuit:
 
 ```sh
 cargo run --release --bin provekit-cli generate-gnark-inputs ./prover.pkp ./proof.np
 ```
 
-Recursively verify in a Gnark proof (reads the proof from `../ProveKit/prover/proof`):
+5. Recursively verify the proof in a Gnark circuit (reads the proof from `../ProveKit/prover/proof`). We provide two methods to prove the deferred evaluations.
+
+   5.1. Spark: the Spark prover generates a proof of the MLE evaluations.
 
-```sh
-cd ../../recursive-verifier
-go run .
-``` + 5.1.1 Generate the Spark Proof: + ```sh + cargo run --release --bin spark-cli -- prove --noir-proof-scheme ./prover.pkp --noir-proof ./proof.np + ``` + + (Optional) Verify the Spark Proof: + ```sh + cargo run --release --bin spark-cli -- verify --spark-proof spark_proof.json --noir-proof ./proof.np + ``` + + 5.1.2 Recursively verify + ```sh + cd ../../../recursive-verifier/cmd/cli + go run . --evaluation spark --config ../../../noir-examples/noir-passport-examples/complete_age_check/params_for_recursive_verifier --spark_config ../../../noir-examples/noir-passport-examples/complete_age_check/gnark_spark_proof.json --r1cs ../../../noir-examples/noir-passport-examples/complete_age_check/r1cs.json + ``` + + 5.2. Direct evaluation: verifier directly evaluates MLEs of R1CS matrices + ```sh + cd ../../../recursive-verifier/cmd/cli + go run . --evaluation direct --config ../../../noir-examples/noir-passport-examples/complete_age_check/params_for_recursive_verifier --spark_config ../../../noir-examples/noir-passport-examples/complete_age_check/gnark_spark_proof.json --r1cs ../../../noir-examples/noir-passport-examples/complete_age_check/r1cs.json + ``` ### Benchmarking @@ -160,9 +179,11 @@ ProveKit follows a modular architecture with clear separation of concerns: - **`provekit/r1cs-compiler/`** - R1CS compilation logic and Noir integration - **`provekit/prover/`** - Proving functionality with witness generation - **`provekit/verifier/`** - Verification functionality +- **`provekit/spark/`** - Spark logic ### Tooling - **`tooling/cli/`** - Command-line interface (`provekit-cli`) +- **`tooling/spark-cli/`** - Command-line interface for Spark(`spark-cli`) - **`tooling/provekit-bench/`** - Benchmarking infrastructure - **`tooling/provekit-gnark/`** - Gnark integration utilities @@ -172,7 +193,7 @@ ProveKit follows a modular architecture with clear separation of concerns: ### Examples & Tests - **`noir-examples/`** - Example circuits and test programs -- **`gnark-whir/`** - Go-based recursive verification using Gnark +- **`recursive-verifier/`** - Go-based recursive verification using Gnark ## Dependencies diff --git a/provekit/common/Cargo.toml b/provekit/common/Cargo.toml index d39db74e..d2f33e3d 100644 --- a/provekit/common/Cargo.toml +++ b/provekit/common/Cargo.toml @@ -23,6 +23,7 @@ ark-crypto-primitives.workspace = true ark-ff.workspace = true ark-serialize.workspace = true ark-std.workspace = true +ark-poly.workspace = true spongefish.workspace = true spongefish-pow.workspace = true whir.workspace = true diff --git a/provekit/common/src/gnark.rs b/provekit/common/src/gnark.rs new file mode 100644 index 00000000..998b8789 --- /dev/null +++ b/provekit/common/src/gnark.rs @@ -0,0 +1,76 @@ +use { + crate::WhirConfig, + ark_poly::EvaluationDomain, + serde::{Deserialize, Serialize}, +}; + +#[derive(Debug, Serialize, Deserialize)] + +pub struct WHIRConfigGnark { + /// number of rounds + pub n_rounds: usize, + /// rate + pub rate: usize, + /// number of variables + pub n_vars: usize, + /// folding factor + pub folding_factor: Vec, + /// out of domain samples + pub ood_samples: Vec, + /// number of queries + pub num_queries: Vec, + /// proof of work bits + pub pow_bits: Vec, + /// final queries + pub final_queries: usize, + /// final proof of work bits + pub final_pow_bits: i32, + /// final folding proof of work bits + pub final_folding_pow_bits: i32, + /// domain generator string + pub domain_generator: String, + /// batch size + pub batch_size: usize, +} + +impl WHIRConfigGnark { + pub fn 
new(whir_params: &WhirConfig) -> Self { + WHIRConfigGnark { + n_rounds: whir_params + .folding_factor + .compute_number_of_rounds(whir_params.mv_parameters.num_variables) + .0, + rate: whir_params.starting_log_inv_rate, + n_vars: whir_params.mv_parameters.num_variables, + folding_factor: (0..(whir_params + .folding_factor + .compute_number_of_rounds(whir_params.mv_parameters.num_variables) + .0)) + .map(|round| whir_params.folding_factor.at_round(round)) + .collect(), + ood_samples: whir_params + .round_parameters + .iter() + .map(|x| x.ood_samples) + .collect(), + num_queries: whir_params + .round_parameters + .iter() + .map(|x| x.num_queries) + .collect(), + pow_bits: whir_params + .round_parameters + .iter() + .map(|x| x.pow_bits as i32) + .collect(), + final_queries: whir_params.final_queries, + final_pow_bits: whir_params.final_pow_bits as i32, + final_folding_pow_bits: whir_params.final_folding_pow_bits as i32, + domain_generator: format!( + "{}", + whir_params.starting_domain.backing_domain.group_gen() + ), + batch_size: whir_params.batch_size, + } + } +} diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index 680715d8..53cfbc98 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -1,24 +1,24 @@ pub mod file; +pub mod gnark; mod interner; mod noir_proof_scheme; mod prover; mod r1cs; pub mod skyscraper; +pub mod spark; mod sparse_matrix; pub mod utils; mod verifier; mod whir_r1cs; pub mod witness; -use crate::{ - interner::{InternedFieldElement, Interner}, - sparse_matrix::{HydratedSparseMatrix, SparseMatrix}, -}; +use crate::interner::{InternedFieldElement, Interner}; pub use { acir::FieldElement as NoirElement, noir_proof_scheme::{NoirProof, NoirProofScheme}, prover::Prover, r1cs::R1CS, + sparse_matrix::{HydratedSparseMatrix, SparseMatrix}, verifier::Verifier, whir::crypto::fields::Field256 as FieldElement, whir_r1cs::{IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme}, diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index 14a3e6ef..f0b26838 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -1,5 +1,6 @@ use { crate::{ + spark::SparkStatement, whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{LayeredWitnessBuilders, NoirWitnessGenerator}, NoirElement, R1CS, @@ -21,6 +22,7 @@ pub struct NoirProofScheme { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct NoirProof { pub whir_r1cs_proof: WhirR1CSProof, + pub spark_statement: SparkStatement, } impl NoirProofScheme { diff --git a/provekit/common/src/spark.rs b/provekit/common/src/spark.rs new file mode 100644 index 00000000..5c314dbc --- /dev/null +++ b/provekit/common/src/spark.rs @@ -0,0 +1,35 @@ +use { + crate::{utils::serde_ark, FieldElement}, + ark_serialize::{CanonicalDeserialize, CanonicalSerialize}, + serde::{Deserialize, Serialize}, +}; + +#[derive( + Debug, Clone, PartialEq, Eq, CanonicalSerialize, Serialize, CanonicalDeserialize, Deserialize, +)] +pub struct Point { + #[serde(with = "serde_ark")] + pub row: Vec, + #[serde(with = "serde_ark")] + pub col: Vec, +} + +#[derive( + Debug, Clone, PartialEq, Eq, CanonicalSerialize, Serialize, CanonicalDeserialize, Deserialize, +)] +pub struct ClaimedValues { + #[serde(with = "serde_ark")] + pub a: FieldElement, + #[serde(with = "serde_ark")] + pub b: FieldElement, + #[serde(with = "serde_ark")] + pub c: FieldElement, +} + +#[derive( + Debug, Clone, PartialEq, Eq, CanonicalSerialize, Serialize, CanonicalDeserialize, 
Deserialize, +)] +pub struct SparkStatement { + pub point_to_evaluate: Point, + pub claimed_values: ClaimedValues, +} diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index 7e1c5a24..a93aa05d 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -152,6 +152,7 @@ pub fn calculate_evaluations_over_boolean_hypercube_for_eq( } /// Evaluates the equality polynomial recursively. +#[inline(always)] fn eval_eq(eval: &[FieldElement], out: &mut [FieldElement], scalar: FieldElement) { debug_assert_eq!(out.len(), 1 << eval.len()); let size = out.len(); diff --git a/provekit/prover/Cargo.toml b/provekit/prover/Cargo.toml index f031a3b2..865e4178 100644 --- a/provekit/prover/Cargo.toml +++ b/provekit/prover/Cargo.toml @@ -12,6 +12,7 @@ repository.workspace = true # Workspace crates provekit-common.workspace = true skyscraper.workspace = true +serde_json.workspace = true # Noir language acir.workspace = true diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index 46d8ddd4..2487408f 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -92,12 +92,15 @@ impl Prove for Prover { .context("While verifying R1CS instance")?; // Prove R1CS instance - let whir_r1cs_proof = self + let (whir_r1cs_proof, spark_statement) = self .whir_for_witness .prove(self.r1cs, witness) .context("While proving R1CS instance")?; - Ok(NoirProof { whir_r1cs_proof }) + Ok(NoirProof { + whir_r1cs_proof, + spark_statement, + }) } fn create_witness_io_pattern(&self) -> IOPattern { diff --git a/provekit/prover/src/noir_proof_scheme.rs b/provekit/prover/src/noir_proof_scheme.rs new file mode 100644 index 00000000..e17a4dbe --- /dev/null +++ b/provekit/prover/src/noir_proof_scheme.rs @@ -0,0 +1,137 @@ +use { + crate::{ + r1cs::R1CSSolver, + whir_r1cs::WhirR1CSProver, + witness::{fill_witness, witness_io_pattern::WitnessIOPattern}, + }, + acir::native_types::WitnessMap, + anyhow::{Context, Result}, + bn254_blackbox_solver::Bn254BlackBoxSolver, + nargo::foreign_calls::DefaultForeignCallBuilder, + noirc_abi::InputMap, + provekit_common::{ + skyscraper::SkyscraperSponge, utils::noir_to_native, witness::WitnessBuilder, FieldElement, + IOPattern, NoirElement, NoirProof, NoirProofScheme, + }, + spongefish::{codecs::arkworks_algebra::FieldToUnitSerialize, ProverState}, + tracing::instrument, +}; + +pub trait NoirProofSchemeProver { + fn generate_witness(&self, input_map: &InputMap) -> Result>; + + fn prove(&self, input_map: &InputMap) -> Result; + + fn create_witness_io_pattern(&self) -> IOPattern; + + fn seed_witness_merlin( + &self, + merlin: &mut ProverState, + witness: &WitnessMap, + ) -> Result<()>; +} + +impl NoirProofSchemeProver for NoirProofScheme { + #[instrument(skip_all)] + fn generate_witness(&self, input_map: &InputMap) -> Result> { + let solver = Bn254BlackBoxSolver::default(); + let mut output_buffer = Vec::new(); + let mut foreign_call_executor = DefaultForeignCallBuilder { + output: &mut output_buffer, + enable_mocks: false, + resolver_url: None, + root_path: None, + package_name: None, + } + .build(); + + let initial_witness = self.witness_generator.abi().encode(input_map, None)?; + + let mut witness_stack = nargo::ops::execute_program( + &self.program, + initial_witness, + &solver, + &mut foreign_call_executor, + )?; + + Ok(witness_stack + .pop() + .context("Missing witness results")? 
+ .witness) + } + + #[instrument(skip_all)] + fn prove(&self, input_map: &InputMap) -> Result { + let acir_witness_idx_to_value_map = self.generate_witness(input_map)?; + + // Solve R1CS instance + let witness_io = self.create_witness_io_pattern(); + let mut witness_merlin = witness_io.to_prover_state(); + self.seed_witness_merlin(&mut witness_merlin, &acir_witness_idx_to_value_map)?; + + let partial_witness = self.r1cs.solve_witness_vec( + &self.witness_builders, + &acir_witness_idx_to_value_map, + &mut witness_merlin, + ); + let witness = fill_witness(partial_witness).context("while filling witness")?; + + // Verify witness (redudant with solve) + #[cfg(test)] + self.r1cs + .test_witness_satisfaction(&witness) + .context("While verifying R1CS instance")?; + + // Prove R1CS instance + let (whir_r1cs_proof, spark_statement) = self + .whir_for_witness + .prove(&self.r1cs, witness) + .context("While proving R1CS instance")?; + + Ok(NoirProof { + whir_r1cs_proof, + spark_statement, + }) + } + + fn create_witness_io_pattern(&self) -> IOPattern { + let circuit = &self.program.functions[0]; + let public_idxs = circuit.public_inputs().indices(); + let num_challenges = self + .witness_builders + .iter() + .filter(|b| matches!(b, WitnessBuilder::Challenge(_))) + .count(); + + // Create witness IO pattern + IOPattern::new("📜") + .add_shape() + .add_public_inputs(public_idxs.len()) + .add_logup_challenges(num_challenges) + } + + fn seed_witness_merlin( + &self, + merlin: &mut ProverState, + witness: &WitnessMap, + ) -> Result<()> { + // Absorb circuit shape + let _ = merlin.add_scalars(&[ + FieldElement::from(self.r1cs.num_constraints() as u64), + FieldElement::from(self.r1cs.num_witnesses() as u64), + ]); + + // Absorb public inputs (values) in canonical order + let circuit = &self.program.functions[0]; + let public_idxs = circuit.public_inputs().indices(); + if !public_idxs.is_empty() { + let pub_vals: Vec = public_idxs + .iter() + .map(|&i| noir_to_native(*witness.get_index(i).expect("missing public input"))) + .collect(); + let _ = merlin.add_scalars(&pub_vals); + } + + Ok(()) + } +} diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index 4f92e79c..4812ccc7 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -3,7 +3,9 @@ use { ark_ff::UniformRand, ark_std::{One, Zero}, provekit_common::{ + file::write, skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, + spark::{ClaimedValues, Point, SparkStatement}, utils::{ pad_to_power_of_two, sumcheck::{ @@ -14,12 +16,13 @@ use { zk_utils::{create_masked_polynomial, generate_random_multilinear_polynomial}, HALF, }, - FieldElement, IOPattern, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, + FieldElement, IOPattern, SparseMatrix, WhirConfig, WhirR1CSProof, WhirR1CSScheme, R1CS, }, spongefish::{ codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField}, ProverState, }, + std::{fs::File, io::Write}, tracing::{info, instrument, warn}, whir::{ poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, @@ -33,12 +36,20 @@ use { }; pub trait WhirR1CSProver { - fn prove(&self, r1cs: R1CS, witness: Vec) -> Result; + fn prove( + &self, + r1cs: R1CS, + witness: Vec, + ) -> Result<(WhirR1CSProof, SparkStatement)>; } impl WhirR1CSProver for WhirR1CSScheme { #[instrument(skip_all)] - fn prove(&self, r1cs: R1CS, witness: Vec) -> Result { + fn prove( + &self, + r1cs: R1CS, + witness: Vec, + ) -> Result<(WhirR1CSProof, SparkStatement)> { ensure!( witness.len() == r1cs.num_witnesses(), 
"Unexpected witness length for R1CS instance" @@ -92,7 +103,7 @@ impl WhirR1CSProver for WhirR1CSScheme { drop(z); // Compute weights from R1CS instance - let alphas = calculate_external_row_of_r1cs_matrices(alpha, r1cs); + let alphas = calculate_external_row_of_r1cs_matrices(alpha.clone(), r1cs); let (statement, f_sums, g_sums) = create_combined_statement_over_two_polynomials::<3>( self.m, &commitment_to_witness, @@ -104,12 +115,24 @@ impl WhirR1CSProver for WhirR1CSScheme { let _ = merlin.hint::<(Vec, Vec)>(&(f_sums, g_sums)); // Compute WHIR weighted batch opening proof - let (merlin, ..) = + let (merlin, whir_randomness, deferred_evaluations) = run_zk_whir_pcs_prover(commitment_to_witness, statement, &self.whir_witness, merlin); + let spark_statement: SparkStatement = SparkStatement { + point_to_evaluate: Point { + row: alpha, + col: whir_randomness.0, + }, + claimed_values: ClaimedValues { + a: deferred_evaluations[0], + b: deferred_evaluations[1], + c: deferred_evaluations[2], + }, + }; + let transcript = merlin.narg_string().to_vec(); - Ok(WhirR1CSProof { transcript }) + Ok((WhirR1CSProof { transcript }, spark_statement)) } } @@ -143,7 +166,7 @@ pub fn compute_blinding_coefficients_for_round( let two = FieldElement::one() + FieldElement::one(); let mut prefix_multiplier = FieldElement::one(); - for _ in 0..(n - 1 - compute_for) { + for _ in 0..n - 1 - compute_for { prefix_multiplier = prefix_multiplier + prefix_multiplier; } let suffix_multiplier: ark_ff::Fp< diff --git a/provekit/spark/Cargo.toml b/provekit/spark/Cargo.toml new file mode 100644 index 00000000..ba99ba00 --- /dev/null +++ b/provekit/spark/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "provekit-spark" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +provekit-common.workspace = true +provekit-r1cs-compiler.workspace = true +ark-ff.workspace = true +ark-std.workspace = true +anyhow.workspace = true +serde.workspace = true +serde_json.workspace = true +spongefish.workspace = true +whir.workspace = true +itertools = "0.14.0" +tracing.workspace = true +rayon.workspace = true + +[lints] +workspace = true + +[[bin]] +name = "generate_test_r1cs" +path = "src/bin/generate_test_r1cs.rs" + +[[bin]] +name = "generate_test_request" +path = "src/bin/generate_test_request.rs" + diff --git a/provekit/spark/README.md b/provekit/spark/README.md new file mode 100644 index 00000000..bfca9dae --- /dev/null +++ b/provekit/spark/README.md @@ -0,0 +1,44 @@ +# ProveKit SPARK + +SPARK (Sparse Polynomial Argument of Knowledge) prover and verifier implementation for ProveKit. 
+ +## Structure + +- `src/types.rs` - Type definitions (proof, request, matrices, memory) +- `src/prover.rs` - Prover implementation and trait +- `src/verifier.rs` - Verifier implementation and trait +- `src/preprocessing.rs` - R1CS to SPARK matrix conversion +- `src/sumcheck.rs` - Sumcheck protocol (prover + verifier) +- `src/gpa.rs` - Grand Product Argument (prover + verifier) +- `src/memory.rs` - Memory checking (rowwise + colwise) +- `src/utils.rs` - Utilities (I/O, memory calculation, IO patterns) + +## Usage + +Use the `spark-cli` tool in `tooling/spark-cli`: + +```bash +# Prove +cargo run --release --bin spark-cli -- prove --noir-proof-scheme ./noir-provekit-prover.pkp --noir-proof ./noir-proof.np + +# Verify +cargo run --release --bin spark-cli -- verify --spark-proof spark_proof.json --noir-proof ./noir-proof.np +``` + +### Test Utilities + +Generate test R1CS and request files: + +```bash +cargo run -p provekit-spark --bin generate_test_r1cs +cargo run -p provekit-spark --bin generate_test_request +``` + +## Architecture + +The SPARK implementation follows a trait-based design: + +- **SPARKProver**: Trait for proving, implemented by SPARKProverScheme +- **SPARKVerifier**: Trait for verification, implemented by SPARKVerifierScheme + +The prover and verifier share common types and utilities but are otherwise independent, allowing for easy testing and extension. diff --git a/provekit/spark/src/bin/generate_test_r1cs.rs b/provekit/spark/src/bin/generate_test_r1cs.rs new file mode 100644 index 00000000..aec6764e --- /dev/null +++ b/provekit/spark/src/bin/generate_test_r1cs.rs @@ -0,0 +1,32 @@ +use { + provekit_common::{FieldElement, R1CS}, + std::{fs::File, io::Write}, +}; + +fn main() { + let mut r1cs = R1CS::new(); + r1cs.grow_matrices(256, 256); + let interned_1 = r1cs.interner.intern(FieldElement::from(1)); + let interned_2 = r1cs.interner.intern(FieldElement::from(2)); + let interned_3 = r1cs.interner.intern(FieldElement::from(3)); + + for i in 0..256 { + r1cs.a.set(i, i, interned_1); + r1cs.b.set(i, i, interned_2); + r1cs.c.set(i, i, interned_3); + } + + r1cs.a.set(1, 0, interned_1); + r1cs.a.set(2, 0, interned_1); + r1cs.a.set(3, 0, interned_1); + + let matrix_json = + serde_json::to_string(&r1cs).expect("Error: Failed to serialize R1CS to JSON"); + let mut request_file = + File::create("r1cs.json").expect("Error: Failed to create the r1cs.json file"); + request_file + .write_all(matrix_json.as_bytes()) + .expect("Error: Failed to write JSON data to r1cs.json"); + + println!("Generated r1cs.json"); +} diff --git a/provekit/spark/src/bin/generate_test_request.rs b/provekit/spark/src/bin/generate_test_request.rs new file mode 100644 index 00000000..5d9af972 --- /dev/null +++ b/provekit/spark/src/bin/generate_test_request.rs @@ -0,0 +1,33 @@ +use { + provekit_common::{ + spark::{ClaimedValues, Point, SparkStatement}, + FieldElement, + }, + std::{fs::File, io::Write}, +}; + +fn main() { + let mut row = vec![FieldElement::from(0); 8]; + let col = vec![FieldElement::from(0); 9]; + + row[7] = FieldElement::from(1); + + let spark_request = SparkStatement { + point_to_evaluate: Point { row, col }, + claimed_values: ClaimedValues { + a: FieldElement::from(1), + b: FieldElement::from(0), + c: FieldElement::from(0), + }, + }; + + let request_json = + serde_json::to_string(&spark_request).expect("Error: Failed to serialize request to JSON"); + let mut request_file = + File::create("request.json").expect("Error: Failed to create the request.json file"); + request_file + 
.write_all(request_json.as_bytes()) + .expect("Error: Failed to write JSON data to request.json"); + + println!("Generated request.json"); +} diff --git a/provekit/spark/src/gpa.rs b/provekit/spark/src/gpa.rs new file mode 100644 index 00000000..84efb48a --- /dev/null +++ b/provekit/spark/src/gpa.rs @@ -0,0 +1,490 @@ +use { + provekit_common::{ + skyscraper::SkyscraperSponge, + utils::{ + next_power_of_two, + sumcheck::{ + calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, + sumcheck_fold_map_reduce, + }, + HALF, + }, + FieldElement, + }, + spongefish::{ + codecs::arkworks_algebra::{FieldToUnitDeserialize, FieldToUnitSerialize, UnitToField}, + ProverState, VerifierState, + }, + tracing::instrument, + whir::poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, +}; + +/// Runs the Grand Product Argument (GPA) protocol to prove product equality. +/// +/// GPA constructs a binary multiplication tree from `left` and `right` vectors, +/// then uses sumcheck-based proofs to verify that `∏left[i] = ∏right[i]` +/// without revealing the individual values. +/// +/// This is the core primitive for memory checking in SPARK, enabling efficient +/// verification that read and write sets are consistent. +/// +/// # Arguments +/// +/// * `merlin` - The prover's Fiat-Shamir transcript +/// * `left` - Initial state vector (must be power-of-2 length) +/// * `right` - Final state vector (must match `left` length) +/// +/// # Returns +/// +/// Vector of challenge randomness accumulated across all sumcheck rounds +/// +/// # Panics +/// +/// Panics if input vectors are not power-of-2 length +#[instrument(skip_all)] +pub fn run_gpa2( + merlin: &mut ProverState, + left: &[FieldElement], + right: &[FieldElement], +) -> Vec { + let mut concatenated = left.to_vec(); + concatenated.extend_from_slice(right); + let layers = calculate_binary_multiplication_tree(concatenated); + + let (accumulated_randomness, mut sumcheck_claim) = + add_line_to_transcript(merlin, layers[1].clone()); + let mut accumulated_randomness = accumulated_randomness.to_vec(); + + for i in 2..layers.len() { + (sumcheck_claim, accumulated_randomness) = run_gpa_sumcheck( + merlin, + layers[i].clone(), + sumcheck_claim, + accumulated_randomness, + ); + } + + accumulated_randomness +} + +#[instrument(skip_all)] +pub fn run_gpa4( + merlin: &mut ProverState, + leaves: Vec, +) -> Vec { + let layers = calculate_binary_multiplication_tree(leaves); + + let evaluation_form = EvaluationsList::new(layers[2].clone()); + let coefficient_form = evaluation_form.to_coeffs(); + let coeffs: &[FieldElement] = coefficient_form.coeffs(); + + merlin + .add_scalars(coeffs) + .expect("Failed to add line polynomial to transcript"); + + let mut accumulated_randomness = [FieldElement::from(0); 2].to_vec(); + merlin + .fill_challenge_scalars(&mut accumulated_randomness) + .expect("Failed to sample accumulated_randomness"); + + let mut sumcheck_claim = + coefficient_form.evaluate(&MultilinearPoint(accumulated_randomness.to_vec())); + + for i in 3..layers.len() { + (sumcheck_claim, accumulated_randomness) = run_gpa_sumcheck( + merlin, + layers[i].clone(), + sumcheck_claim, + accumulated_randomness, + ); + } + + accumulated_randomness +} + +/// Constructs a binary multiplication tree from the input vector. +/// +/// Each parent node is the product of its two children, forming a complete +/// binary tree where the root is the product of all elements. 
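+///
+/// Editor's illustration (not part of the original docs): a hypothetical input
+/// `[a, b, c, d]` yields the layers `[[a·b·c·d], [a·b, c·d], [a, b, c, d]]`,
+/// i.e. root first and leaves last after the final `reverse()`.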
+/// +/// # Returns +/// +/// Vector of layers, where: +/// - `layers[0]` is the root (single element) +/// - `layers[layers.len()-1]` is the leaf layer (input) +/// +/// # Panics +/// +/// Panics if input length is not a power of two +fn calculate_binary_multiplication_tree( + array_to_prove: Vec, +) -> Vec> { + assert!( + array_to_prove.len() == (1 << next_power_of_two(array_to_prove.len())), + "Input length must be power of two" + ); + + let mut layers = vec![]; + let mut current_layer = array_to_prove; + + while current_layer.len() > 1 { + let next_layer = current_layer + .chunks_exact(2) + .map(|pair| pair[0] * pair[1]) + .collect(); + + layers.push(current_layer); + current_layer = next_layer; + } + + layers.push(current_layer); + layers.reverse(); + layers +} + +/// Adds a line polynomial to the transcript and samples verifier challenge. +/// +/// Converts evaluations to coefficients, commits them to the transcript, +/// then receives a random challenge to bind the prover to this layer. +/// +/// # Returns +/// +/// Tuple of `(challenge, next_sumcheck_claim)` for the following GPA round +fn add_line_to_transcript( + merlin: &mut ProverState, + arr: Vec, +) -> ([FieldElement; 1], FieldElement) { + let evaluations = EvaluationsList::new(arr); + let coeffs = evaluations.to_coeffs(); + let line_poly: &[FieldElement] = coeffs.coeffs(); + + merlin + .add_scalars(line_poly) + .expect("Failed to add line polynomial to transcript"); + + let mut challenge = [FieldElement::from(0); 1]; + merlin + .fill_challenge_scalars(&mut challenge) + .expect("Failed to sample challenge"); + + let next_claim = line_poly[0] + line_poly[1] * challenge[0]; + + (challenge, next_claim) +} + +/// Executes a single sumcheck round within the GPA protocol. +/// +/// This proves the relation: `eq(r, x) · v₀(x) · v₁(x)` sums correctly +/// over the boolean hypercube, where `v₀` and `v₁` are child layers +/// in the multiplication tree. 
+/// +/// # Returns +/// +/// Tuple of `(final_evaluations, accumulated_randomness)` for next round +fn run_gpa_sumcheck( + merlin: &mut ProverState, + layer: Vec, + mut sumcheck_claim: FieldElement, + accumulated_randomness: Vec, +) -> (FieldElement, Vec) { + let (mut even_layer, mut odd_layer) = split_even_odd(layer); + + let mut eq_evaluations = + calculate_evaluations_over_boolean_hypercube_for_eq(accumulated_randomness); + let mut challenge = [FieldElement::from(0)]; + let mut round_randomness = Vec::::new(); + let mut fold = None; + + loop { + // Evaluate sumcheck polynomial at special points: 0, -1, ∞ + let [eval_at_0, eval_at_neg1, eval_at_inf_over_x3] = sumcheck_fold_map_reduce( + [&mut eq_evaluations, &mut even_layer, &mut odd_layer], + fold, + |[eq, v0, v1]| { + [ + eq.0 * v0.0 * v1.0, + (eq.0 + eq.0 - eq.1) * (v0.0 + v0.0 - v0.1) * (v1.0 + v1.0 - v1.1), + (eq.1 - eq.0) * (v0.1 - v0.0) * (v1.1 - v1.0), + ] + }, + ); + + if fold.is_some() { + eq_evaluations.truncate(eq_evaluations.len() / 2); + even_layer.truncate(even_layer.len() / 2); + odd_layer.truncate(odd_layer.len() / 2); + } + + // Reconstruct cubic polynomial from evaluation points + let poly_coeffs = reconstruct_cubic_from_evaluations( + sumcheck_claim, + eval_at_0, + eval_at_neg1, + eval_at_inf_over_x3, + ); + + // Verify sumcheck binding: h(0) + h(1) = claimed_sum + assert_eq!( + sumcheck_claim, + poly_coeffs[0] + poly_coeffs[0] + poly_coeffs[1] + poly_coeffs[2] + poly_coeffs[3], + "Sumcheck binding check failed" + ); + + merlin + .add_scalars(&poly_coeffs) + .expect("Failed to add polynomial"); + merlin + .fill_challenge_scalars(&mut challenge) + .expect("Failed to sample challenge"); + + fold = Some(challenge[0]); + sumcheck_claim = eval_cubic_poly(poly_coeffs, challenge[0]); + round_randomness.push(challenge[0]); + + if eq_evaluations.len() <= 2 { + break; + } + } + + let final_v0 = even_layer[0] + (even_layer[1] - even_layer[0]) * challenge[0]; + let final_v1 = odd_layer[0] + (odd_layer[1] - odd_layer[0]) * challenge[0]; + + let evaluations = EvaluationsList::new([final_v0, final_v1].to_vec()); + let coeffs = evaluations.to_coeffs(); + let line_poly: &[FieldElement] = coeffs.coeffs(); + + merlin + .add_scalars(line_poly) + .expect("Failed to add line polynomial to transcript"); + let mut challenge = [FieldElement::from(0); 1]; + merlin + .fill_challenge_scalars(&mut challenge) + .expect("Failed to sample challenge"); + let next_claim = line_poly[0] + line_poly[1] * challenge[0]; + round_randomness.push(challenge[0]); + + (next_claim, round_randomness) +} + +/// Reconstructs cubic polynomial coefficients from special point evaluations. +/// +/// Given evaluations at 0, -1, and ∞/x³, computes the unique cubic polynomial +/// that passes through these points and satisfies the sumcheck binding. +fn reconstruct_cubic_from_evaluations( + binding_value: FieldElement, + at_0: FieldElement, + at_neg1: FieldElement, + at_inf_over_x3: FieldElement, +) -> [FieldElement; 4] { + let mut coeffs = [FieldElement::from(0); 4]; + + coeffs[0] = at_0; + coeffs[2] = HALF * (binding_value + at_neg1 - at_0 - at_0 - at_0); + coeffs[3] = at_inf_over_x3; + coeffs[1] = binding_value - coeffs[0] - coeffs[0] - coeffs[3] - coeffs[2]; + + coeffs +} + +/// Splits vector into even-indexed and odd-indexed elements. +/// +/// Used to separate left/right children in the binary multiplication tree. 
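+///
+/// Editor's example: `[x0, x1, x2, x3]` splits into `([x0, x2], [x1, x3])`.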
+fn split_even_odd(input: Vec) -> (Vec, Vec) { + let mut even = Vec::new(); + let mut odd = Vec::new(); + + for (i, item) in input.into_iter().enumerate() { + if i % 2 == 0 { + even.push(item); + } else { + odd.push(item); + } + } + + (even, odd) +} + +/// Result of GPA sumcheck verification containing final randomness and claims. +pub struct GPASumcheckResult { + /// The two claimed values at the leaves (left and right products) + pub claimed_values: Vec, + /// Final sumcheck evaluation after all rounds + pub a_last_sumcheck_value: FieldElement, + /// Accumulated verifier randomness from all rounds + pub randomness: Vec, +} + +/// Verifies a Grand Product Argument proof from the transcript. +/// +/// This is the verifier's counterpart to [`run_gpa2`], checking that the +/// prover's sumcheck proofs are valid without recomputing the multiplication +/// tree. +/// +/// # Arguments +/// +/// * `arthur` - The verifier's transcript state (Fiat-Shamir) +/// * `height_of_binary_tree` - Number of layers in the multiplication tree +/// +/// # Returns +/// +/// [`GPASumcheckResult`] containing verified claims and randomness +pub fn gpa_sumcheck_verifier( + arthur: &mut VerifierState, + height_of_binary_tree: usize, +) -> anyhow::Result { + let mut prev_randomness; + let mut current_randomness = Vec::::new(); + let mut claimed_values = [FieldElement::from(0); 2]; + let mut line_coeffs = [FieldElement::from(0); 2]; + let mut line_challenge = [FieldElement::from(0); 1]; + let mut cubic_coeffs = [FieldElement::from(0); 4]; + let mut sumcheck_challenge = [FieldElement::from(0); 1]; + + arthur.fill_next_scalars(&mut claimed_values)?; + arthur.fill_challenge_scalars(&mut line_challenge)?; + + let mut sumcheck_value = eval_line(&claimed_values, &line_challenge[0]); + current_randomness.push(line_challenge[0]); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + + for layer_idx in 1..height_of_binary_tree - 1 { + for _ in 0..layer_idx { + arthur.fill_next_scalars(&mut cubic_coeffs)?; + arthur.fill_challenge_scalars(&mut sumcheck_challenge)?; + + // Verify sumcheck binding + assert_eq!( + eval_cubic_poly(cubic_coeffs, FieldElement::from(0)) + + eval_cubic_poly(cubic_coeffs, FieldElement::from(1)), + sumcheck_value, + "Sumcheck verification failed at layer {layer_idx}" + ); + + current_randomness.push(sumcheck_challenge[0]); + sumcheck_value = eval_cubic_poly(cubic_coeffs, sumcheck_challenge[0]); + } + + arthur.fill_next_scalars(&mut line_coeffs)?; + arthur.fill_challenge_scalars(&mut line_challenge)?; + + // Verify line polynomial evaluation + let expected_line_value = calculate_eq(&prev_randomness, ¤t_randomness) + * eval_line(&line_coeffs, &FieldElement::from(0)) + * eval_line(&line_coeffs, &FieldElement::from(1)); + assert_eq!( + expected_line_value, sumcheck_value, + "Line evaluation mismatch" + ); + + current_randomness.push(line_challenge[0]); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + sumcheck_value = eval_line(&line_coeffs, &line_challenge[0]); + } + + let claimed_values = [claimed_values[0], claimed_values[0] + claimed_values[1]].to_vec(); + + Ok(GPASumcheckResult { + claimed_values: claimed_values.to_vec(), + a_last_sumcheck_value: sumcheck_value, + randomness: prev_randomness, + }) +} + +pub fn gpa_sumcheck_verifier4( + arthur: &mut VerifierState, + height_of_binary_tree: usize, +) -> anyhow::Result { + let mut claimed_values = [FieldElement::from(0); 4]; + let mut prev_randomness = [FieldElement::from(0); 2]; + let mut current_randomness 
= Vec::::new(); + let mut line_coeffs = [FieldElement::from(0); 2]; + let mut line_challenge = [FieldElement::from(0); 1]; + let mut cubic_coeffs = [FieldElement::from(0); 4]; + let mut sumcheck_challenge = [FieldElement::from(0); 1]; + + arthur.fill_next_scalars(&mut claimed_values)?; + arthur.fill_challenge_scalars(&mut prev_randomness)?; + let mut prev_randomness = prev_randomness.to_vec(); + + // let mut sumcheck_value = eval_line(&claimed_values, &line_challenge[0]); + let mut sumcheck_value = claimed_values[0] + + claimed_values[1] * prev_randomness[1] + + claimed_values[2] * prev_randomness[0] + + claimed_values[3] * prev_randomness[0] * prev_randomness[1]; + + for layer_idx in 2..height_of_binary_tree - 1 { + for _ in 0..layer_idx { + arthur.fill_next_scalars(&mut cubic_coeffs)?; + arthur.fill_challenge_scalars(&mut sumcheck_challenge)?; + + // Verify sumcheck binding + assert_eq!( + eval_cubic_poly(cubic_coeffs, FieldElement::from(0)) + + eval_cubic_poly(cubic_coeffs, FieldElement::from(1)), + sumcheck_value, + "Sumcheck verification failed at layer {layer_idx}" + ); + + current_randomness.push(sumcheck_challenge[0]); + sumcheck_value = eval_cubic_poly(cubic_coeffs, sumcheck_challenge[0]); + } + + arthur.fill_next_scalars(&mut line_coeffs)?; + arthur.fill_challenge_scalars(&mut line_challenge)?; + + // Verify line polynomial evaluation + let expected_line_value = calculate_eq(&prev_randomness, ¤t_randomness) + * eval_line(&line_coeffs, &FieldElement::from(0)) + * eval_line(&line_coeffs, &FieldElement::from(1)); + assert_eq!( + expected_line_value, sumcheck_value, + "Line evaluation mismatch" + ); + + current_randomness.push(line_challenge[0]); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + sumcheck_value = eval_line(&line_coeffs, &line_challenge[0]); + } + + let claimed_values = [ + claimed_values[0], + claimed_values[0] + claimed_values[1], + claimed_values[0] + claimed_values[2], + claimed_values[0] + claimed_values[1] + claimed_values[2] + claimed_values[3], + ] + .to_vec(); + + Ok(GPASumcheckResult { + claimed_values: claimed_values.to_vec(), + a_last_sumcheck_value: sumcheck_value, + randomness: prev_randomness, + }) +} + +/// Evaluates a linear polynomial at a given point. +/// +/// Computes `poly[0] + point * poly[1]` for a degree-1 polynomial. +pub fn eval_line(poly: &[FieldElement], point: &FieldElement) -> FieldElement { + poly[0] + *point * poly[1] +} + +/// Calculates address from binary representation of evaluation point. +/// +/// Interprets the randomness vector as a binary number in reverse order, +/// converting it to the corresponding memory address in the hypercube. 
+/// +/// # Example +/// +/// `[r₀, r₁, r₂]` → `r₂·2² + r₁·2¹ + r₀·2⁰` +pub fn calculate_adr(randomness: &[FieldElement]) -> FieldElement { + randomness + .iter() + .rev() + .enumerate() + .fold(FieldElement::from(0), |acc, (i, &r)| { + acc + r * FieldElement::from(1 << i) + }) +} diff --git a/provekit/spark/src/lib.rs b/provekit/spark/src/lib.rs new file mode 100644 index 00000000..cfed5d18 --- /dev/null +++ b/provekit/spark/src/lib.rs @@ -0,0 +1,14 @@ +pub mod gpa; +pub mod memory; +pub mod prover; +pub mod sumcheck; +pub mod types; +pub mod utils; +pub mod verifier; + +pub use { + prover::{SPARKProver, SPARKScheme as SPARKProverScheme}, + types::{MatrixDimensions, SPARKProof, SPARKProofGnark, SPARKWHIRConfigs}, + utils::{calculate_memory, deserialize_r1cs, deserialize_request}, + verifier::{SPARKScheme as SPARKVerifierScheme, SPARKVerifier}, +}; diff --git a/provekit/spark/src/memory.rs b/provekit/spark/src/memory.rs new file mode 100644 index 00000000..cdeaeb37 --- /dev/null +++ b/provekit/spark/src/memory.rs @@ -0,0 +1,298 @@ +use { + crate::{ + gpa::{calculate_adr, gpa_sumcheck_verifier, run_gpa2}, + types::{Memory, SPARKWHIRConfigs}, + }, + anyhow::{ensure, Result}, + ark_ff::{Field, Fp, MontBackend}, + ark_std::One, + itertools::izip, + provekit_common::{ + skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, + spark::SparkStatement, + utils::sumcheck::calculate_eq, + FieldElement, WhirConfig, + }, + spongefish::{ProverState, VerifierState}, + tracing::instrument, + whir::{ + crypto::fields::BN254Config, + poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, + whir::{ + committer::{reader::ParsedCommitment, Witness}, + prover::Prover, + statement::{Statement, Weights}, + utils::{HintDeserialize, HintSerialize}, + verifier::Verifier, + }, + }, +}; + +/// Configuration bundle for row/column axis-specific data. +/// +/// This zero-cost abstraction eliminates code duplication between +/// row-wise and column-wise memory checking protocols. +struct AxisConfig<'a> { + eq_memory: &'a [FieldElement], + final_timestamp: &'a [FieldElement], + whir_config: &'a WhirConfig, +} + +/// Proves memory consistency for a single axis (row or column). +/// +/// Executes two GPAs: +/// 1. Init-Final GPA: Proves memory state transitions from initialization to +/// final +/// 2. Read-Write GPA: Proves read-set and write-set timestamps are consistent +/// +/// This is the core of SPARK's memory checking, ensuring that claimed memory +/// values match the actual constraint system evaluations. +#[inline] +fn prove_axis( + merlin: &mut ProverState, + config: AxisConfig<'_>, + final_ts_witness: Witness, + gamma: &FieldElement, + tau: &FieldElement, +) -> Result<()> { + // Construct opening vectors for init/final GPA using Fiat-Shamir challenges. 
+ // Each opening encodes (address, value, timestamp) as: a*γ² + v*γ + t - τ + let init_vec: Vec<_> = izip!(0.., config.eq_memory.iter(), config.final_timestamp.iter()) + .map(|(i, &v, _)| { + let a = FieldElement::from(i); + a * gamma * gamma + v * gamma - tau + }) + .collect(); + + let final_vec: Vec<_> = izip!(0.., config.eq_memory.iter(), config.final_timestamp.iter()) + .map(|(i, &v, &t)| { + let a = FieldElement::from(i); + a * gamma * gamma + v * gamma + t - tau + }) + .collect(); + + let gpa_randomness = run_gpa2(merlin, &init_vec, &final_vec); + let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + let final_ts_eval = EvaluationsList::new(config.final_timestamp.to_vec()) + .evaluate(&MultilinearPoint(evaluation_randomness.to_vec())); + merlin.hint(&final_ts_eval)?; + + produce_whir_proof( + merlin, + MultilinearPoint(evaluation_randomness.to_vec()), + final_ts_eval, + config.whir_config.clone(), + final_ts_witness, + )?; + + Ok(()) +} + +/// Proves row-wise memory consistency for the SPARK protocol. +/// +/// # Arguments +/// * `merlin` - Prover's transcript state +/// * `matrix` - The preprocessed SPARK matrix with COO format and timestamps +/// * `memory` - Pre-computed equality check evaluations +/// * `e_rx` - Row evaluation vector +/// * `whir_configs` - WHIR polynomial commitment configurations +/// * `final_row_ts_witness` - Commitment witness for final row timestamps +/// * `rowwise_witness` - Batched commitment witness for row data +#[instrument(skip_all)] +pub fn prove_rowwise( + merlin: &mut ProverState, + final_row: &Vec, + memory: &Memory, + whir_configs: &SPARKWHIRConfigs, + final_row_ts_witness: Witness, + gamma: &FieldElement, + tau: &FieldElement, +) -> Result<()> { + prove_axis( + merlin, + AxisConfig { + eq_memory: &memory.eq_rx, + final_timestamp: final_row, + whir_config: &whir_configs.row, + }, + final_row_ts_witness, + gamma, + tau, + ) +} + +#[instrument(skip_all)] +pub fn prove_colwise( + merlin: &mut ProverState, + final_col: &Vec, + memory: &Memory, + whir_configs: &SPARKWHIRConfigs, + final_col_ts_witness: Witness, + gamma: &FieldElement, + tau: &FieldElement, +) -> Result<()> { + prove_axis( + merlin, + AxisConfig { + eq_memory: &memory.eq_ry, + final_timestamp: final_col, + whir_config: &whir_configs.col, + }, + final_col_ts_witness, + gamma, + tau, + ) +} + +// ============================================================================ +// Verifier - Generic Implementation +// ============================================================================ + +#[inline] +fn verify_axis( + arthur: &mut VerifierState, + num_axis_items: usize, + whir_config: &WhirConfig, + finalts_commitment: ParsedCommitment< + Fp, 4>, + Fp, 4>, + >, + init_mem_fn: impl Fn(&[FieldElement]) -> FieldElement, + tau: &FieldElement, + gamma: &FieldElement, + claimed_rs: &FieldElement, + claimed_ws: &FieldElement, +) -> Result<()> { + // Init Final GPA + let gpa_result = gpa_sumcheck_verifier( + arthur, + provekit_common::utils::next_power_of_two(num_axis_items) + 2, + )?; + + let claimed_init = gpa_result.claimed_values[0]; + let claimed_final = gpa_result.claimed_values[1]; + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let init_adr = calculate_adr(&evaluation_randomness.to_vec()); + let init_mem = init_mem_fn(&evaluation_randomness.to_vec()); + let init_opening = init_adr * gamma * gamma + init_mem * gamma - tau; + + let final_cntr: FieldElement = arthur.hint()?; + + let mut final_cntr_statement = + 
Statement::::new(provekit_common::utils::next_power_of_two(num_axis_items)); + final_cntr_statement.add_constraint( + Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec())), + final_cntr, + ); + + let final_cntr_verifier = Verifier::new(whir_config); + final_cntr_verifier.verify(arthur, &finalts_commitment, &final_cntr_statement)?; + + let final_adr = calculate_adr(&evaluation_randomness.to_vec()); + let final_mem = init_mem_fn(&evaluation_randomness.to_vec()); + let final_opening = final_adr * gamma * gamma + final_mem * gamma + final_cntr - tau; + + let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + + final_opening * last_randomness[0]; + + ensure!(evaluated_value == gpa_result.a_last_sumcheck_value); + + ensure!(claimed_init * claimed_ws == claimed_final * claimed_rs); + + Ok(()) +} + +// ============================================================================ +// Public API - Verifier +// ============================================================================ + +pub fn verify_rowwise( + arthur: &mut VerifierState, + num_rows: usize, + whir_params: &SPARKWHIRConfigs, + request: &SparkStatement, + row_finalts_commitment: ParsedCommitment< + Fp, 4>, + Fp, 4>, + >, + tau: &FieldElement, + gamma: &FieldElement, + claimed_rs: &FieldElement, + claimed_ws: &FieldElement, + matrix_batching_randomness: &FieldElement, +) -> Result<()> { + let b1 = *matrix_batching_randomness / (FieldElement::ONE + *matrix_batching_randomness); + let row_evaluation_point: Vec = std::iter::once(b1) + .chain(request.point_to_evaluate.row.clone()) + .collect(); + verify_axis( + arthur, + num_rows, + &whir_params.row, + row_finalts_commitment, + |eval_rand| calculate_eq(&row_evaluation_point, eval_rand), + tau, + gamma, + claimed_rs, + claimed_ws, + ) +} + +pub fn verify_colwise( + arthur: &mut VerifierState, + num_cols: usize, + whir_params: &SPARKWHIRConfigs, + request: &SparkStatement, + col_finalts_commitment: ParsedCommitment< + Fp, 4>, + Fp, 4>, + >, + tau: &FieldElement, + gamma: &FieldElement, + claimed_rs: &FieldElement, + claimed_ws: &FieldElement, + matrix_batching_randomness: &FieldElement, +) -> Result<()> { + let b1 = *matrix_batching_randomness / (FieldElement::ONE + *matrix_batching_randomness); + let col_evaluation_point: Vec = std::iter::once(b1) + .chain(request.point_to_evaluate.col[1..].to_vec()) + .collect(); + verify_axis( + arthur, + num_cols, + &whir_params.col, + col_finalts_commitment, + |eval_rand| { + calculate_eq(&col_evaluation_point, eval_rand) + * (FieldElement::from(1) - request.point_to_evaluate.col[0]) + }, + tau, + gamma, + claimed_rs, + claimed_ws, + ) +} + +/// Helper to generate and verify a WHIR proof at a specific evaluation point. +/// +/// # Note +/// This is called multiple times during SPARK proving for different polynomial +/// commitments. 
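+///
+/// Editor's sketch of a typical call site (hypothetical bindings, not part of the
+/// diff), mirroring the invocations in `spark_sumcheck` and `prove_axis`:
+///
+/// ```ignore
+/// produce_whir_proof(
+///     merlin,
+///     MultilinearPoint(folding_randomness.to_vec()),
+///     claimed_evaluation,
+///     whir_config.clone(),
+///     witness,
+/// )?;
+/// ```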
+#[instrument(skip_all)] +pub fn produce_whir_proof( + merlin: &mut ProverState, + evaluation_point: MultilinearPoint, + evaluated_value: FieldElement, + config: WhirConfig, + witness: Witness, +) -> Result<()> { + let mut statement = Statement::::new(evaluation_point.num_variables()); + statement.add_constraint(Weights::evaluation(evaluation_point), evaluated_value); + + let prover = Prover::new(config); + prover.prove(merlin, statement, witness)?; + + Ok(()) +} diff --git a/provekit/spark/src/prover.rs b/provekit/spark/src/prover.rs new file mode 100644 index 00000000..123fdd3f --- /dev/null +++ b/provekit/spark/src/prover.rs @@ -0,0 +1,582 @@ +use { + crate::{ + gpa::run_gpa4, + memory::{produce_whir_proof, prove_colwise, prove_rowwise}, + sumcheck::run_spark_sumcheck, + types::{ + COOMatrix, EValuesForMatrix, MatrixDimensions, Memory, SPARKProof, SPARKWHIRConfigs, + SparkMatrix, TimeStamps, + }, + utils::{calculate_memory, SPARKDomainSeparator}, + }, + anyhow::{Context, Result}, + ark_ff::{AdditiveGroup, Field}, + itertools::izip, + provekit_common::{ + skyscraper::{SkyscraperMerkleConfig, SkyscraperSponge}, + spark::SparkStatement, + utils::{next_power_of_two, sumcheck::SumcheckIOPattern}, + FieldElement, IOPattern, WhirConfig, WhirR1CSScheme, R1CS, + }, + provekit_r1cs_compiler::WhirR1CSSchemeBuilder, + rayon::{join, prelude::*}, + spongefish::{ + codecs::arkworks_algebra::{FieldToUnitSerialize, UnitToField}, + ProverState, + }, + tracing::instrument, + whir::{ + poly_utils::{evals::EvaluationsList, multilinear::MultilinearPoint}, + whir::{ + committer::{CommitmentWriter, Witness}, + domainsep::WhirDomainSeparator, + prover::Prover, + statement::{Statement, Weights}, + utils::HintSerialize, + }, + }, +}; + +/// SPARK proving interface for R1CS constraint systems. +pub trait SPARKProver { + /// Generates a SPARK proof from R1CS and evaluation request. + fn prove(&self, r1cs: &R1CS, request: &SparkStatement) -> Result; +} + +/// SPARK scheme with pre-configured WHIR parameters and IO pattern. +pub struct SPARKScheme { + pub whir_configs: SPARKWHIRConfigs, + pub io_pattern: IOPattern, + pub matrix_dimensions: MatrixDimensions, +} + +impl SPARKScheme { + /// Configures SPARK scheme for given R1CS dimensions. 
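+    ///
+    /// Editor's sketch (hypothetical bindings, not part of the diff) of how the
+    /// scheme is meant to be driven, roughly what `tooling/spark-cli` does:
+    ///
+    /// ```ignore
+    /// let scheme = SPARKScheme::new_for_r1cs(&r1cs);
+    /// let spark_proof = scheme.prove(&r1cs, &spark_statement)?;
+    /// ```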
+ pub fn new_for_r1cs(r1cs: &R1CS) -> Self { + let num_rows = 2 * r1cs.num_constraints(); + let num_cols = 2 * r1cs.num_witnesses(); + let nonzero_terms = + r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count(); + + let padded_num_entries = 1 << next_power_of_two(nonzero_terms); + + let row_config = WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(num_rows), 1); + let col_config = WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(num_cols), 1); + let num_terms_1batched_config = + WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(padded_num_entries), 1); + let num_terms_2batched_config = + WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(padded_num_entries), 2); + let num_terms_4batched_config = + WhirR1CSScheme::new_whir_config_for_size(next_power_of_two(padded_num_entries), 4); + + let whir_configs = SPARKWHIRConfigs { + row: row_config.clone(), + col: col_config.clone(), + num_terms_1batched: num_terms_1batched_config.clone(), + num_terms_2batched: num_terms_2batched_config.clone(), + num_terms_4batched: num_terms_4batched_config.clone(), + }; + + let mut io = IOPattern::new("💥"); + + io = io + .hint("point_row") + .hint("point_col") + .add_claimed_evaluations(); + + io = io + .commit_statement(&num_terms_1batched_config) + .commit_statement(&num_terms_4batched_config) + .commit_statement(&row_config) + .commit_statement(&col_config) + .commit_statement(&num_terms_2batched_config) + .add_sumcheck_polynomials(next_power_of_two(padded_num_entries)) + .hint("sumcheck_last_folds") + .add_whir_proof(&num_terms_2batched_config) + .add_whir_proof(&num_terms_1batched_config); + + io = io.add_tau_and_gamma(); + + io = io.add_gpa4_claimed_values(); + for i in 2..=(next_power_of_two(padded_num_entries) + 1) { + io = io.add_sumcheck_polynomials(i).add_line(); + } + + io = io + .hint("row_rs_address_claimed_evaluation") + .hint("row_rs_timestamp_claimed_evaluation") + .hint("col_rs_address_claimed_evaluation") + .hint("col_rs_timestamp_claimed_evaluation") + .add_whir_proof(&num_terms_4batched_config); + + io = io + .hint("row_rs_value_claimed_evaluation") + .hint("col_rs_value_claimed_evaluation") + .add_whir_proof(&num_terms_2batched_config); + + for i in 0..=next_power_of_two(num_rows) { + io = io.add_sumcheck_polynomials(i).add_line(); + } + io = io + .hint("row_final_counter_claimed_evaluation") + .add_whir_proof(&row_config); + + for i in 0..=next_power_of_two(num_cols) { + io = io.add_sumcheck_polynomials(i).add_line(); + } + io = io + .hint("col_final_counter_claimed_evaluation") + .add_whir_proof(&col_config); + + Self { + whir_configs, + io_pattern: io, + matrix_dimensions: MatrixDimensions { + num_rows, + num_cols, + nonzero_terms, + }, + } + } +} + +impl SPARKProver for SPARKScheme { + #[instrument(skip_all)] + fn prove(&self, r1cs: &R1CS, request: &SparkStatement) -> Result { + let original_num_entries = + r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count(); + let padded_num_entries = 1 << next_power_of_two(original_num_entries); + let to_fill = padded_num_entries - original_num_entries; + + let row_cnt = r1cs.num_constraints(); + let col_cnt = r1cs.num_witnesses(); + + let (mut row, mut col, mut val) = ( + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + ); + + for ((r, c), v) in r1cs.a().iter() { + row.push(r); + col.push(c); + val.push(v); + } + for ((r, c), v) in r1cs.b().iter() { + row.push(r); + col.push(c + col_cnt); + 
+            val.push(v);
+        }
+        for ((r, c), v) in r1cs.c().iter() {
+            row.push(r + row_cnt);
+            col.push(c + col_cnt);
+            val.push(v);
+        }
+        for _ in 0..to_fill {
+            row.push(0);
+            col.push(0);
+            val.push(FieldElement::from(0));
+        }
+
+        // Memory timestamps track access order for GPA protocol
+
+        let mut read_row_counters = vec![0; 2 * r1cs.num_constraints()];
+        let mut read_col_counters = vec![0; 2 * r1cs.num_witnesses()];
+        let mut read_row = Vec::with_capacity(padded_num_entries);
+        let mut read_col = Vec::with_capacity(padded_num_entries);
+
+        for i in 0..padded_num_entries {
+            read_row.push(FieldElement::from(read_row_counters[row[i]] as u64));
+            read_row_counters[row[i]] += 1;
+            read_col.push(FieldElement::from(read_col_counters[col[i]] as u64));
+            read_col_counters[col[i]] += 1;
+        }
+
+        let final_row = read_row_counters
+            .iter()
+            .map(|&x| FieldElement::from(x as u64))
+            .collect::<Vec<_>>();
+
+        let final_col = read_col_counters
+            .iter()
+            .map(|&x| FieldElement::from(x as u64))
+            .collect::<Vec<_>>();
+
+        let mut merlin = self.io_pattern.to_prover_state();
+
+        merlin.hint(&request.point_to_evaluate.row)?;
+        merlin.hint(&request.point_to_evaluate.col)?;
+
+        merlin.add_scalars(&[
+            request.claimed_values.a,
+            request.claimed_values.b,
+            request.claimed_values.c,
+        ])?;
+        let mut matrix_batching_randomness = [FieldElement::ZERO; 1];
+        merlin.fill_challenge_scalars(&mut matrix_batching_randomness)?;
+
+        let memory = calculate_memory(
+            matrix_batching_randomness[0] / (FieldElement::ONE + matrix_batching_randomness[0]),
+            request.point_to_evaluate.clone(),
+        );
+        let mut claimed_value = request.claimed_values.a
+            + request.claimed_values.b * matrix_batching_randomness[0]
+            + request.claimed_values.c
+                * matrix_batching_randomness[0]
+                * matrix_batching_randomness[0];
+        claimed_value = (claimed_value / (FieldElement::ONE + matrix_batching_randomness[0]))
+            / (FieldElement::ONE + matrix_batching_randomness[0]);
+
+        let mut e_rx = Vec::with_capacity(padded_num_entries);
+        let mut e_ry = Vec::with_capacity(padded_num_entries);
+
+        for i in 0..padded_num_entries {
+            e_rx.push(memory.eq_rx[row[i]]);
+            e_ry.push(memory.eq_ry[col[i]]);
+        }
+
+        let e_values = EValuesForMatrix { e_rx, e_ry };
+
+        // let spark_matrix = processed.into_spark_matrix(r1cs,
+        // matrix_batching_randomness);
+        let spark_matrix = SparkMatrix {
+            coo: COOMatrix {
+                row: row.iter().map(|r| FieldElement::from(*r as u64)).collect(),
+                col: col.iter().map(|c| FieldElement::from(*c as u64)).collect(),
+                val,
+            },
+            timestamps: TimeStamps {
+                read_row,
+                read_col,
+                final_row,
+                final_col,
+            },
+        };
+
+        prove_spark_for_single_matrix(
+            &mut merlin,
+            spark_matrix,
+            &memory,
+            e_values,
+            claimed_value,
+            &self.whir_configs,
+        )?;
+
+        Ok(SPARKProof {
+            transcript: merlin.narg_string().to_vec(),
+            io_pattern: String::from_utf8(self.io_pattern.as_bytes().to_vec())?,
+            whir_params: self.whir_configs.clone(),
+            matrix_dimensions: self.matrix_dimensions.clone(),
+        })
+    }
+}
+
+/// Core SPARK protocol: sumcheck + row/col memory checking.
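/// At a glance (matching the body below): commit to the COO vectors, the read
/// timestamps and the e-value tables; run the cubic sumcheck on
/// `val · e_rx · e_ry` against the claimed evaluation; draw the memory-checking
/// challenges tau and gamma; run the grand-product argument over the combined
/// read-set/write-set leaves; then finish with the row-wise and column-wise
/// final-counter checks against the committed counters. The claimed value fed
/// into the sumcheck is the batched evaluation `(a + r·b + r²·c) / (1 + r)²`,
/// where `r` is the matrix-combination challenge, matching the normalisation
/// applied by both prover and verifier.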
+#[instrument(skip_all)]
+fn prove_spark_for_single_matrix(
+    merlin: &mut ProverState,
+    matrix: SparkMatrix,
+    memory: &Memory,
+    e_values: EValuesForMatrix,
+    claimed_value: FieldElement,
+    whir_configs: &SPARKWHIRConfigs,
+) -> Result<()> {
+    let (vals_witness, rs_ws_witness, final_row_ts_witness, final_col_ts_witness, evalues_witness) =
+        generate_witnesses(merlin, whir_configs, &matrix, &e_values)?;
+
+    spark_sumcheck(
+        merlin,
+        &matrix.coo.val,
+        &e_values.e_rx,
+        &e_values.e_ry,
+        &claimed_value,
+        &evalues_witness,
+        &vals_witness,
+        &whir_configs.num_terms_1batched,
+        &whir_configs.num_terms_2batched,
+    )?;
+
+    let mut tau_and_gamma = [FieldElement::from(0); 2];
+    merlin.fill_challenge_scalars(&mut tau_and_gamma)?;
+    let tau = tau_and_gamma[0];
+    let gamma = tau_and_gamma[1];
+    // RS WS combined (extracted to helper)
+    run_rs_ws_gpa_and_proofs(
+        merlin,
+        &matrix,
+        &e_values,
+        rs_ws_witness,
+        evalues_witness,
+        whir_configs,
+        &gamma,
+        &tau,
+    )?;
+
+    // Potential optimization: Init and Final can be done together in one GPA if we
+    // make their lengths equal.
+
+    prove_rowwise(
+        merlin,
+        &matrix.timestamps.final_row,
+        memory,
+        whir_configs,
+        final_row_ts_witness,
+        &gamma,
+        &tau,
+    )?;
+
+    prove_colwise(
+        merlin,
+        &matrix.timestamps.final_col,
+        memory,
+        whir_configs,
+        final_col_ts_witness,
+        &gamma,
+        &tau,
+    )?;
+
+    Ok(())
+}
+
+/// Commits to vector and returns WHIR witness.
+fn commit_to_vector(
+    committer: &CommitmentWriter<
+        FieldElement,
+        SkyscraperMerkleConfig,
+        provekit_common::skyscraper::SkyscraperPoW,
+    >,
+    merlin: &mut spongefish::ProverState,
+    vector: Vec<FieldElement>,
+) -> Result<Witness> {
+    assert!(
+        vector.len().is_power_of_two(),
+        "Vector length must be power of two"
+    );
+    let evals = EvaluationsList::new(vector);
+    let coeffs = evals.to_coeffs();
+    committer
+        .commit(merlin, &coeffs)
+        .context("WHIR commitment failed")
+}
+
+#[instrument(skip_all)]
+fn spark_sumcheck(
+    merlin: &mut ProverState,
+    val: &[FieldElement],
+    e_rx: &[FieldElement],
+    e_ry: &[FieldElement],
+    claimed_value: &FieldElement,
+    evalues_witness: &Witness,
+    vals_witness: &Witness,
+    num_terms_1batched: &WhirConfig,
+    num_terms_2batched: &WhirConfig,
+) -> Result<()> {
+    let mles: [&[FieldElement]; 3] = [val, e_rx, e_ry];
+    let (sumcheck_final_folds, folding_randomness) =
+        run_spark_sumcheck(merlin, mles, *claimed_value)?;
+
+    merlin.hint::<Vec<FieldElement>>(
+        &[
+            sumcheck_final_folds[0],
+            sumcheck_final_folds[1],
+            sumcheck_final_folds[2],
+        ]
+        .to_vec(),
+    )?;
+
+    let batched_e_claimed =
+        sumcheck_final_folds[1] + sumcheck_final_folds[2] * evalues_witness.batching_randomness;
+
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(folding_randomness.to_vec()),
+        batched_e_claimed,
+        num_terms_2batched.clone(),
+        evalues_witness.clone(),
+    )?;
+
+    produce_whir_proof(
+        merlin,
+        MultilinearPoint(folding_randomness.to_vec()),
+        sumcheck_final_folds[0],
+        num_terms_1batched.clone(),
+        vals_witness.clone(),
+    )?;
+
+    Ok(())
+}
+
+/// Helper that runs the RS/WS GPA, produces the batched WHIR openings for
+/// rs/ws and the e-values, and advances the transcript accordingly.
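/// Each read-set leaf is the fingerprint `addr·γ² + value·γ + timestamp − τ`,
/// and the matching write-set leaf is the same expression plus one (the
/// timestamp after the access), built row-wise and column-wise from the COO
/// addresses, the e-value tables and the read timestamps.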
+#[instrument(skip_all)]
+fn run_rs_ws_gpa_and_proofs(
+    merlin: &mut ProverState,
+    matrix: &SparkMatrix,
+    e_values: &EValuesForMatrix,
+    rs_ws_witness: Witness,
+    evalues_witness: Witness,
+    whir_configs: &SPARKWHIRConfigs,
+    gamma: &FieldElement,
+    tau: &FieldElement,
+) -> Result<()> {
+    // RS WS combined
+    let gamma_sq = *gamma * *gamma;
+    let one = FieldElement::from(1u64);
+
+    // Build row/col RS/WS vectors in parallel
+    let n = matrix.coo.row.len();
+    let m = matrix.coo.col.len();
+
+    let (row_pairs, col_pairs) = join(
+        || {
+            (0..n)
+                .into_par_iter()
+                .map(|i| {
+                    let a = matrix.coo.row[i];
+                    let v = e_values.e_rx[i];
+                    let t = matrix.timestamps.read_row[i];
+                    let base = a * gamma_sq + v * *gamma + t - *tau;
+                    (base, base + one)
+                })
+                .collect::<Vec<_>>()
+        },
+        || {
+            (0..m)
+                .into_par_iter()
+                .map(|i| {
+                    let a = matrix.coo.col[i];
+                    let v = e_values.e_ry[i];
+                    let t = matrix.timestamps.read_col[i];
+                    let base = a * gamma_sq + v * *gamma + t - *tau;
+                    (base, base + one)
+                })
+                .collect::<Vec<_>>()
+        },
+    );
+    let (row_rs_vec, row_ws_vec): (Vec<_>, Vec<_>) = row_pairs.into_iter().unzip();
+    let (col_rs_vec, col_ws_vec): (Vec<_>, Vec<_>) = col_pairs.into_iter().unzip();
+
+    let mut gpa_leaves_flat = Vec::with_capacity(4 * row_rs_vec.len());
+    let gpa_leaves = [row_rs_vec, row_ws_vec, col_rs_vec, col_ws_vec];
+    gpa_leaves_flat.extend(gpa_leaves.into_iter().flatten());
+    let gpa_randomness = run_gpa4(merlin, gpa_leaves_flat);
+
+    let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(2);
+    let eval_point = MultilinearPoint(evaluation_randomness.to_vec());
+
+    // Evaluate addresses/timestamps in parallel
+    let ((row_address_eval, row_timestamp_eval), (col_address_eval, col_timestamp_eval)) = join(
+        || {
+            join(
+                || EvaluationsList::new(matrix.coo.row.clone()).evaluate(&eval_point),
+                || EvaluationsList::new(matrix.timestamps.read_row.clone()).evaluate(&eval_point),
+            )
+        },
+        || {
+            join(
+                || EvaluationsList::new(matrix.coo.col.clone()).evaluate(&eval_point),
+                || EvaluationsList::new(matrix.timestamps.read_col.clone()).evaluate(&eval_point),
+            )
+        },
+    );
+
+    merlin.hint(&row_address_eval)?;
+    merlin.hint(&row_timestamp_eval)?;
+    merlin.hint(&col_address_eval)?;
+    merlin.hint(&col_timestamp_eval)?;
+
+    let rs_ws_claimed_eval = row_address_eval
+        + row_timestamp_eval * rs_ws_witness.batching_randomness
+        + col_address_eval * rs_ws_witness.batching_randomness * rs_ws_witness.batching_randomness
+        + col_timestamp_eval
+            * rs_ws_witness.batching_randomness
+            * rs_ws_witness.batching_randomness
+            * rs_ws_witness.batching_randomness;
+
+    produce_whir_proof(
+        merlin,
+        eval_point.clone(),
+        rs_ws_claimed_eval,
+        whir_configs.num_terms_4batched.clone(),
+        rs_ws_witness,
+    )?;
+
+    // Evaluate e-values in parallel
+    let (row_value_eval, col_value_eval) = join(
+        || EvaluationsList::new(e_values.e_rx.clone()).evaluate(&eval_point),
+        || EvaluationsList::new(e_values.e_ry.clone()).evaluate(&eval_point),
+    );
+    merlin.hint(&row_value_eval)?;
+    merlin.hint(&col_value_eval)?;
+
+    let claimed_e_eval = row_value_eval + col_value_eval * evalues_witness.batching_randomness;
+
+    produce_whir_proof(
+        merlin,
+        eval_point,
+        claimed_e_eval,
+        whir_configs.num_terms_2batched.clone(),
+        evalues_witness,
+    )?;
+
+    Ok(())
+}
+
+#[instrument(skip_all)]
+fn generate_witnesses(
+    merlin: &mut ProverState,
+    whir_configs: &SPARKWHIRConfigs,
+    matrix: &SparkMatrix,
+    e_values: &EValuesForMatrix,
+) -> Result<(
+    Witness,
+    Witness,
+    Witness,
+    Witness,
+    Witness,
+)> {
+    let row_committer = CommitmentWriter::new(whir_configs.row.clone());
+    let col_committer = CommitmentWriter::new(whir_configs.col.clone());
+    let batched1_committer = CommitmentWriter::new(whir_configs.num_terms_1batched.clone());
+    let batched2_committer = CommitmentWriter::new(whir_configs.num_terms_2batched.clone());
+    let batched4_committer = CommitmentWriter::new(whir_configs.num_terms_4batched.clone());
+
+    // Should be committed before request:
+    let vals_witness = batched1_committer
+        .commit_batch(merlin, &[
+            &EvaluationsList::new(matrix.coo.val.clone()).to_coeffs()
+        ])
+        .context("Commit batch for vals_witness failed")?;
+
+    let rs_ws_witness = batched4_committer
+        .commit_batch(merlin, &[
+            &EvaluationsList::new(matrix.coo.row.clone()).to_coeffs(),
+            &EvaluationsList::new(matrix.timestamps.read_row.clone()).to_coeffs(),
+            &EvaluationsList::new(matrix.coo.col.clone()).to_coeffs(),
+            &EvaluationsList::new(matrix.timestamps.read_col.clone()).to_coeffs(),
+        ])
+        .context("Commit batch for rs_ws_witness failed")?;
+
+    let final_row_ts_witness =
+        commit_to_vector(&row_committer, merlin, matrix.timestamps.final_row.clone())
+            .context("Commit to final_row timestamps failed")?;
+    let final_col_ts_witness =
+        commit_to_vector(&col_committer, merlin, matrix.timestamps.final_col.clone())
+            .context("Commit to final_col timestamps failed")?;
+
+    // Committed for each request:
+    let evalues_witness = batched2_committer
+        .commit_batch(merlin, &[
+            &EvaluationsList::new(e_values.e_rx.clone()).to_coeffs(),
+            &EvaluationsList::new(e_values.e_ry.clone()).to_coeffs(),
+        ])
+        .context("Commit batch for evalues_witness failed")?;
+
+    Ok((
+        vals_witness,
+        rs_ws_witness,
+        final_row_ts_witness,
+        final_col_ts_witness,
+        evalues_witness,
+    ))
+}
diff --git a/provekit/spark/src/sumcheck.rs b/provekit/spark/src/sumcheck.rs
new file mode 100644
index 00000000..a822f561
--- /dev/null
+++ b/provekit/spark/src/sumcheck.rs
@@ -0,0 +1,134 @@
+use {
+    anyhow::{ensure, Result},
+    ark_std::{One, Zero},
+    provekit_common::{
+        skyscraper::SkyscraperSponge,
+        utils::{
+            sumcheck::{eval_cubic_poly, sumcheck_fold_map_reduce},
+            HALF,
+        },
+        FieldElement,
+    },
+    spongefish::{
+        codecs::arkworks_algebra::{FieldToUnitDeserialize, FieldToUnitSerialize, UnitToField},
+        ProverState, VerifierState,
+    },
+    tracing::instrument,
+};
+
+/// Runs sumcheck protocol for SPARK matrix evaluation.
+///
+/// Proves that `∑ m₀(x) · m₁(x) · m₂(x) = claimed_value` over the boolean
+/// hypercube without revealing individual polynomial values.
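/// In each round the prover evaluates the cubic round polynomial ĥᵢ at 0, at
/// −1, and at ∞ (its leading coefficient), then recovers the coefficients from
/// the running claim `c = ĥᵢ(0) + ĥᵢ(1)`: `c₀ = ĥᵢ(0)`,
/// `c₂ = ½·(c + ĥᵢ(−1) − 3·ĥᵢ(0))`, `c₃` is the ∞ evaluation, and
/// `c₁ = c − 2·c₀ − c₂ − c₃`, exactly as in the coefficient recovery below.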
+///
+/// # Returns
+///
+/// Tuple of `(final_folded_values, accumulated_randomness)`
+#[instrument(skip_all)]
+pub fn run_spark_sumcheck(
+    merlin: &mut ProverState,
+    mles: [&[FieldElement]; 3],
+    mut claimed_value: FieldElement,
+) -> Result<([FieldElement; 3], Vec<FieldElement>)> {
+    let mut sumcheck_randomness = [FieldElement::from(0)];
+    let mut sumcheck_randomness_accumulator = Vec::<FieldElement>::new();
+    let mut fold = None;
+
+    let mut m0 = mles[0].to_vec();
+    let mut m1 = mles[1].to_vec();
+    let mut m2 = mles[2].to_vec();
+
+    loop {
+        // Evaluate cubic at special points: 0, -1, ∞
+        let [hhat_i_at_0, hhat_i_at_em1, hhat_i_at_inf_over_x_cube] =
+            sumcheck_fold_map_reduce([&mut m0, &mut m1, &mut m2], fold, |[m0, m1, m2]| {
+                [
+                    m0.0 * m1.0 * m2.0,
+                    (m0.0 + m0.0 - m0.1) * (m1.0 + m1.0 - m1.1) * (m2.0 + m2.0 - m2.1),
+                    (m0.1 - m0.0) * (m1.1 - m1.0) * (m2.1 - m2.0),
+                ]
+            });
+
+        if fold.is_some() {
+            m0.truncate(m0.len() / 2);
+            m1.truncate(m1.len() / 2);
+            m2.truncate(m2.len() / 2);
+        }
+
+        let mut hhat_i_coeffs = [FieldElement::from(0); 4];
+
+        hhat_i_coeffs[0] = hhat_i_at_0;
+        hhat_i_coeffs[2] =
+            HALF * (claimed_value + hhat_i_at_em1 - hhat_i_at_0 - hhat_i_at_0 - hhat_i_at_0);
+        hhat_i_coeffs[3] = hhat_i_at_inf_over_x_cube;
+        hhat_i_coeffs[1] = claimed_value
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[0]
+            - hhat_i_coeffs[3]
+            - hhat_i_coeffs[2];
+
+        debug_assert_eq!(
+            claimed_value,
+            hhat_i_coeffs[0]
+                + hhat_i_coeffs[0]
+                + hhat_i_coeffs[1]
+                + hhat_i_coeffs[2]
+                + hhat_i_coeffs[3],
+            "Sumcheck binding check failed"
+        );
+
+        merlin.add_scalars(&hhat_i_coeffs[..])?;
+        merlin.fill_challenge_scalars(&mut sumcheck_randomness)?;
+        fold = Some(sumcheck_randomness[0]);
+        claimed_value = eval_cubic_poly(hhat_i_coeffs, sumcheck_randomness[0]);
+        sumcheck_randomness_accumulator.push(sumcheck_randomness[0]);
+        if m0.len() <= 2 {
+            break;
+        }
+    }
+
+    let folded_v0 = m0[0] + (m0[1] - m0[0]) * sumcheck_randomness[0];
+    let folded_v1 = m1[0] + (m1[1] - m1[0]) * sumcheck_randomness[0];
+    let folded_v2 = m2[0] + (m2[1] - m2[0]) * sumcheck_randomness[0];
+
+    Ok((
+        [folded_v0, folded_v1, folded_v2],
+        sumcheck_randomness_accumulator,
+    ))
+}
+
+/// Verifies a SPARK sumcheck proof from the transcript.
+///
+/// Checks that the prover's claimed sum is correct by verifying polynomial
+/// evaluations at each round without recomputing the full sum.
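/// Concretely, each round reads the four coefficients of ĥᵢ, checks
/// `ĥᵢ(0) + ĥᵢ(1)` against the running claim, draws the challenge αᵢ and
/// updates the claim to `ĥᵢ(αᵢ)`; the caller then ties the final claim to the
/// openings of the committed polynomials.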
+/// +/// # Returns +/// +/// Tuple of `(accumulated_randomness, final_evaluation)` +pub fn run_sumcheck_verifier_spark( + arthur: &mut VerifierState, + variable_count: usize, + initial_sumcheck_val: FieldElement, +) -> Result<(Vec, FieldElement)> { + let mut saved_val_for_sumcheck_equality_assertion = initial_sumcheck_val; + + let mut alpha = vec![FieldElement::zero(); variable_count]; + + for i in 0..variable_count { + let mut hhat_i = [FieldElement::zero(); 4]; + let mut alpha_i = [FieldElement::zero(); 1]; + arthur.fill_next_scalars(&mut hhat_i)?; + arthur.fill_challenge_scalars(&mut alpha_i)?; + alpha[i] = alpha_i[0]; + + let hhat_i_at_zero = eval_cubic_poly(hhat_i, FieldElement::zero()); + let hhat_i_at_one = eval_cubic_poly(hhat_i, FieldElement::one()); + ensure!( + saved_val_for_sumcheck_equality_assertion == hhat_i_at_zero + hhat_i_at_one, + "Sumcheck equality check failed" + ); + saved_val_for_sumcheck_equality_assertion = eval_cubic_poly(hhat_i, alpha_i[0]); + } + + Ok((alpha, saved_val_for_sumcheck_equality_assertion)) +} diff --git a/provekit/spark/src/test b/provekit/spark/src/test new file mode 100644 index 00000000..fddb6943 --- /dev/null +++ b/provekit/spark/src/test @@ -0,0 +1,8 @@ +[ +4835592336198174705352400703506937093376006563621879466256772000372998731136, 13842047038711022483770763283082747198904712224324162406798144810509600218283, 5803349932421386922653536813787357009325092095930239281002917144817379287788, 7112363092876161088155824432078896182194603041756112465820527588963216680105] + +[7715668456526668687870296270384312490096084851922028954886526938263798440470, 217540674808248399652549832274153101504587080714092820584987064929140712291] +[10330114051913303653629396762973881367473860167600894772273318653472428842159, 9649983605390653446068145518441909119042324595064247650990698324879719926601] + +7715668456526668687870296270384312490096084851922028954886526938263798440470 9006454702512847778418362579575810105528705660702282940541372810136601487147 14390115090120854934028659307147115699956866629208098209396664313241150767438 4835592336198174705352400703506937093376006563621879466256772000372998731136 +10330114051913303653629396762973881367473860167600894772273318653472428842159 14190801329781201609330330783973004155889169685539624587974441820585044400787 21208112425316625014685154500725302840116828827879387222415583857983099580059 967757596223212217301136110280419915949085532308359814746145144444380556652 diff --git a/provekit/spark/src/types.rs b/provekit/spark/src/types.rs new file mode 100644 index 00000000..e170e449 --- /dev/null +++ b/provekit/spark/src/types.rs @@ -0,0 +1,100 @@ +use { + provekit_common::{FieldElement, WhirConfig}, + serde::{Deserialize, Serialize}, +}; + +/// Complete SPARK proof including transcript and configuration. +#[derive(Serialize, Deserialize)] +pub struct SPARKProof { + pub transcript: Vec, + pub io_pattern: String, + pub whir_params: SPARKWHIRConfigs, + pub matrix_dimensions: MatrixDimensions, +} + +/// Dimensions of the R1CS matrices used in the proof. +#[derive(Serialize, Deserialize, Clone)] +pub struct MatrixDimensions { + pub num_rows: usize, + pub num_cols: usize, + pub nonzero_terms: usize, +} + +/// WHIR commitment scheme configurations for different vector sizes. 
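/// As used by the prover: `row`/`col` cover the per-address final counters
/// (length 2·num_constraints and 2·num_witnesses respectively),
/// `num_terms_1batched` covers the `val` vector, `num_terms_2batched` the
/// batched `(e_rx, e_ry)` pair, and `num_terms_4batched` the batched
/// `(row, read_row, col, read_col)` vectors, all over the padded
/// nonzero-term count.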
+#[derive(Serialize, Deserialize, Clone)]
+pub struct SPARKWHIRConfigs {
+    pub row: WhirConfig,
+    pub col: WhirConfig,
+    pub num_terms_1batched: WhirConfig,
+    pub num_terms_2batched: WhirConfig,
+    pub num_terms_4batched: WhirConfig,
+}
+
+/// SPARK matrix in COO format with memory access timestamps.
+#[derive(Debug, Clone)]
+pub struct SparkMatrix {
+    pub coo: COOMatrix,
+    pub timestamps: TimeStamps,
+}
+
+/// Coordinate (COO) sparse matrix format storing row/col indices and values.
+#[derive(Debug, Clone)]
+pub struct COOMatrix {
+    pub row: Vec<FieldElement>,
+    pub col: Vec<FieldElement>,
+    pub val: Vec<FieldElement>,
+}
+
+/// Memory access timestamps for GPA protocol.
+#[derive(Debug, Clone)]
+pub struct TimeStamps {
+    pub read_row: Vec<FieldElement>,
+    pub read_col: Vec<FieldElement>,
+    pub final_row: Vec<FieldElement>,
+    pub final_col: Vec<FieldElement>,
+}
+
+/// Precomputed equality check evaluations for memory arguments.
+#[derive(Debug, Clone)]
+pub struct Memory {
+    pub eq_rx: Vec<FieldElement>,
+    pub eq_ry: Vec<FieldElement>,
+}
+
+/// Row and column evaluation vectors at the challenge point.
+#[derive(Debug, Clone)]
+pub struct EValuesForMatrix {
+    pub e_rx: Vec<FieldElement>,
+    pub e_ry: Vec<FieldElement>,
+}
+
+use provekit_common::gnark::WHIRConfigGnark;
+
+/// SPARK proof formatted for Gnark recursive verifier.
+#[derive(Serialize, Deserialize)]
+pub struct SPARKProofGnark {
+    pub transcript: Vec<u8>,
+    pub io_pattern: String,
+    pub whir_row: WHIRConfigGnark,
+    pub whir_col: WHIRConfigGnark,
+    pub whir_1batched: WHIRConfigGnark,
+    pub whir_2batched: WHIRConfigGnark,
+    pub whir_4batched: WHIRConfigGnark,
+    pub log_num_terms: usize,
+}
+
+impl SPARKProofGnark {
+    /// Converts SPARK proof to Gnark-compatible format.
+    pub fn from_proof(proof: &SPARKProof, log_num_terms: usize) -> Self {
+        Self {
+            transcript: proof.transcript.clone(),
+            io_pattern: proof.io_pattern.clone(),
+            whir_row: WHIRConfigGnark::new(&proof.whir_params.row),
+            whir_col: WHIRConfigGnark::new(&proof.whir_params.col),
+            whir_1batched: WHIRConfigGnark::new(&proof.whir_params.num_terms_1batched),
+            whir_2batched: WHIRConfigGnark::new(&proof.whir_params.num_terms_2batched),
+            whir_4batched: WHIRConfigGnark::new(&proof.whir_params.num_terms_4batched),
+            log_num_terms,
+        }
+    }
+}
diff --git a/provekit/spark/src/utils.rs b/provekit/spark/src/utils.rs
new file mode 100644
index 00000000..67f04cf4
--- /dev/null
+++ b/provekit/spark/src/utils.rs
@@ -0,0 +1,77 @@
+pub use crate::types::Memory;
+use {
+    anyhow::{Context, Result},
+    provekit_common::{
+        spark::{Point, SparkStatement},
+        utils::{next_power_of_two, sumcheck::calculate_evaluations_over_boolean_hypercube_for_eq},
+        FieldElement, R1CS,
+    },
+    spongefish::codecs::arkworks_algebra::FieldDomainSeparator,
+    std::{fs, path::Path},
+};
+
+/// Deserializes R1CS from JSON and pads matrices to power-of-2 dimensions.
+pub fn deserialize_r1cs(path: impl AsRef<Path>) -> Result<R1CS> {
+    let json_str = fs::read_to_string(path).context("Failed to read R1CS file")?;
+    let mut r1cs: R1CS = serde_json::from_str(&json_str).context("Failed to deserialize R1CS")?;
+    r1cs.grow_matrices(
+        1 << next_power_of_two(r1cs.num_constraints()),
+        1 << next_power_of_two(r1cs.num_witnesses()),
+    );
+    Ok(r1cs)
+}
+
+/// Deserializes SPARK request from JSON.
+pub fn deserialize_request(path: impl AsRef<Path>) -> Result<SparkStatement> {
+    let json_str = fs::read_to_string(path).context("Failed to read request file")?;
+    serde_json::from_str(&json_str).context("Failed to deserialize request")
+}
+
+/// Computes equality check evaluations for row and column points.
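/// `eq_rx` is the multilinear eq table at the point `(b, point.row)` over the
/// boolean hypercube; `eq_ry` is the table at `(b, point.col[1..])` scaled by
/// `(1 − point.col[0])`. The caller passes `b = r / (1 + r)`, where `r` is the
/// matrix-combination challenge.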
+pub fn calculate_memory(b: FieldElement, point_to_evaluate: Point) -> Memory { + Memory { + eq_rx: calculate_evaluations_over_boolean_hypercube_for_eq( + std::iter::once(b).chain(point_to_evaluate.row).collect(), + ), + eq_ry: calculate_evaluations_over_boolean_hypercube_for_eq( + std::iter::once(b) + .chain(point_to_evaluate.col[1..].to_vec()) + .collect(), + ) + .iter() + .map(|x| *x * (FieldElement::from(1) - point_to_evaluate.col[0])) + .collect(), + } +} + +/// Trait extending IO patterns with SPARK-specific domain separators. +pub trait SPARKDomainSeparator { + fn add_tau_and_gamma(self) -> Self; + fn add_line(self) -> Self; + fn add_gpa4_claimed_values(self) -> Self; + fn add_claimed_evaluations(self) -> Self; +} + +impl SPARKDomainSeparator for IOPattern +where + IOPattern: FieldDomainSeparator, +{ + fn add_tau_and_gamma(self) -> Self { + self.challenge_scalars(2, "tau and gamma") + } + + fn add_line(self) -> Self { + self.add_scalars(2, "gpa line") + .challenge_scalars(1, "gpa line random") + } + + fn add_gpa4_claimed_values(self) -> Self { + self.add_scalars(4, "gpa claimed values") + .challenge_scalars(2, "gpa randomness") + } + + fn add_claimed_evaluations(self) -> Self { + self.add_scalars(3, "claimed evaluations") + .challenge_scalars(1, "matrix combination randomness") + } +} diff --git a/provekit/spark/src/verifier.rs b/provekit/spark/src/verifier.rs new file mode 100644 index 00000000..f5d4e653 --- /dev/null +++ b/provekit/spark/src/verifier.rs @@ -0,0 +1,240 @@ +use { + crate::{ + gpa::gpa_sumcheck_verifier4, + memory::{verify_colwise, verify_rowwise}, + sumcheck::run_sumcheck_verifier_spark, + types::{MatrixDimensions, SPARKProof, SPARKWHIRConfigs}, + }, + anyhow::{ensure, Context, Result}, + ark_ff::Field, + provekit_common::{ + skyscraper::SkyscraperSponge, spark::SparkStatement, utils::next_power_of_two, + FieldElement, IOPattern, + }, + spongefish::codecs::arkworks_algebra::{FieldToUnitDeserialize, UnitToField}, + whir::{ + poly_utils::multilinear::MultilinearPoint, + whir::{ + committer::CommitmentReader, + statement::{Statement, Weights}, + utils::HintDeserialize, + verifier::Verifier, + }, + }, +}; + +/// SPARK verification interface. +pub trait SPARKVerifier { + /// Verifies a SPARK proof against the given request. + fn verify(&self, proof: &SPARKProof, request: &SparkStatement) -> Result<()>; +} + +/// SPARK verification scheme with configuration extracted from proof. +pub struct SPARKScheme { + pub whir_configs: SPARKWHIRConfigs, + pub io_pattern: IOPattern, + pub matrix_dimensions: MatrixDimensions, +} + +impl SPARKScheme { + /// Constructs verifier scheme from proof metadata. 
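    /// Everything the verifier needs — WHIR parameters, IO pattern and matrix
    /// dimensions — is read back out of the serialized `SPARKProof`. A minimal
    /// usage sketch (assuming `proof` and `request` have already been loaded,
    /// e.g. via `deserialize_request`):
    /// `SPARKScheme::from_proof(&proof).verify(&proof, &request)?`.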
+    pub fn from_proof(proof: &SPARKProof) -> Self {
+        Self {
+            whir_configs: proof.whir_params.clone(),
+            io_pattern: IOPattern::from_string(proof.io_pattern.clone()),
+            matrix_dimensions: proof.matrix_dimensions.clone(),
+        }
+    }
+}
+
+impl SPARKVerifier for SPARKScheme {
+    fn verify(&self, proof: &SPARKProof, request: &SparkStatement) -> Result<()> {
+        let io = IOPattern::from_string(proof.io_pattern.clone());
+        let mut arthur = io.to_verifier_state(&proof.transcript);
+
+        let _point_row: Vec<FieldElement> = arthur.hint()?;
+        let _point_col: Vec<FieldElement> = arthur.hint()?;
+
+        let mut claimed_values = [FieldElement::from(0); 3];
+        arthur.fill_next_scalars(&mut claimed_values)?;
+
+        let mut matrix_batching_randomness = [FieldElement::from(0); 1];
+        arthur.fill_challenge_scalars(&mut matrix_batching_randomness)?;
+        let matrix_batching_randomness = matrix_batching_randomness[0];
+
+        let mut claimed_value = claimed_values[0]
+            + claimed_values[1] * matrix_batching_randomness
+            + claimed_values[2] * matrix_batching_randomness * matrix_batching_randomness;
+
+        claimed_value = (claimed_value / (FieldElement::ONE + matrix_batching_randomness))
+            / (FieldElement::ONE + matrix_batching_randomness);
+
+        verify_spark_single_matrix(
+            &matrix_batching_randomness,
+            &proof.whir_params,
+            proof.matrix_dimensions.clone(),
+            &mut arthur,
+            request,
+            &claimed_value,
+        )
+    }
+}
+
+/// Core SPARK verification: sumcheck + row/col memory checks.
+fn verify_spark_single_matrix(
+    matrix_batching_randomness: &FieldElement,
+    whir_params: &SPARKWHIRConfigs,
+    matrix_dimensions: MatrixDimensions,
+    arthur: &mut spongefish::VerifierState,
+    request: &SparkStatement,
+    claimed_value: &FieldElement,
+) -> Result<()> {
+    let commitment_reader_row = CommitmentReader::new(&whir_params.row);
+    let commitment_reader_col = CommitmentReader::new(&whir_params.col);
+    let commitment_reader_batched1 = CommitmentReader::new(&whir_params.num_terms_1batched);
+    let commitment_reader_batched2 = CommitmentReader::new(&whir_params.num_terms_2batched);
+    let commitment_reader_batched4 = CommitmentReader::new(&whir_params.num_terms_4batched);
+
+    let val_commitment = commitment_reader_batched1.parse_commitment(arthur)?;
+    let rsws_commitment = commitment_reader_batched4.parse_commitment(arthur)?;
+    let a_row_finalts_commitment = commitment_reader_row.parse_commitment(arthur)?;
+    let a_col_finalts_commitment = commitment_reader_col.parse_commitment(arthur)?;
+
+    let evalues_commitment = commitment_reader_batched2.parse_commitment(arthur)?;
+
+    let (randomness, a_last_sumcheck_value) = run_sumcheck_verifier_spark(
+        arthur,
+        next_power_of_two(matrix_dimensions.nonzero_terms),
+        *claimed_value,
+    )
+    .context("While verifying SPARK sumcheck")?;
+
+    let sumcheck_hints: Vec<FieldElement> = arthur.hint()?;
+
+    ensure!(a_last_sumcheck_value == sumcheck_hints[0] * sumcheck_hints[1] * sumcheck_hints[2]);
+
+    let mut sumcheck_evalues_statement =
+        Statement::<FieldElement>::new(next_power_of_two(matrix_dimensions.nonzero_terms));
+
+    sumcheck_evalues_statement.add_constraint(
+        Weights::evaluation(MultilinearPoint(randomness.clone())),
+        sumcheck_hints[1] + sumcheck_hints[2] * evalues_commitment.batching_randomness,
+    );
+
+    let sumcheck_evalues_verifier = Verifier::new(&whir_params.num_terms_2batched);
+    sumcheck_evalues_verifier.verify(arthur, &evalues_commitment, &sumcheck_evalues_statement)?;
+
+    let mut spark_sumcheck_val_statement_verifier =
+        Statement::<FieldElement>::new(next_power_of_two(matrix_dimensions.nonzero_terms));
+
+    spark_sumcheck_val_statement_verifier.add_constraint(
+        Weights::evaluation(MultilinearPoint(randomness.clone())),
+        sumcheck_hints[0],
+    );
+
+    let spark_sumcheck_val_verifier = Verifier::new(&whir_params.num_terms_1batched);
+    spark_sumcheck_val_verifier.verify(
+        arthur,
+        &val_commitment,
+        &spark_sumcheck_val_statement_verifier,
+    )?;
+
+    let mut tau_and_gamma = [FieldElement::from(0); 2];
+    arthur.fill_challenge_scalars(&mut tau_and_gamma)?;
+    let tau = tau_and_gamma[0];
+    let gamma = tau_and_gamma[1];
+
+    let gpa_result = gpa_sumcheck_verifier4(
+        arthur,
+        provekit_common::utils::next_power_of_two(matrix_dimensions.nonzero_terms) + 3,
+    )?;
+
+    let (combination_randomness, evaluation_randomness) = gpa_result.randomness.split_at(2);
+
+    let claimed_row_rs = gpa_result.claimed_values[0];
+    let claimed_row_ws = gpa_result.claimed_values[1];
+    let claimed_col_rs = gpa_result.claimed_values[2];
+    let claimed_col_ws = gpa_result.claimed_values[3];
+
+    let row_adr: FieldElement = arthur.hint()?;
+    let row_timestamp: FieldElement = arthur.hint()?;
+    let col_adr: FieldElement = arthur.hint()?;
+    let col_timestamp: FieldElement = arthur.hint()?;
+
+    let mut rsws_statement = Statement::<FieldElement>::new(
+        provekit_common::utils::next_power_of_two(matrix_dimensions.nonzero_terms),
+    );
+    rsws_statement.add_constraint(
+        Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec())),
+        row_adr
+            + row_timestamp * rsws_commitment.batching_randomness
+            + col_adr * rsws_commitment.batching_randomness * rsws_commitment.batching_randomness
+            + col_timestamp
+                * rsws_commitment.batching_randomness
+                * rsws_commitment.batching_randomness
+                * rsws_commitment.batching_randomness,
+    );
+    let rsws_verifier = Verifier::new(&whir_params.num_terms_4batched);
+    rsws_verifier.verify(arthur, &rsws_commitment, &rsws_statement)?;
+
+    let row_mem: FieldElement = arthur.hint()?;
+    let col_mem: FieldElement = arthur.hint()?;
+
+    let mut evalues_statement = Statement::<FieldElement>::new(
+        provekit_common::utils::next_power_of_two(matrix_dimensions.nonzero_terms),
+    );
+    evalues_statement.add_constraint(
+        Weights::evaluation(MultilinearPoint(evaluation_randomness.to_vec())),
+        row_mem + col_mem * evalues_commitment.batching_randomness,
+    );
+    let evalues_verifier = Verifier::new(&whir_params.num_terms_2batched);
+    evalues_verifier.verify(arthur, &evalues_commitment, &evalues_statement)?;
+
+    let row_rs_opening = row_adr * gamma * gamma + row_mem * gamma + row_timestamp - tau;
+    let row_ws_opening =
+        row_adr * gamma * gamma + row_mem * gamma + row_timestamp + FieldElement::from(1) - tau;
+    let col_rs_opening = col_adr * gamma * gamma + col_mem * gamma + col_timestamp - tau;
+    let col_ws_opening =
+        col_adr * gamma * gamma + col_mem * gamma + col_timestamp + FieldElement::from(1) - tau;
+
+    let evaluated_value = row_rs_opening
+        * (FieldElement::from(1) - combination_randomness[0])
+        * (FieldElement::from(1) - combination_randomness[1])
+        + row_ws_opening
+            * (FieldElement::from(1) - combination_randomness[0])
+            * combination_randomness[1]
+        + col_rs_opening
+            * combination_randomness[0]
+            * (FieldElement::from(1) - combination_randomness[1])
+        + col_ws_opening * combination_randomness[0] * combination_randomness[1];
+
+    ensure!(evaluated_value == gpa_result.a_last_sumcheck_value);
+
+    verify_rowwise(
+        arthur,
+        matrix_dimensions.num_rows,
+        whir_params,
+        request,
+        a_row_finalts_commitment,
+        &tau,
+        &gamma,
+        &claimed_row_rs,
+        &claimed_row_ws,
+        matrix_batching_randomness,
+    )?;
+
+    verify_colwise(
+        arthur,
+        matrix_dimensions.num_cols,
+        whir_params,
+        request,
+
a_col_finalts_commitment, + &tau, + &gamma, + &claimed_col_rs, + &claimed_col_ws, + matrix_batching_randomness, + )?; + + Ok(()) +} diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index f72d220d..ef203921 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -46,8 +46,7 @@ impl WhirR1CSVerifier for WhirR1CSScheme { let data_from_sumcheck_verifier = run_sumcheck_verifier( &mut arthur, self.m_0, - &self.whir_for_hiding_spartan, - // proof.whir_spartan_blinding_values, + &self.whir_for_hiding_spartan, // proof.whir_spartan_blinding_values, ) .context("while verifying sumcheck")?; diff --git a/recursive-verifier/app/circuit/circuit.go b/recursive-verifier/app/circuit/circuit.go index 68927073..9f0f4966 100644 --- a/recursive-verifier/app/circuit/circuit.go +++ b/recursive-verifier/app/circuit/circuit.go @@ -18,10 +18,11 @@ import ( "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/std/math/uints" + gnarkNimue "github.com/reilabs/gnark-nimue" + skyscraper "github.com/reilabs/gnark-skyscraper" ) type Circuit struct { - // Inputs WitnessLinearStatementEvaluations []frontend.Variable HidingSpartanLinearStatementEvaluations []frontend.Variable LogNumConstraints int @@ -39,9 +40,21 @@ type Circuit struct { MatrixA []MatrixCell MatrixB []MatrixCell MatrixC []MatrixCell - // Public Input - IO []byte - Transcript []uints.U8 `gnark:",public"` + + IO []byte + UseSpark bool + SPARKTranscript []uints.U8 + + SPARKIO []byte + Transcript []uints.U8 + + WHIRRow WHIRParams + WHIRCol WHIRParams + + PointRow []frontend.Variable + PointCol []frontend.Variable + + SparkRLC SPARKMatrixData } func (circuit *Circuit) Define(api frontend.API) error { @@ -50,7 +63,7 @@ func (circuit *Circuit) Define(api frontend.API) error { return err } - rootHash, batchingRandomness, initialOODQueries, initialOODAnswers, err := parseBatchedCommitment(arthur, circuit.WHIRParamsWitness) + spartanCommitment, err := parseBatchedCommitment(arthur, circuit.WHIRParamsWitness) if err != nil { return err @@ -67,7 +80,10 @@ func (circuit *Circuit) Define(api frontend.API) error { return err } - whirFoldingRandomness, err := RunZKWhir(api, arthur, uapi, sc, circuit.WitnessMerkle, circuit.WitnessFirstRound, circuit.WHIRParamsWitness, [][]frontend.Variable{circuit.WitnessClaimedEvaluations, circuit.WitnessBlindingEvaluations}, circuit.WitnessLinearStatementEvaluations, batchingRandomness, initialOODQueries, initialOODAnswers, rootHash) + whirFoldingRandomness, err := RunZKWhir(api, arthur, uapi, sc, circuit.WitnessMerkle, circuit.WitnessFirstRound, circuit.WHIRParamsWitness, [][]frontend.Variable{circuit.WitnessClaimedEvaluations, circuit.WitnessBlindingEvaluations}, circuit.WitnessLinearStatementEvaluations, spartanCommitment, + [][]frontend.Variable{{}, {}}, + [][]frontend.Variable{}, + ) if err != nil { return err @@ -76,17 +92,41 @@ func (circuit *Circuit) Define(api frontend.API) error { x := api.Mul(api.Sub(api.Mul(circuit.WitnessClaimedEvaluations[0], circuit.WitnessClaimedEvaluations[1]), circuit.WitnessClaimedEvaluations[2]), calculateEQ(api, spartanSumcheckRand, tRand)) api.AssertIsEqual(spartanSumcheckLastValue, x) - matrixExtensionEvals := evaluateR1CSMatrixExtension(api, circuit, spartanSumcheckRand, whirFoldingRandomness) + if circuit.UseSpark { + sc := skyscraper.NewSkyscraper(api, 2) + arthur, err := gnarkNimue.NewSkyscraperArthur(api, sc, circuit.SPARKIO, circuit.SPARKTranscript[:], true) + if 
err != nil { + return err + } + uapi, err := uints.New[uints.U64](api) + if err != nil { + return err + } + + err = sparkSingleMatrix( + api, + arthur, + uapi, + sc, + circuit.SparkRLC, + circuit, + ) + if err != nil { + return err + } + } else { + matrixExtensionEvals := evaluateR1CSMatrixExtension(api, circuit, spartanSumcheckRand, whirFoldingRandomness) - for i := 0; i < 3; i++ { - api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + for i := range 3 { + api.AssertIsEqual(matrixExtensionEvals[i], circuit.WitnessLinearStatementEvaluations[i]) + } } return nil } func verifyCircuit( - deferred []Fp256, cfg Config, hints Hints, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, claimedEvaluations ClaimedEvaluations, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, + deferred []Fp256, cfg Config, sparkConfig SparkConfig, hints Hints, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, claimedEvaluations ClaimedEvaluations, internedR1CS R1CS, interner Interner, buildOps common.BuildOps, ) error { transcriptT := make([]uints.U8, cfg.TranscriptLen) contTranscript := make([]uints.U8, cfg.TranscriptLen) @@ -95,6 +135,13 @@ func verifyCircuit( transcriptT[i] = uints.NewU8(cfg.Transcript[i]) } + sparkTranscriptT := make([]uints.U8, len(sparkConfig.Transcript)) + sparkContTranscript := make([]uints.U8, len(sparkConfig.Transcript)) + + for i := range sparkConfig.Transcript { + sparkTranscriptT[i] = uints.NewU8(sparkConfig.Transcript[i]) + } + witnessLinearStatementEvaluations := make([]frontend.Variable, 3) hidingSpartanLinearStatementEvaluations := make([]frontend.Variable, 1) contWitnessLinearStatementEvaluations := make([]frontend.Variable, 3) @@ -105,6 +152,24 @@ func verifyCircuit( witnessLinearStatementEvaluations[1] = typeConverters.LimbsToBigIntMod(deferred[2].Limbs) witnessLinearStatementEvaluations[2] = typeConverters.LimbsToBigIntMod(deferred[3].Limbs) + contSparkSumcheckLast := make([]frontend.Variable, 3) + sparkSumcheckLast := make([]frontend.Variable, 3) + sparkSumcheckLast[0] = typeConverters.LimbsToBigIntMod(hints.SparkHints.sparkClaimedEvaluations[0].Limbs) + sparkSumcheckLast[1] = typeConverters.LimbsToBigIntMod(hints.SparkHints.sparkClaimedEvaluations[1].Limbs) + sparkSumcheckLast[2] = typeConverters.LimbsToBigIntMod(hints.SparkHints.sparkClaimedEvaluations[2].Limbs) + + contPointRow := make([]frontend.Variable, len(hints.pointRow)) + pointRow := make([]frontend.Variable, len(hints.pointRow)) + for i := range len(hints.pointRow) { + pointRow[i] = typeConverters.LimbsToBigIntMod(hints.pointRow[i].Limbs) + } + + contPointCol := make([]frontend.Variable, len(hints.pointCol)) + pointCol := make([]frontend.Variable, len(hints.pointCol)) + for i := range len(hints.pointCol) { + pointCol[i] = typeConverters.LimbsToBigIntMod(hints.pointCol[i].Limbs) + } + fSums, gSums := parseClaimedEvaluations(claimedEvaluations, true) matrixA := make([]MatrixCell, len(internedR1CS.A.Values)) @@ -152,12 +217,12 @@ func verifyCircuit( } } + useSpark := buildOps.Evaluation == "spark" + var circuit = Circuit{ IO: []byte(cfg.IOPattern), Transcript: contTranscript, LogNumConstraints: cfg.LogNumConstraints, - LogNumVariables: cfg.LogNumVariables, - LogANumTerms: cfg.LogANumTerms, WitnessClaimedEvaluations: fSums, WitnessBlindingEvaluations: gSums, WitnessLinearStatementEvaluations: contWitnessLinearStatementEvaluations, @@ -173,6 +238,58 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + + SPARKIO: []byte(sparkConfig.IOPattern), + 
SPARKTranscript: sparkContTranscript, + WHIRRow: NewWhirParams(sparkConfig.WHIRRow), + WHIRCol: NewWhirParams(sparkConfig.WHIRCol), + + LogANumTerms: sparkConfig.LogNumTerms, + + PointRow: contPointRow, + PointCol: contPointCol, + + SparkRLC: SPARKMatrixData{ + Claimed: typeConverters.LimbsToBigIntMod(hints.SparkHints.claimed.Limbs), + + SparkSumcheckLast: contSparkSumcheckLast, + + RowFinalCounter: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowFinalCounter.Limbs), + RowRSAddressEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowRSAddressEvaluation.Limbs), + RowRSValueEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowRSValueEvaluation.Limbs), + RowRSTimestampEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowRSTimestampEvaluation.Limbs), + + ColFinalCounter: typeConverters.LimbsToBigIntMod(hints.SparkHints.colFinalCounter.Limbs), + ColRSAddressEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.colRSAddressEvaluation.Limbs), + ColRSValueEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.colRSValueEvaluation.Limbs), + ColRSTimestampEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.colRSTimestampEvaluation.Limbs), + + EvaluesSumcheckMerkleFirstRound: newMerkle(hints.SparkHints.evaluesSumcheck.firstRoundMerklePaths.path, true), + EvaluesSumcheckMerkleRemainingRounds: newMerkle(hints.SparkHints.evaluesSumcheck.roundHints, true), + + ValsMerkleFirstRound: newMerkle(hints.SparkHints.vals.firstRoundMerklePaths.path, true), + ValsMerkleRemainingRounds: newMerkle(hints.SparkHints.vals.roundHints, true), + + RSWSMerkleFirstRound: newMerkle(hints.SparkHints.rsws.firstRoundMerklePaths.path, true), + RSWSMerkleRemainingRounds: newMerkle(hints.SparkHints.rsws.roundHints, true), + + EvaluesRSWSMerkleFirstRound: newMerkle(hints.SparkHints.evaluesRSWS.firstRoundMerklePaths.path, true), + EvaluesRSWSMerkleRemainingRounds: newMerkle(hints.SparkHints.evaluesRSWS.roundHints, true), + + RowFinalMerkleFirstRound: newMerkle(hints.SparkHints.rowFinal.firstRoundMerklePaths.path, true), + RowFinalMerkleRemainingRounds: newMerkle(hints.SparkHints.rowFinal.roundHints, true), + + ColFinalMerkleFirstRound: newMerkle(hints.SparkHints.colFinal.firstRoundMerklePaths.path, true), + ColFinalMerkleRemainingRounds: newMerkle(hints.SparkHints.colFinal.roundHints, true), + + WHIR1: NewWhirParams(sparkConfig.WHIR1), + WHIR2: NewWhirParams(sparkConfig.WHIR2), + WHIR4: NewWhirParams(sparkConfig.WHIR4), + + LogNumTerms: sparkConfig.LogNumTerms, + }, + + UseSpark: useSpark, } ccs, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) @@ -273,6 +390,57 @@ func verifyCircuit( MatrixA: matrixA, MatrixB: matrixB, MatrixC: matrixC, + + SPARKIO: []byte(sparkConfig.IOPattern), + SPARKTranscript: sparkTranscriptT, + WHIRRow: NewWhirParams(sparkConfig.WHIRRow), + WHIRCol: NewWhirParams(sparkConfig.WHIRCol), + LogANumTerms: sparkConfig.LogNumTerms, + + PointRow: pointRow, + PointCol: pointCol, + + SparkRLC: SPARKMatrixData{ + Claimed: typeConverters.LimbsToBigIntMod(hints.SparkHints.claimed.Limbs), + + SparkSumcheckLast: sparkSumcheckLast, + + RowFinalCounter: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowFinalCounter.Limbs), + RowRSAddressEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowRSAddressEvaluation.Limbs), + RowRSValueEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowRSValueEvaluation.Limbs), + RowRSTimestampEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.rowRSTimestampEvaluation.Limbs), + + 
ColFinalCounter: typeConverters.LimbsToBigIntMod(hints.SparkHints.colFinalCounter.Limbs), + ColRSAddressEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.colRSAddressEvaluation.Limbs), + ColRSValueEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.colRSValueEvaluation.Limbs), + ColRSTimestampEvaluation: typeConverters.LimbsToBigIntMod(hints.SparkHints.colRSTimestampEvaluation.Limbs), + + EvaluesSumcheckMerkleFirstRound: newMerkle(hints.SparkHints.evaluesSumcheck.firstRoundMerklePaths.path, false), + EvaluesSumcheckMerkleRemainingRounds: newMerkle(hints.SparkHints.evaluesSumcheck.roundHints, false), + + ValsMerkleFirstRound: newMerkle(hints.SparkHints.vals.firstRoundMerklePaths.path, false), + ValsMerkleRemainingRounds: newMerkle(hints.SparkHints.vals.roundHints, false), + + RSWSMerkleFirstRound: newMerkle(hints.SparkHints.rsws.firstRoundMerklePaths.path, false), + RSWSMerkleRemainingRounds: newMerkle(hints.SparkHints.rsws.roundHints, false), + + EvaluesRSWSMerkleFirstRound: newMerkle(hints.SparkHints.evaluesRSWS.firstRoundMerklePaths.path, false), + EvaluesRSWSMerkleRemainingRounds: newMerkle(hints.SparkHints.evaluesRSWS.roundHints, false), + + RowFinalMerkleFirstRound: newMerkle(hints.SparkHints.rowFinal.firstRoundMerklePaths.path, false), + RowFinalMerkleRemainingRounds: newMerkle(hints.SparkHints.rowFinal.roundHints, false), + + ColFinalMerkleFirstRound: newMerkle(hints.SparkHints.colFinal.firstRoundMerklePaths.path, false), + ColFinalMerkleRemainingRounds: newMerkle(hints.SparkHints.colFinal.roundHints, false), + + WHIR1: NewWhirParams(sparkConfig.WHIR1), + WHIR2: NewWhirParams(sparkConfig.WHIR2), + WHIR4: NewWhirParams(sparkConfig.WHIR4), + + LogNumTerms: sparkConfig.LogNumTerms, + }, + + UseSpark: useSpark, } witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) @@ -305,3 +473,365 @@ func parseClaimedEvaluations(claimedEvaluations ClaimedEvaluations, isContainer return fSums, gSums } + +func gpaSumcheckVerifier( + api frontend.API, + arthur gnarkNimue.Arthur, + layerCount int, +) (GPASumcheckResult, error) { + l := make([]frontend.Variable, 2) + r := make([]frontend.Variable, 1) + + gpaClaimedValues := make([]frontend.Variable, 2) + err := arthur.FillNextScalars(gpaClaimedValues) + if err != nil { + return GPASumcheckResult{}, err + } + err = arthur.FillChallengeScalars(r) + if err != nil { + return GPASumcheckResult{}, err + } + lastEval := utilities.UnivarPoly(api, gpaClaimedValues, r)[0] + prevRand := []frontend.Variable{r[0]} + var rand []frontend.Variable + + for i := 1; i < (layerCount - 1); i++ { + rand, lastEval, err = runSumcheck( + api, + arthur, + lastEval, + i, + 4, + ) + if err != nil { + return GPASumcheckResult{}, err + } + + err = arthur.FillNextScalars(l) + if err != nil { + return GPASumcheckResult{}, err + } + err = arthur.FillChallengeScalars(r) + if err != nil { + return GPASumcheckResult{}, err + } + claimedLastSch := api.Mul( + calculateEQ(api, prevRand, rand), + utilities.UnivarPoly(api, l, []frontend.Variable{0})[0], + utilities.UnivarPoly(api, l, []frontend.Variable{1})[0], + ) + api.AssertIsEqual(claimedLastSch, lastEval) + prevRand = append(rand, r[0]) + lastEval = utilities.UnivarPoly(api, l, []frontend.Variable{r[0]})[0] + } + + gpaClaimedValues = []frontend.Variable{ + gpaClaimedValues[0], + api.Add(gpaClaimedValues[0], gpaClaimedValues[1]), + } + + return GPASumcheckResult{ + claimedProducts: gpaClaimedValues, + lastSumcheckValue: lastEval, + randomness: prevRand, + }, nil +} + +type GPASumcheckResult struct { + 
claimedProducts []frontend.Variable + lastSumcheckValue frontend.Variable + randomness []frontend.Variable +} + +func CalculateAdr(api frontend.API, coefficients []frontend.Variable) frontend.Variable { + ans := frontend.Variable(0) + for _, coefficient := range coefficients { + ans = api.Add(api.Mul(ans, 2), coefficient) + } + + return ans +} + +func sparkSingleMatrix( + api frontend.API, + arthur gnarkNimue.Arthur, + uapi *uints.BinaryField[uints.U64], + sc *skyscraper.Skyscraper, + matrix SPARKMatrixData, + circuit *Circuit, +) error { + claimedEvaluations := make([]frontend.Variable, 3) + if err := arthur.FillNextScalars(claimedEvaluations); err != nil { + return err + } + + matrixCombinationRandomness := make([]frontend.Variable, 1) + if err := arthur.FillChallengeScalars(matrixCombinationRandomness); err != nil { + return err + } + + claimedValue := api.Add( + claimedEvaluations[0], + api.Mul(claimedEvaluations[1], matrixCombinationRandomness[0]), + api.Mul(claimedEvaluations[2], matrixCombinationRandomness[0], matrixCombinationRandomness[0]), + ) + + claimedValue = api.Div(api.Div(claimedValue, (api.Add(1, matrixCombinationRandomness[0]))), (api.Add(1, matrixCombinationRandomness[0]))) + + valsCommitment, err := parseBatchedCommitment(arthur, matrix.WHIR1) + if err != nil { + return err + } + rsWSCommitment, err := parseBatchedCommitment(arthur, matrix.WHIR4) + if err != nil { + return err + } + rowFinalCommitment, err := parseBatchedCommitment(arthur, circuit.WHIRRow) + if err != nil { + return err + } + colFinalCommitment, err := parseBatchedCommitment(arthur, circuit.WHIRCol) + if err != nil { + return err + } + evaluesCommitment, err := parseBatchedCommitment(arthur, matrix.WHIR2) + if err != nil { + return err + } + + sparkSumcheckFoldingRandomness, sparkSumcheckLastEval, err := runSumcheck(api, arthur, claimedValue, matrix.LogNumTerms, 4) + if err != nil { + return err + } + + // Verify spark sumcheck last value + + api.AssertIsEqual(sparkSumcheckLastEval, api.Mul(matrix.SparkSumcheckLast[0], matrix.SparkSumcheckLast[1], matrix.SparkSumcheckLast[2])) + + _, err = RunZKWhir(api, arthur, uapi, sc, matrix.EvaluesSumcheckMerkleRemainingRounds, matrix.EvaluesSumcheckMerkleFirstRound, matrix.WHIR2, [][]frontend.Variable{{}, {}}, []frontend.Variable{}, evaluesCommitment, + [][]frontend.Variable{{matrix.SparkSumcheckLast[1]}, {matrix.SparkSumcheckLast[2]}}, + [][]frontend.Variable{sparkSumcheckFoldingRandomness}, + ) + if err != nil { + return err + } + + _, err = RunZKWhir(api, arthur, uapi, sc, matrix.ValsMerkleRemainingRounds, matrix.ValsMerkleFirstRound, matrix.WHIR1, [][]frontend.Variable{{}}, []frontend.Variable{}, valsCommitment, + [][]frontend.Variable{{matrix.SparkSumcheckLast[0]}}, + [][]frontend.Variable{sparkSumcheckFoldingRandomness}, + ) + if err != nil { + return err + } + + // RS WS + tauGammaTemp := make([]frontend.Variable, 2) + if err := arthur.FillChallengeScalars(tauGammaTemp); err != nil { + return err + } + tau := tauGammaTemp[0] + gamma := tauGammaTemp[1] + + gpaResultRSWS, err := gpaSumcheckVerifier4(api, arthur, matrix.LogNumTerms+3) + if err != nil { + return err + } + + rsws_combination_randomness := gpaResultRSWS.randomness[0:2] + rsws_evaluation_randomness := gpaResultRSWS.randomness[2:] + + claimedRowRS := gpaResultRSWS.claimedProducts[0] + claimedRowWS := gpaResultRSWS.claimedProducts[1] + claimedColRS := gpaResultRSWS.claimedProducts[2] + claimedColWS := gpaResultRSWS.claimedProducts[3] + + _, err = RunZKWhir(api, arthur, uapi, sc, 
matrix.RSWSMerkleRemainingRounds, matrix.RSWSMerkleFirstRound, matrix.WHIR4, [][]frontend.Variable{{}}, []frontend.Variable{}, rsWSCommitment, + [][]frontend.Variable{{matrix.RowRSAddressEvaluation}, {matrix.RowRSTimestampEvaluation}, {matrix.ColRSAddressEvaluation}, {matrix.ColRSTimestampEvaluation}}, + [][]frontend.Variable{rsws_evaluation_randomness}, + ) + if err != nil { + return err + } + + _, err = RunZKWhir(api, arthur, uapi, sc, matrix.EvaluesRSWSMerkleRemainingRounds, matrix.EvaluesRSWSMerkleFirstRound, matrix.WHIR2, [][]frontend.Variable{{}}, []frontend.Variable{}, evaluesCommitment, + [][]frontend.Variable{{matrix.RowRSValueEvaluation}, {matrix.ColRSValueEvaluation}}, + [][]frontend.Variable{rsws_evaluation_randomness}, + ) + if err != nil { + return err + } + + row_rs_opening := api.Sub(api.Add(api.Mul(matrix.RowRSAddressEvaluation, gamma, gamma), api.Mul(matrix.RowRSValueEvaluation, gamma), matrix.RowRSTimestampEvaluation), tau) + row_ws_opening := api.Sub(api.Add(api.Mul(matrix.RowRSAddressEvaluation, gamma, gamma), api.Mul(matrix.RowRSValueEvaluation, gamma), matrix.RowRSTimestampEvaluation, 1), tau) + col_rs_opening := api.Sub(api.Add(api.Mul(matrix.ColRSAddressEvaluation, gamma, gamma), api.Mul(matrix.ColRSValueEvaluation, gamma), matrix.ColRSTimestampEvaluation), tau) + col_ws_opening := api.Sub(api.Add(api.Mul(matrix.ColRSAddressEvaluation, gamma, gamma), api.Mul(matrix.ColRSValueEvaluation, gamma), matrix.ColRSTimestampEvaluation, 1), tau) + + evaluated_value := api.Add( + api.Mul( + row_rs_opening, + api.Sub(1, rsws_combination_randomness[0]), + api.Sub(1, rsws_combination_randomness[1]), + ), + api.Mul( + row_ws_opening, + api.Sub(1, rsws_combination_randomness[0]), + rsws_combination_randomness[1], + ), + api.Mul( + col_rs_opening, + rsws_combination_randomness[0], + api.Sub(1, rsws_combination_randomness[1]), + ), + api.Mul( + col_ws_opening, + rsws_combination_randomness[0], + rsws_combination_randomness[1], + ), + ) + + api.AssertIsEqual(evaluated_value, gpaResultRSWS.lastSumcheckValue) + + b1 := api.Div(matrixCombinationRandomness[0], api.Add(1, matrixCombinationRandomness[0])) + + // Rowwise + circuit.PointRow = append([]frontend.Variable{b1}, circuit.PointRow...) 
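	// The rowwise block below runs the init/final grand-product check for row
	// memory: gpaSumcheckVerifier folds the two claimed products down to one
	// evaluation, CalculateAdr recomputes the address at the GPA's evaluation
	// point, the expected memory value is the eq polynomial of PointRow at that
	// point with init counters fixed to zero, and the final counters come from
	// the RowFinalCounter opening checked via RunZKWhir. The colwise block
	// repeats the same steps with the (1 - PointCol[0])-scaled eq value.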
+ + rowwiseGpaResult, err := gpaSumcheckVerifier(api, arthur, len(circuit.PointRow)+2) + if err != nil { + return err + } + + rowwiseClaimedInit := rowwiseGpaResult.claimedProducts[0] + rowwiseClaimedFinal := rowwiseGpaResult.claimedProducts[1] + + last_randomness := rowwiseGpaResult.randomness[0] + evaluation_randomness := rowwiseGpaResult.randomness[1:] + + addr := CalculateAdr(api, evaluation_randomness) + mem := calculateEQ(api, circuit.PointRow, evaluation_randomness) + init_cntr := 0 + + init_opening := api.Sub(api.Add(api.Mul(addr, gamma, gamma), api.Mul(mem, gamma), init_cntr), tau) + + _, err = RunZKWhir(api, arthur, uapi, sc, matrix.RowFinalMerkleRemainingRounds, matrix.RowFinalMerkleFirstRound, circuit.WHIRRow, [][]frontend.Variable{{}}, []frontend.Variable{}, rowFinalCommitment, + [][]frontend.Variable{{matrix.RowFinalCounter}}, + [][]frontend.Variable{evaluation_randomness}, + ) + if err != nil { + return err + } + + final_opening := api.Sub(api.Add(api.Mul(addr, gamma, gamma), api.Mul(mem, gamma), matrix.RowFinalCounter), tau) + + rowwise_evaluated_value := api.Add(api.Mul(init_opening, api.Sub(1, last_randomness)), api.Mul(final_opening, last_randomness)) + + api.AssertIsEqual(rowwiseGpaResult.lastSumcheckValue, rowwise_evaluated_value) + + // Colwise + + colwiseInitFinalGpaResult, err := gpaSumcheckVerifier(api, arthur, len(circuit.PointCol)+2) + if err != nil { + return err + } + + colwiseClaimedInit := colwiseInitFinalGpaResult.claimedProducts[0] + colwiseClaimedFinal := colwiseInitFinalGpaResult.claimedProducts[1] + + colwiseLast_randomness := colwiseInitFinalGpaResult.randomness[0] + colwiseEvaluation_randomness := colwiseInitFinalGpaResult.randomness[1:] + + colwiseaddr := CalculateAdr(api, colwiseEvaluation_randomness) + + colwisemem := api.Mul(calculateEQ(api, append([]frontend.Variable{b1}, circuit.PointCol[1:]...), colwiseEvaluation_randomness), api.Sub(1, circuit.PointCol[0])) + colwiseinit_cntr := 0 + + colwiseinit_opening := api.Sub(api.Add(api.Mul(colwiseaddr, gamma, gamma), api.Mul(colwisemem, gamma), colwiseinit_cntr), tau) + + _, err = RunZKWhir(api, arthur, uapi, sc, circuit.SparkRLC.ColFinalMerkleRemainingRounds, circuit.SparkRLC.ColFinalMerkleFirstRound, circuit.WHIRCol, [][]frontend.Variable{{}}, []frontend.Variable{}, colFinalCommitment, + [][]frontend.Variable{{matrix.ColFinalCounter}}, + [][]frontend.Variable{colwiseEvaluation_randomness}, + ) + if err != nil { + return err + } + + colwisefinal_opening := api.Sub(api.Add(api.Mul(colwiseaddr, gamma, gamma), api.Mul(colwisemem, gamma), matrix.ColFinalCounter), tau) + colwiseevaluated_value := api.Add(api.Mul(colwiseinit_opening, api.Sub(1, colwiseLast_randomness)), api.Mul(colwisefinal_opening, colwiseLast_randomness)) + api.AssertIsEqual(colwiseInitFinalGpaResult.lastSumcheckValue, colwiseevaluated_value) + + api.AssertIsEqual(api.Mul(rowwiseClaimedInit, claimedRowWS), api.Mul(claimedRowRS, rowwiseClaimedFinal)) + api.AssertIsEqual(api.Mul(colwiseClaimedInit, claimedColWS), api.Mul(claimedColRS, colwiseClaimedFinal)) + + return nil +} + +func gpaSumcheckVerifier4( + api frontend.API, + arthur gnarkNimue.Arthur, + layerCount int, +) (GPASumcheckResult, error) { + l := make([]frontend.Variable, 2) + r := make([]frontend.Variable, 1) + gpaClaimedValues := make([]frontend.Variable, 4) + prevRand := make([]frontend.Variable, 2) + + err := arthur.FillNextScalars(gpaClaimedValues) + if err != nil { + return GPASumcheckResult{}, err + } + err = arthur.FillChallengeScalars(prevRand) + if err != nil { + return 
GPASumcheckResult{}, err + } + lastEval := api.Add( + gpaClaimedValues[0], + api.Mul(gpaClaimedValues[1], prevRand[1]), + api.Mul(gpaClaimedValues[2], prevRand[0]), + api.Mul(gpaClaimedValues[3], prevRand[0], prevRand[1]), + ) + + var rand []frontend.Variable + + for i := 2; i < (layerCount - 1); i++ { + rand, lastEval, err = runSumcheck( + api, + arthur, + lastEval, + i, + 4, + ) + if err != nil { + return GPASumcheckResult{}, err + } + + err = arthur.FillNextScalars(l) + if err != nil { + return GPASumcheckResult{}, err + } + err = arthur.FillChallengeScalars(r) + if err != nil { + return GPASumcheckResult{}, err + } + claimedLastSch := api.Mul( + calculateEQ(api, prevRand, rand), + utilities.UnivarPoly(api, l, []frontend.Variable{0})[0], + utilities.UnivarPoly(api, l, []frontend.Variable{1})[0], + ) + api.AssertIsEqual(claimedLastSch, lastEval) + prevRand = append(rand, r[0]) + lastEval = utilities.UnivarPoly(api, l, []frontend.Variable{r[0]})[0] + } + + gpaClaimedValues = []frontend.Variable{ + gpaClaimedValues[0], + api.Add(gpaClaimedValues[0], gpaClaimedValues[1]), + api.Add(gpaClaimedValues[0], gpaClaimedValues[2]), + api.Add(gpaClaimedValues[0], gpaClaimedValues[1], gpaClaimedValues[2], gpaClaimedValues[3]), + } + + return GPASumcheckResult{ + claimedProducts: gpaClaimedValues, + lastSumcheckValue: lastEval, + randomness: prevRand, + }, nil +} diff --git a/recursive-verifier/app/circuit/common.go b/recursive-verifier/app/circuit/common.go index 52acea04..364ec1b5 100644 --- a/recursive-verifier/app/circuit/common.go +++ b/recursive-verifier/app/circuit/common.go @@ -14,7 +14,7 @@ import ( "reilabs/whir-verifier-circuit/app/common" ) -func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, buildOps common.BuildOps) error { +func PrepareAndVerifyCircuit(config Config, sparkConfig SparkConfig, r1cs R1CS, pk *groth16.ProvingKey, vk *groth16.VerifyingKey, buildOps common.BuildOps) error { io := gnarkNimue.IOPattern{} err := io.Parse([]byte(config.IOPattern)) if err != nil { @@ -108,6 +108,238 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v config.Transcript = truncated + // Spark start + spark_io := gnarkNimue.IOPattern{} + err = spark_io.Parse([]byte(sparkConfig.IOPattern)) + if err != nil { + return fmt.Errorf("failed to parse IO pattern: %w", err) + } + + var spark_pointer uint64 + var spark_truncated_transcript []byte + + var sparkMerklePaths []FullMultiPath[KeccakDigest] + var sparkStirAnswers [][][]Fp256 + var sparkClaimedEvaluations [][]Fp256 + + var rowFinalCounter []Fp256 + var rowRSAddressEvaluation []Fp256 + var rowRSValueEvaluation []Fp256 + var rowRSTimestampEvaluation []Fp256 + + var colFinalCounter []Fp256 + var colRSAddressEvaluation []Fp256 + var colRSValueEvaluation []Fp256 + var colRSTimestampEvaluation []Fp256 + + var claimedA Fp256 + var claimedB Fp256 + var claimedC Fp256 + var pointRow []Fp256 + var pointCol []Fp256 + + for _, op := range spark_io.Ops { + switch op.Kind { + case gnarkNimue.Hint: + if spark_pointer+4 > uint64(len(sparkConfig.Transcript)) { + return fmt.Errorf("insufficient bytes for hint length") + } + hintLen := binary.LittleEndian.Uint32(sparkConfig.Transcript[spark_pointer : spark_pointer+4]) + start := spark_pointer + 4 + end := start + uint64(hintLen) + + if end > uint64(len(sparkConfig.Transcript)) { + return fmt.Errorf("insufficient bytes for merkle proof") + } + + switch string(op.Label) { + case "merkle_proof": + var path FullMultiPath[KeccakDigest] + _, 
err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &path, + false, false, + ) + sparkMerklePaths = append(sparkMerklePaths, path) + + case "stir_answers": + var stirAnswersTemporary [][]Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &stirAnswersTemporary, + false, false, + ) + sparkStirAnswers = append(sparkStirAnswers, stirAnswersTemporary) + case "sumcheck_last_folds": + var temp []Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize spark_last_folds: %w", err) + } + sparkClaimedEvaluations = append(sparkClaimedEvaluations, temp) + case "row_final_counter_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize row_final_counter_claimed_evaluation : %w", err) + } + rowFinalCounter = append(rowFinalCounter, temp) + case "row_rs_address_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize row_rs_address_claimed_evaluation : %w", err) + } + rowRSAddressEvaluation = append(rowRSAddressEvaluation, temp) + case "row_rs_value_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize row_rs_value_claimed_evaluation : %w", err) + } + rowRSValueEvaluation = append(rowRSValueEvaluation, temp) + case "row_rs_timestamp_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize row_rs_timestamp_claimed_evaluation : %w", err) + } + rowRSTimestampEvaluation = append(rowRSTimestampEvaluation, temp) + case "col_final_counter_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize col_final_counter_claimed_evaluation : %w", err) + } + colFinalCounter = append(colFinalCounter, temp) + case "col_rs_address_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize col_rs_address_claimed_evaluation : %w", err) + } + colRSAddressEvaluation = append(colRSAddressEvaluation, temp) + case "col_rs_value_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize col_rs_value_claimed_evaluation : %w", err) + } + colRSValueEvaluation = append(colRSValueEvaluation, temp) + case "col_rs_timestamp_claimed_evaluation": + var temp Fp256 + _, err = arkSerialize.CanonicalDeserializeWithMode( + 
bytes.NewReader(sparkConfig.Transcript[start:end]), + &temp, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize col_rs_timestamp_claimed_evaluation : %w", err) + } + colRSTimestampEvaluation = append(colRSTimestampEvaluation, temp) + case "claimed_a": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &claimedA, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize claimed_a : %w", err) + } + case "claimed_b": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &claimedB, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize claimed_b : %w", err) + } + case "claimed_c": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &claimedC, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize claimed_c : %w", err) + } + case "point_row": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &pointRow, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize point_row : %w", err) + } + case "point_col": + _, err = arkSerialize.CanonicalDeserializeWithMode( + bytes.NewReader(sparkConfig.Transcript[start:end]), + &pointCol, + false, false, + ) + if err != nil { + return fmt.Errorf("failed to deserialize point_col : %w", err) + } + } + + if err != nil { + return fmt.Errorf("failed to deserialize hint: %w", err) + } + + spark_pointer = end + + case gnarkNimue.Absorb: + start := spark_pointer + if string(op.Label) == "pow-nonce" { + spark_pointer += op.Size + } else { + spark_pointer += op.Size * 32 + } + + if spark_pointer > uint64(len(sparkConfig.Transcript)) { + return fmt.Errorf("absorb exceeds transcript length") + } + + spark_truncated_transcript = append(spark_truncated_transcript, sparkConfig.Transcript[start:spark_pointer]...) 
+ } + } + + sparkConfig.Transcript = spark_truncated_transcript + internerBytes, err := hex.DecodeString(r1cs.Interner.Values) if err != nil { return fmt.Errorf("failed to decode interner values: %w", err) @@ -122,14 +354,50 @@ func PrepareAndVerifyCircuit(config Config, r1cs R1CS, pk *groth16.ProvingKey, v } var hidingSpartanData = consumeWhirData(config.WHIRConfigHidingSpartan, &merklePaths, &stirAnswers) - var witnessData = consumeWhirData(config.WHIRConfigWitness, &merklePaths, &stirAnswers) + var evaluesSumcheck = consumeWhirData(sparkConfig.WHIR2, &sparkMerklePaths, &sparkStirAnswers) + var vals = consumeWhirData(sparkConfig.WHIR1, &sparkMerklePaths, &sparkStirAnswers) + var rsws = consumeWhirData(sparkConfig.WHIR4, &sparkMerklePaths, &sparkStirAnswers) + var evaluesRSWS = consumeWhirData(sparkConfig.WHIR2, &sparkMerklePaths, &sparkStirAnswers) + var rowFinal = consumeWhirData(sparkConfig.WHIRRow, &sparkMerklePaths, &sparkStirAnswers) + var colFinal = consumeWhirData(sparkConfig.WHIRCol, &sparkMerklePaths, &sparkStirAnswers) + + // fmt.Print(len(evalues.firstRoundMerklePaths.path.stirAnswers[0][0])) + hints := Hints{ + pointRow: pointRow, + pointCol: pointCol, + witnessHints: witnessData, spartanHidingHint: hidingSpartanData, + + SparkHints: SparkMatrixHints{ + claimed: claimedA, + + evaluesSumcheck: evaluesSumcheck, + vals: vals, + rsws: rsws, + evaluesRSWS: evaluesRSWS, + rowFinal: rowFinal, + colFinal: colFinal, + + sparkClaimedEvaluations: sparkClaimedEvaluations[0], + + rowFinalCounter: rowFinalCounter[0], + rowRSAddressEvaluation: rowRSAddressEvaluation[0], + rowRSValueEvaluation: rowRSValueEvaluation[0], + rowRSTimestampEvaluation: rowRSTimestampEvaluation[0], + + colFinalCounter: colFinalCounter[0], + colRSAddressEvaluation: colRSAddressEvaluation[0], + colRSValueEvaluation: colRSValueEvaluation[0], + colRSTimestampEvaluation: colRSTimestampEvaluation[0], + }, } - err = verifyCircuit(deferred, config, hints, pk, vk, claimedEvaluations, r1cs, interner, buildOps) + + err = verifyCircuit(deferred, config, sparkConfig, hints, pk, vk, claimedEvaluations, r1cs, interner, buildOps) + if err != nil { return fmt.Errorf("verification failed: %w", err) } diff --git a/recursive-verifier/app/circuit/mtUtilities.go b/recursive-verifier/app/circuit/mtUtilities.go index 264727b3..1e73f0ba 100644 --- a/recursive-verifier/app/circuit/mtUtilities.go +++ b/recursive-verifier/app/circuit/mtUtilities.go @@ -17,15 +17,21 @@ func initialSumcheck( initialOODAnswers []frontend.Variable, whirParams WHIRParams, linearStatementEvaluations [][]frontend.Variable, + evaluationStatementClaimedValues [][]frontend.Variable, ) (InitialSumcheckData, frontend.Variable, []frontend.Variable, error) { + lengthOfLinearStatementEvaluations := len(linearStatementEvaluations[0]) + lengthOfEvaluationStatement := len(evaluationStatementClaimedValues[0]) - initialCombinationRandomness, err := GenerateCombinationRandomness(api, arthur, len(initialOODAnswers)+len(linearStatementEvaluations[0])) + initialCombinationRandomness, err := GenerateCombinationRandomness(api, arthur, len(initialOODAnswers)+lengthOfLinearStatementEvaluations+lengthOfEvaluationStatement) if err != nil { return InitialSumcheckData{}, nil, nil, err } - combinedLinearStatementEvaluations := make([]frontend.Variable, len(linearStatementEvaluations[0])) //[0, 1, 2] - for evaluationIndex := range len(linearStatementEvaluations[0]) { + _ = initialCombinationRandomness + _ = lengthOfEvaluationStatement + + combinedLinearStatementEvaluations := 
make([]frontend.Variable, lengthOfLinearStatementEvaluations) + for evaluationIndex := range lengthOfLinearStatementEvaluations { sum := frontend.Variable(0) multiplier := frontend.Variable(1) for j := range len(linearStatementEvaluations) { @@ -34,7 +40,20 @@ func initialSumcheck( } combinedLinearStatementEvaluations[evaluationIndex] = sum } - OODAnswersAndStatmentEvaluations := append(initialOODAnswers, combinedLinearStatementEvaluations...) + + combinedEvaluationStatementEvaluations := make([]frontend.Variable, lengthOfEvaluationStatement) + for evaluationIndex := range lengthOfEvaluationStatement { + sum := frontend.Variable(0) + multiplier := frontend.Variable(1) + for j := range len(evaluationStatementClaimedValues) { + sum = api.Add(sum, api.Mul(evaluationStatementClaimedValues[j][evaluationIndex], multiplier)) + multiplier = api.Mul(multiplier, batchingRandomness) + } + combinedEvaluationStatementEvaluations[evaluationIndex] = sum + } + + OODAnswersAndStatmentEvaluations := append(append(initialOODAnswers, combinedLinearStatementEvaluations...), combinedEvaluationStatementEvaluations...) + lastEval := utilities.DotProduct(api, initialCombinationRandomness, OODAnswersAndStatmentEvaluations) initialSumcheckFoldingRandomness, lastEval, err := runWhirSumcheckRounds(api, lastEval, arthur, whirParams.FoldingFactorArray[0], 3) @@ -42,37 +61,48 @@ func initialSumcheck( return InitialSumcheckData{}, nil, nil, err } + _ = initialSumcheckFoldingRandomness + return InitialSumcheckData{ InitialOODQueries: initialOODQueries, InitialCombinationRandomness: initialCombinationRandomness, }, lastEval, initialSumcheckFoldingRandomness, nil + } -func parseBatchedCommitment(arthur gnarkNimue.Arthur, whir_params WHIRParams) (frontend.Variable, frontend.Variable, []frontend.Variable, [][]frontend.Variable, error) { +func parseBatchedCommitment(arthur gnarkNimue.Arthur, whir_params WHIRParams) (Commitment, error) { rootHash := make([]frontend.Variable, 1) if err := arthur.FillNextScalars(rootHash); err != nil { - return nil, nil, nil, [][]frontend.Variable{}, err + return Commitment{}, err } oodPoints := make([]frontend.Variable, 1) oodAnswers := make([][]frontend.Variable, whir_params.BatchSize) if err := arthur.FillChallengeScalars(oodPoints); err != nil { - return nil, nil, nil, nil, err + return Commitment{}, err } for i := range whir_params.BatchSize { oodAnswer := make([]frontend.Variable, 1) if err := arthur.FillNextScalars(oodAnswer); err != nil { - return nil, nil, nil, nil, err + return Commitment{}, err } oodAnswers[i] = oodAnswer } - batchingRandomness := make([]frontend.Variable, 1) - if err := arthur.FillChallengeScalars(batchingRandomness); err != nil { - return nil, 0, nil, nil, err + batchingRandomness := []frontend.Variable{0} + if whir_params.BatchSize > 1 { + if err := arthur.FillChallengeScalars(batchingRandomness); err != nil { + return Commitment{}, err + } } - return rootHash[0], batchingRandomness[0], oodPoints, oodAnswers, nil + return Commitment{ + rootHash: rootHash[0], + batchingRandomness: batchingRandomness[0], + initialOODQueries: oodPoints, + initialOODAnswers: oodAnswers, + }, nil + } func generateFinalCoefficientsAndRandomnessPoints(api frontend.API, arthur gnarkNimue.Arthur, whir_params WHIRParams, circuit Merkle, uapi *uints.BinaryField[uints.U64], sc *skyscraper.Skyscraper, domainSize int, expDomainGenerator frontend.Variable) ([]frontend.Variable, []frontend.Variable, error) { diff --git a/recursive-verifier/app/circuit/types.go 
b/recursive-verifier/app/circuit/types.go index 67bc53b4..b7c8efd5 100644 --- a/recursive-verifier/app/circuit/types.go +++ b/recursive-verifier/app/circuit/types.go @@ -102,8 +102,13 @@ type Config struct { } type Hints struct { + pointRow []Fp256 + pointCol []Fp256 + witnessHints ZKHint spartanHidingHint ZKHint + + SparkHints SparkMatrixHints } type Hint struct { @@ -125,3 +130,83 @@ type ClaimedEvaluations struct { FSums []Fp256 GSums []Fp256 } + +type SparkConfig struct { + Transcript []byte `json:"transcript"` + IOPattern string `json:"io_pattern"` + WHIRRow WHIRConfig `json:"whir_row"` + WHIRCol WHIRConfig `json:"whir_col"` + WHIR1 WHIRConfig `json:"whir_1batched"` + WHIR2 WHIRConfig `json:"whir_2batched"` + WHIR4 WHIRConfig `json:"whir_4batched"` + LogNumTerms int `json:"log_num_terms"` +} + +type Commitment struct { + rootHash frontend.Variable + batchingRandomness frontend.Variable + initialOODQueries []frontend.Variable + initialOODAnswers [][]frontend.Variable +} + +type SPARKMatrixData struct { + Claimed frontend.Variable + + WHIR1 WHIRParams + WHIR2 WHIRParams + WHIR4 WHIRParams + LogNumTerms int + + SparkSumcheckLast []frontend.Variable + + RowFinalCounter frontend.Variable + RowRSAddressEvaluation frontend.Variable + RowRSValueEvaluation frontend.Variable + RowRSTimestampEvaluation frontend.Variable + + ColFinalCounter frontend.Variable + ColRSAddressEvaluation frontend.Variable + ColRSValueEvaluation frontend.Variable + ColRSTimestampEvaluation frontend.Variable + + EvaluesSumcheckMerkleFirstRound Merkle + EvaluesSumcheckMerkleRemainingRounds Merkle + + ValsMerkleFirstRound Merkle + ValsMerkleRemainingRounds Merkle + + RSWSMerkleFirstRound Merkle + RSWSMerkleRemainingRounds Merkle + + EvaluesRSWSMerkleFirstRound Merkle + EvaluesRSWSMerkleRemainingRounds Merkle + + RowFinalMerkleFirstRound Merkle + RowFinalMerkleRemainingRounds Merkle + + ColFinalMerkleFirstRound Merkle + ColFinalMerkleRemainingRounds Merkle +} + +type SparkMatrixHints struct { + claimed Fp256 + + evaluesSumcheck ZKHint + vals ZKHint + rsws ZKHint + evaluesRSWS ZKHint + rowFinal ZKHint + colFinal ZKHint + + sparkClaimedEvaluations []Fp256 + + rowFinalCounter Fp256 + rowRSAddressEvaluation Fp256 + rowRSValueEvaluation Fp256 + rowRSTimestampEvaluation Fp256 + + colFinalCounter Fp256 + colRSAddressEvaluation Fp256 + colRSValueEvaluation Fp256 + colRSTimestampEvaluation Fp256 +} diff --git a/recursive-verifier/app/circuit/utilities.go b/recursive-verifier/app/circuit/utilities.go index 65f76d69..790ed11a 100644 --- a/recursive-verifier/app/circuit/utilities.go +++ b/recursive-verifier/app/circuit/utilities.go @@ -22,7 +22,7 @@ import ( func calculateEQ(api frontend.API, alphas []frontend.Variable, r []frontend.Variable) frontend.Variable { ans := frontend.Variable(1) for i, alpha := range alphas { - ans = api.Mul(ans, api.Add(api.Mul(alpha, r[i]), api.Mul(api.Sub(frontend.Variable(1), alpha), api.Sub(frontend.Variable(1), r[i])))) + ans = api.Mul(ans, api.Add(api.Mul(alpha, r[i]), api.Mul(api.Sub(1, alpha), api.Sub(1, r[i])))) } return ans } @@ -176,7 +176,7 @@ func runZKSumcheck( whirParams WHIRParams, ) ([]frontend.Variable, frontend.Variable, error) { - rootHash, batchingRandomness, initialOODQueries, initialOODAnswers, err := parseBatchedCommitment(arthur, whirParams) + commitment, err := parseBatchedCommitment(arthur, whirParams) if err != nil { return nil, nil, err } @@ -195,7 +195,10 @@ func runZKSumcheck( lastEval, polynomialSums := unblindLastEval(api, arthur, lastEval, rhoRandomness) - _, err = 
RunZKWhir(api, arthur, uapi, sc, circuit.HidingSpartanMerkle, circuit.HidingSpartanFirstRound, whirParams, [][]frontend.Variable{{polynomialSums[0]}, {polynomialSums[1]}}, circuit.HidingSpartanLinearStatementEvaluations, batchingRandomness, initialOODQueries, initialOODAnswers, rootHash) + _, err = RunZKWhir(api, arthur, uapi, sc, circuit.HidingSpartanMerkle, circuit.HidingSpartanFirstRound, whirParams, [][]frontend.Variable{{polynomialSums[0]}, {polynomialSums[1]}}, circuit.HidingSpartanLinearStatementEvaluations, commitment, + [][]frontend.Variable{{}, {}}, + [][]frontend.Variable{}, + ) if err != nil { return nil, nil, err } diff --git a/recursive-verifier/app/circuit/whir.go b/recursive-verifier/app/circuit/whir.go index b43ec3ad..b9c3dc47 100644 --- a/recursive-verifier/app/circuit/whir.go +++ b/recursive-verifier/app/circuit/whir.go @@ -58,35 +58,30 @@ func RunZKWhir( whirParams WHIRParams, linearStatementEvaluations [][]frontend.Variable, linearStatementValuesAtPoints []frontend.Variable, - batchingRandomness frontend.Variable, - initialOODQueries []frontend.Variable, - initialOODAnswers [][]frontend.Variable, - rootHashes frontend.Variable, + + commitment Commitment, + + // batchingRandomness frontend.Variable, + // initialOODQueries []frontend.Variable, + // initialOODAnswers [][]frontend.Variable, + // rootHashes frontend.Variable, + + evaluationStatementClaimedValues [][]frontend.Variable, + evaluationPoints [][]frontend.Variable, + ) (totalFoldingRandomness []frontend.Variable, err error) { - initialOODs := oodAnswers(api, initialOODAnswers, batchingRandomness) - // batchSizeLen := whirParams.BatchSize + initialOODs := oodAnswers(api, commitment.initialOODAnswers, commitment.batchingRandomness) - initialSumcheckData, lastEval, initialSumcheckFoldingRandomness, err := initialSumcheck(api, arthur, batchingRandomness, initialOODQueries, initialOODs, whirParams, linearStatementEvaluations) + initialSumcheckData, lastEval, initialSumcheckFoldingRandomness, err := initialSumcheck(api, arthur, commitment.batchingRandomness, commitment.initialOODQueries, initialOODs, whirParams, linearStatementEvaluations, evaluationStatementClaimedValues) if err != nil { return } - copyOfFirstLeaves := make([][][]frontend.Variable, len(firstRound.Leaves)) - for i := range len(firstRound.Leaves) { - copyOfFirstLeaves[i] = make([][]frontend.Variable, len(firstRound.Leaves[i])) - for j := range len(firstRound.Leaves[i]) { - copyOfFirstLeaves[i][j] = make([]frontend.Variable, len(firstRound.Leaves[i][j])) - for k := range len(firstRound.Leaves[i][j]) { - copyOfFirstLeaves[i][j][k] = firstRound.Leaves[i][j][k] - } - } - } - roundAnswers := make([][][]frontend.Variable, len(circuit.Leaves)+1) foldSize := 1 << whirParams.FoldingFactorArray[0] - collapsed := rlcBatchedLeaves(api, firstRound.Leaves[0], foldSize, whirParams.BatchSize, batchingRandomness) + collapsed := rlcBatchedLeaves(api, firstRound.Leaves[0], foldSize, whirParams.BatchSize, commitment.batchingRandomness) roundAnswers[0] = collapsed for i := range len(circuit.Leaves) { @@ -130,7 +125,7 @@ func RunZKWhir( if err != nil { return } - err = verifyMerkleTreeProofs(api, uapi, sc, firstRound.LeafIndexes[0], firstRound.Leaves[0], firstRound.LeafSiblingHashes[0], firstRound.AuthPaths[0], rootHashes) + err = verifyMerkleTreeProofs(api, uapi, sc, firstRound.LeafIndexes[0], firstRound.Leaves[0], firstRound.LeafSiblingHashes[0], firstRound.AuthPaths[0], commitment.rootHash) if err != nil { return } @@ -207,6 +202,7 @@ func RunZKWhir( mainRoundData, 
totalFoldingRandomness, linearStatementValuesAtPoints, + evaluationPoints, ) api.AssertIsEqual( @@ -218,142 +214,143 @@ func RunZKWhir( } //nolint:unused -func runWhir( - api frontend.API, - arthur gnarkNimue.Arthur, - uapi *uints.BinaryField[uints.U64], - sc *skyscraper.Skyscraper, - circuit Merkle, - whirParams WHIRParams, - linearStatementEvaluations []frontend.Variable, - linearStatementValuesAtPoints []frontend.Variable, -) (totalFoldingRandomness []frontend.Variable, err error) { - if err = fillInAndVerifyRootHash(0, api, uapi, sc, circuit, arthur); err != nil { - return - } - - initialOODQueries, initialOODAnswers, tempErr := fillInOODPointsAndAnswers(whirParams.CommittmentOODSamples, arthur) - if tempErr != nil { - err = tempErr - return - } - - initialCombinationRandomness, tempErr := GenerateCombinationRandomness(api, arthur, whirParams.CommittmentOODSamples+len(linearStatementEvaluations)) - if tempErr != nil { - err = tempErr - return - } - - OODAnswersAndStatmentEvaluations := append(initialOODAnswers, linearStatementEvaluations...) - lastEval := utilities.DotProduct(api, initialCombinationRandomness, OODAnswersAndStatmentEvaluations) - - initialSumcheckFoldingRandomness, lastEval, tempErr := runWhirSumcheckRounds(api, lastEval, arthur, whirParams.FoldingFactorArray[0], 3) - if tempErr != nil { - err = tempErr - return - } - - initialData := InitialSumcheckData{ - InitialOODQueries: initialOODQueries, - InitialCombinationRandomness: initialCombinationRandomness, - } - - computedFold := computeFold(circuit.Leaves[0], initialSumcheckFoldingRandomness, api) - - mainRoundData := generateEmptyMainRoundData(whirParams) - - expDomainGenerator := utilities.Exponent(api, uapi, whirParams.StartingDomainBackingDomainGenerator, uints.NewU64(uint64(1<, - /// out of domain samples - pub ood_samples: Vec, - /// number of queries - pub num_queries: Vec, - /// proof of work bits - pub pow_bits: Vec, - /// final queries - pub final_queries: usize, - /// final proof of work bits - pub final_pow_bits: i32, - /// final folding proof of work bits - pub final_folding_pow_bits: i32, - /// domain generator string - pub domain_generator: String, - /// batch size - pub batch_size: usize, -} - -impl WHIRConfigGnark { - pub fn new(whir_params: &WhirConfig) -> Self { - WHIRConfigGnark { - n_rounds: whir_params - .folding_factor - .compute_number_of_rounds(whir_params.mv_parameters.num_variables) - .0, - rate: whir_params.starting_log_inv_rate, - n_vars: whir_params.mv_parameters.num_variables, - folding_factor: (0..(whir_params - .folding_factor - .compute_number_of_rounds(whir_params.mv_parameters.num_variables) - .0)) - .map(|round| whir_params.folding_factor.at_round(round)) - .collect(), - ood_samples: whir_params - .round_parameters - .iter() - .map(|x| x.ood_samples) - .collect(), - num_queries: whir_params - .round_parameters - .iter() - .map(|x| x.num_queries) - .collect(), - pow_bits: whir_params - .round_parameters - .iter() - .map(|x| x.pow_bits as i32) - .collect(), - final_queries: whir_params.final_queries, - final_pow_bits: whir_params.final_pow_bits as i32, - final_folding_pow_bits: whir_params.final_folding_pow_bits as i32, - domain_generator: format!( - "{}", - whir_params.starting_domain.backing_domain.group_gen() - ), - batch_size: whir_params.batch_size, - } - } -} - /// Writes config used for Gnark circuit to a file #[instrument(skip_all)] pub fn gnark_parameters( diff --git a/tooling/spark-cli/Cargo.toml b/tooling/spark-cli/Cargo.toml new file mode 100644 index 00000000..7cd3b34c --- 
/dev/null +++ b/tooling/spark-cli/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "spark-cli" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +provekit-common.workspace = true +provekit-spark.workspace = true +anyhow.workspace = true +argh.workspace = true +serde_json.workspace = true +tracing.workspace = true +tracing-subscriber.workspace = true +tracing-tracy = { workspace = true, optional = true, features = ["default", "sampling","manual-lifetime"] } + +[lints] +workspace = true + +[features] +default = ["profiling-allocator"] +profiling-allocator = [] +tracy = ["dep:tracing-tracy"] + + diff --git a/tooling/spark-cli/src/cmd/mod.rs b/tooling/spark-cli/src/cmd/mod.rs new file mode 100644 index 00000000..ec5e380f --- /dev/null +++ b/tooling/spark-cli/src/cmd/mod.rs @@ -0,0 +1,2 @@ +pub mod prove; +pub mod verify; diff --git a/tooling/spark-cli/src/cmd/prove.rs b/tooling/spark-cli/src/cmd/prove.rs new file mode 100644 index 00000000..1c2b45bb --- /dev/null +++ b/tooling/spark-cli/src/cmd/prove.rs @@ -0,0 +1,77 @@ +use { + anyhow::{Context, Result}, + argh::FromArgs, + provekit_common::{file::read, utils::next_power_of_two, NoirProof, Prover}, + provekit_spark::{SPARKProofGnark, SPARKProver, SPARKProverScheme}, + std::{fs::File, io::Write, path::PathBuf}, + tracing::instrument, +}; + +#[derive(FromArgs)] +#[argh(subcommand, name = "prove")] +#[argh(description = "Generate a SPARK proof")] +pub struct ProveArgs { + /// path to NPS file + #[argh(option)] + noir_proof_scheme: PathBuf, + + /// path to NoirProof file (.np or .json) containing the SPARK statement + #[argh(option)] + noir_proof: PathBuf, + + /// output path for proof (default: spark_proof.json) + #[argh(option, short = 'o', default = "PathBuf::from(\"spark_proof.json\")")] + output: PathBuf, + + /// output path for gnark proof (default: gnark_spark_proof.json) + #[argh(option, default = "PathBuf::from(\"gnark_spark_proof.json\")")] + gnark_output: PathBuf, +} + +#[instrument(skip_all)] +pub fn execute(args: ProveArgs) -> Result<()> { + println!("Loading R1CS from {:?}...", args.noir_proof_scheme); + let scheme: Prover = + read(&args.noir_proof_scheme).context("while reading Noir proof scheme")?; + let mut r1cs = scheme.r1cs.clone(); + r1cs.grow_matrices( + 1 << next_power_of_two(r1cs.num_constraints()), + 1 << next_power_of_two(r1cs.num_witnesses()), + ); + drop(scheme); + + println!("Loading NoirProof from {:?}...", args.noir_proof); + let noir_proof: NoirProof = read(&args.noir_proof).context("Failed to read NoirProof file")?; + + // Extract SPARK statement from the proof + let spark_statement = noir_proof.spark_statement; + println!("✓ Extracted SPARK statement from NoirProof"); + + println!("Creating SPARK scheme..."); + let scheme = SPARKProverScheme::new_for_r1cs(&r1cs); + + println!("Generating SPARK proof..."); + let proof = scheme + .prove(&r1cs, &spark_statement) + .context("Failed to generate proof")?; + + // Write proof + println!("Writing proof to {:?}...", args.output); + let mut file = File::create(&args.output).context("Failed to create output file")?; + file.write_all(serde_json::to_string(&proof)?.as_bytes()) + .context("Failed to write proof")?; + + // Write gnark proof + println!("Writing gnark proof to {:?}...", args.gnark_output); + let log_num_terms = + provekit_common::utils::next_power_of_two(proof.matrix_dimensions.nonzero_terms); + let gnark_proof = 
SPARKProofGnark::from_proof(&proof, log_num_terms); + let mut gnark_file = + File::create(&args.gnark_output).context("Failed to create gnark output file")?; + gnark_file + .write_all(serde_json::to_string(&gnark_proof)?.as_bytes()) + .context("Failed to write gnark proof")?; + + println!("✓ SPARK proof generated successfully"); + Ok(()) +} diff --git a/tooling/spark-cli/src/cmd/verify.rs b/tooling/spark-cli/src/cmd/verify.rs new file mode 100644 index 00000000..0ca79687 --- /dev/null +++ b/tooling/spark-cli/src/cmd/verify.rs @@ -0,0 +1,45 @@ +use { + anyhow::{Context, Result}, + argh::FromArgs, + provekit_common::{file::read, NoirProof}, + provekit_spark::{SPARKProof, SPARKVerifier, SPARKVerifierScheme}, + std::{fs, path::PathBuf}, +}; + +#[derive(FromArgs)] +#[argh(subcommand, name = "verify")] +#[argh(description = "Verify a SPARK proof")] +pub struct VerifyArgs { + /// path to proof file + #[argh(option)] + spark_proof: PathBuf, + + /// path to NoirProof file (.np or .json) containing the SPARK statement + #[argh(option)] + noir_proof: PathBuf, +} + +pub fn execute(args: VerifyArgs) -> Result<()> { + println!("Loading spark-proof from {:?}...", args.spark_proof); + let proof_str = fs::read_to_string(&args.spark_proof).context("Failed to read proof file")?; + let proof: SPARKProof = + serde_json::from_str(&proof_str).context("Failed to deserialize spark-proof")?; + + println!("Loading NoirProof from {:?}...", args.noir_proof); + let noir_proof: NoirProof = read(&args.noir_proof).context("Failed to read NoirProof file")?; + + println!("✓ Extracted SPARK statement from NoirProof"); + let spark_statement = noir_proof.spark_statement.clone(); + drop(noir_proof); + + println!("Creating verification scheme..."); + let scheme = SPARKVerifierScheme::from_proof(&proof); + + println!("Verifying proof..."); + scheme + .verify(&proof, &spark_statement) + .context("Verification failed")?; + + println!("✓ Proof verified successfully"); + Ok(()) +} diff --git a/tooling/spark-cli/src/main.rs b/tooling/spark-cli/src/main.rs new file mode 100644 index 00000000..e921d36b --- /dev/null +++ b/tooling/spark-cli/src/main.rs @@ -0,0 +1,45 @@ +#![allow(missing_docs)] +mod cmd; +#[cfg(feature = "profiling-allocator")] +mod profiling_alloc; +mod span_stats; + +#[cfg(feature = "profiling-allocator")] +use crate::profiling_alloc::ProfilingAllocator; + +#[cfg(feature = "profiling-allocator")] +#[global_allocator] +static ALLOCATOR: ProfilingAllocator = ProfilingAllocator::new(); + +use { + anyhow::Result, + argh::FromArgs, + span_stats::SpanStats, + tracing::{instrument, subscriber}, + tracing_subscriber::{self, layer::SubscriberExt as _, Registry}, +}; + +#[derive(FromArgs)] +#[argh(description = "SPARK Prover CLI")] +struct Args { + #[argh(subcommand)] + command: Command, +} + +#[derive(FromArgs)] +#[argh(subcommand)] +enum Command { + Prove(cmd::prove::ProveArgs), + Verify(cmd::verify::VerifyArgs), +} + +fn main() -> Result<()> { + let args: Args = argh::from_env(); + let subscriber = Registry::default().with(SpanStats); + subscriber::set_global_default(subscriber)?; + + match args.command { + Command::Prove(args) => cmd::prove::execute(args), + Command::Verify(args) => cmd::verify::execute(args), + } +} diff --git a/tooling/spark-cli/src/profiling_alloc.rs b/tooling/spark-cli/src/profiling_alloc.rs new file mode 100644 index 00000000..c1fae147 --- /dev/null +++ b/tooling/spark-cli/src/profiling_alloc.rs @@ -0,0 +1,168 @@ +use std::{ + alloc::{GlobalAlloc, Layout, System as SystemAlloc}, + 
sync::atomic::{AtomicUsize, Ordering}, +}; +#[cfg(feature = "tracy")] +use {std::sync::atomic::AtomicBool, tracing_tracy::client::sys as tracy_sys}; + +/// Custom allocator that keeps track of statistics to see program memory +/// consumption. +pub struct ProfilingAllocator { + /// Allocated bytes + current: AtomicUsize, + + /// Maximum allocated bytes (reached so far) + max: AtomicUsize, + + /// Number of allocations done + count: AtomicUsize, + + /// Enable Tracy allocation profiling + #[cfg(feature = "tracy")] + tracy_enabled: AtomicBool, + + /// Stack depth to include in Tracy allocation profiling + /// (only used if `tracy_enabled` is true) + /// **Note.** This makes allocation very slow. + #[cfg(feature = "tracy")] + tracy_depth: AtomicUsize, +} + +impl ProfilingAllocator { + pub const fn new() -> Self { + Self { + current: AtomicUsize::new(0), + max: AtomicUsize::new(0), + count: AtomicUsize::new(0), + + #[cfg(feature = "tracy")] + tracy_enabled: AtomicBool::new(false), + #[cfg(feature = "tracy")] + tracy_depth: AtomicUsize::new(0), + } + } + + pub fn current(&self) -> usize { + self.current.load(Ordering::SeqCst) + } + + pub fn max(&self) -> usize { + self.max.load(Ordering::SeqCst) + } + + pub fn reset_max(&self) -> usize { + let current = self.current(); + self.max.store(current, Ordering::SeqCst); + current + } + + pub fn count(&self) -> usize { + self.count.load(Ordering::SeqCst) + } + + #[cfg(feature = "tracy")] + pub fn enable_tracy(&self, depth: usize) { + self.tracy_enabled.store(true, Ordering::SeqCst); + self.tracy_depth.store(depth, Ordering::SeqCst); + } + + #[allow(unused_variables)] // Conditional compilation may not use all variables + fn tracy_alloc(&self, size: usize, ptr: *mut u8) { + // If Tracy profiling is enabled, report this allocation to Tracy. + #[cfg(feature = "tracy")] + if self.tracy_enabled.load(Ordering::SeqCst) { + let depth = self.tracy_depth.load(Ordering::SeqCst); + if depth == 0 { + // If depth is 0, we don't capture any stack information + unsafe { + tracy_sys::___tracy_emit_memory_alloc(ptr.cast(), size, 1); + } + } else { + // Capture stack information up to `depth` frames + unsafe { + tracy_sys::___tracy_emit_memory_alloc_callstack( + ptr.cast(), + size, + depth as i32, + 1, + ); + } + } + } + } + + #[allow(unused_variables)] // Conditional compilation may not use all variables + fn tracy_dealloc(&self, ptr: *mut u8) { + // If Tracy profiling is enabled, report this deallocation to Tracy. 
+ #[cfg(feature = "tracy")] + if self.tracy_enabled.load(Ordering::SeqCst) { + let depth = self.tracy_depth.load(Ordering::SeqCst); + if depth == 0 { + // If depth is 0, we don't capture any stack information + unsafe { + tracy_sys::___tracy_emit_memory_free(ptr.cast(), 1); + } + } else { + // Capture stack information up to `depth` frames + unsafe { + tracy_sys::___tracy_emit_memory_free_callstack(ptr.cast(), depth as i32, 1); + } + } + } + } +} + +#[allow(unsafe_code)] +unsafe impl GlobalAlloc for ProfilingAllocator { + unsafe fn alloc(&self, layout: Layout) -> *mut u8 { + let ptr = SystemAlloc.alloc(layout); + let size = layout.size(); + let current = self + .current + .fetch_add(size, Ordering::SeqCst) + .wrapping_add(size); + self.max.fetch_max(current, Ordering::SeqCst); + self.count.fetch_add(1, Ordering::SeqCst); + self.tracy_alloc(size, ptr); + ptr + } + + unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { + self.current.fetch_sub(layout.size(), Ordering::SeqCst); + self.tracy_dealloc(ptr); + SystemAlloc.dealloc(ptr, layout); + } + + unsafe fn alloc_zeroed(&self, layout: Layout) -> *mut u8 { + let ptr = SystemAlloc.alloc_zeroed(layout); + let size = layout.size(); + let current = self + .current + .fetch_add(size, Ordering::SeqCst) + .wrapping_add(size); + self.max.fetch_max(current, Ordering::SeqCst); + self.count.fetch_add(1, Ordering::SeqCst); + self.tracy_alloc(size, ptr); + ptr + } + + unsafe fn realloc(&self, ptr: *mut u8, old_layout: Layout, new_size: usize) -> *mut u8 { + self.tracy_dealloc(ptr); + let ptr = SystemAlloc.realloc(ptr, old_layout, new_size); + let old_size = old_layout.size(); + if new_size > old_size { + let diff = new_size - old_size; + let current = self + .current + .fetch_add(diff, Ordering::SeqCst) + .wrapping_add(diff); + self.max.fetch_max(current, Ordering::SeqCst); + self.count.fetch_add(1, Ordering::SeqCst); + } else { + self.current + .fetch_sub(old_size - new_size, Ordering::SeqCst); + } + self.tracy_alloc(new_size, ptr); + ptr + } +} diff --git a/tooling/spark-cli/src/span_stats.rs b/tooling/spark-cli/src/span_stats.rs new file mode 100644 index 00000000..64f36318 --- /dev/null +++ b/tooling/spark-cli/src/span_stats.rs @@ -0,0 +1,263 @@ +//! Using `tracing` spans to print performance statistics for the program. +//! +//! NOTE: This module is only included in the bin, not in the lib. 
+#[cfg(feature = "profiling-allocator")] +use crate::ALLOCATOR; +use { + provekit_common::utils::human, + std::{ + cmp::max, + fmt::{self, Write as _}, + time::Instant, + }, + tracing::{ + field::{Field, Visit}, + span::{Attributes, Id}, + Level, Subscriber, + }, + tracing_subscriber::{self, layer::Context, registry::LookupSpan, Layer}, +}; + +const DIM: &str = "\x1b[2m"; +const UNDIM: &str = "\x1b[22m"; + +// Span extension data +pub struct Data { + depth: usize, + time: Instant, + + #[cfg(feature = "profiling-allocator")] + memory: usize, + + #[cfg(feature = "profiling-allocator")] + allocations: usize, + + /// `peak_memory` will be updated as it is not monotonic + #[cfg(feature = "profiling-allocator")] + peak_memory: usize, + + children: bool, + kvs: Vec<(&'static str, String)>, +} + +impl Data { + pub fn new(attrs: &Attributes<'_>, depth: usize) -> Self { + let mut span = Self { + depth, + time: Instant::now(), + + #[cfg(feature = "profiling-allocator")] + memory: ALLOCATOR.current(), + #[cfg(feature = "profiling-allocator")] + allocations: ALLOCATOR.count(), + #[cfg(feature = "profiling-allocator")] + peak_memory: ALLOCATOR.current(), + + children: false, + kvs: Vec::new(), + }; + attrs.record(&mut span); + span + } +} + +impl Visit for Data { + fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) { + self.kvs.push((field.name(), format!("{value:?}"))); + } +} + +pub struct FmtEvent<'a>(&'a mut String); + +impl Visit for FmtEvent<'_> { + fn record_debug(&mut self, field: &Field, value: &dyn fmt::Debug) { + match field.name() { + "message" => { + write!(self.0, " {value:?}").unwrap(); + } + name => { + write!(self.0, " {name}={value:?}").unwrap(); + } + } + } +} + +/// Logging layer that keeps track of time and memory consumption of spans. 
+pub struct SpanStats; + +impl Layer for SpanStats +where + S: Subscriber + for<'span> LookupSpan<'span>, +{ + fn on_new_span(&self, attrs: &Attributes, id: &Id, ctx: Context) { + let span = ctx.span(id).expect("invalid span in on_new_span"); + + // Update parent + if let Some(parent) = span.parent() { + if let Some(data) = parent.extensions_mut().get_mut::() { + data.children = true; + + #[cfg(feature = "profiling-allocator")] + { + data.peak_memory = max(data.peak_memory, ALLOCATOR.max()); + } + } + } + #[cfg(feature = "profiling-allocator")] + ALLOCATOR.reset_max(); + + // Add Data if it hasn't already + if span.extensions().get::().is_none() { + let depth = span.parent().map_or(0, |s| { + s.extensions() + .get::() + .expect("parent span has no data") + .depth + + 1 + }); + let data = Data::new(attrs, depth); + span.extensions_mut().insert(data); + } + + // Fetch data + let ext = span.extensions(); + let data = ext.get::().expect("span does not have data"); + + let mut buffer = String::with_capacity(100); + + // Box draw tree indentation + if data.depth >= 1 { + for _ in 0..(data.depth - 1) { + let _ = write!(&mut buffer, "│ "); + } + let _ = write!(&mut buffer, "├─"); + } + let _ = write!(&mut buffer, "╮ "); + + // Span name + let _ = write!( + &mut buffer, + "{DIM}{}::{UNDIM}{}", + span.metadata().target(), + span.metadata().name() + ); + + // KV args + for (key, val) in &data.kvs { + let _ = write!(&mut buffer, " {key}={val}"); + } + + // Start-of-span memory stats + #[cfg(feature = "profiling-allocator")] + let _ = write!( + &mut buffer, + " {DIM}start:{UNDIM} {}B{DIM} current, {UNDIM}{:#}{DIM} allocations{UNDIM}", + human(ALLOCATOR.current() as f64), + human(ALLOCATOR.count() as f64) + ); + + eprintln!("{buffer}"); + } + + fn on_event(&self, event: &tracing::Event<'_>, ctx: Context<'_, S>) { + let span = ctx.current_span().id().and_then(|id| ctx.span(id)); + + let mut buffer = String::with_capacity(100); + + // Span indentation + time in span + if let Some(span) = &span { + // Flag child on parent + if let Some(parent) = span.parent() { + if let Some(data) = parent.extensions_mut().get_mut::() { + data.children = true; + } + } + + if let Some(data) = span.extensions().get::() { + // Box draw tree indentation + for _ in 0..=data.depth { + let _ = write!(&mut buffer, "│ "); + } + + // Time + let elapsed = data.time.elapsed(); + let _ = write!( + &mut buffer, + "{DIM}{:6}s {UNDIM}", + human(elapsed.as_secs_f64()) + ); + } + } + + // Log level + match *event.metadata().level() { + Level::TRACE => write!(&mut buffer, "TRACE"), + Level::DEBUG => write!(&mut buffer, "DEBUG"), + Level::INFO => write!(&mut buffer, "\x1b[1;32mINFO\x1b[0m"), + Level::WARN => write!(&mut buffer, "\x1b[1;38;5;208mWARN\x1b[0m"), + Level::ERROR => write!(&mut buffer, "\x1b[1;31mERROR\x1b[0m"), + } + .unwrap(); + + let mut visitor = FmtEvent(&mut buffer); + event.record(&mut visitor); + + eprintln!("{buffer}"); + } + + fn on_close(&self, id: Id, ctx: Context) { + let span = ctx.span(&id).expect("invalid span in on_close"); + let ext = span.extensions(); + let data = ext.get::().expect("span does not have data"); + let duration = data.time.elapsed(); + + let mut buffer = String::with_capacity(100); + + // Box draw tree indentation + if data.depth >= 1 { + for _ in 0..(data.depth - 1) { + let _ = write!(&mut buffer, "│ "); + } + let _ = write!(&mut buffer, "├─"); + } + let _ = write!(&mut buffer, "╯ "); + + // Short span name if not childless + if data.children { + let _ = write!(&mut buffer, "{DIM}{}: {UNDIM}", 
span.metadata().name()); + } + + // Print stats + let _ = write!( + &mut buffer, + "{}s{DIM} duration", + human(duration.as_secs_f64()), + ); + #[cfg(feature = "profiling-allocator")] + { + let peak_memory: usize = std::cmp::max(ALLOCATOR.max(), data.peak_memory); + let allocations = ALLOCATOR.count() - data.allocations; + let own = peak_memory - data.memory; + + // Update parent + if let Some(parent) = span.parent() { + if let Some(data) = parent.extensions_mut().get_mut::() { + data.peak_memory = max(data.peak_memory, peak_memory); + } + } + + let current_now = ALLOCATOR.current(); + let _ = write!( + &mut buffer, + ", {UNDIM}{}B{DIM} peak memory, {UNDIM}{}B{DIM} local, {UNDIM}{}B{DIM} current, \ + {UNDIM}{:#}{DIM} allocations{UNDIM}", + human(peak_memory as f64), + human(own as f64), + human(current_now as f64), + human(allocations as f64) + ); + } + + eprintln!("{buffer}"); + } +}
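
For reference, a minimal caller-side sketch of how the new `SparkConfig` argument to `PrepareAndVerifyCircuit` might be populated from the `gnark_spark_proof.json` file that `spark-cli prove` writes. This is an illustration under stated assumptions, not code from this change: the import paths, the JSON encoding of the transcript bytes, and the surrounding plumbing (Config, R1CS, Groth16 keys, build options) are assumed.

```go
// Sketch only: wire the Spark transcript into the extended
// PrepareAndVerifyCircuit signature. Import paths and the JSON layout of
// SparkConfig.Transcript are assumptions, not part of this diff.
package main

import (
	"encoding/json"
	"os"

	"github.com/consensys/gnark/backend/groth16"
	"reilabs/whir-verifier-circuit/app/circuit"
	"reilabs/whir-verifier-circuit/app/common"
)

func verifyWithSpark(
	config circuit.Config,
	r1cs circuit.R1CS,
	pk *groth16.ProvingKey,
	vk *groth16.VerifyingKey,
	buildOps common.BuildOps,
	sparkConfigPath string, // path to gnark_spark_proof.json
) error {
	raw, err := os.ReadFile(sparkConfigPath)
	if err != nil {
		return err
	}
	// Field names follow the `json:"..."` tags on SparkConfig (transcript,
	// io_pattern, whir_row, whir_col, whir_1batched, whir_2batched, ...).
	var sparkConfig circuit.SparkConfig
	if err := json.Unmarshal(raw, &sparkConfig); err != nil {
		return err
	}
	return circuit.PrepareAndVerifyCircuit(config, sparkConfig, r1cs, pk, vk, buildOps)
}
```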