Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions ceno_recursion/src/zkvm_verifier/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use openvm_native_recursion::{
use openvm_stark_backend::p3_field::{FieldAlgebra, extension::BinomialExtensionField};
use openvm_stark_sdk::p3_baby_bear::BabyBear;
use p3::field::FieldExtensionAlgebra;
use std::cmp::max;
use sumcheck::structs::IOPProof;

pub type F = BabyBear;
Expand Down Expand Up @@ -402,7 +403,7 @@ pub struct ZKVMChipProofInput {
pub has_ecc_proof: usize,
pub ecc_proof: EccQuarkProofInput,

pub num_instances: Vec<usize>,
pub num_instances: [usize; 2],

pub wits_in_evals: Vec<E>,
pub fixed_in_evals: Vec<E>,
Expand Down Expand Up @@ -657,18 +658,12 @@ impl Hintable<InnerConfig> for ZKVMChipProofInput {
stream.extend(<usize as Hintable<InnerConfig>>::write(&self.has_ecc_proof));
stream.extend(self.ecc_proof.write());

stream.extend(<Vec<usize> as Hintable<InnerConfig>>::write(
&self.num_instances,
));
stream.extend(self.num_instances.to_vec().write());

let n_inst_0 = self.num_instances[0];
let n_inst_0_bit_decomps = decompose_minus_one_bits(n_inst_0);

let n_inst_1 = if self.num_instances.len() > 1 {
self.num_instances[1]
} else {
1usize
};
let n_inst_1 = max(self.num_instances[1], 1);
let n_inst_1_bit_decomps = decompose_minus_one_bits(n_inst_1);

stream.extend(n_inst_0_bit_decomps.write());
Expand Down
75 changes: 33 additions & 42 deletions ceno_recursion/src/zkvm_verifier/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,54 +122,46 @@ pub fn verify_zkvm_proof<C: Config<F = F>>(
},
);

builder
.if_eq(zkvm_proof_input.shard_id.clone(), Usize::from(0))
.then(|builder| {
if let Some(fixed_commit) = vk.fixed_commit.as_ref() {
let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into();
let commit_array: Array<C, Felt<C::F>> = builder.dyn_array(commit.value.len());
if let Some(fixed_commit) = vk.fixed_commit.as_ref() {
let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into();
let commit_array: Array<C, Felt<C::F>> = builder.dyn_array(commit.value.len());

commit.value.into_iter().enumerate().for_each(|(i, v)| {
let v = builder.constant(v);
// TODO: put fixed commit to public values
// builder.commit_public_value(v);
commit.value.into_iter().enumerate().for_each(|(i, v)| {
let v = builder.constant(v);
// TODO: put fixed commit to public values
// builder.commit_public_value(v);

builder.set_value(&commit_array, i, v);
});
builder.set_value(&commit_array, i, v);
});

challenger_multi_observe(builder, &mut challenger, &commit_array);
challenger_multi_observe(builder, &mut challenger, &commit_array);

let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize(
fixed_commit.log2_max_codeword_size,
));
let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize(
fixed_commit.log2_max_codeword_size,
));

challenger.observe(builder, log2_max_codeword_size_felt);
}
});
challenger.observe(builder, log2_max_codeword_size_felt);
}

builder
.if_ne(zkvm_proof_input.shard_id.clone(), Usize::from(0))
.then(|builder| {
if let Some(fixed_commit) = vk.fixed_no_omc_init_commit.as_ref() {
let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into();
let commit_array: Array<C, Felt<C::F>> = builder.dyn_array(commit.value.len());

commit.value.into_iter().enumerate().for_each(|(i, v)| {
let v = builder.constant(v);
// TODO: put fixed commit to public values
// builder.commit_public_value(v);

builder.set_value(&commit_array, i, v);
});
challenger_multi_observe(builder, &mut challenger, &commit_array);
if let Some(fixed_commit) = vk.fixed_no_omc_init_commit.as_ref() {
let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into();
let commit_array: Array<C, Felt<C::F>> = builder.dyn_array(commit.value.len());

let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize(
fixed_commit.log2_max_codeword_size,
));
commit.value.into_iter().enumerate().for_each(|(i, v)| {
let v = builder.constant(v);
// TODO: put fixed commit to public values
// builder.commit_public_value(v);

challenger.observe(builder, log2_max_codeword_size_felt);
}
builder.set_value(&commit_array, i, v);
});
challenger_multi_observe(builder, &mut challenger, &commit_array);

let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize(
fixed_commit.log2_max_codeword_size,
));

challenger.observe(builder, log2_max_codeword_size_felt);
}

