diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index d17829fe9..c8d94fedb 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -265,7 +265,7 @@ impl CenoAggregationProver { .collect(); let user_public_values: Vec = zkvm_proof_inputs .iter() - .flat_map(|p| p.raw_pi.iter().flat_map(|v| v.clone()).collect::>()) + .flat_map(|p| p.raw_pi.to_vec()) .collect(); let leaf_inputs = chunk_ceno_leaf_proof_inputs(zkvm_proof_inputs); @@ -410,18 +410,9 @@ impl CenoLeafVmVerifierConfig { } let pv = &raw_pi; - let init_pc = { - let arr = builder.get(pv, INIT_PC_IDX); - builder.get(&arr, 0) - }; - let end_pc = { - let arr = builder.get(pv, END_PC_IDX); - builder.get(&arr, 0) - }; - let exit_code = { - let arr = builder.get(pv, EXIT_CODE_IDX); - builder.get(&arr, 0) - }; + let init_pc = builder.get(pv, INIT_PC_IDX); + let end_pc = builder.get(pv, END_PC_IDX); + let exit_code = builder.get(pv, EXIT_CODE_IDX); builder.assign(&stark_pvs.connector.initial_pc, init_pc); builder.assign(&stark_pvs.connector.final_pc, end_pc); builder.assign(&stark_pvs.connector.exit_code, exit_code); diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index 08c76dbe0..aacb10601 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -14,6 +14,7 @@ use crate::{ }, }; use ceno_zkvm::{ + instructions::riscv::constants::{LIMB_BITS, LIMB_MASK, UINT_LIMBS}, scheme::{ZKVMChipProof, ZKVMProof}, structs::{EccQuarkProof, TowerProofs, ZKVMVerifyingKey}, }; @@ -41,6 +42,48 @@ pub type E = BinomialExtensionField; pub type RecPcs = Basefold; pub type InnerConfig = AsmConfig; +fn raw_pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec { + vec![ + F::from_canonical_u32(public_values.exit_code & 0xffff), + F::from_canonical_u32((public_values.exit_code >> 16) & 0xffff), + F::from_canonical_u32(public_values.init_pc), + F::from_canonical_u64(public_values.init_cycle), + F::from_canonical_u32(public_values.end_pc), + F::from_canonical_u64(public_values.end_cycle), + F::from_canonical_u32(public_values.shard_id), + F::from_canonical_u32(public_values.heap_start_addr), + F::from_canonical_u32(public_values.heap_shard_len), + F::from_canonical_u32(public_values.hint_start_addr), + F::from_canonical_u32(public_values.hint_shard_len), + ] + .into_iter() + .chain( + public_values + .shard_rw_sum + .iter() + .map(|value| F::from_canonical_u32(*value)), + ) + .collect_vec() +} + +fn mles_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec> { + (0..UINT_LIMBS) + .map(|limb_index| { + public_values + .public_io + .iter() + .map(|value| { + F::from_canonical_u16(((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16) + }) + .collect_vec() + }) + .collect_vec() +} + +fn pi_evals_from_raw_pi(raw_pi: &[F]) -> Vec { + raw_pi.iter().map(|v| E::from(*v)).collect_vec() +} + pub fn decompose_minus_one_bits(n: usize) -> Vec { let a = if n > 0 { n - 1 } else { 0 }; let mut bit_decomp: Vec = vec![]; @@ -71,7 +114,8 @@ pub fn decompose_prefixed_layer_bits(n: usize) -> (Vec, Vec>) { #[derive(DslVariable, Clone)] pub struct ZKVMProofInputVariable { pub shard_id: Usize, - pub raw_pi: Array>>, + pub raw_pi: Array>, + pub mles: Array>>, pub raw_pi_num_variables: Array>, pub pi_evals: Array>, pub chip_proofs: Array>>, @@ -95,7 +139,9 @@ pub struct TowerProofInputVariable { pub(crate) struct ZKVMProofInput { pub shard_id: usize, - pub raw_pi: Vec>, + pub raw_pi: Vec, + pub mles: Vec>, + pub raw_pi_num_variables: Vec, // Evaluation of raw_pi. pub pi_evals: Vec, pub chip_proofs: BTreeMap, @@ -109,6 +155,14 @@ impl ZKVMProofInput { zkvm_proof: ZKVMProof, vk: &ZKVMVerifyingKey, ) -> Self { + let raw_pi = raw_pi_from_public_values(&zkvm_proof.public_values); + let mles = mles_from_public_values(&zkvm_proof.public_values); + let raw_pi_num_variables = mles + .iter() + .map(|v| ceil_log2(v.len().next_power_of_two())) + .collect::>(); + let pi_evals = pi_evals_from_raw_pi(&raw_pi); + let mut chip_witin_num_vars: HashMap = HashMap::new(); // (chip_id, (num_witin, num_fixed)) let mut chip_indices = zkvm_proof .chip_proofs @@ -136,8 +190,10 @@ impl ZKVMProofInput { ZKVMProofInput { shard_id, - raw_pi: zkvm_proof.raw_pi, - pi_evals: zkvm_proof.pi_evals, + raw_pi, + mles, + raw_pi_num_variables, + pi_evals, chip_proofs: zkvm_proof .chip_proofs .into_iter() @@ -168,7 +224,8 @@ impl Hintable for ZKVMProofInput { fn read(builder: &mut Builder) -> Self::HintVariable { let shard_id = Usize::Var(usize::read(builder)); - let raw_pi = Vec::>::read(builder); + let raw_pi = Vec::::read(builder); + let mles = Vec::>::read(builder); let raw_pi_num_variables = Vec::::read(builder); let pi_evals = Vec::::read(builder); builder.cycle_tracker_start("read chip proofs"); @@ -187,6 +244,7 @@ impl Hintable for ZKVMProofInput { ZKVMProofInputVariable { shard_id, raw_pi, + mles, raw_pi_num_variables, pi_evals, chip_proofs, @@ -201,11 +259,6 @@ impl Hintable for ZKVMProofInput { fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - let raw_pi_num_variables: Vec = self - .raw_pi - .iter() - .map(|v| ceil_log2(v.len().next_power_of_two())) - .collect(); let witin_num_vars = self .chip_proofs .iter() @@ -217,21 +270,21 @@ impl Hintable for ZKVMProofInput { .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .map(|proof| proof.wits_in_evals.len().max(1)) + .map(|proof| proof.num_witin.max(1)) .collect::>(); let fixed_num_vars = self .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .filter(|proof| !proof.fixed_in_evals.is_empty()) + .filter(|proof| proof.num_fixed > 0) .map(|proof| proof.num_vars) .collect::>(); let fixed_max_widths = self .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .filter(|proof| !proof.fixed_in_evals.is_empty()) - .map(|proof| proof.fixed_in_evals.len()) + .filter(|proof| proof.num_fixed > 0) + .map(|proof| proof.num_fixed) .collect::>(); let max_num_var = witin_num_vars .iter() @@ -264,7 +317,8 @@ impl Hintable for ZKVMProofInput { stream.extend(>::write(&self.shard_id)); stream.extend(self.raw_pi.write()); - stream.extend(raw_pi_num_variables.write()); + stream.extend(self.mles.write()); + stream.extend(self.raw_pi_num_variables.write()); stream.extend(self.pi_evals.write()); stream.extend(vec![vec![F::from_canonical_usize(self.chip_proofs.len())]]); for proofs in self.chip_proofs.values() { @@ -403,9 +457,6 @@ pub struct ZKVMChipProofInput { pub ecc_proof: EccQuarkProofInput, pub num_instances: Vec, - - pub wits_in_evals: Vec, - pub fixed_in_evals: Vec, } impl VecAutoHintable for ZKVMChipProofInput {} @@ -499,8 +550,6 @@ impl From<(usize, ZKVMChipProof, usize, usize)> for ZKVMChipProofInput { EccQuarkProofInput::dummy() }, num_instances: p.num_instances, - wits_in_evals: p.wits_in_evals, - fixed_in_evals: p.fixed_in_evals, } } } @@ -531,9 +580,6 @@ pub struct ZKVMChipProofInputVariable { pub num_instances: Array>, pub n_inst_0_bit_decomps: Array>, pub n_inst_1_bit_decomps: Array>, - - pub fixed_in_evals: Array>, - pub wits_in_evals: Array>, } impl Hintable for ZKVMChipProofInput { type HintVariable = ZKVMChipProofInputVariable; @@ -571,11 +617,6 @@ impl Hintable for ZKVMChipProofInput { let n_inst_0_bit_decomps = Vec::::read(builder); let n_inst_1_bit_decomps = Vec::::read(builder); - builder.cycle_tracker_start("read wit/fixed evals"); - let fixed_in_evals = Vec::::read(builder); - let wits_in_evals = Vec::::read(builder); - builder.cycle_tracker_end("read wit/fixed evals"); - ZKVMChipProofInputVariable { idx, idx_felt, @@ -597,8 +638,6 @@ impl Hintable for ZKVMChipProofInput { num_instances, n_inst_0_bit_decomps, n_inst_1_bit_decomps, - fixed_in_evals, - wits_in_evals, } } @@ -674,9 +713,6 @@ impl Hintable for ZKVMChipProofInput { stream.extend(n_inst_0_bit_decomps.write()); stream.extend(n_inst_1_bit_decomps.write()); - stream.extend(self.fixed_in_evals.write()); - stream.extend(self.wits_in_evals.write()); - stream } } diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index e0d823814..743ce9ad1 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -26,7 +26,10 @@ use crate::{ SepticExtensionVariable, SepticPointVariable, SumcheckLayerProofVariable, }, }; -use ceno_zkvm::structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}; +use ceno_zkvm::{ + instructions::riscv::constants::{END_CYCLE_IDX, END_PC_IDX, INIT_CYCLE_IDX, INIT_PC_IDX}, + structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}, +}; use ff_ext::BabyBearExt4; use crate::transcript::{challenger_add_forked_index, clone_challenger_state}; @@ -104,21 +107,20 @@ pub fn verify_zkvm_proof>( let prod_w: Ext = builder.constant(C::EF::ONE); let logup_sum: Ext = builder.constant(C::EF::ZERO); - iter_zip!(builder, zkvm_proof_input.raw_pi).for_each(|ptr_vec, builder| { - let v = builder.iter_ptr_get(&zkvm_proof_input.raw_pi, ptr_vec[0]); - challenger_multi_observe(builder, &mut challenger, &v); - }); + for (_, circuit_vk) in vk.circuit_vks.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { + let raw = builder.get(&zkvm_proof_input.raw_pi, instance_value.0); + // Match native verifier transcript behavior: append base-field PI element directly. + challenger.observe(builder, raw); + } + } iter_zip!(builder, zkvm_proof_input.raw_pi, zkvm_proof_input.pi_evals).for_each( |ptr_vec, builder| { let raw = builder.iter_ptr_get(&zkvm_proof_input.raw_pi, ptr_vec[0]); let eval = builder.iter_ptr_get(&zkvm_proof_input.pi_evals, ptr_vec[1]); - let raw0 = builder.get(&raw, 0); - - builder.if_eq(raw.len(), Usize::from(1)).then(|builder| { - let raw0_ext = builder.ext_from_base_slice(&[raw0]); - builder.assert_ext_eq(raw0_ext, eval); - }); + let raw_ext = builder.ext_from_base_slice(&[raw]); + builder.assert_ext_eq(raw_ext, eval); }, ); @@ -273,14 +275,6 @@ pub fn verify_zkvm_proof>( // fork transcript to support chip concurrently proved let mut chip_challenger = clone_challenger_state(builder, &challenger); challenger_add_forked_index(builder, &mut chip_challenger, &forked_sample_index); - builder.assert_usize_eq( - chip_proof.wits_in_evals.len(), - Usize::from(circuit_vk.get_cs().num_witin()), - ); - builder.assert_usize_eq( - chip_proof.fixed_in_evals.len(), - Usize::from(circuit_vk.get_cs().num_fixed()), - ); builder.assert_usize_eq( chip_proof.rw_out_evals.length.clone(), Usize::from( @@ -336,8 +330,8 @@ pub fn verify_zkvm_proof>( builder, &mut chip_challenger, &chip_proof, - &zkvm_proof_input.pi_evals, &zkvm_proof_input.raw_pi, + &zkvm_proof_input.mles, &zkvm_proof_input.raw_pi_num_variables, &challenges, chip_vk, @@ -373,13 +367,20 @@ pub fn verify_zkvm_proof>( let point_clone: Array> = builder.eval(input_opening_point.clone()); + let (wits_in_evals, fixed_in_evals, _pi_in_evals) = split_input_opening_evals( + builder, + &chip_proof, + circuit_vk.get_cs().num_witin(), + circuit_vk.get_cs().num_fixed(), + circuit_vk.get_cs().instance_openings().len(), + ); if circuit_vk.get_cs().num_witin() > 0 { let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { num_var: input_opening_point.len().get_var(), point_and_evals: PointAndEvalsVariable { point: PointVariable { fs: point_clone }, - evals: chip_proof.wits_in_evals, + evals: wits_in_evals, }, }); builder.set_value(&witin_openings, num_witin_openings.get_var(), witin_round); @@ -392,7 +393,7 @@ pub fn verify_zkvm_proof>( point: PointVariable { fs: input_opening_point, }, - evals: chip_proof.fixed_in_evals, + evals: fixed_in_evals, }, }); @@ -522,6 +523,17 @@ pub fn verify_zkvm_proof>( &unipoly_extrapolator, &mut challenger, ); + // Global-state expressions are defined over compact/query-order PI slots. + // Keep this aligned with ceno_zkvm verifier: [init_pc, init_cycle, end_pc, end_cycle]. + let global_state_pi_evals: Array> = builder.dyn_array(4); + [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] + .into_iter() + .enumerate() + .for_each(|(i, idx)| { + let raw = builder.get(&zkvm_proof_input.raw_pi, idx); + let eval = builder.ext_from_base_slice(&[raw]); + builder.set(&global_state_pi_evals, i, eval); + }); let empty_arr: Array> = builder.dyn_array(0); let initial_global_state = eval_ceno_expr_with_instance( @@ -529,7 +541,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &empty_arr, &empty_arr, - &zkvm_proof_input.pi_evals, + &global_state_pi_evals, &challenges, &vk.initial_global_state_expr, ); @@ -540,7 +552,7 @@ pub fn verify_zkvm_proof>( &empty_arr, &empty_arr, &empty_arr, - &zkvm_proof_input.pi_evals, + &global_state_pi_evals, &challenges, &vk.finalize_global_state_expr, ); @@ -556,13 +568,43 @@ pub fn verify_zkvm_proof>( shard_ec_sum } +fn split_input_opening_evals( + builder: &mut Builder, + chip_proof: &ZKVMChipProofInputVariable, + num_witin: usize, + num_fixed: usize, + num_pi: usize, +) -> ( + Array>, + Array>, + Array>, +) { + let last_layer_idx: Usize = + builder.eval(chip_proof.gkr_iop_proof.layer_proofs.len() - Usize::from(1)); + let last_layer = builder.get(&chip_proof.gkr_iop_proof.layer_proofs, last_layer_idx); + let main_evals = last_layer.main.evals; + + let wit_end = Usize::from(num_witin); + let fixed_end: Usize = builder.eval(wit_end.clone() + Usize::from(num_fixed)); + let pi_end: Usize = builder.eval(fixed_end.clone() + Usize::from(num_pi)); + // Native verifier accepts extra trailing evals; only the prefix is consumed here. + // Keep recursion semantics aligned by slicing the required prefix. + let eval_prefix = main_evals.slice(builder, Usize::from(0), pi_end.clone()); + + ( + eval_prefix.slice(builder, Usize::from(0), wit_end), + eval_prefix.slice(builder, Usize::from(num_witin), fixed_end), + eval_prefix.slice(builder, Usize::from(num_witin + num_fixed), pi_end), + ) +} + pub fn verify_chip_proof( circuit_name: &str, builder: &mut Builder, challenger: &mut DuplexChallengerVariable, chip_proof: &ZKVMChipProofInputVariable, - pi_evals: &Array>, - raw_pi: &Array>>, + raw_pi: &Array>, + mles: &Array>>, raw_pi_num_variables: &Array>, challenges: &Array>, vk: &VerifyingKey, @@ -709,6 +751,13 @@ pub fn verify_chip_proof( builder.set(&q_slice, idx_vec[0], cpt); }); let gkr_circuit = gkr_circuit.clone().unwrap(); + let circuit_pi_evals: Array> = + builder.dyn_array(Usize::from(cs.instance_values.len())); + for (i, instance) in cs.instance_values.iter().enumerate() { + let raw = builder.get(raw_pi, instance.0); + let eval = builder.ext_from_base_slice(&[raw]); + builder.set(&circuit_pi_evals, i, eval); + } let zero_bit_decomps: Array> = builder.dyn_array(32); let selector_ctxs: Vec> = if cs.ec_final_sum.is_empty() { @@ -807,11 +856,10 @@ pub fn verify_chip_proof( gkr_circuit, &chip_proof.gkr_iop_proof, challenges, - pi_evals, - raw_pi, + &circuit_pi_evals, + mles, raw_pi_num_variables, &out_evals, - chip_proof, selector_ctxs, unipoly_extrapolator, poly_evaluator, @@ -829,10 +877,9 @@ pub fn verify_gkr_circuit( gkr_proof: &GKRProofVariable, challenges: &Array>, pub_io_evals: &Array>, - raw_pi: &Array>>, + mles: &Array>>, raw_pi_num_variables: &Array>, claims: &Array>, - _chip_proof: &ZKVMChipProofInputVariable, selector_ctxs: Vec>, unipoly_extrapolator: &UniPolyExtrapolator, poly_evaluator: &mut PolyEvaluator, @@ -1127,7 +1174,7 @@ pub fn verify_gkr_circuit( let pubio_offset = layer.n_witin + layer.n_fixed; for (index, instance) in layer.instance_openings.iter().enumerate() { let index: usize = pubio_offset + index; - let poly = builder.get(raw_pi, instance.0); + let poly = builder.get(mles, instance.0); let num_variable = builder.get(raw_pi_num_variables, instance.0); let in_point_slice = in_point.slice(builder, 0, num_variable); let expected_eval = diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 09e5a1131..396321724 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -19,7 +19,6 @@ use gkr_iop::hal::ProverBackend; use mpcs::{ Basefold, BasefoldRSParams, PolynomialCommitmentScheme, SecurityLevel, Whir, WhirDefaultSpec, }; -use p3::field::FieldAlgebra; use serde::{Serialize, de::DeserializeOwned}; use std::{fs, panic, panic::AssertUnwindSafe, path::PathBuf}; use tracing::{error, level_filters::LevelFilter}; @@ -404,8 +403,7 @@ fn soundness_test>( // do sanity check let transcript = Transcript::new(b"riscv"); // change public input maliciously should cause verifier to reject proof - zkvm_proof.raw_pi[0] = vec![E::BaseField::ONE]; - zkvm_proof.raw_pi[1] = vec![E::BaseField::ONE]; + zkvm_proof.public_values.exit_code = 1; // capture panic message, if have let result = with_panic_hook(Box::new(|_info| ()), || { @@ -428,7 +426,7 @@ fn soundness_test>( unreachable!() }; - if !msg.starts_with("0th round's prover message is not consistent with the claim") { + if !msg.starts_with("assertion `left == right` failed") { error!("unknown panic {msg:?}"); panic::resume_unwind(err); }; diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index bb63ce504..bea259603 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -27,12 +27,12 @@ pub trait PublicValuesQuery { fn query_public_io(&mut self) -> Result<[Instance; UINT_LIMBS], CircuitBuilderError>; #[allow(dead_code)] fn query_shard_id(&mut self) -> Result; - fn query_heap_start_addr(&self) -> Result; + fn query_heap_start_addr(&mut self) -> Result; #[allow(dead_code)] - fn query_heap_shard_len(&self) -> Result; - fn query_hint_start_addr(&self) -> Result; + fn query_heap_shard_len(&mut self) -> Result; + fn query_hint_start_addr(&mut self) -> Result; #[allow(dead_code)] - fn query_hint_shard_len(&self) -> Result; + fn query_hint_shard_len(&mut self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -95,18 +95,18 @@ impl<'a, E: ExtensionField> PublicValuesQuery for CircuitBuilder<'a, E> { self.cs.query_instance(SHARD_ID_IDX) } - fn query_heap_start_addr(&self) -> Result { + fn query_heap_start_addr(&mut self) -> Result { self.cs.query_instance(HEAP_START_ADDR_IDX) } - fn query_heap_shard_len(&self) -> Result { + fn query_heap_shard_len(&mut self) -> Result { self.cs.query_instance(HEAP_LENGTH_IDX) } - fn query_hint_start_addr(&self) -> Result { + fn query_hint_start_addr(&mut self) -> Result { self.cs.query_instance(HINT_START_ADDR_IDX) } - fn query_hint_shard_len(&self) -> Result { + fn query_hint_shard_len(&mut self) -> Result { self.cs.query_instance(HINT_LENGTH_IDX) } } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..0c3d936b7 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1021,7 +1021,7 @@ pub fn emulate_program<'a>( platform.hints.start, hints_final.len() as u32, io_init.iter().map(|rec| rec.value).collect_vec(), - vec![0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity + [0; SEPTIC_EXTENSION_DEGREE * 2], // point_at_infinity ); #[cfg(debug_assertions)] diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 61e673246..2ac17f528 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -3,6 +3,8 @@ pub use ceno_emul::PC_STEP_SIZE; pub const ECALL_HALT_OPCODE: [usize; 2] = [0x00_00, 0x00_00]; pub const EXIT_PC: usize = 0; + +/// scalar-based public value, id start from 0 pub const EXIT_CODE_IDX: usize = 0; // exit code u32 occupied 2 limb, each with 16 pub const INIT_PC_IDX: usize = EXIT_CODE_IDX + 2; @@ -14,8 +16,11 @@ pub const HEAP_START_ADDR_IDX: usize = SHARD_ID_IDX + 1; pub const HEAP_LENGTH_IDX: usize = HEAP_START_ADDR_IDX + 1; pub const HINT_START_ADDR_IDX: usize = HEAP_LENGTH_IDX + 1; pub const HINT_LENGTH_IDX: usize = HINT_START_ADDR_IDX + 1; -pub const PUBLIC_IO_IDX: usize = HINT_LENGTH_IDX + 1; -pub const SHARD_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; + +pub const SHARD_RW_SUM_IDX: usize = HINT_LENGTH_IDX + 1; + +/// vector-based public value, id start from 0 +pub const PUBLIC_IO_IDX: usize = 0; pub const LIMB_BITS: usize = 16; pub const LIMB_MASK: u32 = 0xFFFF; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index fa9127f10..b105eb7f9 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -3,6 +3,7 @@ use ff_ext::ExtensionField; use gkr_iop::gkr::GKRProof; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; +use multilinear_extensions::mle::{IntoMLE, MultilinearExtension}; use p3::field::FieldAlgebra; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ @@ -18,10 +19,15 @@ use crate::{ instructions::{ Instruction, riscv::{ - constants::{LIMB_BITS, LIMB_MASK, UINT_LIMBS}, + constants::{ + END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, + HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, LIMB_BITS, + LIMB_MASK, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, + }, ecall::HaltInstruction, }, }, + scheme::constants::SEPTIC_EXTENSION_DEGREE, structs::{TowerProofs, ZKVMVerifyingKey}, }; @@ -65,13 +71,10 @@ pub struct ZKVMChipProof { pub ecc_proof: Option>, pub num_instances: Vec, - - pub fixed_in_evals: Vec, - pub wits_in_evals: Vec, } /// each field will be interpret to (constant) polynomial -#[derive(Default, Clone, Debug)] +#[derive(Default, Clone, Debug, Serialize, Deserialize)] pub struct PublicValues { pub exit_code: u32, pub init_pc: u32, @@ -84,7 +87,7 @@ pub struct PublicValues { pub hint_start_addr: u32, pub hint_shard_len: u32, pub public_io: Vec, - pub shard_rw_sum: Vec, + pub shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], } impl PublicValues { @@ -101,7 +104,7 @@ impl PublicValues { hint_start_addr: u32, hint_shard_len: u32, public_io: Vec, - shard_rw_sum: Vec, + shard_rw_sum: [u32; SEPTIC_EXTENSION_DEGREE * 2], ) -> Self { Self { exit_code, @@ -118,45 +121,52 @@ impl PublicValues { shard_rw_sum, } } - pub fn to_vec(&self) -> Vec> { - vec![ - vec![E::BaseField::from_canonical_u32(self.exit_code & 0xffff)], - vec![E::BaseField::from_canonical_u32( - (self.exit_code >> 16) & 0xffff, - )], - vec![E::BaseField::from_canonical_u32(self.init_pc)], - vec![E::BaseField::from_canonical_u64(self.init_cycle)], - vec![E::BaseField::from_canonical_u32(self.end_pc)], - vec![E::BaseField::from_canonical_u64(self.end_cycle)], - vec![E::BaseField::from_canonical_u32(self.shard_id)], - vec![E::BaseField::from_canonical_u32(self.heap_start_addr)], - vec![E::BaseField::from_canonical_u32(self.heap_shard_len)], - vec![E::BaseField::from_canonical_u32(self.hint_start_addr)], - vec![E::BaseField::from_canonical_u32(self.hint_shard_len)], - ] - .into_iter() - .chain( - // public io processed into UINT_LIMBS column - (0..UINT_LIMBS) - .map(|limb_index| { - self.public_io - .iter() - .map(|value| { - E::BaseField::from_canonical_u16( - ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, - ) - }) - .collect_vec() - }) - .collect_vec(), - ) - .chain( - self.shard_rw_sum - .iter() - .map(|value| vec![E::BaseField::from_canonical_u32(*value)]) - .collect_vec(), - ) - .collect::>() + pub fn query_by_index(&self, index: usize) -> E::BaseField { + match index { + EXIT_CODE_IDX => E::BaseField::from_canonical_u32(self.exit_code & 0xffff), + idx if idx == EXIT_CODE_IDX + 1 => { + E::BaseField::from_canonical_u32((self.exit_code >> 16) & 0xffff) + } + INIT_PC_IDX => E::BaseField::from_canonical_u32(self.init_pc), + INIT_CYCLE_IDX => E::BaseField::from_canonical_u64(self.init_cycle), + END_PC_IDX => E::BaseField::from_canonical_u32(self.end_pc), + END_CYCLE_IDX => E::BaseField::from_canonical_u64(self.end_cycle), + SHARD_ID_IDX => E::BaseField::from_canonical_u32(self.shard_id), + HEAP_START_ADDR_IDX => E::BaseField::from_canonical_u32(self.heap_start_addr), + HEAP_LENGTH_IDX => E::BaseField::from_canonical_u32(self.heap_shard_len), + HINT_START_ADDR_IDX => E::BaseField::from_canonical_u32(self.hint_start_addr), + HINT_LENGTH_IDX => E::BaseField::from_canonical_u32(self.hint_shard_len), + idx if (SHARD_RW_SUM_IDX..(SHARD_RW_SUM_IDX + SEPTIC_EXTENSION_DEGREE * 2)) + .contains(&idx) => + { + E::BaseField::from_canonical_u32(self.shard_rw_sum[idx - SHARD_RW_SUM_IDX]) + } + _ => panic!("public value index {index} out of range"), + } + } + + pub fn mles(&self) -> Vec> { + // public_io is represented as UINT_LIMBS columns. + (0..UINT_LIMBS) + .map(|limb_index| { + let limb_values = self + .public_io + .iter() + .map(|value| { + E::BaseField::from_canonical_u16( + ((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16, + ) + }) + .collect_vec(); + + // Empty public_io means a constant-zero public input column. + if limb_values.is_empty() { + vec![E::BaseField::ZERO].into_mle() + } else { + limb_values.into_mle() + } + }) + .collect_vec() } } @@ -169,11 +179,7 @@ impl PublicValues { deserialize = "E::BaseField: DeserializeOwned" ))] pub struct ZKVMProof> { - // TODO preserve in serde only for auxiliary public input - // other raw value can be construct by verifier directly. - pub raw_pi: Vec>, - // the evaluation of raw_pi. - pub pi_evals: Vec, + pub public_values: PublicValues, // each circuit may have multiple proof instances pub chip_proofs: BTreeMap>>, pub witin_commit: >::Commitment, @@ -182,41 +188,19 @@ pub struct ZKVMProof> { impl> ZKVMProof { pub fn new( - raw_pi: Vec>, - pi_evals: Vec, + public_values: PublicValues, chip_proofs: BTreeMap>>, witin_commit: >::Commitment, opening_proof: PCS::Proof, ) -> Self { Self { - raw_pi, - pi_evals, + public_values, chip_proofs, witin_commit, opening_proof, } } - pub fn pi_evals(raw_pi: &[Vec]) -> Vec { - raw_pi - .iter() - .map(|pv| { - if pv.len() == 1 { - // this is constant poly, and always evaluate to same constant value - E::from(pv[0]) - } else { - // set 0 as placeholder. will be evaluate lazily - // Or the vector is empty, i.e. the constant 0 polynomial. - E::ZERO - } - }) - .collect_vec() - } - - pub fn update_pi_eval(&mut self, idx: usize, v: E) { - self.pi_evals[idx] = v; - } - pub fn num_circuits(&self) -> usize { self.chip_proofs.len() } diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index a773f4de9..85decd78e 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -842,11 +842,8 @@ impl> MainSumcheckProver> MainSumcheckProver { pub struct MainSumcheckEvals { pub wits_in_evals: Vec, pub fixed_in_evals: Vec, + pub pi_in_evals: Vec, } pub trait MainSumcheckProver { diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index a477b1c6a..20f4da1b1 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -26,7 +26,7 @@ use gkr_iop::{ use itertools::{Itertools, chain, enumerate, izip}; use multilinear_extensions::{ Expression, WitnessId, fmt, - mle::{ArcMultilinearExtension, IntoMLEs, MultilinearExtension}, + mle::{ArcMultilinearExtension, MultilinearExtension}, util::ceil_log2, utils::{eval_by_expr, eval_by_expr_with_fixed, eval_by_expr_with_instance}, }; @@ -40,7 +40,6 @@ use std::{ hash::Hash, io::{BufReader, ErrorKind}, marker::PhantomData, - ops::Index, sync::OnceLock, }; use strum::IntoEnumIterator; @@ -964,17 +963,22 @@ Hints: ) where E: LkMultiplicityKey, { - let pub_io_evals = pi - .to_vec::() - .into_iter() - .map(|v| Either::Right(E::from(*v.index(0)))) - .collect_vec(); - let pi_mles: Vec> = pi - .to_vec::() - .into_mles() - .into_iter() - .map(|v| v.into()) - .collect_vec(); + let all_pi_mles: Vec> = + pi.mles::().into_iter().map(|v| v.into()).collect_vec(); + let get_circuit_pi_inputs = |circuit_cs: &ConstraintSystem| { + let circuit_pub_io_evals = circuit_cs + .instance_values + .iter() + .map(|instance| Either::Right(E::from(pi.query_by_index::(instance.0)))) + .collect_vec(); + let circuit_pi_mles = circuit_cs + .instance_openings + .iter() + .map(|instance| all_pi_mles[instance.0].clone()) + .collect_vec(); + (circuit_pub_io_evals, circuit_pi_mles) + }; + let mut rng = thread_rng(); let challenges = [0u8; 2].map(|_| E::random(&mut rng)); @@ -1000,11 +1004,7 @@ Hints: let ComposedConstrainSystem { zkvm_v1_css: cs, .. } = &composed_cs; - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); // skip init table on non-first shard if composed_cs.with_omc_init_only() && !shard_ctx.is_first_shard() { @@ -1061,8 +1061,8 @@ Hints: &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, num_rows, challenges, lkm_from_assignments, @@ -1093,8 +1093,8 @@ Hints: &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1108,8 +1108,8 @@ Hints: &fixed, &witness, &structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ) .get_ext_field_vec() @@ -1162,11 +1162,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { @@ -1204,8 +1200,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); @@ -1217,8 +1213,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let w_selector_vec = w_selector.get_base_field_vec(); @@ -1268,11 +1264,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); - let pi_mles = cs - .instance_openings - .iter() - .map(|instance| pi_mles[instance.0].clone()) - .collect_vec(); + let (circuit_pub_io_evals, circuit_pi_mles) = get_circuit_pi_inputs(cs); let num_rows = num_instances.get(circuit_name).unwrap(); if *num_rows == 0 { continue; @@ -1308,8 +1300,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let ram_type_vec = ram_type_mle.get_ext_field_vec(); @@ -1321,8 +1313,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); let r_selector_vec = r_selector.get_base_field_vec(); @@ -1349,8 +1341,8 @@ Hints: fixed, witness, structural_witness, - &pi_mles, - &pub_io_evals, + &circuit_pi_mles, + &circuit_pub_io_evals, &challenges, ); filter_mle_by_selector_mle(v, r_selector.clone()) @@ -1481,38 +1473,23 @@ Hints: let mut cb = CircuitBuilder::new(&mut cs); let gs_init = GlobalState::initial_global_state(&mut cb).unwrap(); let gs_final = GlobalState::finalize_global_state(&mut cb).unwrap(); + let gs_pub_io_evals = cs + .instance_values + .iter() + .map(|instance| E::from(pi.query_by_index::(instance.0))) + .collect_vec(); let (mut gs_rs, rs_grp_by_anno, mut gs_ws, ws_grp_by_anno, gs) = derive_ram_rws!(RAMType::GlobalState); gs_rs.insert( - eval_by_expr_with_instance( - &[], - &[], - &[], - &pub_io_evals - .iter() - .map(|v| v.right().unwrap()) - .collect_vec(), - &challenges, - &gs_final, - ) - .right() - .unwrap(), + eval_by_expr_with_instance(&[], &[], &[], &gs_pub_io_evals, &challenges, &gs_final) + .right() + .unwrap(), ); gs_ws.insert( - eval_by_expr_with_instance( - &[], - &[], - &[], - &pub_io_evals - .iter() - .map(|v| v.right().unwrap()) - .collect_vec(), - &challenges, - &gs_init, - ) - .right() - .unwrap(), + eval_by_expr_with_instance(&[], &[], &[], &gs_pub_io_evals, &challenges, &gs_init) + .right() + .unwrap(), ); // gs stores { (pc, timestamp) } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 99ad1ac39..96b28ce1a 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -3,11 +3,7 @@ use gkr_iop::{ cpu::{CpuBackend, CpuProver}, hal::ProverBackend, }; -use std::{ - collections::{BTreeMap, HashMap}, - marker::PhantomData, - sync::Arc, -}; +use std::{collections::BTreeMap, marker::PhantomData, sync::Arc}; #[cfg(feature = "gpu")] use crate::scheme::gpu::estimate_chip_proof_memory; @@ -17,13 +13,9 @@ use crate::scheme::{ scheduler::{ChipScheduler, ChipTask, ChipTaskResult}, }; use either::Either; -use gkr_iop::hal::MultilinearPolynomial; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; -use multilinear_extensions::{ - Expression, Instance, - mle::{IntoMLE, MultilinearExtension}, -}; +use multilinear_extensions::{Expression, Instance}; use p3::field::FieldAlgebra; use std::iter::Iterator; use sumcheck::{ @@ -39,6 +31,7 @@ use crate::structs::ProvingKey; use crate::{ e2e::ShardContext, error::ZKVMError, + instructions::riscv::constants::UINT_LIMBS, scheme::{ hal::{DeviceProvingKey, ProofInput}, utils::build_main_witness, @@ -46,7 +39,7 @@ use crate::{ structs::{TowerProofs, ZKVMProvingKey, ZKVMWitnesses}, }; -type CreateTableProof = (ZKVMChipProof, HashMap, Point); +type CreateTableProof = (ZKVMChipProof, MainSumcheckEvals, Point); pub type ZkVMCpuProver = ZKVMProver, CpuProver>>; @@ -150,24 +143,24 @@ impl< .get_device_proving_key(shard_ctx) .map(|dpk| dpk.fixed_mles.clone()) .unwrap_or_default(); + let pi_mles_preload = pi.mles::(); info_span!( "[ceno] create_proof_of_shard", shard_id = shard_ctx.shard_id ) .in_scope(|| { - let raw_pi = pi.to_vec::(); - let mut pi_evals = ZKVMProof::::pi_evals(&raw_pi); - let span = entered_span!("commit_to_pi", profiling_1 = true); - // including raw public input to transcript - for v in raw_pi.iter().flatten() { - transcript.append_field_element(v); + // Include transcript-visible public values in canonical circuit order. + // The order must match verifier and recursion verifier exactly. + // TODO deal with vector-based public value to transcript + for (_, circuit_pk) in self.pk.circuit_pks.iter() { + for instance_value in circuit_pk.get_cs().zkvm_v1_css.instance_values.iter() { + transcript.append_field_element(&pi.query_by_index::(instance_value.0)); + } } - exit_span!(span); - let pi: Vec> = - raw_pi.iter().map(|p| p.to_vec().into_mle()).collect(); + exit_span!(span); // commit to fixed commitment let span = entered_span!("commit_to_fixed_commit", profiling_1 = true); @@ -276,7 +269,7 @@ impl< tracing::debug!("global challenges in prover: {:?}", challenges); let public_input_span = entered_span!("public_input", profiling_1 = true); - let public_input = self.device.transport_mles(&pi); + let public_input = self.device.transport_mles(&pi_mles_preload); exit_span!(public_input_span); let main_proofs_span = entered_span!("main_proofs", profiling_1 = true); @@ -292,7 +285,7 @@ impl< fixed_mles, challenges, public_input, - &pi_evals, + &pi, &circuit_trace_indices, ); exit_span!(build_tasks_span); @@ -307,11 +300,7 @@ impl< // Phase 3: Collect results let collect_results_span = entered_span!("collect_chip_results", profiling_1 = true); - let (chip_proofs, points, evaluations, pi_updates) = - Self::collect_chip_results(results); - for (idx, eval) in pi_updates { - pi_evals[idx] = eval; - } + let (chip_proofs, points, evaluations) = Self::collect_chip_results(results); exit_span!(collect_results_span); exit_span!(main_proofs_span); @@ -335,13 +324,7 @@ impl< }); exit_span!(pcs_opening); - let vm_proof = ZKVMProof::new( - raw_pi, - pi_evals, - chip_proofs, - witin_commit, - mpcs_opening_proof, - ); + let vm_proof = ZKVMProof::new(pi, chip_proofs, witin_commit, mpcs_opening_proof); Ok(vm_proof) }) @@ -389,7 +372,7 @@ impl< let gpu_input: ProofInput<'static, gkr_iop::gpu::GpuBackend> = unsafe { std::mem::transmute(task.input) }; - let (proof, pi_in_evals, input_opening_point) = + let (proof, opening_evals, input_opening_point) = create_chip_proof_gpu_impl::( task.circuit_name.as_str(), task.pk, @@ -406,7 +389,7 @@ impl< task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - pi_in_evals, + opening_evals, input_opening_point, has_witness_or_fixed: task.has_witness_or_fixed, }) @@ -424,14 +407,14 @@ impl< // Prepare: deferred extraction for GPU, no-op for CPU self.device.prepare_chip_input(&mut task, witness_data); - let (proof, pi_in_evals, input_opening_point) = + let (proof, opening_evals, input_opening_point) = self.create_chip_proof(&task, transcript)?; Ok(ChipTaskResult { task_id: task.task_id, circuit_idx: task.circuit_idx, proof, - pi_in_evals, + opening_evals, input_opening_point, has_witness_or_fixed: task.has_witness_or_fixed, }) @@ -520,23 +503,10 @@ impl< let MainSumcheckEvals { wits_in_evals, fixed_in_evals, + pi_in_evals, } = evals; exit_span!(span); - // evaluate pi if there is instance query - let mut pi_in_evals: HashMap = HashMap::new(); - if !cs.instance_openings().is_empty() { - let span = entered_span!("pi::evals", profiling_2 = true); - for &Instance(idx) in cs.instance_openings() { - let poly = &input.public_input[idx]; - pi_in_evals.insert( - idx, - poly.eval(input_opening_point[..poly.num_vars()].to_vec()), - ); - } - exit_span!(span); - } - Ok(( ZKVMChipProof { r_out_evals, @@ -546,11 +516,13 @@ impl< gkr_iop_proof, tower_proof, ecc_proof, - fixed_in_evals, - wits_in_evals, num_instances: input.num_instances.clone(), }, - pi_in_evals, + MainSumcheckEvals { + wits_in_evals, + fixed_in_evals, + pi_in_evals, + }, input_opening_point, )) } @@ -568,7 +540,7 @@ impl< mut fixed_mles: Vec>>, challenges: [E; 2], public_input: Vec>>, - pi_evals: &[E], + pi: &PublicValues, circuit_trace_indices: &[Option], ) -> Vec> { // CPU path: eagerly extract witness MLEs from pcs_data @@ -645,12 +617,31 @@ impl< }; let fixed = fixed_mles.drain(..cs.num_fixed()).collect_vec(); + let public_io = cs + .instance_openings() + .iter() + .map(|Instance(idx)| { + debug_assert!( + *idx < UINT_LIMBS, + "instance_opening index {idx} out of range" + ); + public_input[*idx].clone() + }) + .collect_vec(); + + let pi_evals = cs + .zkvm_v1_css + .instance_values + .iter() + .map(|Instance(idx)| Either::Left(pi.query_by_index::(*idx))) + .collect_vec(); + let input_temp: ProofInput<'_, PB> = ProofInput { witness: witness_mle, fixed, structural_witness, - public_input: public_input.clone(), - pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), + public_input: public_io, + pub_io_evals: pi_evals, num_instances: num_instances.clone(), has_ecc_ops: cs.has_ecc_ops(), }; @@ -718,12 +709,10 @@ impl< BTreeMap>>, Vec>, Vec>>, - HashMap, ) { let mut chip_proofs = BTreeMap::new(); let mut points = Vec::new(); let mut evaluations = Vec::new(); - let mut pi_updates = HashMap::new(); for result in results { tracing::trace!( @@ -735,23 +724,18 @@ impl< if result.has_witness_or_fixed { points.push(result.input_opening_point); evaluations.push(vec![ - result.proof.wits_in_evals.clone(), - result.proof.fixed_in_evals.clone(), + result.opening_evals.wits_in_evals, + result.opening_evals.fixed_in_evals, + result.opening_evals.pi_in_evals, ]); - } else { - assert!(result.proof.wits_in_evals.is_empty()); - assert!(result.proof.fixed_in_evals.is_empty()); } chip_proofs .entry(result.circuit_idx) .or_insert(vec![]) .push(result.proof); - for (idx, eval) in result.pi_in_evals { - pi_updates.insert(idx, eval); - } } - (chip_proofs, points, evaluations, pi_updates) + (chip_proofs, points, evaluations) } } @@ -880,23 +864,10 @@ where let MainSumcheckEvals { wits_in_evals, fixed_in_evals, + pi_in_evals, } = evals; exit_span!(span); - // evaluate pi if there is instance query - let mut pi_in_evals: HashMap = HashMap::new(); - if !cs.instance_openings().is_empty() { - let span = entered_span!("pi::evals", profiling_2 = true); - for &Instance(idx) in cs.instance_openings() { - let poly = &input.public_input[idx]; - pi_in_evals.insert( - idx, - poly.eval(input_opening_point[..poly.num_vars()].to_vec()), - ); - } - exit_span!(span); - } - Ok(( ZKVMChipProof { r_out_evals, @@ -906,11 +877,13 @@ where gkr_iop_proof, tower_proof, ecc_proof, - fixed_in_evals, - wits_in_evals, num_instances: input.num_instances, }, - pi_in_evals, + MainSumcheckEvals { + wits_in_evals, + fixed_in_evals, + pi_in_evals, + }, input_opening_point, )) } diff --git a/ceno_zkvm/src/scheme/scheduler.rs b/ceno_zkvm/src/scheme/scheduler.rs index 438421b2e..c1bd260f4 100644 --- a/ceno_zkvm/src/scheme/scheduler.rs +++ b/ceno_zkvm/src/scheme/scheduler.rs @@ -12,14 +12,17 @@ use crate::{ error::ZKVMError, - scheme::{ZKVMChipProof, hal::ProofInput}, + scheme::{ + ZKVMChipProof, + hal::{MainSumcheckEvals, ProofInput}, + }, structs::ProvingKey, }; use ff_ext::ExtensionField; use gkr_iop::hal::ProverBackend; use mpcs::Point; use p3::field::FieldAlgebra; -use std::{collections::HashMap, sync::OnceLock}; +use std::sync::OnceLock; use transcript::Transcript; static CHIP_PROVING_MODE: OnceLock = OnceLock::new(); @@ -77,8 +80,8 @@ pub struct ChipTaskResult { pub circuit_idx: usize, /// The generated proof pub proof: ZKVMChipProof, - /// Public input evaluations - pub pi_in_evals: HashMap, + /// Prover-only opening evaluations split by witness/fixed/pi domains. + pub opening_evals: MainSumcheckEvals, /// Opening point for this proof pub input_opening_point: Point, /// Whether this circuit has witness or fixed polynomials diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 91a4e4563..763c3bb4e 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -248,13 +248,12 @@ fn test_rw_lk_expression_combination() { { Instrumented::<<::BaseField as PoseidonField>::P>::clear_metrics(); } - verifier + let _ = verifier .verify_chip_proof( name.as_str(), verifier.vk.circuit_vks.get(&name).unwrap(), &proof, - &[], - &[], + &PublicValues::default(), &mut v_transcript, NUM_FANIN, &PointAndEval::default(), @@ -397,7 +396,7 @@ fn test_single_add_instance_e2e() { .assign_table_circuit::>(&zkvm_cs, &prog_config, &program) .unwrap(); - let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], vec![0; 14]); + let pi = PublicValues::new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, vec![0], [0; 14]); let transcript = BasicTranscript::new(b"riscv"); let zkvm_proof = prover .create_proof(&shard_ctx, zkvm_witness, pi, transcript) diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2c49ee658..28afc506b 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,5 +1,5 @@ use either::Either; -use ff_ext::ExtensionField; +use ff_ext::{ExtensionField, SmallField}; use std::{ iter::{self, once, repeat_n}, marker::PhantomData, @@ -8,11 +8,12 @@ use std::{ #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; -use super::{ZKVMChipProof, ZKVMProof}; +use super::{PublicValues, ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, instructions::riscv::constants::{ - END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, SHARD_ID_IDX, + END_CYCLE_IDX, END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, + INIT_PC_IDX, }, scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, @@ -51,6 +52,41 @@ pub struct ZKVMVerifier> { } impl> ZKVMVerifier { + #[allow(clippy::type_complexity)] + fn split_input_opening_evals( + circuit_vk: &VerifyingKey, + proof: &ZKVMChipProof, + ) -> Result<(Vec, Vec, Vec), ZKVMError> { + let cs = circuit_vk.get_cs(); + let Some(gkr_proof) = proof.gkr_iop_proof.as_ref() else { + return Err(ZKVMError::InvalidProof("missing gkr proof".into())); + }; + let Some(last_layer) = gkr_proof.0.last() else { + return Err(ZKVMError::InvalidProof("empty gkr proof layers".into())); + }; + + let evals = &last_layer.main.evals; + let wit_len = cs.num_witin(); + let fixed_len = cs.num_fixed(); + let pi_len = cs.instance_openings().len(); + let min_len = wit_len + fixed_len + pi_len; + if evals.len() < min_len { + return Err(ZKVMError::InvalidProof( + format!( + "insufficient main evals: {} < required {}", + evals.len(), + min_len + ) + .into(), + )); + } + + let wits_in_evals = evals[..wit_len].to_vec(); + let fixed_in_evals = evals[wit_len..(wit_len + fixed_len)].to_vec(); + let pi_in_evals = evals[(wit_len + fixed_len)..(wit_len + fixed_len + pi_len)].to_vec(); + Ok((wits_in_evals, fixed_in_evals, pi_in_evals)) + } + pub fn new(vk: ZKVMVerifyingKey) -> Self { ZKVMVerifier { vk } } @@ -116,19 +152,34 @@ impl> ZKVMVerifier } // each shard set init cycle = Tracer::SUBCYCLES_PER_INSN // to satisfy initial reads for all prev_cycle = 0 < init_cycle - assert_eq!(vm_proof.pi_evals[INIT_CYCLE_IDX], E::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN)); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_CYCLE_IDX), + E::BaseField::from_canonical_u64(Tracer::SUBCYCLES_PER_INSN) + ); // check init_pc match prev end_pc if let Some(prev_pc) = prev_pc { - assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], prev_pc); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_PC_IDX), + prev_pc + ); } else { // first chunk, check program entry - assert_eq!(vm_proof.pi_evals[INIT_PC_IDX], E::from_canonical_u32(self.vk.entry_pc)); + assert_eq!( + vm_proof.public_values.query_by_index::(INIT_PC_IDX), + E::BaseField::from_canonical_u32(self.vk.entry_pc) + ); } - let end_pc = vm_proof.pi_evals[END_PC_IDX]; + let end_pc = vm_proof.public_values.query_by_index::(END_PC_IDX); // check memory continuation consistency - let heap_addr_start_u32 = vm_proof.pi_evals[HEAP_START_ADDR_IDX].to_canonical_u64() as u32; - let heap_len= vm_proof.pi_evals[HEAP_LENGTH_IDX].to_canonical_u64() as u32; + let heap_addr_start_u32 = vm_proof + .public_values + .query_by_index::(HEAP_START_ADDR_IDX) + .to_canonical_u64() as u32; + let heap_len = vm_proof + .public_values + .query_by_index::(HEAP_LENGTH_IDX) + .to_canonical_u64() as u32; if let Some(prev_heap_addr_end) = prev_heap_addr_end { assert_eq!(heap_addr_start_u32, prev_heap_addr_end); // TODO check heap addr in prime field within range @@ -165,7 +216,12 @@ impl> ZKVMVerifier let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; - let pi_evals = &vm_proof.pi_evals; + // Global-state expressions are built from compact instance IDs + // (query order), not absolute public-value indices. + let pi_evals = [INIT_PC_IDX, INIT_CYCLE_IDX, END_PC_IDX, END_CYCLE_IDX] + .into_iter() + .map(|idx| E::from(vm_proof.public_values.query_by_index::(idx))) + .collect_vec(); // make sure circuit index of chip proofs are // subset of that of self.vk.circuit_vks @@ -181,33 +237,18 @@ impl> ZKVMVerifier } } - // TODO fix soundness: construct raw public input by ourself and trustless from proof - // including raw public input to transcript - vm_proof - .raw_pi - .iter() - .for_each(|v| v.iter().for_each(|v| transcript.append_field_element(v))); + // Include transcript-visible public values in canonical circuit order. + // This must match prover and recursion verifier exactly. + for (_, circuit_vk) in self.vk.circuit_vks.iter() { + for instance_value in circuit_vk.get_cs().zkvm_v1_css.instance_values.iter() { + transcript.append_field_element( + &vm_proof.public_values.query_by_index::(instance_value.0), + ); + } + } // check shard id - assert_eq!( - vm_proof.raw_pi[SHARD_ID_IDX], - vec![E::BaseField::from_canonical_usize(shard_id)] - ); - - // verify constant poly(s) evaluation result match - // we can evaluate at this moment because constant always evaluate to same value - // non-constant poly(s) will be verified in respective (table) proof accordingly - izip!(&vm_proof.raw_pi, pi_evals) - .enumerate() - .try_for_each(|(i, (raw, eval))| { - if raw.len() == 1 && E::from(raw[0]) != *eval { - Err(ZKVMError::VerifyError( - format!("{shard_id}th shard pub input on index {i} mismatch {raw:?} != {eval:?}").into(), - )) - } else { - Ok(()) - } - })?; + assert_eq!(vm_proof.public_values.shard_id, shard_id as u32); // write fixed commitment to transcript // TODO check soundness if there is no fixed_commit but got fixed proof? @@ -291,21 +332,6 @@ impl> ZKVMVerifier let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; - // check chip proof is well-formed - if proof.wits_in_evals.len() != circuit_vk.get_cs().num_witin() - || proof.fixed_in_evals.len() != circuit_vk.get_cs().num_fixed() - { - return Err(ZKVMError::InvalidProof( - format!( - "{shard_id}th shard witness/fixed evaluations length mismatch: ({}, {}) != ({}, {})", - proof.wits_in_evals.len(), - proof.fixed_in_evals.len(), - circuit_vk.get_cs().num_witin(), - circuit_vk.get_cs().num_fixed(), - ) - .into(), - )); - } if proof.r_out_evals.len() != circuit_vk.get_cs().num_reads() || proof.w_out_evals.len() != circuit_vk.get_cs().num_writes() { @@ -358,27 +384,27 @@ impl> ZKVMVerifier logup_sum += chip_logup_sum; }; - let (input_opening_point, chip_shard_ec_sum) = self.verify_chip_proof( - circuit_name, - circuit_vk, - proof, - pi_evals, - &vm_proof.raw_pi, - transcript, - NUM_FANIN, - &point_eval, - &challenges, - )?; + let (input_opening_point, chip_shard_ec_sum, wits_in_evals, fixed_in_evals) = self + .verify_chip_proof( + circuit_name, + circuit_vk, + proof, + &vm_proof.public_values, + transcript, + NUM_FANIN, + &point_eval, + &challenges, + )?; if circuit_vk.get_cs().num_witin() > 0 { witin_openings.push(( input_opening_point.len(), - (input_opening_point.clone(), proof.wits_in_evals.clone()), + (input_opening_point.clone(), wits_in_evals), )); } if circuit_vk.get_cs().num_fixed() > 0 { fixed_openings.push(( input_opening_point.len(), - (input_opening_point.clone(), proof.fixed_in_evals.clone()), + (input_opening_point.clone(), fixed_in_evals), )); } prod_w *= proof.w_out_evals.iter().flatten().copied().product::(); @@ -435,7 +461,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi_evals, &challenges, &self.vk.initial_global_state_expr, ) @@ -446,7 +472,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi_evals, &challenges, &self.vk.finalize_global_state_expr, ) @@ -478,13 +504,12 @@ impl> ZKVMVerifier _name: &str, circuit_vk: &VerifyingKey, proof: &ZKVMChipProof, - pi: &[E], - raw_pi: &[Vec], + public_values: &PublicValues, transcript: &mut impl Transcript, num_product_fanin: usize, _out_evals: &PointAndEval, challenges: &[E; 2], // derive challenge from PCS - ) -> Result<(Point, Option>), ZKVMError> { + ) -> Result<(Point, Option>, Vec, Vec), ZKVMError> { let composed_cs = circuit_vk.get_cs(); let ComposedConstrainSystem { zkvm_v1_css: cs, @@ -669,17 +694,29 @@ impl> ZKVMVerifier }, ] }; + let pi = cs + .instance_values + .iter() + .map(|instance| E::from(public_values.query_by_index::(instance.0))) + .collect_vec(); + let (wits_in_evals, fixed_in_evals, _pi_in_evals) = + Self::split_input_opening_evals(circuit_vk, proof)?; + let instance_mles = public_values + .mles::() + .into_iter() + .map(|mle| mle.get_base_field_vec().to_vec()) + .collect_vec(); let (_, rt) = gkr_circuit.verify( num_var_with_rotation, proof.gkr_iop_proof.clone().unwrap(), &evals, - pi, - raw_pi, + &pi, + &instance_mles, challenges, transcript, &selector_ctxs, )?; - Ok((rt, shard_ec_sum)) + Ok((rt, shard_ec_sum, wits_in_evals, fixed_in_evals)) } } diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index d934127f9..130adc787 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -38,11 +38,12 @@ impl DynVolatileRamTable for HeapTable { params: &ProgramParams, ) -> Result<(Expression, StructuralWitIn), CircuitBuilderError> { let max_len = Self::max_len(params); + let offset_instance_id = cb.query_heap_start_addr()?.0 as WitnessId; let addr = cb.create_structural_witin( || "addr", StructuralWitInType::EqualDistanceDynamicSequence { max_len, - offset_instance_id: cb.query_heap_start_addr()?.0 as WitnessId, + offset_instance_id, multi_factor: WORD_SIZE, descending: Self::DESCENDING, }, @@ -143,11 +144,12 @@ impl DynVolatileRamTable for HintsTable { params: &ProgramParams, ) -> Result<(Expression, StructuralWitIn), CircuitBuilderError> { let max_len = Self::max_len(params); + let offset_instance_id = cb.query_hint_start_addr()?.0 as WitnessId; let addr = cb.create_structural_witin( || "addr", StructuralWitInType::EqualDistanceDynamicSequence { max_len, - offset_instance_id: cb.query_hint_start_addr()?.0 as WitnessId, + offset_instance_id, multi_factor: WORD_SIZE, descending: Self::DESCENDING, }, diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..f9caf5513 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -651,7 +651,7 @@ mod tests { use mpcs::{BasefoldDefault, PolynomialCommitmentScheme, SecurityLevel}; use p3::babybear::BabyBear; use rand::thread_rng; - use std::{ops::Index, sync::Arc}; + use std::sync::Arc; use tracing_forest::{ForestLayer, util::LevelFilter}; use tracing_subscriber::{EnvFilter, Registry, layer::SubscriberExt, util::SubscriberInitExt}; use transcript::BasicTranscript; @@ -659,8 +659,8 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ - PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, - septic_curve::SepticPoint, verifier::ZKVMVerifier, + PublicValues, constants::SEPTIC_EXTENSION_DEGREE, create_backend, create_prover, + hal::ProofInput, prover::ZKVMProver, septic_curve::SepticPoint, verifier::ZKVMVerifier, }, structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, @@ -670,7 +670,6 @@ mod tests { gpu::{MultilinearExtensionGpu, get_cuda_hal}, hal::MultilinearPolynomial, }; - use multilinear_extensions::mle::IntoMLE; use p3::field::PrimeField32; type E = BabyBearExt4; @@ -754,6 +753,16 @@ mod tests { .map(|record| record.ec_point.point.clone()) .sum(); + let mut shard_rw_sum = [0u32; SEPTIC_EXTENSION_DEGREE * 2]; + for (i, fe) in global_ec_sum + .x + .iter() + .chain(global_ec_sum.y.iter()) + .enumerate() + { + shard_rw_sum[i] = fe.as_canonical_u32(); + } + let public_value = PublicValues::new( 0, 0, @@ -766,12 +775,7 @@ mod tests { 0, 0, vec![0], // dummy - global_ec_sum - .x - .iter() - .chain(global_ec_sum.y.iter()) - .map(|fe| fe.as_canonical_u32()) - .collect_vec(), + shard_rw_sum, ); // assign witness @@ -807,18 +811,22 @@ mod tests { let zkvm_prover = ZKVMProver::new(zkvm_pk.into(), pd); let mut transcript = BasicTranscript::new(b"global chip test"); - let pub_io_evals = public_value - .to_vec::() - .into_iter() - .map(|v| Either::Right(E::from(*v.index(0)))) + let pub_io_evals = pk + .get_cs() + .zkvm_v1_css + .instance_values + .iter() + .map(|instance| Either::Right(E::from(public_value.query_by_index::(instance.0)))) .collect_vec(); + let pi_mles = public_value.mles::(); #[cfg(not(feature = "gpu"))] let (witness_mles, structural_mles, public_input_mles) = { - let public_input_mles = public_value - .to_vec::() - .into_iter() - .map(|v| Arc::new(v.into_mle())) + let public_input_mles = pk + .get_cs() + .instance_openings() + .iter() + .map(|instance| Arc::new(pi_mles[instance.0].clone())) .collect_vec(); ( witness[0].to_mles().into_iter().map(Arc::new).collect(), @@ -831,10 +839,12 @@ mod tests { let cuda_hal = get_cuda_hal().unwrap(); let witness_cpu: Vec<_> = witness[0].to_mles(); let structural_cpu: Vec<_> = witness[1].to_mles(); - let public_cpu: Vec<_> = public_value - .to_vec::() + let public_cpu: Vec<_> = pk + .get_cs() + .instance_openings() + .iter() + .map(|instance| pi_mles[instance.0].clone()) .into_iter() - .map(|v| v.into_mle()) .collect_vec(); ( witness_cpu @@ -882,17 +892,12 @@ mod tests { let mut transcript = BasicTranscript::new(b"global chip test"); let verifier = ZKVMVerifier::new(zkvm_vk); - let pi_evals = public_input_mles - .iter() - .map(|mle| mle.evaluate(&point[..mle.num_vars()])) - .collect_vec(); - let (vrf_point, _) = verifier + let (vrf_point, _, _, _) = verifier .verify_chip_proof( "global", &pk.vk, &proof, - &pi_evals, - &public_value.to_vec::(), + &public_value, &mut transcript, 2, &PointAndEval::default(), diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index d84cb4d55..8c54fedaf 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -102,6 +102,8 @@ pub struct ConstraintSystem { pub num_fixed: usize, pub fixed_namespace_map: Vec, + // record which public input index is involving in constraint computation + pub instance_values: Vec, pub instance_openings: Vec, pub ec_point_exprs: Vec>, @@ -175,6 +177,7 @@ impl ConstraintSystem { num_fixed: 0, fixed_namespace_map: vec![], ns: NameSpace::new(root_name_fn), + instance_values: vec![], instance_openings: vec![], ec_final_sum: vec![], ec_slope_exprs: vec![], @@ -259,9 +262,14 @@ impl ConstraintSystem { f } - pub fn query_instance(&self, idx: usize) -> Result { + pub fn query_instance(&mut self, idx: usize) -> Result { let i = Instance(idx); - Ok(i) + assert!( + !self.instance_values.contains(&i), + "query same pubio idx {idx} value more than once", + ); + self.instance_values.push(i); + Ok(Instance(self.instance_values.len() - 1)) } pub fn query_instance_for_openings( @@ -276,7 +284,6 @@ impl ConstraintSystem { ); self.instance_openings.push(i); - // return instance only count Ok(Instance(self.instance_openings.len() - 1)) } diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index b025aa1e4..a741775ff 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -123,7 +123,7 @@ impl GKRCircuit { gkr_proof: GKRProof, out_evals: &[PointAndEval], pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -145,7 +145,7 @@ impl GKRCircuit { layer_proof, &mut evaluations, pub_io_evals, - raw_pi, + instance_mles, &mut challenges, transcript, selector_ctxs, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a67862df7..711458dbf 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -258,7 +258,7 @@ impl Layer { proof: LayerProof, claims: &mut [PointAndEval], pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &mut Vec, transcript: &mut Trans, selector_ctxs: &[SelectorContext], @@ -273,7 +273,7 @@ impl Layer { proof, eval_and_dedup_points, pub_io_evals, - raw_pi, + instance_mles, challenges, transcript, selector_ctxs, diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index cf1da6df2..7ac6ae8d4 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -77,7 +77,7 @@ pub trait ZerocheckLayer { proof: LayerProof, eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -228,7 +228,7 @@ impl ZerocheckLayer for Layer { proof: LayerProof, mut eval_and_dedup_points: Vec<(Vec, Option>)>, pub_io_evals: &[E], - raw_pi: &[Vec], + instance_mles: &[Vec], challenges: &[E], transcript: &mut impl Transcript, selector_ctxs: &[SelectorContext], @@ -386,12 +386,15 @@ impl ZerocheckLayer for Layer { } } - // check pub-io - // assume public io is tiny vector, so we evaluate it directly without PCS + // check pub-io openings by evaluating the opened public-input MLEs. let pubio_offset = self.n_witin + self.n_fixed; for (index, instance) in self.instance_openings.iter().enumerate() { let index = pubio_offset + index; - let poly = raw_pi[instance.0].to_vec().into_mle(); + let poly = instance_mles + .get(instance.0) + .expect("instance opening index out of bounds for instance_mles") + .clone() + .into_mle(); let expected_eval = poly.evaluate(&in_point[..poly.num_vars()]); if expected_eval != main_evals[index] { return Err(BackendError::LayerVerificationFailed(