From 1f0427b7e6624abae8d07771a1946b31f008491e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 25 Mar 2026 14:17:19 +0800 Subject: [PATCH] Refactor num_instances to fixed [usize; 2] in zkVM prove/verify flow --- ceno_recursion/src/zkvm_verifier/binding.rs | 13 ++-- ceno_recursion/src/zkvm_verifier/verifier.rs | 75 +++++++++----------- ceno_zkvm/benches/riscv_add.rs | 2 +- ceno_zkvm/src/scheme.rs | 2 +- ceno_zkvm/src/scheme/hal.rs | 2 +- ceno_zkvm/src/scheme/prover.rs | 25 +++---- ceno_zkvm/src/scheme/tests.rs | 2 +- ceno_zkvm/src/scheme/verifier.rs | 12 ++-- ceno_zkvm/src/structs.rs | 33 +++------ ceno_zkvm/src/tables/shard_ram.rs | 2 +- 10 files changed, 65 insertions(+), 103 deletions(-) diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index 08c76dbe0..8d9b7e5f5 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -34,6 +34,7 @@ use openvm_native_recursion::{ use openvm_stark_backend::p3_field::{FieldAlgebra, extension::BinomialExtensionField}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3::field::FieldExtensionAlgebra; +use std::cmp::max; use sumcheck::structs::IOPProof; pub type F = BabyBear; @@ -402,7 +403,7 @@ pub struct ZKVMChipProofInput { pub has_ecc_proof: usize, pub ecc_proof: EccQuarkProofInput, - pub num_instances: Vec, + pub num_instances: [usize; 2], pub wits_in_evals: Vec, pub fixed_in_evals: Vec, @@ -657,18 +658,12 @@ impl Hintable for ZKVMChipProofInput { stream.extend(>::write(&self.has_ecc_proof)); stream.extend(self.ecc_proof.write()); - stream.extend( as Hintable>::write( - &self.num_instances, - )); + stream.extend(self.num_instances.to_vec().write()); let n_inst_0 = self.num_instances[0]; let n_inst_0_bit_decomps = decompose_minus_one_bits(n_inst_0); - let n_inst_1 = if self.num_instances.len() > 1 { - self.num_instances[1] - } else { - 1usize - }; + let n_inst_1 = max(self.num_instances[1], 1); let n_inst_1_bit_decomps = decompose_minus_one_bits(n_inst_1); stream.extend(n_inst_0_bit_decomps.write()); diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index e0d823814..829860b73 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -122,54 +122,46 @@ pub fn verify_zkvm_proof>( }, ); - builder - .if_eq(zkvm_proof_input.shard_id.clone(), Usize::from(0)) - .then(|builder| { - if let Some(fixed_commit) = vk.fixed_commit.as_ref() { - let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into(); - let commit_array: Array> = builder.dyn_array(commit.value.len()); + if let Some(fixed_commit) = vk.fixed_commit.as_ref() { + let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into(); + let commit_array: Array> = builder.dyn_array(commit.value.len()); - commit.value.into_iter().enumerate().for_each(|(i, v)| { - let v = builder.constant(v); - // TODO: put fixed commit to public values - // builder.commit_public_value(v); + commit.value.into_iter().enumerate().for_each(|(i, v)| { + let v = builder.constant(v); + // TODO: put fixed commit to public values + // builder.commit_public_value(v); - builder.set_value(&commit_array, i, v); - }); + builder.set_value(&commit_array, i, v); + }); - challenger_multi_observe(builder, &mut challenger, &commit_array); + challenger_multi_observe(builder, &mut challenger, &commit_array); - let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize( - fixed_commit.log2_max_codeword_size, - )); + let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize( + fixed_commit.log2_max_codeword_size, + )); - challenger.observe(builder, log2_max_codeword_size_felt); - } - }); + challenger.observe(builder, log2_max_codeword_size_felt); + } - builder - .if_ne(zkvm_proof_input.shard_id.clone(), Usize::from(0)) - .then(|builder| { - if let Some(fixed_commit) = vk.fixed_no_omc_init_commit.as_ref() { - let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into(); - let commit_array: Array> = builder.dyn_array(commit.value.len()); - - commit.value.into_iter().enumerate().for_each(|(i, v)| { - let v = builder.constant(v); - // TODO: put fixed commit to public values - // builder.commit_public_value(v); - - builder.set_value(&commit_array, i, v); - }); - challenger_multi_observe(builder, &mut challenger, &commit_array); + if let Some(fixed_commit) = vk.fixed_no_omc_init_commit.as_ref() { + let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into(); + let commit_array: Array> = builder.dyn_array(commit.value.len()); - let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize( - fixed_commit.log2_max_codeword_size, - )); + commit.value.into_iter().enumerate().for_each(|(i, v)| { + let v = builder.constant(v); + // TODO: put fixed commit to public values + // builder.commit_public_value(v); - challenger.observe(builder, log2_max_codeword_size_felt); - } + builder.set_value(&commit_array, i, v); }); + challenger_multi_observe(builder, &mut challenger, &commit_array); + + let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize( + fixed_commit.log2_max_codeword_size, + )); + + challenger.observe(builder, log2_max_codeword_size_felt); + } iter_zip!(builder, zkvm_proof_input.chip_proofs).for_each(|ptr_vec, builder| { let chip_proofs = builder.iter_ptr_get(&zkvm_proof_input.chip_proofs, ptr_vec[0]); @@ -712,7 +704,8 @@ pub fn verify_chip_proof( let zero_bit_decomps: Array> = builder.dyn_array(32); let selector_ctxs: Vec> = if cs.ec_final_sum.is_empty() { - builder.assert_usize_eq(chip_proof.num_instances.len(), Usize::from(1)); + let non_shard_n1 = Usize::Var(builder.get(&chip_proof.num_instances, 1)); + builder.assert_usize_eq(non_shard_n1, Usize::from(0)); let num_instances_bit_decomps: Array>> = builder.dyn_array(1); builder.set( &num_instances_bit_decomps, @@ -740,8 +733,6 @@ pub fn verify_chip_proof( .unwrap_or(0) ] } else { - builder.assert_usize_eq(chip_proof.num_instances.len(), Usize::from(2)); - let num_inst_0_bit_decomps: Array>> = builder.dyn_array(1); let num_inst_1_bit_decomps: Array>> = builder.dyn_array(1); let num_inst_sum_bit_decomps: Array>> = builder.dyn_array(1); diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 3823219d7..9e150f6a3 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -114,7 +114,7 @@ fn bench_add(c: &mut Criterion) { structural_witness: vec![], public_input: vec![], pub_io_evals: vec![], - num_instances: vec![num_instances], + num_instances: [num_instances, 0], has_ecc_ops: false, }; let task = ChipTask { diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index fa9127f10..e2bb673e3 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -64,7 +64,7 @@ pub struct ZKVMChipProof { pub tower_proof: TowerProofs, pub ecc_proof: Option>, - pub num_instances: Vec, + pub num_instances: [usize; 2], pub fixed_in_evals: Vec, pub wits_in_evals: Vec, diff --git a/ceno_zkvm/src/scheme/hal.rs b/ceno_zkvm/src/scheme/hal.rs index b9a48485f..0274478fc 100644 --- a/ceno_zkvm/src/scheme/hal.rs +++ b/ceno_zkvm/src/scheme/hal.rs @@ -50,7 +50,7 @@ pub struct ProofInput<'a, PB: ProverBackend> { pub fixed: Vec>>, pub public_input: Vec>>, pub pub_io_evals: Vec::BaseField, PB::E>>, - pub num_instances: Vec, + pub num_instances: [usize; 2], pub has_ecc_ops: bool, } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 99ad1ac39..c05f5e1a9 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -171,14 +171,11 @@ impl< // commit to fixed commitment let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); - if let Some(fixed_commit) = &self.pk.fixed_commit - && shard_ctx.is_first_shard() - { + if let Some(fixed_commit) = self.pk.fixed_commit.as_ref() { PCS::write_commitment(fixed_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; - } else if let Some(fixed_commit) = &self.pk.fixed_no_omc_init_commit - && !shard_ctx.is_first_shard() - { + } + if let Some(fixed_commit) = self.pk.fixed_no_omc_init_commit.as_ref() { PCS::write_commitment(fixed_commit, &mut transcript) .map_err(ZKVMError::PCSError)?; } @@ -198,10 +195,10 @@ impl< // num_instance from witness might include rotation let num_instances = chip_inputs .iter() - .flat_map(|chip_input| &chip_input.num_instances) + .flat_map(|chip_input| chip_input.num_instances) .collect_vec(); - if num_instances.is_empty() { + if num_instances.iter().sum::() == 0 { continue; } @@ -210,7 +207,7 @@ impl< transcript.append_field_element(&E::BaseField::from_canonical_usize(*circuit_idx)); for num_instance in num_instances { transcript - .append_field_element(&E::BaseField::from_canonical_usize(*num_instance)); + .append_field_element(&E::BaseField::from_canonical_usize(num_instance)); } } @@ -548,7 +545,7 @@ impl< ecc_proof, fixed_in_evals, wits_in_evals, - num_instances: input.num_instances.clone(), + num_instances: input.num_instances, }, pi_in_evals, input_opening_point, @@ -561,7 +558,7 @@ impl< fn build_chip_tasks<'data>( &self, shard_ctx: &ShardContext, - name_and_instances: Vec<(String, Vec)>, + name_and_instances: Vec<(String, [usize; 2])>, structural_rmms: Vec>, #[allow(unused_mut)] mut witness_mles: Vec>, witness_data: &PB::PcsData, @@ -600,11 +597,11 @@ impl< let pk = self.pk.circuit_pks.get(&circuit_name).unwrap(); let cs = pk.get_cs(); if !shard_ctx.is_first_shard() && cs.with_omc_init_only() { - assert!(num_instances.is_empty()); + assert_eq!(num_instances, [0, 0]); // skip drain respective fixed because we use different set of fixed commitment continue; } - if num_instances.is_empty() { + if num_instances.iter().sum::() == 0 { // we need to drain respective fixed when num_instances is 0 if cs.num_fixed() > 0 { let _ = fixed_mles.drain(..cs.num_fixed()).collect_vec(); @@ -651,7 +648,7 @@ impl< structural_witness, public_input: public_input.clone(), pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), - num_instances: num_instances.clone(), + num_instances, has_ecc_ops: cs.has_ecc_ops(), }; // SAFETY: All Arcs in ProofInput contain 'static data: diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 91a4e4563..f95632463 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -213,7 +213,7 @@ fn test_rw_lk_expression_combination() { structural_witness: structural_in, public_input: vec![], pub_io_evals: vec![], - num_instances: vec![num_instances], + num_instances: [num_instances, 0], has_ecc_ops: false, }; let task = crate::scheme::scheduler::ChipTask { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2c49ee658..b17126134 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -211,13 +211,10 @@ impl> ZKVMVerifier // write fixed commitment to transcript // TODO check soundness if there is no fixed_commit but got fixed proof? - if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() - && shard_id == 0 - { + if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() { PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; - } else if let Some(fixed_commit) = self.vk.fixed_no_omc_init_commit.as_ref() - && shard_id > 0 - { + } + if let Some(fixed_commit) = self.vk.fixed_no_omc_init_commit.as_ref() { PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?; } @@ -632,7 +629,7 @@ impl> ZKVMVerifier let gkr_circuit = gkr_circuit.as_ref().unwrap(); let selector_ctxs = if cs.ec_final_sum.is_empty() { - assert_eq!(proof.num_instances.len(), 1); + assert_eq!(proof.num_instances[1], 0); // it's not shard chip vec![ SelectorContext::new(0, num_instances, num_var_with_rotation); @@ -643,7 +640,6 @@ impl> ZKVMVerifier .unwrap_or(0) ] } else { - assert_eq!(proof.num_instances.len(), 2); // it's shard chip tracing::debug!( "num_reads: {}, num_writes: {}, total: {}", diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f6847140..83289e988 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -325,15 +325,14 @@ impl ZKVMFixedTraces { pub struct ChipInput { pub name: String, pub witness_rmms: RMMCollections, - // in shard ram chip, num_instances length would be > 1 - pub num_instances: Vec, + pub num_instances: [usize; 2], } impl ChipInput { pub fn new( name: String, witness_rmms: RMMCollections, - num_instances: Vec, + num_instances: [usize; 2], ) -> Self { Self { name, @@ -382,16 +381,8 @@ impl ZKVMWitnesses { shard_steps, indices, )?; - let num_instances = vec![witness[0].num_instances()]; - let input = ChipInput::new( - OC::name(), - witness, - if num_instances[0] > 0 { - num_instances - } else { - vec![] - }, - ); + let num_instances = [witness[0].num_instances(), 0]; + let input = ChipInput::new(OC::name(), witness, num_instances); assert!(self.witnesses.insert(OC::name(), vec![input]).is_none()); assert!( self.lk_mlts @@ -445,15 +436,7 @@ impl ZKVMWitnesses { input, )?; let num_instances = std::cmp::max(witness[0].num_instances(), witness[1].num_instances()); - let input = ChipInput::new( - TC::name(), - witness, - if num_instances > 0 { - vec![num_instances] - } else { - vec![] - }, - ); + let input = ChipInput::new(TC::name(), witness, [num_instances, 0]); assert!(self.witnesses.insert(TC::name(), vec![input]).is_none()); Ok(()) @@ -613,7 +596,7 @@ impl ZKVMWitnesses { Ok(ChipInput::new( ShardRamCircuit::::name(), witness, - vec![num_reads, num_writes], + [num_reads, num_writes], )) }) .collect::, ZKVMError>>()?; @@ -627,13 +610,13 @@ impl ZKVMWitnesses { Ok(()) } - pub fn get_witnesses_name_instance(&self) -> Vec<(String, Vec)> { + pub fn get_witnesses_name_instance(&self) -> Vec<(String, [usize; 2])> { self.witnesses .iter() .flat_map(|(_, chip_inputs)| { chip_inputs .iter() - .map(|chip_input| (chip_input.name.clone(), chip_input.num_instances.clone())) + .map(|chip_input| (chip_input.name.clone(), chip_input.num_instances)) }) .collect_vec() } diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..b2a70cc6a 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -858,7 +858,7 @@ mod tests { fixed: vec![], public_input: public_input_mles.clone(), pub_io_evals, - num_instances: vec![n_global_writes as usize, n_global_reads as usize], + num_instances: [n_global_writes as usize, n_global_reads as usize], has_ecc_ops: true, }; let mut rng = thread_rng();