iter_zip!(builder, zkvm_proof_input.chip_proofs).for_each(|ptr_vec, builder| {
let chip_proofs = builder.iter_ptr_get(&zkvm_proof_input.chip_proofs, ptr_vec[0]);
Expand Down Expand Up @@ -712,7 +704,8 @@ pub fn verify_chip_proof<C: Config>(

let zero_bit_decomps: Array<C, Felt<C::F>> = builder.dyn_array(32);
let selector_ctxs: Vec<SelectorContextVariable<C>> = if cs.ec_final_sum.is_empty() {
builder.assert_usize_eq(chip_proof.num_instances.len(), Usize::from(1));
let non_shard_n1 = Usize::Var(builder.get(&chip_proof.num_instances, 1));
builder.assert_usize_eq(non_shard_n1, Usize::from(0));
let num_instances_bit_decomps: Array<C, Array<C, Felt<C::F>>> = builder.dyn_array(1);
builder.set(
&num_instances_bit_decomps,
Expand Down Expand Up @@ -740,8 +733,6 @@ pub fn verify_chip_proof<C: Config>(
.unwrap_or(0)
]
} else {
builder.assert_usize_eq(chip_proof.num_instances.len(), Usize::from(2));

let num_inst_0_bit_decomps: Array<C, Array<C, Felt<C::F>>> = builder.dyn_array(1);
let num_inst_1_bit_decomps: Array<C, Array<C, Felt<C::F>>> = builder.dyn_array(1);
let num_inst_sum_bit_decomps: Array<C, Array<C, Felt<C::F>>> = builder.dyn_array(1);
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/benches/riscv_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ fn bench_add(c: &mut Criterion) {
structural_witness: vec![],
public_input: vec![],
pub_io_evals: vec![],
num_instances: vec![num_instances],
num_instances: [num_instances, 0],
has_ecc_ops: false,
};
let task = ChipTask {
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ pub struct ZKVMChipProof<E: ExtensionField> {
pub tower_proof: TowerProofs<E>,
pub ecc_proof: Option<EccQuarkProof<E>>,

pub num_instances: Vec<usize>,
pub num_instances: [usize; 2],

pub fixed_in_evals: Vec<E>,
pub wits_in_evals: Vec<E>,
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/hal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub struct ProofInput<'a, PB: ProverBackend> {
pub fixed: Vec<Arc<PB::MultilinearPoly<'a>>>,
pub public_input: Vec<Arc<PB::MultilinearPoly<'a>>>,
pub pub_io_evals: Vec<Either<<PB::E as ExtensionField>::BaseField, PB::E>>,
pub num_instances: Vec<usize>,
pub num_instances: [usize; 2],
pub has_ecc_ops: bool,
}

Expand Down
25 changes: 11 additions & 14 deletions ceno_zkvm/src/scheme/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,14 +171,11 @@ impl<

// commit to fixed commitment
let span = entered_span!("commit_to_fixed_commit", profiling_1 = true);
if let Some(fixed_commit) = &self.pk.fixed_commit
&& shard_ctx.is_first_shard()
{
if let Some(fixed_commit) = self.pk.fixed_commit.as_ref() {
PCS::write_commitment(fixed_commit, &mut transcript)
.map_err(ZKVMError::PCSError)?;
} else if let Some(fixed_commit) = &self.pk.fixed_no_omc_init_commit
&& !shard_ctx.is_first_shard()
{
}
if let Some(fixed_commit) = self.pk.fixed_no_omc_init_commit.as_ref() {
PCS::write_commitment(fixed_commit, &mut transcript)
.map_err(ZKVMError::PCSError)?;
}
Expand All @@ -198,10 +195,10 @@ impl<
// num_instance from witness might include rotation
let num_instances = chip_inputs
.iter()
.flat_map(|chip_input| &chip_input.num_instances)
.flat_map(|chip_input| chip_input.num_instances)
.collect_vec();

if num_instances.is_empty() {
if num_instances.iter().sum::<usize>() == 0 {
continue;
}

Expand All @@ -210,7 +207,7 @@ impl<
transcript.append_field_element(&E::BaseField::from_canonical_usize(*circuit_idx));
for num_instance in num_instances {
transcript
.append_field_element(&E::BaseField::from_canonical_usize(*num_instance));
.append_field_element(&E::BaseField::from_canonical_usize(num_instance));
}
}

Expand Down Expand Up @@ -548,7 +545,7 @@ impl<
ecc_proof,
fixed_in_evals,
wits_in_evals,
num_instances: input.num_instances.clone(),
num_instances: input.num_instances,
},
pi_in_evals,
input_opening_point,
Expand All @@ -561,7 +558,7 @@ impl<
fn build_chip_tasks<'data>(
&self,
shard_ctx: &ShardContext,
name_and_instances: Vec<(String, Vec<usize>)>,
name_and_instances: Vec<(String, [usize; 2])>,
structural_rmms: Vec<witness::RowMajorMatrix<E::BaseField>>,
#[allow(unused_mut)] mut witness_mles: Vec<PB::MultilinearPoly<'data>>,
witness_data: &PB::PcsData,
Expand Down Expand Up @@ -600,11 +597,11 @@ impl<
let pk = self.pk.circuit_pks.get(&circuit_name).unwrap();
let cs = pk.get_cs();
if !shard_ctx.is_first_shard() && cs.with_omc_init_only() {
assert!(num_instances.is_empty());
assert_eq!(num_instances, [0, 0]);
// skip drain respective fixed because we use different set of fixed commitment
continue;
}
if num_instances.is_empty() {
if num_instances.iter().sum::<usize>() == 0 {
// we need to drain respective fixed when num_instances is 0
if cs.num_fixed() > 0 {
let _ = fixed_mles.drain(..cs.num_fixed()).collect_vec();
Expand Down Expand Up @@ -651,7 +648,7 @@ impl<
structural_witness,
public_input: public_input.clone(),
pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(),
num_instances: num_instances.clone(),
num_instances,
has_ecc_ops: cs.has_ecc_ops(),
};
// SAFETY: All Arcs in ProofInput contain 'static data:
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/scheme/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ fn test_rw_lk_expression_combination() {
structural_witness: structural_in,
public_input: vec![],
pub_io_evals: vec![],
num_instances: vec![num_instances],
num_instances: [num_instances, 0],
has_ecc_ops: false,
};
let task = crate::scheme::scheduler::ChipTask {
Expand Down
12 changes: 4 additions & 8 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,13 +211,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>

// write fixed commitment to transcript
// TODO check soundness if there is no fixed_commit but got fixed proof?
if let Some(fixed_commit) = self.vk.fixed_commit.as_ref()
&& shard_id == 0
{
if let Some(fixed_commit) = self.vk.fixed_commit.as_ref() {
PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?;
} else if let Some(fixed_commit) = self.vk.fixed_no_omc_init_commit.as_ref()
&& shard_id > 0
{
}
if let Some(fixed_commit) = self.vk.fixed_no_omc_init_commit.as_ref() {
PCS::write_commitment(fixed_commit, &mut transcript).map_err(ZKVMError::PCSError)?;
}

Expand Down Expand Up @@ -632,7 +629,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>

let gkr_circuit = gkr_circuit.as_ref().unwrap();
let selector_ctxs = if cs.ec_final_sum.is_empty() {
assert_eq!(proof.num_instances.len(), 1);
assert_eq!(proof.num_instances[1], 0);
// it's not shard chip
vec![
SelectorContext::new(0, num_instances, num_var_with_rotation);
Expand All @@ -643,7 +640,6 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
.unwrap_or(0)
]
} else {
assert_eq!(proof.num_instances.len(), 2);
// it's shard chip
tracing::debug!(
"num_reads: {}, num_writes: {}, total: {}",
Expand Down
33 changes: 8 additions & 25 deletions ceno_zkvm/src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,14 @@ impl<E: ExtensionField> ZKVMFixedTraces<E> {
pub struct ChipInput<E: ExtensionField> {
pub name: String,
pub witness_rmms: RMMCollections<E::BaseField>,
// in shard ram chip, num_instances length would be > 1
pub num_instances: Vec<usize>,
pub num_instances: [usize; 2],
}

impl<E: ExtensionField> ChipInput<E> {
pub fn new(
name: String,
witness_rmms: RMMCollections<E::BaseField>,
num_instances: Vec<usize>,
num_instances: [usize; 2],
) -> Self {
Self {
name,
Expand Down Expand Up @@ -382,16 +381,8 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
shard_steps,
indices,
)?;
let num_instances = vec![witness[0].num_instances()];
let input = ChipInput::new(
OC::name(),
witness,
if num_instances[0] > 0 {
num_instances
} else {
vec![]
},
);
let num_instances = [witness[0].num_instances(), 0];
let input = ChipInput::new(OC::name(), witness, num_instances);
assert!(self.witnesses.insert(OC::name(), vec![input]).is_none());
assert!(
self.lk_mlts
Expand Down Expand Up @@ -445,15 +436,7 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
input,
)?;
let num_instances = std::cmp::max(witness[0].num_instances(), witness[1].num_instances());
let input = ChipInput::new(
TC::name(),
witness,
if num_instances > 0 {
vec![num_instances]
} else {
vec![]
},
);
let input = ChipInput::new(TC::name(), witness, [num_instances, 0]);
assert!(self.witnesses.insert(TC::name(), vec![input]).is_none());

Ok(())
Expand Down Expand Up @@ -613,7 +596,7 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
Ok(ChipInput::new(
ShardRamCircuit::<E>::name(),
witness,
vec![num_reads, num_writes],
[num_reads, num_writes],
))
})
.collect::<Result<Vec<_>, ZKVMError>>()?;
Expand All @@ -627,13 +610,13 @@ impl<E: ExtensionField> ZKVMWitnesses<E> {
Ok(())
}

pub fn get_witnesses_name_instance(&self) -> Vec<(String, Vec<usize>)> {
pub fn get_witnesses_name_instance(&self) -> Vec<(String, [usize; 2])> {
self.witnesses
.iter()
.flat_map(|(_, chip_inputs)| {
chip_inputs
.iter()
.map(|chip_input| (chip_input.name.clone(), chip_input.num_instances.clone()))
.map(|chip_input| (chip_input.name.clone(), chip_input.num_instances))
})
.collect_vec()
}
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/tables/shard_ram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ mod tests {
fixed: vec![],
public_input: public_input_mles.clone(),
pub_io_evals,
num_instances: vec![n_global_writes as usize, n_global_reads as usize],
num_instances: [n_global_writes as usize, n_global_reads as usize],
has_ecc_ops: true,
};
let mut rng = thread_rng();
Expand Down
Loading