Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions ceno_recursion/src/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ impl CenoAggregationProver {
.collect();
let user_public_values: Vec<F> = zkvm_proof_inputs
.iter()
.flat_map(|p| p.raw_pi.iter().flat_map(|v| v.clone()).collect::<Vec<F>>())
.flat_map(|p| p.raw_pi.to_vec())
.collect();
let leaf_inputs = chunk_ceno_leaf_proof_inputs(zkvm_proof_inputs);

Expand Down Expand Up @@ -410,18 +410,9 @@ impl CenoLeafVmVerifierConfig {
}

let pv = &raw_pi;
let init_pc = {
let arr = builder.get(pv, INIT_PC_IDX);
builder.get(&arr, 0)
};
let end_pc = {
let arr = builder.get(pv, END_PC_IDX);
builder.get(&arr, 0)
};
let exit_code = {
let arr = builder.get(pv, EXIT_CODE_IDX);
builder.get(&arr, 0)
};
let init_pc = builder.get(pv, INIT_PC_IDX);
let end_pc = builder.get(pv, END_PC_IDX);
let exit_code = builder.get(pv, EXIT_CODE_IDX);
builder.assign(&stark_pvs.connector.initial_pc, init_pc);
builder.assign(&stark_pvs.connector.final_pc, end_pc);
builder.assign(&stark_pvs.connector.exit_code, exit_code);
Expand Down
102 changes: 69 additions & 33 deletions ceno_recursion/src/zkvm_verifier/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::{
},
};
use ceno_zkvm::{
instructions::riscv::constants::{LIMB_BITS, LIMB_MASK, UINT_LIMBS},
scheme::{ZKVMChipProof, ZKVMProof},
structs::{EccQuarkProof, TowerProofs, ZKVMVerifyingKey},
};
Expand Down Expand Up @@ -41,6 +42,48 @@ pub type E = BinomialExtensionField<F, 4>;
pub type RecPcs = Basefold<E, BasefoldRSParams>;
pub type InnerConfig = AsmConfig<F, E>;

fn raw_pi_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec<F> {
vec![
F::from_canonical_u32(public_values.exit_code & 0xffff),
F::from_canonical_u32((public_values.exit_code >> 16) & 0xffff),
F::from_canonical_u32(public_values.init_pc),
F::from_canonical_u64(public_values.init_cycle),
F::from_canonical_u32(public_values.end_pc),
F::from_canonical_u64(public_values.end_cycle),
F::from_canonical_u32(public_values.shard_id),
F::from_canonical_u32(public_values.heap_start_addr),
F::from_canonical_u32(public_values.heap_shard_len),
F::from_canonical_u32(public_values.hint_start_addr),
F::from_canonical_u32(public_values.hint_shard_len),
]
.into_iter()
.chain(
public_values
.shard_rw_sum
.iter()
.map(|value| F::from_canonical_u32(*value)),
)
.collect_vec()
}

fn mles_from_public_values(public_values: &ceno_zkvm::scheme::PublicValues) -> Vec<Vec<F>> {
(0..UINT_LIMBS)
.map(|limb_index| {
public_values
.public_io
.iter()
.map(|value| {
F::from_canonical_u16(((value >> (limb_index * LIMB_BITS)) & LIMB_MASK) as u16)
})
.collect_vec()
})
.collect_vec()
}

fn pi_evals_from_raw_pi(raw_pi: &[F]) -> Vec<E> {
raw_pi.iter().map(|v| E::from(*v)).collect_vec()
}

