From 68c787d6cde151bcaff34893d1837f6fe7a85c42 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 25 Mar 2026 23:20:32 +0800 Subject: [PATCH 1/7] refactor: use instance_value-indexed PI evals in zkvm transcripts Replace raw public-input transcript absorption with circuit-driven instance_value indexing in prover, native verifier, and recursion verifier. Also add instance_value tracking in ConstraintSystem::query_instance and fix mutability/borrow fallout in public-value query helpers and RAM table address setup. --- ceno_recursion/src/zkvm_verifier/verifier.rs | 11 +++++++---- ceno_zkvm/src/chip_handler/general.rs | 16 ++++++++-------- ceno_zkvm/src/scheme/prover.rs | 16 +++++++++++++--- ceno_zkvm/src/scheme/verifier.rs | 17 +++++++++++------ ceno_zkvm/src/tables/ram.rs | 6 ++++-- gkr_iop/src/circuit_builder.rs | 11 +++++++++-- 6 files changed, 52 insertions(+), 25 deletions(-) diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index e0d823814..3a13d9126 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -104,10 +104,13 @@ pub fn verify_zkvm_proof>( let prod_w: Ext = builder.constant(C::EF::ONE); let logup_sum: Ext = builder.constant(C::EF::ZERO); - iter_zip!(builder, zkvm_proof_input.raw_pi).for_each(|ptr_vec, builder| { - let v = builder.iter_ptr_get(&zkvm_proof_input.raw_pi, ptr_vec[0]); - challenger_multi_observe(builder, &mut challenger, &v); - }); + for (_, circuit_vk) in vk.circuit_vks.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { + let eval = builder.get(&zkvm_proof_input.pi_evals, instance_value.0); + let eval_felts = builder.ext2felt(eval); + challenger.observe_slice(builder, eval_felts); + } + } iter_zip!(builder, zkvm_proof_input.raw_pi, zkvm_proof_input.pi_evals).for_each( |ptr_vec, builder| { diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index bb63ce504..bea259603 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -27,12 +27,12 @@ pub trait PublicValuesQuery { fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; #[allow(dead_code)] fn query_shard_id(&mut self) -> Result; - fn query_heap_start_addr(&self) -> Result; + fn query_heap_start_addr(&mut self) -> Result; #[allow(dead_code)] - fn query_heap_shard_len(&self) -> Result; - fn query_hint_start_addr(&self) -> Result; + fn query_heap_shard_len(&mut self) -> Result; + fn query_hint_start_addr(&mut self) -> Result; #[allow(dead_code)] - fn query_hint_shard_len(&self) -> Result; + fn query_hint_shard_len(&mut self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -95,18 +95,18 @@ impl<'a, E: ExtensionField> PublicValuesQuery for CircuitBuilder<'a, E> { self.cs.query_instance(SHARD_ID_IDX) } - fn query_heap_start_addr(&self) -> Result { + fn query_heap_start_addr(&mut self) -> Result { self.cs.query_instance(HEAP_START_ADDR_IDX) } - fn query_heap_shard_len(&self) -> Result { + fn query_heap_shard_len(&mut self) -> Result { self.cs.query_instance(HEAP_LENGTH_IDX) } - fn query_hint_start_addr(&self) -> Result { + fn query_hint_start_addr(&mut self) -> Result { self.cs.query_instance(HINT_START_ADDR_IDX) } - fn query_hint_shard_len(&self) -> Result { + fn query_hint_shard_len(&mut self) -> Result { self.cs.query_instance(HINT_LENGTH_IDX) } } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 99ad1ac39..0f4153226 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -160,10 +160,20 @@ impl< let mut pi_evals = ZKVMProof::::pi_evals(&raw_pi); let span = entered_span!("commit_to_pi", profiling_1 = true); - // including raw public input to transcript - for v in raw_pi.iter().flatten() { - transcript.append_field_element(v); + // Include transcript-visible public values in canonical circuit order. + // The order must match verifier and recursion verifier exactly. + // TODO deal with public io memory later + for (_, circuit_pk) in self.pk.circuit_pks.iter() { + for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance_value.iter() { + let idx = instance_value.0; + let eval = pi_evals + .get(idx) + .copied() + .expect("instance_value index out of bounds for pi_evals"); + transcript.append_field_element_ext(&eval); + } } + exit_span!(span); let pi: Vec> = diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2c49ee658..ddd978ebf 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -181,12 +181,17 @@ impl> ZKVMVerifier } } - // TODO fix soundness: construct raw public input by ourself and trustless from proof - // including raw public input to transcript - vm_proof - .raw_pi - .iter() - .for_each(|v| v.iter().for_each(|v| transcript.append_field_element(v))); + // Include transcript-visible public values in canonical circuit order. + // This must match prover and recursion verifier exactly. + for (_, circuit_vk) in self.vk.circuit_vks.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { + let idx = instance_value.0; + let eval = *pi_evals + .get(idx) + .expect("instance_value index out of bounds for pi_evals"); + transcript.append_field_element_ext(&eval); + } + } // check shard id assert_eq!( diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index d934127f9..130adc787 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -38,11 +38,12 @@ impl DynVolatileRamTable for HeapTable { params: &ProgramParams, ) -> Result<(Expression, StructuralWitIn), CircuitBuilderError> { let max_len = Self::max_len(params); + let offset_instance_id = cb.query_heap_start_addr()?.0 as WitnessId; let addr = cb.create_structural_witin( || "addr", StructuralWitInType::EqualDistanceDynamicSequence { max_len, - offset_instance_id: cb.query_heap_start_addr()?.0 as WitnessId, + offset_instance_id, multi_factor: WORD_SIZE, descending: Self::DESCENDING, }, @@ -143,11 +144,12 @@ impl DynVolatileRamTable for HintsTable { params: &ProgramParams, ) -> Result<(Expression, StructuralWitIn), CircuitBuilderError> { let max_len = Self::max_len(params); + let offset_instance_id = cb.query_hint_start_addr()?.0 as WitnessId; let addr = cb.create_structural_witin( || "addr", StructuralWitInType::EqualDistanceDynamicSequence { max_len, - offset_instance_id: cb.query_hint_start_addr()?.0 as WitnessId, + offset_instance_id, multi_factor: WORD_SIZE, descending: Self::DESCENDING, }, diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index d84cb4d55..567a8ef06 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -102,6 +102,8 @@ pub struct ConstraintSystem { pub num_fixed: usize, pub fixed_namespace_map: Vec, + // record which public input index is involving in constraint computation + pub instance_values: Vec, pub instance_openings: Vec, pub ec_point_exprs: Vec>, @@ -175,6 +177,7 @@ impl ConstraintSystem { num_fixed: 0, fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), + instance_values: vec![], instance_openings: vec![], ec_final_sum: vec![], ec_slope_exprs: vec![], @@ -259,8 +262,13 @@ impl ConstraintSystem { f } - pub fn query_instance(&self, idx: usize) -> Result { + pub fn query_instance(&mut self, idx: usize) -> Result { let i = Instance(idx); + assert!( + !self.instance_values.contains(&i), + "query same pubio idx {idx} value more than once", + ); + self.instance_values.push(i); Ok(i) } @@ -276,7 +284,6 @@ impl ConstraintSystem { ); self.instance_openings.push(i); - // return instance only count Ok(Instance(self.instance_openings.len() - 1)) } From 69da9ba2fe05f6d2e947ee12916faa3300de5c1c Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 26 Mar 2026 17:11:21 +0800 Subject: [PATCH 2/7] refactor(zkvm): derive public-input openings from PublicValues MLEs --- ceno_recursion/src/aggregation/mod.rs | 17 +- ceno_recursion/src/zkvm_verifier/binding.rs | 99 ++++++++---- ceno_recursion/src/zkvm_verifier/verifier.rs | 82 +++++++--- ceno_zkvm/src/bin/e2e.rs | 4 +- ceno_zkvm/src/e2e.rs | 2 +- ceno_zkvm/src/instructions/riscv/constants.rs | 9 +- ceno_zkvm/src/scheme.rs | 126 ++++++--------- ceno_zkvm/src/scheme/cpu/mod.rs | 14 +- ceno_zkvm/src/scheme/gpu/mod.rs | 14 +- ceno_zkvm/src/scheme/hal.rs | 1 + ceno_zkvm/src/scheme/mock_prover.rs | 95 +++++------ ceno_zkvm/src/scheme/prover.rs | 131 ++++++---------- ceno_zkvm/src/scheme/scheduler.rs | 8 +- ceno_zkvm/src/scheme/tests.rs | 7 +- ceno_zkvm/src/scheme/verifier.rs | 148 +++++++++++------- ceno_zkvm/src/tables/shard_ram.rs | 56 ++++--- gkr_iop/src/circuit_builder.rs | 2 +- gkr_iop/src/gkr.rs | 4 +- gkr_iop/src/gkr/layer.rs | 4 +- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 13 +- 20 files changed, 440 insertions(+), 396 deletions(-) diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index d17829fe9..c7ef469d6 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -265,7 +265,7 @@ impl CenoAggregationProver { .collect(); let user_public_values: Vec = zkvm_proof_inputs .iter() - .flat_map(|p| p.raw_pi.iter().flat_map(|v| v.clone()).collect::>()) + .flat_map(|p| p.raw_pi.iter().copied().collect::>()) .collect(); let leaf_inputs = chunk_ceno_leaf_proof_inputs(zkvm_proof_inputs); @@ -410,18 +410,9 @@ impl CenoLeafVmVerifierConfig { } let pv = &raw_pi; - let init_pc = { - let arr = builder.get(pv, INIT_PC_IDX); - builder.get(&arr, 0) - }; - let end_pc = { - let arr = builder.get(pv, END_PC_IDX); - builder.get(&arr, 0) - }; - let exit_code = { - let arr = builder.get(pv, EXIT_CODE_IDX); - builder.get(&arr, 0) - }; + let init_pc = builder.get(pv, INIT_PC_IDX); + let end_pc = builder.get(pv, END_PC_IDX); + let exit_code = builder.get(pv, EXIT_CODE_IDX); builder.assign(&stark_pvs.connector.initial_pc, init_pc); builder.assign(&stark_pvs.connector.final_pc, end_pc); builder.assign(&stark_pvs.connector.exit_code, exit_code); diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index 08c76dbe0..e6abae576 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -14,6 +14,7 @@ use crate::{ }, }; use ceno_zkvm::{ + instructions::riscv::constants::{LIMB_BITS, LIMB_MASK, UINT_LIMBS}, scheme::{ZKVMChipProof, ZKVMProof}, structs::{EccQuarkProof, TowerProofs, ZKVMVerifyingKey}, }; @@ -41,6 +42,46 @@ pub type E = BinomialExtensionField; pub type RecPcs = Basefold; pub type InnerConfig = AsmConfig; +fn raw_pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec { + vec![ + F::from_canonical_u32(public_values.exit_code & 0xffff), + F::from_canonical_u32((public_values.exit_code >> 16) & 0xffff), + F::from_canonical_u32(public_values.init_pc), + F::from_canonical_u64(public_values.init_cycle), + F::from_canonical_u32(public_values.end_pc), + F::from_canonical_u64(public_values.end_cycle), + F::from_canonical_u32(public_values.shard_id), + F::from_canonical_u32(public_values.heap_start_addr), + F::from_canonical_u32(public_values.heap_shard_len), + F::from_canonical_u32(public_values.hint_start_addr), + F::from_canonical_u32(public_values.hint_shard_len), + ] + .into_iter() + .chain( + public_values + .shard_rw_sum + .iter() + .map(|value| F::from_canonical_u32(*value)), + ) + .collect_vec() +} + +fn mles_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec> { + (0..UINT_LIMBS) + .map(|limb_index| { + public_values + .public_io + .iter() + .map(|value| F::from_canonical_u16(((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16)) + .collect_vec() + }) + .collect_vec() +} + +fn pi_evals_from_raw_pi(raw_pi: &[F]) -> Vec { + raw_pi.iter().map(|v| E::from(*v)).collect_vec() +} + pub fn decompose_minus_one_bits(n: usize) -> Vec { let a = if n > 0 { n - 1 } else { 0 }; let mut bit_decomp: Vec = vec![]; @@ -71,7 +112,8 @@ pub fn decompose_prefixed_layer_bits(n: usize) -> (Vec, Vec>) { #[derive(DslVariable, Clone)] pub struct ZKVMProofInputVariable { pub shard_id: Usize, - pub raw_pi: Array>>, + pub raw_pi: Array>, + pub mles: Array>>, pub raw_pi_num_variables: Array>, pub pi_evals: Array>, pub chip_proofs: Array>>, @@ -95,7 +137,9 @@ pub struct TowerProofInputVariable { pub(crate) struct ZKVMProofInput { pub shard_id: usize, - pub raw_pi: Vec>, + pub raw_pi: Vec, + pub mles: Vec>, + pub raw_pi_num_variables: Vec, // Evaluation of raw_pi. pub pi_evals: Vec, pub chip_proofs: BTreeMap, @@ -109,6 +153,14 @@ impl ZKVMProofInput { zkvm_proof: ZKVMProof, vk: &ZKVMVerifyingKey, ) -> Self { + let raw_pi = raw_pi_from_public_values(&zkvm_proof.public_values); + let mles = mles_from_public_values(&zkvm_proof.public_values); + let raw_pi_num_variables = mles + .iter() + .map(|v| ceil_log2(v.len().next_power_of_two())) + .collect::>(); + let pi_evals = pi_evals_from_raw_pi(&raw_pi); + let mut chip_witin_num_vars: HashMap = HashMap::new(); // (chip_id, (num_witin, num_fixed)) let mut chip_indices = zkvm_proof .chip_proofs @@ -136,8 +188,10 @@ impl ZKVMProofInput { ZKVMProofInput { shard_id, - raw_pi: zkvm_proof.raw_pi, - pi_evals: zkvm_proof.pi_evals, + raw_pi, + mles, + raw_pi_num_variables, + pi_evals, chip_proofs: zkvm_proof .chip_proofs .into_iter() @@ -168,7 +222,8 @@ impl Hintable for ZKVMProofInput { fn read(builder: &mut Builder) -> Self::HintVariable { let shard_id = Usize::Var(usize::read(builder)); - let raw_pi = Vec::>::read(builder); + let raw_pi = Vec::::read(builder); + let mles = Vec::>::read(builder); let raw_pi_num_variables = Vec::::read(builder); let pi_evals = Vec::::read(builder); builder.cycle_tracker_start("read chip proofs"); @@ -187,6 +242,7 @@ impl Hintable for ZKVMProofInput { ZKVMProofInputVariable { shard_id, raw_pi, + mles, raw_pi_num_variables, pi_evals, chip_proofs, @@ -201,11 +257,6 @@ impl Hintable for ZKVMProofInput { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - let raw_pi_num_variables: Vec = self - .raw_pi - .iter() - .map(|v| ceil_log2(v.len().next_power_of_two())) - .collect(); let witin_num_vars = self .chip_proofs .iter() @@ -217,21 +268,21 @@ impl Hintable for ZKVMProofInput { .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .map(|proof| proof.wits_in_evals.len().max(1)) + .map(|proof| proof.num_witin.max(1)) .collect::>(); let fixed_num_vars = self .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .filter(|proof| !proof.fixed_in_evals.is_empty()) + .filter(|proof| proof.num_fixed > 0) .map(|proof| proof.num_vars) .collect::>(); let fixed_max_widths = self .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .filter(|proof| !proof.fixed_in_evals.is_empty()) - .map(|proof| proof.fixed_in_evals.len()) + .filter(|proof| proof.num_fixed > 0) + .map(|proof| proof.num_fixed) .collect::>(); let max_num_var = witin_num_vars .iter() @@ -264,7 +315,8 @@ impl Hintable for ZKVMProofInput { stream.extend(>::write(&self.shard_id)); stream.extend(self.raw_pi.write()); - stream.extend(raw_pi_num_variables.write()); + stream.extend(self.mles.write()); + stream.extend(self.raw_pi_num_variables.write()); stream.extend(self.pi_evals.write()); stream.extend(vec![vec![F::from_canonical_usize(self.chip_proofs.len())]]); for proofs in self.chip_proofs.values() { @@ -403,9 +455,6 @@ pub struct ZKVMChipProofInput { pub ecc_proof: EccQuarkProofInput, pub num_instances: Vec, - - pub wits_in_evals: Vec, - pub fixed_in_evals: Vec, } impl VecAutoHintable for ZKVMChipProofInput {} @@ -499,8 +548,6 @@ impl From<(usize, ZKVMChipProof, usize, usize)> for ZKVMChipProofInput { EccQuarkProofInput::dummy() }, num_instances: p.num_instances, - wits_in_evals: p.wits_in_evals, - fixed_in_evals: p.fixed_in_evals, } } } @@ -531,9 +578,6 @@ pub struct ZKVMChipProofInputVariable { pub num_instances: Array>, pub n_inst_0_bit_decomps: Array>, pub n_inst_1_bit_decomps: Array>, - - pub fixed_in_evals: Array>, - pub wits_in_evals: Array>, } impl Hintable for ZKVMChipProofInput { type HintVariable = ZKVMChipProofInputVariable; @@ -571,11 +615,6 @@ impl Hintable for ZKVMChipProofInput { let n_inst_0_bit_decomps = Vec::::read(builder); let n_inst_1_bit_decomps = Vec::::read(builder); - builder.cycle_tracker_start("read wit/fixed evals"); - let fixed_in_evals = Vec::::read(builder); - let wits_in_evals = Vec::::read(builder); - builder.cycle_tracker_end("read wit/fixed evals"); - ZKVMChipProofInputVariable { idx, idx_felt, @@ -597,8 +636,6 @@ impl Hintable for ZKVMChipProofInput { num_instances, n_inst_0_bit_decomps, n_inst_1_bit_decomps, - fixed_in_evals, - wits_in_evals, } } @@ -674,8 +711,6 @@ impl Hintable for ZKVMChipProofInput { stream.extend(n_inst_0_bit_decomps.write()); stream.extend(n_inst_1_bit_decomps.write()); - stream.extend(self.fixed_in_evals.write()); - stream.extend(self.wits_in_evals.write()); stream } diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 3a13d9126..1681c1ddc 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -106,7 +106,8 @@ pub fn verify_zkvm_proof>( for (_, circuit_vk) in vk.circuit_vks.iter() { for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { - let eval = builder.get(&zkvm_proof_input.pi_evals, instance_value.0); + let raw = builder.get(&zkvm_proof_input.raw_pi, instance_value.0); + let eval = builder.ext_from_base_slice(&[raw]); let eval_felts = builder.ext2felt(eval); challenger.observe_slice(builder, eval_felts); } @@ -116,12 +117,8 @@ pub fn verify_zkvm_proof>( |ptr_vec, builder| { let raw = builder.iter_ptr_get(&zkvm_proof_input.raw_pi, ptr_vec[0]); let eval = builder.iter_ptr_get(&zkvm_proof_input.pi_evals, ptr_vec[1]); - let raw0 = builder.get(&raw, 0); - - builder.if_eq(raw.len(), Usize::from(1)).then(|builder| { - let raw0_ext = builder.ext_from_base_slice(&[raw0]); - builder.assert_ext_eq(raw0_ext, eval); - }); + let raw_ext = builder.ext_from_base_slice(&[raw]); + builder.assert_ext_eq(raw_ext, eval); }, ); @@ -276,14 +273,6 @@ pub fn verify_zkvm_proof>( // fork transcript to support chip concurrently proved let mut chip_challenger = clone_challenger_state(builder, &challenger); challenger_add_forked_index(builder, &mut chip_challenger, &forked_sample_index); - builder.assert_usize_eq( - chip_proof.wits_in_evals.len(), - Usize::from(circuit_vk.get_cs().num_witin()), - ); - builder.assert_usize_eq( - chip_proof.fixed_in_evals.len(), - Usize::from(circuit_vk.get_cs().num_fixed()), - ); builder.assert_usize_eq( chip_proof.rw_out_evals.length.clone(), Usize::from( @@ -339,8 +328,8 @@ pub fn verify_zkvm_proof>( builder, &mut chip_challenger, &chip_proof, - &zkvm_proof_input.pi_evals, &zkvm_proof_input.raw_pi, + &zkvm_proof_input.mles, &zkvm_proof_input.raw_pi_num_variables, &challenges, chip_vk, @@ -376,13 +365,20 @@ pub fn verify_zkvm_proof>( let point_clone: Array> = builder.eval(input_opening_point.clone()); + let (wits_in_evals, fixed_in_evals, _pi_in_evals) = split_input_opening_evals( + builder, + &chip_proof, + circuit_vk.get_cs().num_witin(), + circuit_vk.get_cs().num_fixed(), + circuit_vk.get_cs().instance_openings().len(), + ); if circuit_vk.get_cs().num_witin() > 0 { let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { num_var: input_opening_point.len().get_var(), point_and_evals: PointAndEvalsVariable { point: PointVariable { fs: point_clone }, - evals: chip_proof.wits_in_evals, + evals: wits_in_evals, }, }); builder.set_value(&witin_openings, num_witin_openings.get_var(), witin_round); @@ -395,7 +391,7 @@ pub fn verify_zkvm_proof>( point: PointVariable { fs: input_opening_point, }, - evals: chip_proof.fixed_in_evals, + evals: fixed_in_evals, }, }); @@ -559,13 +555,41 @@ pub fn verify_zkvm_proof>( shard_ec_sum } +fn split_input_opening_evals( + builder: &mut Builder, + chip_proof: &ZKVMChipProofInputVariable, + num_witin: usize, + num_fixed: usize, + num_pi: usize, +) -> ( + Array>, + Array>, + Array>, +) { + let last_layer_idx: Usize = + builder.eval(chip_proof.gkr_iop_proof.layer_proofs.len() - Usize::from(1)); + let last_layer = builder.get(&chip_proof.gkr_iop_proof.layer_proofs, last_layer_idx); + let main_evals = last_layer.main.evals; + + let wit_end = Usize::from(num_witin); + let fixed_end: Usize = builder.eval(wit_end.clone() + Usize::from(num_fixed)); + let pi_end: Usize = builder.eval(fixed_end.clone() + Usize::from(num_pi)); + builder.assert_usize_eq(main_evals.len(), pi_end.clone()); + + ( + main_evals.slice(builder, Usize::from(0), wit_end), + main_evals.slice(builder, Usize::from(num_witin), fixed_end), + main_evals.slice(builder, Usize::from(num_witin + num_fixed), pi_end), + ) +} + pub fn verify_chip_proof( circuit_name: &str, builder: &mut Builder, challenger: &mut DuplexChallengerVariable, chip_proof: &ZKVMChipProofInputVariable, - pi_evals: &Array>, - raw_pi: &Array>>, + raw_pi: &Array>, + mles: &Array>>, raw_pi_num_variables: &Array>, challenges: &Array>, vk: &VerifyingKey, @@ -712,6 +736,13 @@ pub fn verify_chip_proof( builder.set(&q_slice, idx_vec[0], cpt); }); let gkr_circuit = gkr_circuit.clone().unwrap(); + let circuit_pi_evals: Array> = + builder.dyn_array(Usize::from(cs.instance_values.len())); + for (i, instance) in cs.instance_values.iter().enumerate() { + let raw = builder.get(raw_pi, instance.0); + let eval = builder.ext_from_base_slice(&[raw]); + builder.set(&circuit_pi_evals, i, eval); + } let zero_bit_decomps: Array> = builder.dyn_array(32); let selector_ctxs: Vec> = if cs.ec_final_sum.is_empty() { @@ -810,11 +841,10 @@ pub fn verify_chip_proof( gkr_circuit, &chip_proof.gkr_iop_proof, challenges, - pi_evals, - raw_pi, + &circuit_pi_evals, + mles, raw_pi_num_variables, &out_evals, - chip_proof, selector_ctxs, unipoly_extrapolator, poly_evaluator, @@ -832,10 +862,9 @@ pub fn verify_gkr_circuit( gkr_proof: &GKRProofVariable, challenges: &Array>, pub_io_evals: &Array>, - raw_pi: &Array>>, + mles: &Array>>, raw_pi_num_variables: &Array>, claims: &Array>, - _chip_proof: &ZKVMChipProofInputVariable, selector_ctxs: Vec>, unipoly_extrapolator: &UniPolyExtrapolator, poly_evaluator: &mut PolyEvaluator, @@ -1130,7 +1159,7 @@ pub fn verify_gkr_circuit( let pubio_offset = layer.n_witin + layer.n_fixed; for (index, instance) in layer.instance_openings.iter().enumerate() { let index: usize = pubio_offset + index; - let poly = builder.get(raw_pi, instance.0); + let poly = builder.get(mles, instance.0); let num_variable = builder.get(raw_pi_num_variables, instance.0); let in_point_slice = in_point.slice(builder, 0, num_variable); let expected_eval = @@ -1139,6 +1168,7 @@ pub fn verify_gkr_circuit( builder.assert_ext_eq(expected_eval, main_eval); } + // TODO: we should store alpha_pows in a bigger array to avoid concatenating them let main_sumcheck_challenges_len: Usize = builder.eval(alpha_pows.len() + Usize::from(2)); diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 09e5a1131..e16d93c82 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -19,7 +19,6 @@ use gkr_iop::hal::ProverBackend; use mpcs::{ Basefold, BasefoldRSParams, PolynomialCommitmentScheme, SecurityLevel, Whir, WhirDefaultSpec, }; -use p3::field::FieldAlgebra; use serde::{Serialize, de::DeserializeOwned}; use std::{fs, panic, panic::AssertUnwindSafe, path::PathBuf}; use tracing::{error, level_filters::LevelFilter}; @@ -404,8 +403,7 @@ fn soundness_test>( // do sanity check let transcript = Transcript::new(b"riscv"); // change public input maliciously should cause verifier to reject proof - zkvm_proof.raw_pi[0] = vec![E::BaseField::ONE]; - zkvm_proof.raw_pi[1] = vec![E::BaseField::ONE]; + zkvm_proof.public_values.exit_code = 1; // capture panic message, if have let result = with_panic_hook(Box::new(|_info| ()), || { diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..0c3d936b7 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1021,7 +1021,7 @@ pub fn emulate_program<'a>( platform.hints.start, hints_final.len() as u32, io_init.iter().map(|rec| rec.value).collect_vec(), - vec![0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity + [0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); #[cfg(debug_assertions)] diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 61e673246..2ac17f528 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -3,6 +3,8 @@ pub use ceno_emul::PC_STEP_SIZE; pub const ECALL_HALT_OPCODE: [usize; 2] = [0x00_00, 0x00_00]; pub const EXIT_PC: usize = 0; + +/// scalar-based public value, id start from 0 pub const EXIT_CODE_IDX: usize = 0; // exit code u32 occupied 2 limb, each with 16 pub const INIT_PC_IDX: usize = EXIT_CODE_IDX + 2; @@ -14,8 +16,11 @@ pub const HEAP_START_ADDR_IDX: usize = SHARD_ID_IDX + 1; pub const HEAP_LENGTH_IDX: usize = HEAP_START_ADDR_IDX + 1; pub const HINT_START_ADDR_IDX: usize = HEAP_LENGTH_IDX + 1; pub const HINT_LENGTH_IDX: usize = HINT_START_ADDR_IDX + 1; -pub const PUBLIC_IO_IDX: usize = HINT_LENGTH_IDX + 1; -pub const SHARD_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; + +pub const SHARD_RW_SUM_IDX: usize = HINT_LENGTH_IDX + 1; + +/// vector-based public value, id start from 0 +pub const PUBLIC_IO_IDX: usize = 0; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index fa9127f10..3ba4e3651 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -3,6 +3,7 @@ use ff_ext::ExtensionField; use gkr_iop::gkr::GKRProof; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; +use multilinear_extensions::mle::{IntoMLE, MultilinearExtension}; use p3::field::FieldAlgebra; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ @@ -18,10 +19,15 @@ use crate::{ instructions::{ Instruction, riscv::{ - constants::{LIMB_BITS, LIMB_MASK, UINT_LIMBS}, + constants::{ + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, + HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, LIMB_BITS, + LIMB_MASK, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, + }, ecall::HaltInstruction, }, }, + scheme::constants::SEPTIC_EXTENSION_DEGREE, structs::{TowerProofs, ZKVMVerifyingKey}, }; @@ -65,13 +71,10 @@ pub struct ZKVMChipProof { pub ecc_proof: Option>, pub num_instances: Vec, - - pub fixed_in_evals: Vec, - pub wits_in_evals: Vec, } /// each field will be interpret to (constant) polynomial -#[derive(Default, Clone, Debug)] +#[derive(Default, Clone, Debug, Serialize, Deserialize)] pub struct PublicValues { pub exit_code: u32, pub init_pc: u32, @@ -84,7 +87,7 @@ pub struct PublicValues { pub hint_start_addr: u32, pub hint_shard_len: u32, pub public_io: Vec, - pub shard_rw_sum: Vec, + pub shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], } impl PublicValues { @@ -101,7 +104,7 @@ impl PublicValues { hint_start_addr: u32, hint_shard_len: u32, public_io: Vec, - shard_rw_sum: Vec, + shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], ) -> Self { Self { exit_code, @@ -118,45 +121,45 @@ impl PublicValues { shard_rw_sum, } } - pub fn to_vec(&self) -> Vec> { - vec![ - vec![E::BaseField::from_canonical_u32(self.exit_code & 0xffff)], - vec![E::BaseField::from_canonical_u32( - (self.exit_code >> 16) & 0xffff, - )], - vec![E::BaseField::from_canonical_u32(self.init_pc)], - vec![E::BaseField::from_canonical_u64(self.init_cycle)], - vec![E::BaseField::from_canonical_u32(self.end_pc)], - vec![E::BaseField::from_canonical_u64(self.end_cycle)], - vec![E::BaseField::from_canonical_u32(self.shard_id)], - vec![E::BaseField::from_canonical_u32(self.heap_start_addr)], - vec![E::BaseField::from_canonical_u32(self.heap_shard_len)], - vec![E::BaseField::from_canonical_u32(self.hint_start_addr)], - vec![E::BaseField::from_canonical_u32(self.hint_shard_len)], - ] - .into_iter() - .chain( - // public io processed into UINT_LIMBS column - (0..UINT_LIMBS) - .map(|limb_index| { - self.public_io - .iter() - .map(|value| { - E::BaseField::from_canonical_u16( - ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, - ) - }) - .collect_vec() - }) - .collect_vec(), - ) - .chain( - self.shard_rw_sum - .iter() - .map(|value| vec![E::BaseField::from_canonical_u32(*value)]) - .collect_vec(), - ) - .collect::>() + pub fn query_by_index(&self, index: usize) -> E { + match index { + EXIT_CODE_IDX => E::from_canonical_u32(self.exit_code & 0xffff), + idx if idx == EXIT_CODE_IDX + 1 => { + E::from_canonical_u32((self.exit_code >> 16) & 0xffff) + } + INIT_PC_IDX => E::from_canonical_u32(self.init_pc), + INIT_CYCLE_IDX => E::from_canonical_u64(self.init_cycle), + END_PC_IDX => E::from_canonical_u32(self.end_pc), + END_CYCLE_IDX => E::from_canonical_u64(self.end_cycle), + SHARD_ID_IDX => E::from_canonical_u32(self.shard_id), + HEAP_START_ADDR_IDX => E::from_canonical_u32(self.heap_start_addr), + HEAP_LENGTH_IDX => E::from_canonical_u32(self.heap_shard_len), + HINT_START_ADDR_IDX => E::from_canonical_u32(self.hint_start_addr), + HINT_LENGTH_IDX => E::from_canonical_u32(self.hint_shard_len), + idx if (SHARD_RW_SUM_IDX..(SHARD_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE * 2)) + .contains(&idx) => + { + E::from_canonical_u32(self.shard_rw_sum[idx - SHARD_RW_SUM_IDX]) + } + _ => panic!("public value index {index} out of range"), + } + } + + pub fn mles(&self) -> Vec> { + // public_io is represented as UINT_LIMBS columns. + (0..UINT_LIMBS) + .map(|limb_index| { + self.public_io + .iter() + .map(|value| { + E::BaseField::from_canonical_u16( + ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, + ) + }) + .collect_vec() + .into_mle() + }) + .collect_vec() } } @@ -169,11 +172,7 @@ impl PublicValues { deserialize = "E::BaseField: DeserializeOwned" ))] pub struct ZKVMProof> { - // TODO preserve in serde only for auxiliary public input - // other raw value can be construct by verifier directly. - pub raw_pi: Vec>, - // the evaluation of raw_pi. - pub pi_evals: Vec, + pub public_values: PublicValues, // each circuit may have multiple proof instances pub chip_proofs: BTreeMap>>, pub witin_commit: >::Commitment, @@ -182,40 +181,19 @@ pub struct ZKVMProof> { impl> ZKVMProof { pub fn new( - raw_pi: Vec>, - pi_evals: Vec, + public_values: PublicValues, chip_proofs: BTreeMap>>, witin_commit: >::Commitment, opening_proof: PCS::Proof, ) -> Self { Self { - raw_pi, - pi_evals, + public_values, chip_proofs, witin_commit, opening_proof, } } - pub fn pi_evals(raw_pi: &[Vec]) -> Vec { - raw_pi - .iter() - .map(|pv| { - if pv.len() == 1 { - // this is constant poly, and always evaluate to same constant value - E::from(pv[0]) - } else { - // set 0 as placeholder. will be evaluate lazily - // Or the vector is empty, i.e. the constant 0 polynomial. - E::ZERO - } - }) - .collect_vec() - } - - pub fn update_pi_eval(&mut self, idx: usize, v: E) { - self.pi_evals[idx] = v; - } pub fn num_circuits(&self) -> usize { self.chip_proofs.len() diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index a773f4de9..85decd78e 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -842,11 +842,8 @@ impl> MainSumcheckProver> MainSumcheckProver { pub struct MainSumcheckEvals { pub wits_in_evals: Vec, pub fixed_in_evals: Vec, + pub pi_in_evals: Vec, } pub trait MainSumcheckProver { diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index a477b1c6a..825c4fae4 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -26,7 +26,7 @@ use gkr_iop::{ use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{ Expression, WitnessId, fmt, - mle::{ArcMultilinearExtension, IntoMLEs, MultilinearExtension}, + mle::{ArcMultilinearExtension, MultilinearExtension}, util::ceil_log2, utils::{eval_by_expr, eval_by_expr_with_fixed, eval_by_expr_with_instance}, }; @@ -40,7 +40,6 @@ use std::{ hash::Hash, io::{BufReader, ErrorKind}, marker::PhantomData, - ops::Index, sync::OnceLock, }; use strum::IntoEnumIterator; @@ -964,17 +963,22 @@ Hints: ) where E: LkMultiplicityKey, { - let pub_io_evals = pi - .to_vec::() - .into_iter() - .map(|v| Either::Right(E::from(*v.index(0)))) - .collect_vec(); - let pi_mles: Vec> = pi - .to_vec::() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(); + let all_pi_mles: Vec> = + pi.mles::().into_iter().map(|v| v.into()).collect_vec(); + let get_circuit_pi_inputs = |circuit_cs: &ConstraintSystem| { + let circuit_pub_io_evals = circuit_cs + .instance_values + .iter() + .map(|instance| Either::Right(pi.query_by_index::(instance.0))) + .collect_vec(); + let circuit_pi_mles = circuit_cs + .instance_openings + .iter() + .map(|instance| all_pi_mles[instance.0].clone()) + .collect_vec(); + (circuit_pub_io_evals, circuit_pi_mles) + }; + let mut rng = thread_rng(); let challenges = [0u8; 2].map(|_| E::random(&mut rng)); @@ -1000,11 +1004,7 @@ Hints: let ComposedConstrainSystem { zkvm_v1_css: cs, .. } = &composed_cs; - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); // skip init table on non-first shard if composed_cs.with_omc_init_only() && !shard_ctx.is_first_shard() { @@ -1061,8 +1061,8 @@ Hints: &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, num_rows, challenges, lkm_from_assignments, @@ -1093,8 +1093,8 @@ Hints: &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1108,8 +1108,8 @@ Hints: &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1162,11 +1162,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { @@ -1204,8 +1200,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); @@ -1217,8 +1213,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let w_selector_vec = w_selector.get_base_field_vec(); @@ -1268,11 +1264,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { continue; @@ -1308,8 +1300,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); @@ -1321,8 +1313,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let r_selector_vec = r_selector.get_base_field_vec(); @@ -1349,8 +1341,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); filter_mle_by_selector_mle(v, r_selector.clone()) @@ -1481,6 +1473,11 @@ Hints: let mut cb = CircuitBuilder::new(&mut cs); let gs_init = GlobalState::initial_global_state(&mut cb).unwrap(); let gs_final = GlobalState::finalize_global_state(&mut cb).unwrap(); + let gs_pub_io_evals = cs + .instance_values + .iter() + .map(|instance| pi.query_by_index::(instance.0)) + .collect_vec(); let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) = derive_ram_rws!(RAMType::GlobalState); @@ -1489,10 +1486,7 @@ Hints: &[], &[], &[], - &pub_io_evals - .iter() - .map(|v| v.right().unwrap()) - .collect_vec(), + &gs_pub_io_evals, &challenges, &gs_final, ) @@ -1504,10 +1498,7 @@ Hints: &[], &[], &[], - &pub_io_evals - .iter() - .map(|v| v.right().unwrap()) - .collect_vec(), + &gs_pub_io_evals, &challenges, &gs_init, ) diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 0f4153226..0c6215fb3 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -4,7 +4,7 @@ use gkr_iop::{ hal::ProverBackend, }; use std::{ - collections::{BTreeMap, HashMap}, + collections::BTreeMap, marker::PhantomData, sync::Arc, }; @@ -17,13 +17,9 @@ use crate::scheme::{ scheduler::{ChipScheduler, ChipTask, ChipTaskResult}, }; use either::Either; -use gkr_iop::hal::MultilinearPolynomial; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{ - Expression, Instance, - mle::{IntoMLE, MultilinearExtension}, -}; +use multilinear_extensions::{Expression, Instance}; use p3::field::FieldAlgebra; use std::iter::Iterator; use sumcheck::{ @@ -39,6 +35,7 @@ use crate::structs::ProvingKey; use crate::{ e2e::ShardContext, error::ZKVMError, + instructions::riscv::constants::UINT_LIMBS, scheme::{ hal::{DeviceProvingKey, ProofInput}, utils::build_main_witness, @@ -46,7 +43,7 @@ use crate::{ structs::{TowerProofs, ZKVMProvingKey, ZKVMWitnesses}, }; -type CreateTableProof = (ZKVMChipProof, HashMap, Point); +type CreateTableProof = (ZKVMChipProof, MainSumcheckEvals, Point); pub type ZkVMCpuProver = ZKVMProver, CpuProver>>; @@ -150,35 +147,25 @@ impl< .get_device_proving_key(shard_ctx) .map(|dpk| dpk.fixed_mles.clone()) .unwrap_or_default(); + let pi_mles_preload = pi.mles::(); info_span!( "[ceno] create_proof_of_shard", shard_id = shard_ctx.shard_id ) .in_scope(|| { - let raw_pi = pi.to_vec::(); - let mut pi_evals = ZKVMProof::::pi_evals(&raw_pi); - let span = entered_span!("commit_to_pi", profiling_1 = true); // Include transcript-visible public values in canonical circuit order. // The order must match verifier and recursion verifier exactly. - // TODO deal with public io memory later for (_, circuit_pk) in self.pk.circuit_pks.iter() { - for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance_value.iter() { - let idx = instance_value.0; - let eval = pi_evals - .get(idx) - .copied() - .expect("instance_value index out of bounds for pi_evals"); + for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance_values.iter() { + let eval = pi.query_by_index::(instance_value.0); transcript.append_field_element_ext(&eval); } } exit_span!(span); - let pi: Vec> = - raw_pi.iter().map(|p| p.to_vec().into_mle()).collect(); - // commit to fixed commitment let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); if let Some(fixed_commit) = &self.pk.fixed_commit @@ -286,7 +273,7 @@ impl< tracing::debug!("global challenges in prover: {:?}", challenges); let public_input_span = entered_span!("public_input", profiling_1 = true); - let public_input = self.device.transport_mles(&pi); + let public_input = self.device.transport_mles(&pi_mles_preload); exit_span!(public_input_span); let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); @@ -302,7 +289,7 @@ impl< fixed_mles, challenges, public_input, - &pi_evals, + &pi, &circuit_trace_indices, ); exit_span!(build_tasks_span); @@ -317,11 +304,7 @@ impl< // Phase 3: Collect results let collect_results_span = entered_span!("collect_chip_results", profiling_1 = true); - let (chip_proofs, points, evaluations, pi_updates) = - Self::collect_chip_results(results); - for (idx, eval) in pi_updates { - pi_evals[idx] = eval; - } + let (chip_proofs, points, evaluations) = Self::collect_chip_results(results); exit_span!(collect_results_span); exit_span!(main_proofs_span); @@ -346,8 +329,7 @@ impl< exit_span!(pcs_opening); let vm_proof = ZKVMProof::new( - raw_pi, - pi_evals, + pi, chip_proofs, witin_commit, mpcs_opening_proof, @@ -399,7 +381,7 @@ impl< let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend> = unsafe { std::mem::transmute(task.input) }; - let (proof, pi_in_evals, input_opening_point) = + let (proof, opening_evals, input_opening_point) = create_chip_proof_gpu_impl::( task.circuit_name.as_str(), task.pk, @@ -416,7 +398,7 @@ impl< task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - pi_in_evals, + opening_evals, input_opening_point, has_witness_or_fixed: task.has_witness_or_fixed, }) @@ -434,14 +416,14 @@ impl< // Prepare: deferred extraction for GPU, no-op for CPU self.device.prepare_chip_input(&mut task, witness_data); - let (proof, pi_in_evals, input_opening_point) = + let (proof, opening_evals, input_opening_point) = self.create_chip_proof(&task, transcript)?; Ok(ChipTaskResult { task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - pi_in_evals, + opening_evals, input_opening_point, has_witness_or_fixed: task.has_witness_or_fixed, }) @@ -530,23 +512,10 @@ impl< let MainSumcheckEvals { wits_in_evals, fixed_in_evals, + pi_in_evals, } = evals; exit_span!(span); - // evaluate pi if there is instance query - let mut pi_in_evals: HashMap = HashMap::new(); - if !cs.instance_openings().is_empty() { - let span = entered_span!("pi::evals", profiling_2 = true); - for &Instance(idx) in cs.instance_openings() { - let poly = &input.public_input[idx]; - pi_in_evals.insert( - idx, - poly.eval(input_opening_point[..poly.num_vars()].to_vec()), - ); - } - exit_span!(span); - } - Ok(( ZKVMChipProof { r_out_evals, @@ -556,11 +525,13 @@ impl< gkr_iop_proof, tower_proof, ecc_proof, - fixed_in_evals, - wits_in_evals, num_instances: input.num_instances.clone(), }, - pi_in_evals, + MainSumcheckEvals { + wits_in_evals, + fixed_in_evals, + pi_in_evals, + }, input_opening_point, )) } @@ -578,7 +549,7 @@ impl< mut fixed_mles: Vec>>, challenges: [E; 2], public_input: Vec>>, - pi_evals: &[E], + pi: &PublicValues, circuit_trace_indices: &[Option], ) -> Vec> { // CPU path: eagerly extract witness MLEs from pcs_data @@ -655,12 +626,28 @@ impl< }; let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); + let public_io = cs + .instance_openings() + .iter() + .map(|Instance(idx)| { + debug_assert!(*idx < UINT_LIMBS, "instance_opening index {idx} out of range"); + public_input[*idx].clone() + }) + .collect_vec(); + + let pi_evals = cs + .zkvm_v1_css + .instance_values + .iter() + .map(|Instance(idx)| Either::Right(pi.query_by_index::(*idx))) + .collect_vec(); + let input_temp: ProofInput<'_, PB> = ProofInput { witness: witness_mle, fixed, structural_witness, - public_input: public_input.clone(), - pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), + public_input: public_io, + pub_io_evals: pi_evals, num_instances: num_instances.clone(), has_ecc_ops: cs.has_ecc_ops(), }; @@ -728,12 +715,10 @@ impl< BTreeMap>>, Vec>, Vec>>, - HashMap, ) { let mut chip_proofs = BTreeMap::new(); let mut points = Vec::new(); let mut evaluations = Vec::new(); - let mut pi_updates = HashMap::new(); for result in results { tracing::trace!( @@ -745,23 +730,18 @@ impl< if result.has_witness_or_fixed { points.push(result.input_opening_point); evaluations.push(vec![ - result.proof.wits_in_evals.clone(), - result.proof.fixed_in_evals.clone(), + result.opening_evals.wits_in_evals, + result.opening_evals.fixed_in_evals, + result.opening_evals.pi_in_evals, ]); - } else { - assert!(result.proof.wits_in_evals.is_empty()); - assert!(result.proof.fixed_in_evals.is_empty()); } chip_proofs .entry(result.circuit_idx) .or_insert(vec![]) .push(result.proof); - for (idx, eval) in result.pi_in_evals { - pi_updates.insert(idx, eval); - } } - (chip_proofs, points, evaluations, pi_updates) + (chip_proofs, points, evaluations) } } @@ -890,23 +870,10 @@ where let MainSumcheckEvals { wits_in_evals, fixed_in_evals, + pi_in_evals, } = evals; exit_span!(span); - // evaluate pi if there is instance query - let mut pi_in_evals: HashMap = HashMap::new(); - if !cs.instance_openings().is_empty() { - let span = entered_span!("pi::evals", profiling_2 = true); - for &Instance(idx) in cs.instance_openings() { - let poly = &input.public_input[idx]; - pi_in_evals.insert( - idx, - poly.eval(input_opening_point[..poly.num_vars()].to_vec()), - ); - } - exit_span!(span); - } - Ok(( ZKVMChipProof { r_out_evals, @@ -916,11 +883,13 @@ where gkr_iop_proof, tower_proof, ecc_proof, - fixed_in_evals, - wits_in_evals, num_instances: input.num_instances, }, - pi_in_evals, + MainSumcheckEvals { + wits_in_evals, + fixed_in_evals, + pi_in_evals, + }, input_opening_point, )) } diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index 438421b2e..f0f76b03b 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -12,14 +12,14 @@ use crate::{ error::ZKVMError, - scheme::{ZKVMChipProof, hal::ProofInput}, + scheme::{ZKVMChipProof, hal::MainSumcheckEvals, hal::ProofInput}, structs::ProvingKey, }; use ff_ext::ExtensionField; use gkr_iop::hal::ProverBackend; use mpcs::Point; use p3::field::FieldAlgebra; -use std::{collections::HashMap, sync::OnceLock}; +use std::sync::OnceLock; use transcript::Transcript; static CHIP_PROVING_MODE: OnceLock = OnceLock::new(); @@ -77,8 +77,8 @@ pub struct ChipTaskResult { pub circuit_idx: usize, /// The generated proof pub proof: ZKVMChipProof, - /// Public input evaluations - pub pi_in_evals: HashMap, + /// Prover-only opening evaluations split by witness/fixed/pi domains. + pub opening_evals: MainSumcheckEvals, /// Opening point for this proof pub input_opening_point: Point, /// Whether this circuit has witness or fixed polynomials diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 91a4e4563..763c3bb4e 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -248,13 +248,12 @@ fn test_rw_lk_expression_combination() { { Instrumented::<<::BaseField as PoseidonField>::P>::clear_metrics(); } - verifier + let _ = verifier .verify_chip_proof( name.as_str(), verifier.vk.circuit_vks.get(&name).unwrap(), &proof, - &[], - &[], + &PublicValues::default(), &mut v_transcript, NUM_FANIN, &PointAndEval::default(), @@ -397,7 +396,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], vec![0; 14]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], [0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(&shard_ctx, zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index ddd978ebf..859c88510 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -8,11 +8,12 @@ use std::{ #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use super::{ZKVMChipProof, ZKVMProof}; +use super::{PublicValues, ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, instructions::riscv::constants::{ - END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, SHARD_ID_IDX, + END_CYCLE_IDX, END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, + INIT_PC_IDX, }, scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, @@ -51,6 +52,40 @@ pub struct ZKVMVerifier> { } impl> ZKVMVerifier { + fn split_input_opening_evals( + circuit_vk: &VerifyingKey, + proof: &ZKVMChipProof, + ) -> Result<(Vec, Vec, Vec), ZKVMError> { + let cs = circuit_vk.get_cs(); + let Some(gkr_proof) = proof.gkr_iop_proof.as_ref() else { + return Err(ZKVMError::InvalidProof("missing gkr proof".into())); + }; + let Some(last_layer) = gkr_proof.0.last() else { + return Err(ZKVMError::InvalidProof("empty gkr proof layers".into())); + }; + + let evals = &last_layer.main.evals; + let wit_len = cs.num_witin(); + let fixed_len = cs.num_fixed(); + let pi_len = cs.instance_openings().len(); + let min_len = wit_len + fixed_len + pi_len; + if evals.len() < min_len { + return Err(ZKVMError::InvalidProof( + format!( + "insufficient main evals: {} < required {}", + evals.len(), + min_len + ) + .into(), + )); + } + + let wits_in_evals = evals[..wit_len].to_vec(); + let fixed_in_evals = evals[wit_len..(wit_len + fixed_len)].to_vec(); + let pi_in_evals = evals[(wit_len + fixed_len)..(wit_len + fixed_len + pi_len)].to_vec(); + Ok((wits_in_evals, fixed_in_evals, pi_in_evals)) + } + pub fn new(vk: ZKVMVerifyingKey) -> Self { ZKVMVerifier { vk } } @@ -116,19 +151,31 @@ impl> ZKVMVerifier } // each shard set init cycle = Tracer::SUBCYCLES_PER_INSN // to satisfy initial reads for all prev_cycle = 0 < init_cycle - assert_eq!(vm_proof.pi_evals[INIT_CYCLE_IDX], E::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN)); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_CYCLE_IDX), + E::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN) + ); // check init_pc match prev end_pc if let Some(prev_pc) = prev_pc { - assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], prev_pc); + assert_eq!(vm_proof.public_values.query_by_index::(INIT_PC_IDX), prev_pc); } else { // first chunk, check program entry - assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], E::from_canonical_u32(self.vk.entry_pc)); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_PC_IDX), + E::from_canonical_u32(self.vk.entry_pc) + ); } - let end_pc = vm_proof.pi_evals[END_PC_IDX]; + let end_pc = vm_proof.public_values.query_by_index::(END_PC_IDX); // check memory continuation consistency - let heap_addr_start_u32 = vm_proof.pi_evals[HEAP_START_ADDR_IDX].to_canonical_u64() as u32; - let heap_len= vm_proof.pi_evals[HEAP_LENGTH_IDX].to_canonical_u64() as u32; + let heap_addr_start_u32 = vm_proof + .public_values + .query_by_index::(HEAP_START_ADDR_IDX) + .to_canonical_u64() as u32; + let heap_len = vm_proof + .public_values + .query_by_index::(HEAP_LENGTH_IDX) + .to_canonical_u64() as u32; if let Some(prev_heap_addr_end) = prev_heap_addr_end { assert_eq!(heap_addr_start_u32, prev_heap_addr_end); // TODO check heap addr in prime field within range @@ -165,7 +212,13 @@ impl> ZKVMVerifier let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; - let pi_evals = &vm_proof.pi_evals; + + // Global-state expressions are built from compact instance IDs + // (query order), not absolute public-value indices. + let pi_evals = [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] + .into_iter() + .map(|idx| vm_proof.public_values.query_by_index::(idx)) + .collect_vec(); // make sure circuit index of chip proofs are // subset of that of self.vk.circuit_vks @@ -185,34 +238,13 @@ impl> ZKVMVerifier // This must match prover and recursion verifier exactly. for (_, circuit_vk) in self.vk.circuit_vks.iter() { for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { - let idx = instance_value.0; - let eval = *pi_evals - .get(idx) - .expect("instance_value index out of bounds for pi_evals"); + let eval = vm_proof.public_values.query_by_index::(instance_value.0); transcript.append_field_element_ext(&eval); } } // check shard id - assert_eq!( - vm_proof.raw_pi[SHARD_ID_IDX], - vec![E::BaseField::from_canonical_usize(shard_id)] - ); - - // verify constant poly(s) evaluation result match - // we can evaluate at this moment because constant always evaluate to same value - // non-constant poly(s) will be verified in respective (table) proof accordingly - izip!(&vm_proof.raw_pi, pi_evals) - .enumerate() - .try_for_each(|(i, (raw, eval))| { - if raw.len() == 1 && E::from(raw[0]) != *eval { - Err(ZKVMError::VerifyError( - format!("{shard_id}th shard pub input on index {i} mismatch {raw:?} != {eval:?}").into(), - )) - } else { - Ok(()) - } - })?; + assert_eq!(vm_proof.public_values.shard_id, shard_id as u32); // write fixed commitment to transcript // TODO check soundness if there is no fixed_commit but got fixed proof? @@ -296,21 +328,6 @@ impl> ZKVMVerifier let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; - // check chip proof is well-formed - if proof.wits_in_evals.len() != circuit_vk.get_cs().num_witin() - || proof.fixed_in_evals.len() != circuit_vk.get_cs().num_fixed() - { - return Err(ZKVMError::InvalidProof( - format!( - "{shard_id}th shard witness/fixed evaluations length mismatch: ({}, {}) != ({}, {})", - proof.wits_in_evals.len(), - proof.fixed_in_evals.len(), - circuit_vk.get_cs().num_witin(), - circuit_vk.get_cs().num_fixed(), - ) - .into(), - )); - } if proof.r_out_evals.len() != circuit_vk.get_cs().num_reads() || proof.w_out_evals.len() != circuit_vk.get_cs().num_writes() { @@ -363,12 +380,12 @@ impl> ZKVMVerifier logup_sum += chip_logup_sum; }; - let (input_opening_point, chip_shard_ec_sum) = self.verify_chip_proof( + let (input_opening_point, chip_shard_ec_sum, wits_in_evals, fixed_in_evals) = + self.verify_chip_proof( circuit_name, circuit_vk, proof, - pi_evals, - &vm_proof.raw_pi, + &vm_proof.public_values, transcript, NUM_FANIN, &point_eval, @@ -377,13 +394,13 @@ impl> ZKVMVerifier if circuit_vk.get_cs().num_witin() > 0 { witin_openings.push(( input_opening_point.len(), - (input_opening_point.clone(), proof.wits_in_evals.clone()), + (input_opening_point.clone(), wits_in_evals), )); } if circuit_vk.get_cs().num_fixed() > 0 { fixed_openings.push(( input_opening_point.len(), - (input_opening_point.clone(), proof.fixed_in_evals.clone()), + (input_opening_point.clone(), fixed_in_evals), )); } prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); @@ -440,7 +457,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi_evals, &challenges, &self.vk.initial_global_state_expr, ) @@ -451,7 +468,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi_evals, &challenges, &self.vk.finalize_global_state_expr, ) @@ -483,13 +500,12 @@ impl> ZKVMVerifier _name: &str, circuit_vk: &VerifyingKey, proof: &ZKVMChipProof, - pi: &[E], - raw_pi: &[Vec], + public_values: &PublicValues, transcript: &mut impl Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS - ) -> Result<(Point, Option>), ZKVMError> { + ) -> Result<(Point, Option>, Vec, Vec), ZKVMError> { let composed_cs = circuit_vk.get_cs(); let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -674,17 +690,29 @@ impl> ZKVMVerifier }, ] }; + let pi = cs + .instance_values + .iter() + .map(|instance| public_values.query_by_index::(instance.0)) + .collect_vec(); + let (wits_in_evals, fixed_in_evals, _pi_in_evals) = + Self::split_input_opening_evals(circuit_vk, proof)?; + let instance_mles = public_values + .mles::() + .into_iter() + .map(|mle| mle.get_base_field_vec().to_vec()) + .collect_vec(); let (_, rt) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), &evals, - pi, - raw_pi, + &pi, + &instance_mles, challenges, transcript, &selector_ctxs, )?; - Ok((rt, shard_ec_sum)) + Ok((rt, shard_ec_sum, wits_in_evals, fixed_in_evals)) } } diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..cecd661d9 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -754,6 +754,16 @@ mod tests { .map(|record| record.ec_point.point.clone()) .sum(); + let mut shard_rw_sum = [0u32; SEPTIC_EXTENSION_DEGREE * 2]; + for (i, fe) in global_ec_sum + .x + .iter() + .chain(global_ec_sum.y.iter()) + .enumerate() + { + shard_rw_sum[i] = fe.as_canonical_u32(); + } + let public_value = PublicValues::new( 0, 0, @@ -766,12 +776,7 @@ mod tests { 0, 0, vec![0], // dummy - global_ec_sum - .x - .iter() - .chain(global_ec_sum.y.iter()) - .map(|fe| fe.as_canonical_u32()) - .collect_vec(), + shard_rw_sum, ); // assign witness @@ -807,18 +812,24 @@ mod tests { let zkvm_prover = ZKVMProver::new(zkvm_pk.into(), pd); let mut transcript = BasicTranscript::new(b"global chip test"); - let pub_io_evals = public_value - .to_vec::() - .into_iter() - .map(|v| Either::Right(E::from(*v.index(0)))) + let pub_io_evals = pk + .get_cs() + .zkvm_v1_css + .instance_values + .iter() + .map(|instance| Either::Right(public_value.query_by_index::(instance.0))) .collect_vec(); + let pi_mles = public_value.mles::(); #[cfg(not(feature = "gpu"))] let (witness_mles, structural_mles, public_input_mles) = { - let public_input_mles = public_value - .to_vec::() - .into_iter() - .map(|v| Arc::new(v.into_mle())) + let public_input_mles = pk + .get_cs() + .instance_openings() + .iter() + .map(|instance| { + Arc::new(pi_mles[instance.0].clone()) + }) .collect_vec(); ( witness[0].to_mles().into_iter().map(Arc::new).collect(), @@ -831,10 +842,12 @@ mod tests { let cuda_hal = get_cuda_hal().unwrap(); let witness_cpu: Vec<_> = witness[0].to_mles(); let structural_cpu: Vec<_> = witness[1].to_mles(); - let public_cpu: Vec<_> = public_value - .to_vec::() + let public_cpu: Vec<_> = pk + .get_cs() + .instance_openings() + .iter() + .map(|instance| pi_mles[instance.0].clone()) .into_iter() - .map(|v| v.into_mle()) .collect_vec(); ( witness_cpu @@ -882,17 +895,12 @@ mod tests { let mut transcript = BasicTranscript::new(b"global chip test"); let verifier = ZKVMVerifier::new(zkvm_vk); - let pi_evals = public_input_mles - .iter() - .map(|mle| mle.evaluate(&point[..mle.num_vars()])) - .collect_vec(); - let (vrf_point, _) = verifier + let (vrf_point, _, _, _) = verifier .verify_chip_proof( "global", &pk.vk, &proof, - &pi_evals, - &public_value.to_vec::(), + &public_value, &mut transcript, 2, &PointAndEval::default(), diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 567a8ef06..8c54fedaf 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -269,7 +269,7 @@ impl ConstraintSystem { "query same pubio idx {idx} value more than once", ); self.instance_values.push(i); - Ok(i) + Ok(Instance(self.instance_values.len() - 1)) } pub fn query_instance_for_openings( diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index b025aa1e4..a741775ff 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -123,7 +123,7 @@ impl GKRCircuit { gkr_proof: GKRProof, out_evals: &[PointAndEval], pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -145,7 +145,7 @@ impl GKRCircuit { layer_proof, &mut evaluations, pub_io_evals, - raw_pi, + instance_mles, &mut challenges, transcript, selector_ctxs, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a67862df7..711458dbf 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -258,7 +258,7 @@ impl Layer { proof: LayerProof, claims: &mut [PointAndEval], pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &mut Vec, transcript: &mut Trans, selector_ctxs: &[SelectorContext], @@ -273,7 +273,7 @@ impl Layer { proof, eval_and_dedup_points, pub_io_evals, - raw_pi, + instance_mles, challenges, transcript, selector_ctxs, diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index cf1da6df2..7ac6ae8d4 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -77,7 +77,7 @@ pub trait ZerocheckLayer { proof: LayerProof, eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -228,7 +228,7 @@ impl ZerocheckLayer for Layer { proof: LayerProof, mut eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -386,12 +386,15 @@ impl ZerocheckLayer for Layer { } } - // check pub-io - // assume public io is tiny vector, so we evaluate it directly without PCS + // check pub-io openings by evaluating the opened public-input MLEs. let pubio_offset = self.n_witin + self.n_fixed; for (index, instance) in self.instance_openings.iter().enumerate() { let index = pubio_offset + index; - let poly = raw_pi[instance.0].to_vec().into_mle(); + let poly = instance_mles + .get(instance.0) + .expect("instance opening index out of bounds for instance_mles") + .clone() + .into_mle(); let expected_eval = poly.evaluate(&in_point[..poly.num_vars()]); if expected_eval != main_evals[index] { return Err(BackendError::LayerVerificationFailed( From 7099d2c8844a9e9cdbcd555267cf0ad1d9255ec2 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 26 Mar 2026 19:15:09 +0800 Subject: [PATCH 3/7] misc: clippy --- ceno_recursion/src/aggregation/mod.rs | 2 +- ceno_zkvm/src/scheme/verifier.rs | 1 + ceno_zkvm/src/tables/shard_ram.rs | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index c7ef469d6..c8d94fedb 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -265,7 +265,7 @@ impl CenoAggregationProver { .collect(); let user_public_values: Vec = zkvm_proof_inputs .iter() - .flat_map(|p| p.raw_pi.iter().copied().collect::>()) + .flat_map(|p| p.raw_pi.to_vec()) .collect(); let leaf_inputs = chunk_ceno_leaf_proof_inputs(zkvm_proof_inputs); diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 859c88510..7e578b4ec 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -52,6 +52,7 @@ pub struct ZKVMVerifier> { } impl> ZKVMVerifier { + #[allow(clippy::type_complexity)] fn split_input_opening_evals( circuit_vk: &VerifyingKey, proof: &ZKVMChipProof, diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index cecd661d9..dd59af682 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -651,7 +651,7 @@ mod tests { use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::babybear::BabyBear; use rand::thread_rng; - use std::{ops::Index, sync::Arc}; + use std::sync::Arc; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use transcript::BasicTranscript; @@ -670,7 +670,7 @@ mod tests { gpu::{MultilinearExtensionGpu, get_cuda_hal}, hal::MultilinearPolynomial, }; - use multilinear_extensions::mle::IntoMLE; + use crate::scheme::constants::SEPTIC_EXTENSION_DEGREE; use p3::field::PrimeField32; type E = BabyBearExt4; From d418e9d3a381cf0ef608772ae08245c277c48c5e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 26 Mar 2026 20:31:28 +0800 Subject: [PATCH 4/7] fmt && fix test --- ceno_recursion/src/zkvm_verifier/binding.rs | 13 +++--- ceno_recursion/src/zkvm_verifier/verifier.rs | 1 - ceno_zkvm/src/scheme.rs | 40 ++++++++++-------- ceno_zkvm/src/scheme/mock_prover.rs | 30 ++++---------- ceno_zkvm/src/scheme/prover.rs | 24 ++++------- ceno_zkvm/src/scheme/scheduler.rs | 5 ++- ceno_zkvm/src/scheme/verifier.rs | 43 +++++++++++--------- ceno_zkvm/src/tables/shard_ram.rs | 11 ++--- 8 files changed, 78 insertions(+), 89 deletions(-) diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index e6abae576..aacb10601 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -69,11 +69,13 @@ fn raw_pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> fn mles_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec> { (0..UINT_LIMBS) .map(|limb_index| { - public_values - .public_io - .iter() - .map(|value| F::from_canonical_u16(((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16)) - .collect_vec() + public_values + .public_io + .iter() + .map(|value| { + F::from_canonical_u16(((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16) + }) + .collect_vec() }) .collect_vec() } @@ -711,7 +713,6 @@ impl Hintable for ZKVMChipProofInput { stream.extend(n_inst_0_bit_decomps.write()); stream.extend(n_inst_1_bit_decomps.write()); - stream } } diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 1681c1ddc..1ea788c6c 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -1168,7 +1168,6 @@ pub fn verify_gkr_circuit( builder.assert_ext_eq(expected_eval, main_eval); } - // TODO: we should store alpha_pows in a bigger array to avoid concatenating them let main_sumcheck_challenges_len: Usize = builder.eval(alpha_pows.len() + Usize::from(2)); diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 3ba4e3651..b105eb7f9 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -121,25 +121,25 @@ impl PublicValues { shard_rw_sum, } } - pub fn query_by_index(&self, index: usize) -> E { + pub fn query_by_index(&self, index: usize) -> E::BaseField { match index { - EXIT_CODE_IDX => E::from_canonical_u32(self.exit_code & 0xffff), + EXIT_CODE_IDX => E::BaseField::from_canonical_u32(self.exit_code & 0xffff), idx if idx == EXIT_CODE_IDX + 1 => { - E::from_canonical_u32((self.exit_code >> 16) & 0xffff) + E::BaseField::from_canonical_u32((self.exit_code >> 16) & 0xffff) } - INIT_PC_IDX => E::from_canonical_u32(self.init_pc), - INIT_CYCLE_IDX => E::from_canonical_u64(self.init_cycle), - END_PC_IDX => E::from_canonical_u32(self.end_pc), - END_CYCLE_IDX => E::from_canonical_u64(self.end_cycle), - SHARD_ID_IDX => E::from_canonical_u32(self.shard_id), - HEAP_START_ADDR_IDX => E::from_canonical_u32(self.heap_start_addr), - HEAP_LENGTH_IDX => E::from_canonical_u32(self.heap_shard_len), - HINT_START_ADDR_IDX => E::from_canonical_u32(self.hint_start_addr), - HINT_LENGTH_IDX => E::from_canonical_u32(self.hint_shard_len), + INIT_PC_IDX => E::BaseField::from_canonical_u32(self.init_pc), + INIT_CYCLE_IDX => E::BaseField::from_canonical_u64(self.init_cycle), + END_PC_IDX => E::BaseField::from_canonical_u32(self.end_pc), + END_CYCLE_IDX => E::BaseField::from_canonical_u64(self.end_cycle), + SHARD_ID_IDX => E::BaseField::from_canonical_u32(self.shard_id), + HEAP_START_ADDR_IDX => E::BaseField::from_canonical_u32(self.heap_start_addr), + HEAP_LENGTH_IDX => E::BaseField::from_canonical_u32(self.heap_shard_len), + HINT_START_ADDR_IDX => E::BaseField::from_canonical_u32(self.hint_start_addr), + HINT_LENGTH_IDX => E::BaseField::from_canonical_u32(self.hint_shard_len), idx if (SHARD_RW_SUM_IDX..(SHARD_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE * 2)) .contains(&idx) => { - E::from_canonical_u32(self.shard_rw_sum[idx - SHARD_RW_SUM_IDX]) + E::BaseField::from_canonical_u32(self.shard_rw_sum[idx - SHARD_RW_SUM_IDX]) } _ => panic!("public value index {index} out of range"), } @@ -149,15 +149,22 @@ impl PublicValues { // public_io is represented as UINT_LIMBS columns. (0..UINT_LIMBS) .map(|limb_index| { - self.public_io + let limb_values = self + .public_io .iter() .map(|value| { E::BaseField::from_canonical_u16( ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, ) }) - .collect_vec() - .into_mle() + .collect_vec(); + + // Empty public_io means a constant-zero public input column. + if limb_values.is_empty() { + vec![E::BaseField::ZERO].into_mle() + } else { + limb_values.into_mle() + } }) .collect_vec() } @@ -194,7 +201,6 @@ impl> ZKVMProof { } } - pub fn num_circuits(&self) -> usize { self.chip_proofs.len() } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index 825c4fae4..20f4da1b1 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -969,7 +969,7 @@ Hints: let circuit_pub_io_evals = circuit_cs .instance_values .iter() - .map(|instance| Either::Right(pi.query_by_index::(instance.0))) + .map(|instance| Either::Right(E::from(pi.query_by_index::(instance.0)))) .collect_vec(); let circuit_pi_mles = circuit_cs .instance_openings @@ -1476,34 +1476,20 @@ Hints: let gs_pub_io_evals = cs .instance_values .iter() - .map(|instance| pi.query_by_index::(instance.0)) + .map(|instance| E::from(pi.query_by_index::(instance.0))) .collect_vec(); let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) = derive_ram_rws!(RAMType::GlobalState); gs_rs.insert( - eval_by_expr_with_instance( - &[], - &[], - &[], - &gs_pub_io_evals, - &challenges, - &gs_final, - ) - .right() - .unwrap(), + eval_by_expr_with_instance(&[], &[], &[], &gs_pub_io_evals, &challenges, &gs_final) + .right() + .unwrap(), ); gs_ws.insert( - eval_by_expr_with_instance( - &[], - &[], - &[], - &gs_pub_io_evals, - &challenges, - &gs_init, - ) - .right() - .unwrap(), + eval_by_expr_with_instance(&[], &[], &[], &gs_pub_io_evals, &challenges, &gs_init) + .right() + .unwrap(), ); // gs stores { (pc, timestamp) } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 0c6215fb3..96b28ce1a 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -3,11 +3,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, hal::ProverBackend, }; -use std::{ - collections::BTreeMap, - marker::PhantomData, - sync::Arc, -}; +use std::{collections::BTreeMap, marker::PhantomData, sync::Arc}; #[cfg(feature = "gpu")] use crate::scheme::gpu::estimate_chip_proof_memory; @@ -157,10 +153,10 @@ impl< let span = entered_span!("commit_to_pi", profiling_1 = true); // Include transcript-visible public values in canonical circuit order. // The order must match verifier and recursion verifier exactly. + // TODO deal with vector-based public value to transcript for (_, circuit_pk) in self.pk.circuit_pks.iter() { for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance_values.iter() { - let eval = pi.query_by_index::(instance_value.0); - transcript.append_field_element_ext(&eval); + transcript.append_field_element(&pi.query_by_index::(instance_value.0)); } } @@ -328,12 +324,7 @@ impl< }); exit_span!(pcs_opening); - let vm_proof = ZKVMProof::new( - pi, - chip_proofs, - witin_commit, - mpcs_opening_proof, - ); + let vm_proof = ZKVMProof::new(pi, chip_proofs, witin_commit, mpcs_opening_proof); Ok(vm_proof) }) @@ -630,7 +621,10 @@ impl< .instance_openings() .iter() .map(|Instance(idx)| { - debug_assert!(*idx < UINT_LIMBS, "instance_opening index {idx} out of range"); + debug_assert!( + *idx < UINT_LIMBS, + "instance_opening index {idx} out of range" + ); public_input[*idx].clone() }) .collect_vec(); @@ -639,7 +633,7 @@ impl< .zkvm_v1_css .instance_values .iter() - .map(|Instance(idx)| Either::Right(pi.query_by_index::(*idx))) + .map(|Instance(idx)| Either::Left(pi.query_by_index::(*idx))) .collect_vec(); let input_temp: ProofInput<'_, PB> = ProofInput { diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index f0f76b03b..c1bd260f4 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -12,7 +12,10 @@ use crate::{ error::ZKVMError, - scheme::{ZKVMChipProof, hal::MainSumcheckEvals, hal::ProofInput}, + scheme::{ + ZKVMChipProof, + hal::{MainSumcheckEvals, ProofInput}, + }, structs::ProvingKey, }; use ff_ext::ExtensionField; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 7e578b4ec..28afc506b 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,5 +1,5 @@ use either::Either; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, SmallField}; use std::{ iter::{self, once, repeat_n}, marker::PhantomData, @@ -154,16 +154,19 @@ impl> ZKVMVerifier // to satisfy initial reads for all prev_cycle = 0 < init_cycle assert_eq!( vm_proof.public_values.query_by_index::(INIT_CYCLE_IDX), - E::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN) + E::BaseField::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN) ); // check init_pc match prev end_pc if let Some(prev_pc) = prev_pc { - assert_eq!(vm_proof.public_values.query_by_index::(INIT_PC_IDX), prev_pc); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_PC_IDX), + prev_pc + ); } else { // first chunk, check program entry assert_eq!( vm_proof.public_values.query_by_index::(INIT_PC_IDX), - E::from_canonical_u32(self.vk.entry_pc) + E::BaseField::from_canonical_u32(self.vk.entry_pc) ); } let end_pc = vm_proof.public_values.query_by_index::(END_PC_IDX); @@ -213,12 +216,11 @@ impl> ZKVMVerifier let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; - // Global-state expressions are built from compact instance IDs // (query order), not absolute public-value indices. let pi_evals = [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] .into_iter() - .map(|idx| vm_proof.public_values.query_by_index::(idx)) + .map(|idx| E::from(vm_proof.public_values.query_by_index::(idx))) .collect_vec(); // make sure circuit index of chip proofs are @@ -239,8 +241,9 @@ impl> ZKVMVerifier // This must match prover and recursion verifier exactly. for (_, circuit_vk) in self.vk.circuit_vks.iter() { for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { - let eval = vm_proof.public_values.query_by_index::(instance_value.0); - transcript.append_field_element_ext(&eval); + transcript.append_field_element( + &vm_proof.public_values.query_by_index::(instance_value.0), + ); } } @@ -381,17 +384,17 @@ impl> ZKVMVerifier logup_sum += chip_logup_sum; }; - let (input_opening_point, chip_shard_ec_sum, wits_in_evals, fixed_in_evals) = - self.verify_chip_proof( - circuit_name, - circuit_vk, - proof, - &vm_proof.public_values, - transcript, - NUM_FANIN, - &point_eval, - &challenges, - )?; + let (input_opening_point, chip_shard_ec_sum, wits_in_evals, fixed_in_evals) = self + .verify_chip_proof( + circuit_name, + circuit_vk, + proof, + &vm_proof.public_values, + transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; if circuit_vk.get_cs().num_witin() > 0 { witin_openings.push(( input_opening_point.len(), @@ -694,7 +697,7 @@ impl> ZKVMVerifier let pi = cs .instance_values .iter() - .map(|instance| public_values.query_by_index::(instance.0)) + .map(|instance| E::from(public_values.query_by_index::(instance.0))) .collect_vec(); let (wits_in_evals, fixed_in_evals, _pi_in_evals) = Self::split_input_opening_evals(circuit_vk, proof)?; diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index dd59af682..f9caf5513 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -659,8 +659,8 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ - PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, - septic_curve::SepticPoint, verifier::ZKVMVerifier, + PublicValues, constants::SEPTIC_EXTENSION_DEGREE, create_backend, create_prover, + hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, }, structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, @@ -670,7 +670,6 @@ mod tests { gpu::{MultilinearExtensionGpu, get_cuda_hal}, hal::MultilinearPolynomial, }; - use crate::scheme::constants::SEPTIC_EXTENSION_DEGREE; use p3::field::PrimeField32; type E = BabyBearExt4; @@ -817,7 +816,7 @@ mod tests { .zkvm_v1_css .instance_values .iter() - .map(|instance| Either::Right(public_value.query_by_index::(instance.0))) + .map(|instance| Either::Right(E::from(public_value.query_by_index::(instance.0)))) .collect_vec(); let pi_mles = public_value.mles::(); @@ -827,9 +826,7 @@ mod tests { .get_cs() .instance_openings() .iter() - .map(|instance| { - Arc::new(pi_mles[instance.0].clone()) - }) + .map(|instance| Arc::new(pi_mles[instance.0].clone())) .collect_vec(); ( witness[0].to_mles().into_iter().map(Arc::new).collect(), From 876d8a7074ef392202dfee37ea91efe1093eabeb Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 26 Mar 2026 21:27:13 +0800 Subject: [PATCH 5/7] fix error message --- ceno_zkvm/src/bin/e2e.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index e16d93c82..396321724 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -426,7 +426,7 @@ fn soundness_test>( unreachable!() }; - if !msg.starts_with("0th round's prover message is not consistent with the claim") { + if !msg.starts_with("assertion `left == right` failed") { error!("unknown panic {msg:?}"); panic::resume_unwind(err); }; From 1c7fc09a347674fa93ed9bb2b20c14e9123caf37 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 26 Mar 2026 22:42:21 +0800 Subject: [PATCH 6/7] recursion verifier bug fix --- ceno_recursion/src/zkvm_verifier/verifier.rs | 25 +++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 1ea788c6c..d8e9a105f 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -26,7 +26,10 @@ use crate::{ SepticExtensionVariable, SepticPointVariable, SumcheckLayerProofVariable, }, }; -use ceno_zkvm::structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}; +use ceno_zkvm::{ + instructions::riscv::constants::{END_CYCLE_IDX, END_PC_IDX, INIT_CYCLE_IDX, INIT_PC_IDX}, + structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}, +}; use ff_ext::BabyBearExt4; use crate::transcript::{challenger_add_forked_index, clone_challenger_state}; @@ -107,9 +110,8 @@ pub fn verify_zkvm_proof>( for (_, circuit_vk) in vk.circuit_vks.iter() { for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { let raw = builder.get(&zkvm_proof_input.raw_pi, instance_value.0); - let eval = builder.ext_from_base_slice(&[raw]); - let eval_felts = builder.ext2felt(eval); - challenger.observe_slice(builder, eval_felts); + // Match native verifier transcript behavior: append base-field PI element directly. + challenger.observe(builder, raw); } } @@ -521,6 +523,17 @@ pub fn verify_zkvm_proof>( &unipoly_extrapolator, &mut challenger, ); + // Global-state expressions are defined over compact/query-order PI slots. + // Keep this aligned with ceno_zkvm verifier: [init_pc, init_cycle, end_pc, end_cycle]. + let global_state_pi_evals: Array> = builder.dyn_array(4); + [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] + .into_iter() + .enumerate() + .for_each(|(i, idx)| { + let raw = builder.get(&zkvm_proof_input.raw_pi, idx); + let eval = builder.ext_from_base_slice(&[raw]); + builder.set(&global_state_pi_evals, i, eval); + }); let empty_arr: Array> = builder.dyn_array(0); let initial_global_state = eval_ceno_expr_with_instance( @@ -528,7 +541,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &empty_arr, &empty_arr, - &zkvm_proof_input.pi_evals, + &global_state_pi_evals, &challenges, &vk.initial_global_state_expr, ); @@ -539,7 +552,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &empty_arr, &empty_arr, - &zkvm_proof_input.pi_evals, + &global_state_pi_evals, &challenges, &vk.finalize_global_state_expr, ); From f49649a643ea044c05174768651896865b4152b8 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 26 Mar 2026 22:57:59 +0800 Subject: [PATCH 7/7] fix(recursion): accept extra tower main eval tail in opening split --- ceno_recursion/src/zkvm_verifier/verifier.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index d8e9a105f..743ce9ad1 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -587,12 +587,14 @@ fn split_input_opening_evals( let wit_end = Usize::from(num_witin); let fixed_end: Usize = builder.eval(wit_end.clone() + Usize::from(num_fixed)); let pi_end: Usize = builder.eval(fixed_end.clone() + Usize::from(num_pi)); - builder.assert_usize_eq(main_evals.len(), pi_end.clone()); + // Native verifier accepts extra trailing evals; only the prefix is consumed here. + // Keep recursion semantics aligned by slicing the required prefix. + let eval_prefix = main_evals.slice(builder, Usize::from(0), pi_end.clone()); ( - main_evals.slice(builder, Usize::from(0), wit_end), - main_evals.slice(builder, Usize::from(num_witin), fixed_end), - main_evals.slice(builder, Usize::from(num_witin + num_fixed), pi_end), + eval_prefix.slice(builder, Usize::from(0), wit_end), + eval_prefix.slice(builder, Usize::from(num_witin), fixed_end), + eval_prefix.slice(builder, Usize::from(num_witin + num_fixed), pi_end), ) }