pub fn decompose_minus_one_bits(n: usize) -> Vec<F> {
let a = if n > 0 { n - 1 } else { 0 };
let mut bit_decomp: Vec<F> = vec![];
Expand Down Expand Up @@ -71,7 +114,8 @@ pub fn decompose_prefixed_layer_bits(n: usize) -> (Vec<usize>, Vec<Vec<F>>) {
#[derive(DslVariable, Clone)]
pub struct ZKVMProofInputVariable<C: Config> {
pub shard_id: Usize<C::N>,
pub raw_pi: Array<C, Array<C, Felt<C::F>>>,
pub raw_pi: Array<C, Felt<C::F>>,
pub mles: Array<C, Array<C, Felt<C::F>>>,
pub raw_pi_num_variables: Array<C, Var<C::N>>,
pub pi_evals: Array<C, Ext<C::F, C::EF>>,
pub chip_proofs: Array<C, Array<C, ZKVMChipProofInputVariable<C>>>,
Expand All @@ -95,7 +139,9 @@ pub struct TowerProofInputVariable<C: Config> {

pub(crate) struct ZKVMProofInput {
pub shard_id: usize,
pub raw_pi: Vec<Vec<F>>,
pub raw_pi: Vec<F>,
pub mles: Vec<Vec<F>>,
pub raw_pi_num_variables: Vec<usize>,
// Evaluation of raw_pi.
pub pi_evals: Vec<E>,
pub chip_proofs: BTreeMap<usize, ZKVMChipProofs>,
Expand All @@ -109,6 +155,14 @@ impl ZKVMProofInput {
zkvm_proof: ZKVMProof<E, RecPcs>,
vk: &ZKVMVerifyingKey<E, RecPcs>,
) -> Self {
let raw_pi = raw_pi_from_public_values(&zkvm_proof.public_values);
let mles = mles_from_public_values(&zkvm_proof.public_values);
let raw_pi_num_variables = mles
.iter()
.map(|v| ceil_log2(v.len().next_power_of_two()))
.collect::<Vec<_>>();
let pi_evals = pi_evals_from_raw_pi(&raw_pi);

let mut chip_witin_num_vars: HashMap<usize, (usize, usize)> = HashMap::new(); // (chip_id, (num_witin, num_fixed))
let mut chip_indices = zkvm_proof
.chip_proofs
Expand Down Expand Up @@ -136,8 +190,10 @@ impl ZKVMProofInput {

ZKVMProofInput {
shard_id,
raw_pi: zkvm_proof.raw_pi,
pi_evals: zkvm_proof.pi_evals,
raw_pi,
mles,
raw_pi_num_variables,
pi_evals,
chip_proofs: zkvm_proof
.chip_proofs
.into_iter()
Expand Down Expand Up @@ -168,7 +224,8 @@ impl Hintable<InnerConfig> for ZKVMProofInput {

fn read(builder: &mut Builder<InnerConfig>) -> Self::HintVariable {
let shard_id = Usize::Var(usize::read(builder));
let raw_pi = Vec::<Vec<F>>::read(builder);
let raw_pi = Vec::<F>::read(builder);
let mles = Vec::<Vec<F>>::read(builder);
let raw_pi_num_variables = Vec::<usize>::read(builder);
let pi_evals = Vec::<E>::read(builder);
builder.cycle_tracker_start("read chip proofs");
Expand All @@ -187,6 +244,7 @@ impl Hintable<InnerConfig> for ZKVMProofInput {
ZKVMProofInputVariable {
shard_id,
raw_pi,
mles,
raw_pi_num_variables,
pi_evals,
chip_proofs,
Expand All @@ -201,11 +259,6 @@ impl Hintable<InnerConfig> for ZKVMProofInput {

fn write(&self) -> Vec<Vec<<InnerConfig as Config>::N>> {
let mut stream = Vec::new();
let raw_pi_num_variables: Vec<usize> = self
.raw_pi
.iter()
.map(|v| ceil_log2(v.len().next_power_of_two()))
.collect();
let witin_num_vars = self
.chip_proofs
.iter()
Expand All @@ -217,21 +270,21 @@ impl Hintable<InnerConfig> for ZKVMProofInput {
.chip_proofs
.iter()
.flat_map(|(_, proofs)| proofs.iter())
.map(|proof| proof.wits_in_evals.len().max(1))
.map(|proof| proof.num_witin.max(1))
.collect::<Vec<_>>();
let fixed_num_vars = self
.chip_proofs
.iter()
.flat_map(|(_, proofs)| proofs.iter())
.filter(|proof| !proof.fixed_in_evals.is_empty())
.filter(|proof| proof.num_fixed > 0)
.map(|proof| proof.num_vars)
.collect::<Vec<_>>();
let fixed_max_widths = self
.chip_proofs
.iter()
.flat_map(|(_, proofs)| proofs.iter())
.filter(|proof| !proof.fixed_in_evals.is_empty())
.map(|proof| proof.fixed_in_evals.len())
.filter(|proof| proof.num_fixed > 0)
.map(|proof| proof.num_fixed)
.collect::<Vec<_>>();
let max_num_var = witin_num_vars
.iter()
Expand Down Expand Up @@ -264,7 +317,8 @@ impl Hintable<InnerConfig> for ZKVMProofInput {

stream.extend(<usize as Hintable<InnerConfig>>::write(&self.shard_id));
stream.extend(self.raw_pi.write());
stream.extend(raw_pi_num_variables.write());
stream.extend(self.mles.write());
stream.extend(self.raw_pi_num_variables.write());
stream.extend(self.pi_evals.write());
stream.extend(vec![vec![F::from_canonical_usize(self.chip_proofs.len())]]);
for proofs in self.chip_proofs.values() {
Expand Down Expand Up @@ -403,9 +457,6 @@ pub struct ZKVMChipProofInput {
pub ecc_proof: EccQuarkProofInput,

pub num_instances: Vec<usize>,

pub wits_in_evals: Vec<E>,
pub fixed_in_evals: Vec<E>,
}

impl VecAutoHintable for ZKVMChipProofInput {}
Expand Down Expand Up @@ -499,8 +550,6 @@ impl From<(usize, ZKVMChipProof<E>, usize, usize)> for ZKVMChipProofInput {
EccQuarkProofInput::dummy()
},
num_instances: p.num_instances,
wits_in_evals: p.wits_in_evals,
fixed_in_evals: p.fixed_in_evals,
}
}
}
Expand Down Expand Up @@ -531,9 +580,6 @@ pub struct ZKVMChipProofInputVariable<C: Config> {
pub num_instances: Array<C, Var<C::N>>,
pub n_inst_0_bit_decomps: Array<C, Felt<C::F>>,
pub n_inst_1_bit_decomps: Array<C, Felt<C::F>>,

pub fixed_in_evals: Array<C, Ext<C::F, C::EF>>,
pub wits_in_evals: Array<C, Ext<C::F, C::EF>>,
}
impl Hintable<InnerConfig> for ZKVMChipProofInput {
type HintVariable = ZKVMChipProofInputVariable<InnerConfig>;
Expand Down Expand Up @@ -571,11 +617,6 @@ impl Hintable<InnerConfig> for ZKVMChipProofInput {
let n_inst_0_bit_decomps = Vec::<F>::read(builder);
let n_inst_1_bit_decomps = Vec::<F>::read(builder);

builder.cycle_tracker_start("read wit/fixed evals");
let fixed_in_evals = Vec::<E>::read(builder);
let wits_in_evals = Vec::<E>::read(builder);
builder.cycle_tracker_end("read wit/fixed evals");

ZKVMChipProofInputVariable {
idx,
idx_felt,
Expand All @@ -597,8 +638,6 @@ impl Hintable<InnerConfig> for ZKVMChipProofInput {
num_instances,
n_inst_0_bit_decomps,
n_inst_1_bit_decomps,
fixed_in_evals,
wits_in_evals,
}
}

Expand Down Expand Up @@ -674,9 +713,6 @@ impl Hintable<InnerConfig> for ZKVMChipProofInput {
stream.extend(n_inst_0_bit_decomps.write());
stream.extend(n_inst_1_bit_decomps.write());

stream.extend(self.fixed_in_evals.write());
stream.extend(self.wits_in_evals.write());

stream
}
}
Expand Down
Loading
Loading