From 1b0984f410902ecee852d336024d28016ffe8abb Mon Sep 17 00:00:00 2001 From: spherel <101384151+spherel@users.noreply.github.com> Date: Mon, 2 Feb 2026 06:26:50 +0800 Subject: [PATCH 1/3] group expressions by selectors --- ceno_recursion/src/zkvm_verifier/verifier.rs | 14 +- ceno_zkvm/src/chip_handler.rs | 14 +- ceno_zkvm/src/chip_handler/memory.rs | 26 +- ceno_zkvm/src/instructions.rs | 32 +- ceno_zkvm/src/instructions/riscv/arith.rs | 2 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 2 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 2 +- .../src/instructions/riscv/branch/test.rs | 12 +- ceno_zkvm/src/instructions/riscv/div.rs | 9 +- .../src/instructions/riscv/dummy/test.rs | 2 +- .../instructions/riscv/ecall/fptower_fp.rs | 10 +- .../riscv/ecall/fptower_fp2_add.rs | 11 +- .../riscv/ecall/fptower_fp2_mul.rs | 10 +- .../src/instructions/riscv/ecall/keccak.rs | 20 +- .../instructions/riscv/ecall/sha_extend.rs | 11 +- .../src/instructions/riscv/ecall/uint256.rs | 23 +- .../riscv/ecall/weierstrass_add.rs | 15 +- .../riscv/ecall/weierstrass_decompress.rs | 15 +- .../riscv/ecall/weierstrass_double.rs | 16 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 27 + ceno_zkvm/src/instructions/riscv/jump/test.rs | 4 +- .../src/instructions/riscv/logic/test.rs | 6 +- .../riscv/logic_imm/logic_imm_circuit.rs | 2 +- .../src/instructions/riscv/logic_imm/test.rs | 2 +- ceno_zkvm/src/instructions/riscv/lui.rs | 2 +- .../src/instructions/riscv/memory/test.rs | 4 +- ceno_zkvm/src/instructions/riscv/mulh.rs | 6 +- ceno_zkvm/src/instructions/riscv/shift.rs | 2 +- ceno_zkvm/src/instructions/riscv/shift_imm.rs | 2 +- ceno_zkvm/src/instructions/riscv/slt.rs | 2 +- ceno_zkvm/src/instructions/riscv/slti.rs | 2 +- ceno_zkvm/src/precompiles/bitwise_keccakf.rs | 33 +- ceno_zkvm/src/precompiles/fptower/fp.rs | 132 +-- .../src/precompiles/fptower/fp2_addsub.rs | 156 +--- ceno_zkvm/src/precompiles/fptower/fp2_mul.rs | 132 +-- ceno_zkvm/src/precompiles/lookup_keccakf.rs | 735 ++++++++------- ceno_zkvm/src/precompiles/mod.rs | 15 +- ceno_zkvm/src/precompiles/sha256/extend.rs | 133 +-- ceno_zkvm/src/precompiles/uint256.rs | 226 ++--- .../weierstrass/weierstrass_add.rs | 135 +-- .../weierstrass/weierstrass_decompress.rs | 126 +-- .../weierstrass/weierstrass_double.rs | 132 +-- ceno_zkvm/src/scheme/cpu/mod.rs | 325 ++++--- ceno_zkvm/src/scheme/gpu/mod.rs | 222 +++-- ceno_zkvm/src/scheme/mock_prover.rs | 858 ++++++++++-------- ceno_zkvm/src/scheme/tests.rs | 104 ++- ceno_zkvm/src/scheme/utils.rs | 36 +- ceno_zkvm/src/scheme/verifier.rs | 112 +-- ceno_zkvm/src/stats.rs | 48 +- ceno_zkvm/src/structs.rs | 20 +- ceno_zkvm/src/tables/mod.rs | 35 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 23 +- ceno_zkvm/src/tables/range/range_circuit.rs | 27 +- ceno_zkvm/src/tables/shard_ram.rs | 58 +- ceno_zkvm/src/uint/arithmetic.rs | 10 +- gkr_iop/src/chip.rs | 21 +- gkr_iop/src/circuit_builder.rs | 486 ++++++++-- gkr_iop/src/circuit_builder/ram.rs | 74 +- gkr_iop/src/gkr.rs | 21 +- gkr_iop/src/gkr/layer.rs | 364 +++++--- gkr_iop/src/gkr/layer/gpu/mod.rs | 24 +- gkr_iop/src/lib.rs | 65 +- gkr_iop/src/selector.rs | 29 +- 63 files changed, 2618 insertions(+), 2606 deletions(-) diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 6883cdd45..8ebc0a5f0 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -539,9 +539,9 @@ pub fn verify_chip_proof( } = &composed_cs; let one: Ext = builder.constant(C::EF::ONE); - let r_len = cs.r_expressions.len() + cs.r_table_expressions.len(); - let w_len = cs.w_expressions.len() + cs.w_table_expressions.len(); - let lk_len = cs.lk_expressions.len() + cs.lk_table_expressions.len(); + let r_len = cs.r_expressions_len() + cs.r_table_expressions_len(); + let w_len = cs.w_expressions_len() + cs.w_table_expressions_len(); + let lk_len = cs.lk_expressions_len() + cs.lk_table_expressions_len(); let num_batched = r_len + w_len + lk_len; let r_counts_per_instance: Usize = Usize::from(r_len); @@ -608,7 +608,7 @@ pub fn verify_chip_proof( ); builder.cycle_tracker_end(format!("verify tower proof for opcode {circuit_name}",).as_str()); - if cs.lk_table_expressions.is_empty() { + if cs.lk_table_expressions_len() == 0 { builder .range(0, logup_p_evals.len()) .for_each(|idx_vec, builder| { @@ -623,7 +623,7 @@ pub fn verify_chip_proof( builder.assert_usize_eq(logup_q_evals.len(), lk_counts_per_instance.clone()); // GKR circuit - let out_evals_len: Usize = if cs.lk_table_expressions.is_empty() { + let out_evals_len: Usize = if cs.lk_table_expressions_len() == 0 { builder.eval(record_evals.len() + logup_q_evals.len()) } else { builder.eval(record_evals.len() + logup_p_evals.len() + logup_q_evals.len()) @@ -638,7 +638,7 @@ pub fn verify_chip_proof( }); let end: Usize = Usize::uninit(builder); - if !cs.lk_table_expressions.is_empty() { + if cs.lk_table_expressions_len() > 0 { builder.assign(&end, record_evals.len() + logup_p_evals.len()); let p_slice = out_evals.slice(builder, record_evals.len(), end.clone()); @@ -687,7 +687,7 @@ pub fn verify_chip_proof( gkr_circuit .layers .first() - .map(|layer| layer.out_sel_and_eval_exprs.len()) + .map(|layer| layer.selector_ctxs_len()) .unwrap_or(0) ] } else { diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index 206a9e99f..0d596b30f 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -1,5 +1,5 @@ use ff_ext::ExtensionField; -use gkr_iop::{error::CircuitBuilderError, gadgets::AssertLtConfig}; +use gkr_iop::{error::CircuitBuilderError, gadgets::AssertLtConfig, selector::SelectorType}; use crate::instructions::riscv::constants::UINT_LIMBS; use multilinear_extensions::{Expression, ToExpr}; @@ -73,4 +73,16 @@ pub trait MemoryChipOperations, N: FnOnce() prev_values: MemoryExpr, value: MemoryExpr, ) -> Result<(Expression, AssertLtConfig), CircuitBuilderError>; + + fn memory_write_with_rw_selectors( + &mut self, + name_fn: N, + memory_addr: &AddressExpr, + prev_ts: Expression, + ts: Expression, + prev_values: MemoryExpr, + value: MemoryExpr, + r_selector: &SelectorType, + w_selector: &SelectorType, + ) -> Result<(Expression, AssertLtConfig), CircuitBuilderError>; } diff --git a/ceno_zkvm/src/chip_handler/memory.rs b/ceno_zkvm/src/chip_handler/memory.rs index ba76041e3..71f0d51a6 100644 --- a/ceno_zkvm/src/chip_handler/memory.rs +++ b/ceno_zkvm/src/chip_handler/memory.rs @@ -5,7 +5,7 @@ use crate::{ structs::RAMType, }; use ff_ext::ExtensionField; -use gkr_iop::error::CircuitBuilderError; +use gkr_iop::{error::CircuitBuilderError, selector::SelectorType}; use multilinear_extensions::Expression; impl, N: FnOnce() -> NR> MemoryChipOperations @@ -48,4 +48,28 @@ impl, N: FnOnce() -> NR> MemoryChipOperation value, ) } + + fn memory_write_with_rw_selectors( + &mut self, + name_fn: N, + memory_addr: &AddressExpr, + prev_ts: Expression, + ts: Expression, + prev_values: MemoryExpr, + value: MemoryExpr, + r_selector: &SelectorType, + w_selector: &SelectorType, + ) -> Result<(Expression, AssertLtConfig), CircuitBuilderError> { + self.ram_type_write_with_rw_selectors( + name_fn, + RAMType::Memory, + memory_addr.clone(), + prev_ts, + ts, + prev_values, + value, + r_selector, + w_selector, + ) + } } diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index df2c24ff9..c3297df51 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -6,6 +6,7 @@ use ceno_emul::StepRecord; use ff_ext::ExtensionField; use gkr_iop::{ chip::Chip, + default_out_eval_groups, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, utils::lk_multiplicity::Multiplicity, @@ -44,35 +45,16 @@ pub trait Instruction { param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { let config = Self::construct_circuit(cb, param)?; - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); let selector = cb.create_placeholder_structural_witin(|| "selector"); let selector_type = SelectorType::Prefix(selector.expr()); + cb.cs.set_default_read_selector(selector_type.clone()); + cb.cs.set_default_write_selector(selector_type.clone()); + cb.cs.set_default_lookup_selector(selector_type.clone()); + cb.cs.set_default_zero_selector(selector_type.clone()); - // all shared the same selector - let (out_evals, mut chip) = ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, 0), - ); - - // register selector to legacy constrain system - cb.cs.r_selector = Some(selector_type.clone()); - cb.cs.w_selector = Some(selector_type.clone()); - cb.cs.lk_selector = Some(selector_type.clone()); - cb.cs.zero_selector = Some(selector_type.clone()); + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); chip.add_layer(layer); diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 260245931..1cd2ec263 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -218,6 +218,6 @@ mod test { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index a1c1d4403..bacd58ade 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -89,6 +89,6 @@ mod test { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index ce7c64a95..f89f2de9b 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -272,6 +272,6 @@ mod tests { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 67f098ff0..28acf2991 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -52,7 +52,7 @@ fn impl_opcode_beq(take_branch: bool, a: u32, b: u32) { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } #[test] @@ -94,7 +94,7 @@ fn impl_opcode_bne(take_branch: bool, a: u32, b: u32) { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } #[test] @@ -138,7 +138,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); Ok(()) } @@ -183,7 +183,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); Ok(()) } @@ -235,7 +235,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); Ok(()) } @@ -287,6 +287,6 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 85718ad24..5e2ba2812 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -180,7 +180,7 @@ mod test { .expect("instruction must declare at least one InsnKind"); let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); // values assignment - let ([raw_witin, _], lkm) = Insn::assign_instances( + let ([raw_witin, raw_structural_witin], lkm) = Insn::assign_instances( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, @@ -215,9 +215,14 @@ mod test { .into_iter() .map(|v| v.into()) .collect_vec(), - &[], + &raw_structural_witin + .to_mles() + .into_iter() + .map(|v| v.into()) + .collect_vec(), &[insn_code], expected_errors, + 1, None, Some(lkm), ); diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 8cccc8d49..cb08e9666 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -28,5 +28,5 @@ fn test_large_ecall_dummy_keccak() { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &program, None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &program, 1, None, Some(lkm)); } diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 7eeba1ab0..f2de34551 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -7,8 +7,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -250,7 +249,7 @@ fn build_fp_op_circuit( 0.into(), ))?; - let mut layout = as ProtocolBuilder>::build_layer_logic(cb, ())?; + let layout = as ProtocolBuilder>::build_layer_logic(cb, ())?; let mut mem_rw = izip!(&layout.input32_exprs[0], &layout.output32_exprs) .enumerate() @@ -287,10 +286,7 @@ fn build_fp_op_circuit( .collect::, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(cb); - let layer = - Layer::from_circuit_builder(cb, layer_name.to_string(), layout.n_challenges, out_evals); - chip.add_layer(layer); + let chip = layout.finalize(layer_name.to_string(), cb); Ok(( EcallFpOpConfig { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6552b6241..7837aa479 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -6,8 +6,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -163,8 +162,7 @@ fn build_fp2_add_circuit as ProtocolBuilder>::build_layer_logic(cb, ())?; + let layout = as ProtocolBuilder>::build_layer_logic(cb, ())?; let mut mem_rw = izip!(&layout.input32_exprs[0], &layout.output32_exprs) .enumerate() @@ -201,10 +199,7 @@ fn build_fp2_add_circuit, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(cb); - let layer = - Layer::from_circuit_builder(cb, "fp2_add".to_string(), layout.n_challenges, out_evals); - chip.add_layer(layer); + let chip = layout.finalize("fp2_add".to_string(), cb); Ok(( EcallFp2AddConfig { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index 709e9734d..2470a3234 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -6,8 +6,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -162,7 +161,7 @@ fn build_fp2_mul_circuit as ProtocolBuilder>::build_layer_logic(cb, ())?; + let layout = as ProtocolBuilder>::build_layer_logic(cb, ())?; let mut mem_rw = izip!(&layout.input32_exprs[0], &layout.output32_exprs) .enumerate() @@ -199,10 +198,7 @@ fn build_fp2_mul_circuit, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(cb); - let layer = - Layer::from_circuit_builder(cb, "fp2_mul".to_string(), layout.n_challenges, out_evals); - chip.add_layer(layer); + let chip = layout.finalize("fp2_mul".to_string(), cb); Ok(( EcallFp2MulConfig { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 51568b56a..84b187d8a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -6,7 +6,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use gkr_iop::{ ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, booleanhypercube::BooleanHypercube, layer::Layer}, + gkr::{GKRCircuit, booleanhypercube::BooleanHypercube}, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -32,8 +32,8 @@ use crate::{ }, }, precompiles::{ - KECCAK_ROUNDS, KECCAK_ROUNDS_CEIL_LOG2, KeccakInstance, KeccakLayout, KeccakParams, - KeccakStateInstance, KeccakTrace, KeccakWitInstance, + KECCAK_ROUNDS, KECCAK_ROUNDS_CEIL_LOG2, KeccakInstance, KeccakLayout, KeccakStateInstance, + KeccakTrace, KeccakWitInstance, }, structs::ProgramParams, tables::{InsnRecord, RMMCollections}, @@ -75,6 +75,10 @@ impl Instruction for KeccakInstruction { cb: &mut CircuitBuilder, _param: &ProgramParams, ) -> Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { + // We should create the layout first to set the default selectors. + // TODO: find a better way to handle this. + let layout = KeccakLayout::build_layer_logic(cb, ())?; + // constrain vmstate let vm_state = StateInOut::construct_circuit(cb, false)?; @@ -108,11 +112,6 @@ impl Instruction for KeccakInstruction { 0.into(), ))?; - let mut layout = as gkr_iop::ProtocolBuilder>::build_layer_logic( - cb, - KeccakParams {}, - )?; - // memory rw, for we in-place update let mem_rw = izip!(&layout.input32_exprs, &layout.output32_exprs) .enumerate() @@ -131,10 +130,7 @@ impl Instruction for KeccakInstruction { }) .collect::, _>>()?; - let (out_evals, mut chip) = layout.finalize(cb); - - let layer = Layer::from_circuit_builder(cb, Self::name(), layout.n_challenges, out_evals); - chip.add_layer(layer); + let chip = layout.finalize(Self::name(), cb); let circuit = chip.gkr_circuit(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index 7f8513dc1..c95e263e4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -3,8 +3,7 @@ use std::{array, marker::PhantomData}; use ceno_emul::{Change, InsnKind, Platform, SHA_EXTEND, StepRecord, WORD_SIZE, WriteOp}; use ff_ext::{ExtensionField, FieldInto}; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -102,7 +101,7 @@ impl Instruction for ShaExtendInstruction { 0.into(), ))?; - let mut layout = + let layout = as gkr_iop::ProtocolBuilder>::build_layer_logic(cb, ())?; let old_value = @@ -128,10 +127,7 @@ impl Instruction for ShaExtendInstruction { vm_state.ts, )?); - let (out_evals, mut chip) = layout.finalize(cb); - - let layer = Layer::from_circuit_builder(cb, Self::name(), layout.n_challenges, out_evals); - chip.add_layer(layer); + let chip = layout.finalize(Self::name(), cb); let circuit = chip.gkr_circuit(); @@ -175,7 +171,6 @@ impl Instruction for ShaExtendInstruction { steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); - let num_structural_witin = config.layout.n_structural_witin.max(num_structural_witin); if steps.is_empty() { return Ok(( [ diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index 8b38df7e1..a0a26c84a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -7,8 +7,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use generic_array::typenum::Unsigned; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, chain, izip}; @@ -131,7 +130,7 @@ impl Instruction for Uint256MulInstruction { 0.into(), ))?; - let mut layout = + let layout = as gkr_iop::ProtocolBuilder>::build_layer_logic(cb, ())?; // Write the result to the same address of the first input point. @@ -177,15 +176,7 @@ impl Instruction for Uint256MulInstruction { .collect::, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(cb); - - let layer = Layer::from_circuit_builder( - cb, - "uint256_mul".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("uint256_mul".to_string(), cb); let circuit = chip.gkr_circuit(); @@ -478,8 +469,7 @@ impl Instruction for Uint256InvInstr 0.into(), ))?; - let mut layout = - as ProtocolBuilder>::build_layer_logic(cb, ())?; + let layout = as ProtocolBuilder>::build_layer_logic(cb, ())?; // Write the result to the same address of the first input point. let mem_rw = layout @@ -503,10 +493,7 @@ impl Instruction for Uint256InvInstr }) .collect::, _>>()?; - let (out_evals, mut chip) = layout.finalize(cb); - - let layer = Layer::from_circuit_builder(cb, Spec::name(), layout.n_challenges, out_evals); - chip.add_layer(layer); + let chip = layout.finalize(Spec::name(), cb); let circuit = chip.gkr_circuit(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 05a91cd97..08903c1ca 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -7,8 +7,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use generic_array::{GenericArray, typenum::Unsigned}; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -130,7 +129,7 @@ impl Instruction 0.into(), ))?; - let mut layout = + let layout = as gkr_iop::ProtocolBuilder>::build_layer_logic( cb, (), @@ -177,15 +176,7 @@ impl Instruction .collect::, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(cb); - - let layer = Layer::from_circuit_builder( - cb, - "weierstrass_add".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("weierstrass_add".to_string(), cb); let circuit = chip.gkr_circuit(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 6d9a7470b..4f349b3f6 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -7,8 +7,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use generic_array::{GenericArray, typenum::Unsigned}; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -88,7 +87,7 @@ impl Instruction Result<(Self::InstructionConfig, GKRCircuit), ZKVMError> { // constrain vmstate - let mut layout = + let layout = as gkr_iop::ProtocolBuilder>::build_layer_logic( cb, (), @@ -178,15 +177,7 @@ impl Instruction, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(cb); - - let layer = Layer::from_circuit_builder( - cb, - "weierstrass_decompress".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("weierstrass_decompress".to_string(), cb); let circuit = chip.gkr_circuit(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 4b9a2aeb6..0e11f55cb 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -7,8 +7,7 @@ use ceno_emul::{ use ff_ext::ExtensionField; use generic_array::{GenericArray, typenum::Unsigned}; use gkr_iop::{ - ProtocolBuilder, ProtocolWitnessGenerator, - gkr::{GKRCircuit, layer::Layer}, + ProtocolBuilder, ProtocolWitnessGenerator, gkr::GKRCircuit, utils::lk_multiplicity::Multiplicity, }; use itertools::{Itertools, izip}; @@ -125,12 +124,11 @@ impl Instruction as gkr_iop::ProtocolBuilder>::build_layer_logic( cb, (), )?; - // Write the result to the same address of the first input point. let mem_rw = izip!(&layout.input32_exprs, &layout.output32_exprs) .enumerate() @@ -150,15 +148,7 @@ impl Instruction, _>>()?; - let (out_evals, mut chip) = layout.finalize(cb); - - let layer = Layer::from_circuit_builder( - cb, - "weierstrass_double".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("weierstrass_double".to_string(), cb); let circuit = chip.gkr_circuit(); diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 1a378ad8c..7e08dbbaf 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -1,5 +1,6 @@ use ceno_emul::{Cycle, StepRecord, Word, WriteOp}; use ff_ext::{ExtensionField, FieldInto, SmallField}; +use gkr_iop::selector::SelectorType; use itertools::Itertools; use p3::field::{Field, FieldAlgebra}; @@ -391,6 +392,31 @@ impl WriteMEM { Ok(WriteMEM { prev_ts, lt_cfg }) } + pub fn construct_circuit_with_rw_selectors( + circuit_builder: &mut CircuitBuilder, + mem_addr: AddressExpr, + prev_value: MemoryExpr, + new_value: MemoryExpr, + cur_ts: WitIn, + r_selector: &SelectorType, + w_selector: &SelectorType, + ) -> Result { + let prev_ts = circuit_builder.create_witin(|| "prev_ts"); + + let (_, lt_cfg) = circuit_builder.memory_write_with_rw_selectors( + || "write_memory", + &mem_addr, + prev_ts.expr(), + cur_ts.expr() + Tracer::SUBCYCLE_MEM, + prev_value, + new_value, + r_selector, + w_selector, + )?; + + Ok(WriteMEM { prev_ts, lt_cfg }) + } + pub fn assign_instance( &self, instance: &mut [::BaseField], @@ -683,6 +709,7 @@ mod test { &[], &[], if is_ok { &[] } else { &["mid_u14"] }, + num_rows, None, None, ); diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 355dad511..28f81eb2f 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -70,7 +70,7 @@ fn verify_test_opcode_jal(pc_offset: i32) { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } #[test] @@ -132,5 +132,5 @@ fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { )], ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index 6bade9c0f..f3c61c644 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -54,7 +54,7 @@ fn test_opcode_and() { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } #[test] @@ -97,7 +97,7 @@ fn test_opcode_or() { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } #[test] @@ -140,5 +140,5 @@ fn test_opcode_xor() { .require_equal(|| "assert_rd_written", &mut cb, &expected_rd_written) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index fea2b03df..e459b4bf4 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -253,6 +253,6 @@ mod test { cb.require_equal(|| "assert_rd_written", rd_written_expr, expected.value()) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 3a003777b..557ddd225 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -90,5 +90,5 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w cb.require_equal(|| "assert_rd_written", rd_written_expr, expected.value()) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 93d24c4ef..0ba17af11 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -186,6 +186,6 @@ mod tests { ) .unwrap(); - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], 1, None, Some(lkm)); } } diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index f6b0fa153..8a12f5755 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -136,7 +136,7 @@ fn impl_opcode_store>(imm: i32) { @@ -184,7 +184,7 @@ fn impl_opcode_load { #[derive(Clone, Debug)] pub struct KeccakLayers { pub output32: Output32Layer, + pub sel: EqT, pub inner_rounds: [KeccakRound; 23], pub first_round: KeccakRound, } @@ -340,6 +341,7 @@ impl Default for KeccakLayout { }); KeccakLayers { output32, + sel, inner_rounds, first_round, } @@ -664,6 +666,7 @@ impl KeccakLayout { let mut chip = Chip { n_fixed: 0, n_committed: STATE_SIZE, + n_structural_witin: 0, n_challenges: 0, n_evaluations: KECCAK_ALL_IN_EVAL_SIZE + KECCAK_OUT_EVAL_SIZE, layers: vec![], @@ -783,10 +786,6 @@ impl KeccakLayout { impl ProtocolBuilder for KeccakLayout { type Params = KeccakParams; - fn finalize(&mut self, _cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - unimplemented!() - } - fn build_layer_logic( _cb: &mut CircuitBuilder, _params: Self::Params, @@ -794,24 +793,8 @@ impl ProtocolBuilder for KeccakLayout { unimplemented!() } - fn n_committed(&self) -> usize { - STATE_SIZE - } - - fn n_fixed(&self) -> usize { - 0 - } - - fn n_challenges(&self) -> usize { - 0 - } - - fn n_layers(&self) -> usize { - 5 * ROUNDS + 1 - } - - fn n_evaluations(&self) -> usize { - KECCAK_ALL_IN_EVAL_SIZE + KECCAK_OUT_EVAL_SIZE + fn finalize(&self, _name: String, _cb: &mut CircuitBuilder) -> Chip { + unimplemented!() } } @@ -971,7 +954,7 @@ pub fn run_keccakf + 'stat gkr_circuit .layers .first() - .map(|layer| layer.out_sel_and_eval_exprs.len()) + .map(|layer| layer.selector_ctxs_len()) .unwrap() ]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit diff --git a/ceno_zkvm/src/precompiles/fptower/fp.rs b/ceno_zkvm/src/precompiles/fptower/fp.rs index ba7c04308..2b8a446ec 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp.rs @@ -28,11 +28,13 @@ use derive::AlignedBorrow; use ff_ext::ExtensionField; use generic_array::{GenericArray, sequence::GenericSequence}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, - circuit_builder::CircuitBuilder, error::CircuitBuilderError, selector::SelectorType, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::CircuitBuilder, + default_out_eval_groups, error::CircuitBuilderError, gkr::layer::Layer, selector::SelectorType, }; use itertools::Itertools; -use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; +use multilinear_extensions::{ + Expression, StructuralWitIn, ToExpr, WitIn, util::max_usable_threads, +}; use num::BigUint; use p3::field::FieldAlgebra; use rayon::{ @@ -50,7 +52,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, gadgets::{FieldOperation, field_op::FieldOpCols, range::FieldLtCols}, - precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, + precompiles::utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, witness::LkMultiplicity, }; @@ -102,13 +104,9 @@ pub struct FpOpLayer { #[derive(Clone, Debug)] pub struct FpOpLayout { pub layer_exprs: FpOpLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, pub input32_exprs: [GenericArray,

::WordsFieldElement>; 2], pub output32_exprs: GenericArray,

::WordsFieldElement>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl FpOpLayout { @@ -123,14 +121,6 @@ impl FpOpLayout { output_range_check: FieldLtCols::create(cb, || "fp_op_output_range"), }; - let eq = cb.create_placeholder_structural_witin(|| "fp_op_structural_witin"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - let input32_exprs: [GenericArray,

::WordsFieldElement>; 2] = array::from_fn(|_| { GenericArray::generate(|_| array::from_fn(|_| Expression::WitIn(0))) @@ -140,13 +130,9 @@ impl FpOpLayout { Self { layer_exprs: FpOpLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "fp_op_sel"), input32_exprs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_structural_witin: 0, - n_challenges: 0, } } @@ -226,31 +212,15 @@ impl ProtocolBuilder for FpOpLayout { Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - (0..r_len).collect_vec(), - (r_len..r_len + w_len).collect_vec(), - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); + + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -267,12 +237,14 @@ impl ProtocolWitnessGenerator for FpOpLayout wits: [&mut RowMajorMatrix; 2], lk_multiplicity: &mut LkMultiplicity, ) { - let (wits_start, num_wit_cols) = + let (layout_start, num_layout_wit_cols) = (self.layer_exprs.wits.is_add.id as usize, num_fp_cols::

()); let [wits, structural_wits] = wits; let num_instances = wits.num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -280,16 +252,14 @@ impl ProtocolWitnessGenerator for FpOpLayout .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { let cols: &mut FpOpWitCols = - row[wits_start..][..num_wit_cols].borrow_mut(); + row[layout_start..][..num_layout_wit_cols].borrow_mut(); Self::populate_row(phase1_instance, cols, &mut lk_multiplicity); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -302,7 +272,7 @@ mod tests { use gkr_iop::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, cpu::{CpuBackend, CpuProver}, - gkr::{GKRProverOutput, layer::Layer}, + gkr::GKRProverOutput, selector::SelectorContext, }; use itertools::Itertools; @@ -331,12 +301,10 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "fp_op_test"); let mut cb = CircuitBuilder::::new(&mut cs); - let mut layout = + let layout = FpOpLayout::::build_layer_logic(&mut cb, ()).expect("build_layer_logic failed"); - let (out_evals, mut chip) = layout.finalize(&mut cb); - let layer = - Layer::from_circuit_builder(&cb, "fp_op".to_string(), layout.n_challenges, out_evals); - chip.add_layer(layer); + + let chip = layout.finalize("fp_op".to_string(), &mut cb); let gkr_circuit = chip.gkr_circuit(); let instances = (0..count) @@ -354,12 +322,12 @@ mod tests { let mut phase1 = RowMajorMatrix::new( instances.len(), - layout.n_committed, + chip.n_committed, InstancePaddingStrategy::Default, ); let mut structural = RowMajorMatrix::new( instances.len(), - layout.n_structural_witin, + chip.n_structural_witin, InstancePaddingStrategy::Default, ); let mut lk_multiplicity = LkMultiplicity::default(); @@ -432,11 +400,9 @@ mod tests { .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - log2_num_instance..point.len()].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>(); @@ -450,7 +416,9 @@ mod tests { } }; - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -488,41 +456,21 @@ mod tests { #[test] fn test_bls12381_fp_ops() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp_ops_helper::(8)) - .expect("spawn fp_ops test thread failed") - .join() - .expect("fp_ops test thread panicked"); + test_fp_ops_helper::(8) } #[test] fn test_bls12381_fp_ops_nonpow2() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp_ops_helper::(7)) - .expect("spawn fp_ops test thread failed") - .join() - .expect("fp_ops test thread panicked"); + test_fp_ops_helper::(7) } #[test] fn test_bn254_fp_ops() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp_ops_helper::(8)) - .expect("spawn fp_ops test thread failed") - .join() - .expect("fp_ops test thread panicked"); + test_fp_ops_helper::(8) } #[test] fn test_bn254_fp_ops_nonpow2() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp_ops_helper::(7)) - .expect("spawn fp_ops test thread failed") - .join() - .expect("fp_ops test thread panicked"); + test_fp_ops_helper::(7) } } diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs index 76d32b31a..cafaa1f29 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_addsub.rs @@ -28,11 +28,13 @@ use derive::AlignedBorrow; use ff_ext::ExtensionField; use generic_array::{GenericArray, sequence::GenericSequence}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, - circuit_builder::CircuitBuilder, error::CircuitBuilderError, selector::SelectorType, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::CircuitBuilder, + default_out_eval_groups, error::CircuitBuilderError, gkr::layer::Layer, selector::SelectorType, }; use itertools::Itertools; -use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; +use multilinear_extensions::{ + Expression, StructuralWitIn, ToExpr, WitIn, util::max_usable_threads, +}; use num::BigUint; use p3::field::FieldAlgebra; use rayon::{ @@ -50,7 +52,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, gadgets::{FieldOperation, field_op::FieldOpCols, range::FieldLtCols}, - precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, + precompiles::utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, witness::LkMultiplicity, }; @@ -109,13 +111,9 @@ pub struct Fp2AddSubAssignLayer { #[derive(Clone, Debug)] pub struct Fp2AddSubAssignLayout { pub layer_exprs: Fp2AddSubAssignLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, pub input32_exprs: [GenericArray,

::WordsCurvePoint>; 2], pub output32_exprs: GenericArray,

::WordsCurvePoint>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl Fp2AddSubAssignLayout { @@ -132,14 +130,6 @@ impl Fp2AddSubAssignLayout { c1_range_check: FieldLtCols::create(cb, || "fp2_c1_range"), }; - let eq = cb.create_placeholder_structural_witin(|| "fp2_addsub_structural_witin"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - let input32_exprs: [GenericArray,

::WordsCurvePoint>; 2] = array::from_fn(|_| { GenericArray::generate(|_| array::from_fn(|_| Expression::WitIn(0))) @@ -149,13 +139,9 @@ impl Fp2AddSubAssignLayout { Self { layer_exprs: Fp2AddSubAssignLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "fp2_addsub_sel"), input32_exprs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_structural_witin: 0, - n_challenges: 0, } } @@ -250,31 +236,15 @@ impl ProtocolBuilder for Fp2AddSubAssignLayo Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - (0..r_len).collect_vec(), - (r_len..r_len + w_len).collect_vec(), - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); + + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -291,7 +261,7 @@ impl ProtocolWitnessGenerator for Fp2AddSubA wits: [&mut RowMajorMatrix; 2], lk_multiplicity: &mut LkMultiplicity, ) { - let (wits_start, num_wit_cols) = ( + let (layout_start, num_layout_wit_cols) = ( self.layer_exprs.wits.is_add.id as usize, num_fp2_addsub_cols::

(), ); @@ -299,6 +269,8 @@ impl ProtocolWitnessGenerator for Fp2AddSubA let num_instances = wits.num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -306,16 +278,14 @@ impl ProtocolWitnessGenerator for Fp2AddSubA .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { let cols: &mut Fp2AddSubAssignWitCols = - row[wits_start..][..num_wit_cols].borrow_mut(); + row[layout_start..][..num_layout_wit_cols].borrow_mut(); Self::populate_row(phase1_instance, cols, &mut lk_multiplicity); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -329,7 +299,7 @@ mod tests { use gkr_iop::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, cpu::{CpuBackend, CpuProver}, - gkr::{GKRProverOutput, layer::Layer}, + gkr::GKRProverOutput, selector::SelectorContext, }; use itertools::Itertools; @@ -344,7 +314,7 @@ mod tests { use transcript::{BasicTranscript, Transcript}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; - use crate::{gadgets::FieldOperation, witness::LkMultiplicity}; + use crate::{gadgets::FieldOperation, scheme::utils::gkr_witness, witness::LkMultiplicity}; fn random_mod() -> BigUint { let mut bytes = vec![0u8; P::NB_LIMBS + 8]; @@ -358,16 +328,9 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "fp2_addsub_test"); let mut cb = CircuitBuilder::::new(&mut cs); - let mut layout = Fp2AddSubAssignLayout::::build_layer_logic(&mut cb, ()) + let layout = Fp2AddSubAssignLayout::::build_layer_logic(&mut cb, ()) .expect("build_layer_logic failed"); - let (out_evals, mut chip) = layout.finalize(&mut cb); - let layer = Layer::from_circuit_builder( - &cb, - "fp2_addsub".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("fp2_addsub".to_string(), &mut cb); let gkr_circuit = chip.gkr_circuit(); let instances = (0..count) @@ -387,12 +350,12 @@ mod tests { let mut phase1 = RowMajorMatrix::new( instances.len(), - layout.n_committed, + chip.n_committed, InstancePaddingStrategy::Default, ); let mut structural = RowMajorMatrix::new( instances.len(), - layout.n_structural_witin, + chip.n_structural_witin, InstancePaddingStrategy::Default, ); let mut lk_multiplicity = LkMultiplicity::default(); @@ -458,16 +421,15 @@ mod tests { .map(Arc::new) .collect_vec(); - let (gkr_witness, gkr_output) = - crate::scheme::utils::gkr_witness::, CpuProver<_>>( - &gkr_circuit, - &phase1_witness_group, - &structural_witness, - &fixed, - &[], - &[], - &challenges, - ); + let (gkr_witness, gkr_output) = gkr_witness::, CpuProver<_>>( + &gkr_circuit, + &phase1_witness_group, + &structural_witness, + &fixed, + &[], + &[], + &challenges, + ); let out_evals = { let mut point = Vec::with_capacity(log2_num_instance); @@ -477,11 +439,9 @@ mod tests { .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - log2_num_instance..point.len()].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>(); @@ -495,7 +455,9 @@ mod tests { } }; - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -533,41 +495,21 @@ mod tests { #[test] fn test_bls12381_fp2_addsub() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_addsub_helper::(8)) - .expect("spawn fp2_addsub test thread failed") - .join() - .expect("fp2_addsub test thread panicked"); + test_fp2_addsub_helper::(8) } #[test] fn test_bls12381_fp2_addsub_nonpow2() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_addsub_helper::(7)) - .expect("spawn fp2_addsub test thread failed") - .join() - .expect("fp2_addsub test thread panicked"); + test_fp2_addsub_helper::(7) } #[test] fn test_bn254_fp2_addsub() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_addsub_helper::(8)) - .expect("spawn fp2_addsub test thread failed") - .join() - .expect("fp2_addsub test thread panicked"); + test_fp2_addsub_helper::(8) } #[test] fn test_bn254_fp2_addsub_nonpow2() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_addsub_helper::(7)) - .expect("spawn fp2_addsub test thread failed") - .join() - .expect("fp2_addsub test thread panicked"); + test_fp2_addsub_helper::(7) } } diff --git a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs index c9160e6d9..6762fdc30 100644 --- a/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs +++ b/ceno_zkvm/src/precompiles/fptower/fp2_mul.rs @@ -28,11 +28,13 @@ use derive::AlignedBorrow; use ff_ext::ExtensionField; use generic_array::{GenericArray, sequence::GenericSequence}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, - circuit_builder::CircuitBuilder, error::CircuitBuilderError, selector::SelectorType, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::CircuitBuilder, + default_out_eval_groups, error::CircuitBuilderError, gkr::layer::Layer, selector::SelectorType, }; use itertools::Itertools; -use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; +use multilinear_extensions::{ + Expression, StructuralWitIn, ToExpr, WitIn, util::max_usable_threads, +}; use num::BigUint; use p3::field::FieldAlgebra; use rayon::{ @@ -50,7 +52,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ chip_handler::MemoryExpr, gadgets::{FieldOperation, field_op::FieldOpCols, range::FieldLtCols}, - precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, + precompiles::utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, witness::LkMultiplicity, }; @@ -110,13 +112,9 @@ pub struct Fp2MulAssignLayer { #[derive(Clone, Debug)] pub struct Fp2MulAssignLayout { pub layer_exprs: Fp2MulAssignLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, pub input32_exprs: [GenericArray,

::WordsCurvePoint>; 2], pub output32_exprs: GenericArray,

::WordsCurvePoint>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl Fp2MulAssignLayout { @@ -136,14 +134,6 @@ impl Fp2MulAssignLayout { c1_range_check: FieldLtCols::create(cb, || "fp2_mul_c1_range"), }; - let eq = cb.create_placeholder_structural_witin(|| "fp2_mul_structural_witin"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - let input32_exprs: [GenericArray,

::WordsCurvePoint>; 2] = array::from_fn(|_| { GenericArray::generate(|_| array::from_fn(|_| Expression::WitIn(0))) @@ -153,13 +143,9 @@ impl Fp2MulAssignLayout { Self { layer_exprs: Fp2MulAssignLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "fp2_mul_sel"), input32_exprs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_structural_witin: 0, - n_challenges: 0, } } @@ -279,31 +265,15 @@ impl ProtocolBuilder for Fp2MulAssignLayout< Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - (0..r_len).collect_vec(), - (r_len..r_len + w_len).collect_vec(), - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); + + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -320,7 +290,7 @@ impl ProtocolWitnessGenerator for Fp2MulAssi wits: [&mut RowMajorMatrix; 2], lk_multiplicity: &mut LkMultiplicity, ) { - let (wits_start, num_wit_cols) = ( + let (layout_start, num_layout_wit_cols) = ( self.layer_exprs.wits.a0.0[0].id as usize, num_fp2_mul_cols::

(), ); @@ -328,6 +298,8 @@ impl ProtocolWitnessGenerator for Fp2MulAssi let num_instances = wits.num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -335,16 +307,14 @@ impl ProtocolWitnessGenerator for Fp2MulAssi .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { let cols: &mut Fp2MulAssignWitCols = - row[wits_start..][..num_wit_cols].borrow_mut(); + row[layout_start..][..num_layout_wit_cols].borrow_mut(); Self::populate_row(phase1_instance, cols, &mut lk_multiplicity); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -358,7 +328,7 @@ mod tests { use gkr_iop::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, cpu::{CpuBackend, CpuProver}, - gkr::{GKRProverOutput, layer::Layer}, + gkr::GKRProverOutput, selector::SelectorContext, }; use itertools::Itertools; @@ -387,12 +357,9 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "fp2_mul_test"); let mut cb = CircuitBuilder::::new(&mut cs); - let mut layout = Fp2MulAssignLayout::::build_layer_logic(&mut cb, ()) + let layout = Fp2MulAssignLayout::::build_layer_logic(&mut cb, ()) .expect("build_layer_logic failed"); - let (out_evals, mut chip) = layout.finalize(&mut cb); - let layer = - Layer::from_circuit_builder(&cb, "fp2_mul".to_string(), layout.n_challenges, out_evals); - chip.add_layer(layer); + let chip = layout.finalize("fp2_mul".to_string(), &mut cb); let gkr_circuit = chip.gkr_circuit(); let instances = (0..count) @@ -407,12 +374,12 @@ mod tests { let mut phase1 = RowMajorMatrix::new( instances.len(), - layout.n_committed, + chip.n_committed, InstancePaddingStrategy::Default, ); let mut structural = RowMajorMatrix::new( instances.len(), - layout.n_structural_witin, + chip.n_structural_witin, InstancePaddingStrategy::Default, ); let mut lk_multiplicity = LkMultiplicity::default(); @@ -481,7 +448,6 @@ mod tests { &[], &challenges, ); - let out_evals = { let mut point = Vec::with_capacity(log2_num_instance); point.extend(prover_transcript.sample_vec(log2_num_instance).to_vec()); @@ -490,11 +456,9 @@ mod tests { .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - log2_num_instance..point.len()].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>(); @@ -508,7 +472,9 @@ mod tests { } }; - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -546,41 +512,21 @@ mod tests { #[test] fn test_bls12381_fp2_mul() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_mul_helper::(8)) - .expect("spawn fp2_mul test thread failed") - .join() - .expect("fp2_mul test thread panicked"); + test_fp2_mul_helper::(8) } #[test] fn test_bls12381_fp2_mul_nonpow2() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_mul_helper::(7)) - .expect("spawn fp2_mul test thread failed") - .join() - .expect("fp2_mul test thread panicked"); + test_fp2_mul_helper::(7) } #[test] fn test_bn254_fp2_mul() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_mul_helper::(8)) - .expect("spawn fp2_mul test thread failed") - .join() - .expect("fp2_mul test thread panicked"); + test_fp2_mul_helper::(8) } #[test] fn test_bn254_fp2_mul_nonpow2() { - std::thread::Builder::new() - .stack_size(32 * 1024 * 1024) - .spawn(|| test_fp2_mul_helper::(7)) - .expect("spawn fp2_mul test thread failed") - .join() - .expect("fp2_mul test thread panicked"); + test_fp2_mul_helper::(7) } } diff --git a/ceno_zkvm/src/precompiles/lookup_keccakf.rs b/ceno_zkvm/src/precompiles/lookup_keccakf.rs index 52391267a..de213a75e 100644 --- a/ceno_zkvm/src/precompiles/lookup_keccakf.rs +++ b/ceno_zkvm/src/precompiles/lookup_keccakf.rs @@ -1,10 +1,11 @@ use ceno_emul::{ByteAddr, Cycle, MemOp, StepRecord}; use ff_ext::ExtensionField; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::{CircuitBuilder, ConstraintSystem, expansion_expr, rotation_split}, cpu::{CpuBackend, CpuProver}, + default_out_eval_groups, error::{BackendError, CircuitBuilderError}, gkr::{ GKRCircuit, GKRProof, GKRProverOutput, @@ -26,9 +27,9 @@ use ndarray::{ArrayView, Ix2, Ix3, s}; use p3::field::FieldAlgebra; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, - slice::{ParallelSlice, ParallelSliceMut}, + slice::ParallelSlice, }; -use std::{array, mem::transmute, sync::Arc}; +use std::{array, sync::Arc}; use sumcheck::{ macros::{entered_span, exit_span}, util::optimal_sumcheck_threads, @@ -41,9 +42,8 @@ use crate::{ e2e::ShardContext, error::ZKVMError, instructions::riscv::insn_base::{StateInOut, WriteMEM}, - precompiles::{ - SelectorTypeLayout, - utils::{Mask, MaskRepresentation, not8_expr, set_slice_felts_from_u64 as push_instance}, + precompiles::utils::{ + Mask, MaskRepresentation, not8_expr, set_slice_felts_from_u64 as push_instance, }, scheme::utils::gkr_witness, }; @@ -101,10 +101,6 @@ const LOOKUP_FELTS_PER_ROUND: usize = pub const AND_LOOKUPS: usize = AND_LOOKUPS_PER_ROUND; pub const XOR_LOOKUPS: usize = XOR_LOOKUPS_PER_ROUND; pub const RANGE_LOOKUPS: usize = RANGE_LOOKUPS_PER_ROUND; -pub const STRUCTURAL_WITIN: usize = 6; - -#[derive(Clone, Debug)] -pub struct KeccakParams; #[derive(Clone, Debug)] #[repr(C)] @@ -149,15 +145,17 @@ pub struct KeccakLayer { #[derive(Clone, Debug)] pub struct KeccakLayout { - pub params: KeccakParams, pub layer_exprs: KeccakLayer, pub selector_type_layout: SelectorTypeLayout, pub input32_exprs: [MemoryExpr; KECCAK_INPUT32_SIZE], pub output32_exprs: [MemoryExpr; KECCAK_OUTPUT32_SIZE], - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, +} + +#[derive(Clone, Debug)] +pub struct SelectorTypeLayout { + pub sel_first: SelectorType, + pub sel_last: SelectorType, + pub sel_all: SelectorType, } const ROTATION_WITNESS_LEN: usize = 196; @@ -184,33 +182,45 @@ fn split_mask_to_array(value: u64, sizes: &[usize; N]) -> [u64; } impl KeccakLayout { - fn new(cb: &mut CircuitBuilder, params: KeccakParams) -> Self { + fn new(cb: &mut CircuitBuilder) -> Self { // allocate witnesses, fixed, and eqs - let ( - wits, - // fixed, - [ - sel_first, - sel_last, - eq_zero, - eq_rotation_left, - eq_rotation_right, - eq_rotation, - ], - ): (KeccakWitCols, [StructuralWitIn; STRUCTURAL_WITIN]) = unsafe { - ( - transmute::<[WitIn; size_of::>()], KeccakWitCols>( - array::from_fn(|id| cb.create_witin(|| format!("keccak_witin_{}", id))), - ), - // transmute::<[Fixed; 8], KeccakFixedCols>(array::from_fn(|id| { - // cb.create_fixed(|| format!("keccak_fixed_{}", id)) - // })), - array::from_fn(|id| { - cb.create_placeholder_structural_witin(|| format!("keccak_eq_{}", id)) - }), - ) + let wits = KeccakWitCols { + input8: array::from_fn(|id| cb.create_witin(|| format!("keccak_input8_{id}"))), + c_aux: array::from_fn(|id| cb.create_witin(|| format!("keccak_c_aux_{id}"))), + c_temp: array::from_fn(|id| cb.create_witin(|| format!("keccak_c_temp_{id}"))), + c_rot: array::from_fn(|id| cb.create_witin(|| format!("keccak_c_rot_{id}"))), + d: array::from_fn(|id| cb.create_witin(|| format!("keccak_d_{id}"))), + theta_output: array::from_fn(|id| { + cb.create_witin(|| format!("keccak_theta_output_{id}")) + }), + rotation_witness: array::from_fn(|id| { + cb.create_witin(|| format!("keccak_rotation_witness_{id}")) + }), + rhopi_output: array::from_fn(|id| cb.create_witin(|| format!("keccak_rhopi_{id}"))), + nonlinear: array::from_fn(|id| cb.create_witin(|| format!("keccak_nonlinear_{id}"))), + chi_output: array::from_fn(|id| cb.create_witin(|| format!("keccak_chi_{id}"))), + iota_output: array::from_fn(|id| cb.create_witin(|| format!("keccak_iota_{id}"))), + rc: array::from_fn(|id| cb.create_witin(|| format!("keccak_rc_{id}"))), }; + // let fixed = KeccakFixedCols { + // rc: array::from_fn(|id| cb.create_fixed(|| format!("keccak_fixed_{id}"))), + // }; + + // The order matters because our prover implementation assumes that the + // selectors are assigned before the rotation witnesses. + + let [ + sel_first, + sel_last, + sel_all, + eq_rotation_left, + eq_rotation_right, + eq_rotation, + ] = array::from_fn(|id| { + cb.create_placeholder_structural_witin(|| format!("keccak_eq_{}", id)) + }); + // indices to activate zero/lookup constraints let checked_indices = CYCLIC_POW2_5 .iter() @@ -220,7 +230,6 @@ impl KeccakLayout { .map(|v| v as usize) .collect_vec(); Self { - params, layer_exprs: KeccakLayer { wits, // fixed, @@ -229,40 +238,36 @@ impl KeccakLayout { eq_rotation, }, selector_type_layout: SelectorTypeLayout { - sel_first: Some(SelectorType::OrderedSparse { + sel_first: SelectorType::OrderedSparse { num_vars: 5, indices: vec![CYCLIC_POW2_5[0] as usize], expression: sel_first.expr(), - }), - sel_last: Some(SelectorType::OrderedSparse { + }, + sel_last: SelectorType::OrderedSparse { num_vars: 5, indices: vec![CYCLIC_POW2_5[ROUNDS - 1] as usize], expression: sel_last.expr(), - }), + }, sel_all: SelectorType::OrderedSparse { num_vars: 5, indices: checked_indices.clone(), - expression: eq_zero.expr(), + expression: sel_all.expr(), }, }, input32_exprs: array::from_fn(|_| array::from_fn(|_| Expression::WitIn(0))), output32_exprs: array::from_fn(|_| array::from_fn(|_| Expression::WitIn(0))), - n_fixed: 0, - n_committed: 0, - n_structural_witin: STRUCTURAL_WITIN, - n_challenges: 0, } } } impl ProtocolBuilder for KeccakLayout { - type Params = KeccakParams; + type Params = (); fn build_layer_logic( cb: &mut CircuitBuilder, - params: Self::Params, + _params: Self::Params, ) -> Result { - let mut layout = Self::new(cb, params); + let mut layout = Self::new(cb); let system = cb; let KeccakWitCols { @@ -514,55 +519,21 @@ impl ProtocolBuilder for KeccakLayout { Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_challenges = 0; - - // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_first.clone().unwrap()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_last.clone().unwrap()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) - } - - fn n_committed(&self) -> usize { - unimplemented!("retrieve from constrain system") - } - - fn n_fixed(&self) -> usize { - unimplemented!("retrieve from constrain system") - } - - fn n_challenges(&self) -> usize { - 0 - } - - fn n_evaluations(&self) -> usize { - unimplemented!() - } - - fn n_layers(&self) -> usize { - 1 + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + cb.cs + .set_default_read_selector(self.selector_type_layout.sel_first.clone()); + cb.cs + .set_default_write_selector(self.selector_type_layout.sel_last.clone()); + cb.cs + .set_default_lookup_selector(self.selector_type_layout.sel_all.clone()); + cb.cs + .set_default_zero_selector(self.selector_type_layout.sel_all.clone()); + + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -657,292 +628,305 @@ where } = self.layer_exprs; let num_instances = phase1.instances.len(); - - // keccak instance full rounds (24 rounds + 8 round padding) as chunk size - // we need to do assignment on respective 31 cyclic group index - wits.values - .par_chunks_mut(self.n_committed * ROUNDS.next_power_of_two()) - .take(num_instances) - .zip_eq( - structural_wits - .values - .par_chunks_mut(self.n_structural_witin * ROUNDS.next_power_of_two()) - .take(num_instances), - ) - .zip(&phase1.instances) - .for_each(|((wits, structural_wits), KeccakInstance { witin, .. })| { + let (layout_start, num_layout_wit_cols) = + (input8_witin[0].id as usize, size_of::>()); + + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; + + let num_rotation = ROUNDS.next_power_of_two(); + let num_instance_per_batch = num_instances.div_ceil(max_usable_threads()).max(1); + let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); + let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); + raw_witin_iter + .zip_eq(raw_structural_wits_iter) + .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) + .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - let state_32_iter = witin.instance.iter().map(|e| *e as u64); - let mut state64 = [[0u64; 5]; 5]; - zip_eq(iproduct!(0..5, 0..5), state_32_iter.tuples()) - .map(|((x, y), (lo, hi))| { - state64[x][y] = lo | (hi << 32); - }) - .count(); - - let bh = BooleanHypercube::new(ROUNDS_CEIL_LOG2); - let mut cyclic_group = bh.into_iter(); - - let Some(sel_first) = self.selector_type_layout.sel_first.as_ref() else { - panic!("sel_first must be Some"); - }; - let (mut sel_first_iter, sel_first_structural_witin) = ( - sel_first.sparse_indices().iter(), - sel_first.selector_expr().id(), - ); - - let Some(sel_last) = self.selector_type_layout.sel_last.as_ref() else { - panic!("sel_last must be Some"); - }; - let (mut sel_last_iter, sel_last_structural_witin) = ( - sel_last.sparse_indices().iter(), - sel_last.selector_expr().id(), - ); - - let (mut sel_all_iter, sel_all_structural_witin) = ( - self.selector_type_layout.sel_all.sparse_indices().iter(), - self.selector_type_layout.sel_all.selector_expr().id(), - ); - - #[allow(clippy::needless_range_loop)] - for round in 0..ROUNDS { - let round_index = cyclic_group.next().unwrap(); - let wits = - &mut wits[round_index as usize * self.n_committed..][..self.n_committed]; - - // set selector - if let Some(index) = sel_first_iter.next() { - structural_wits - [index * self.n_structural_witin + sel_first_structural_witin] = - E::BaseField::ONE; - } - if let Some(index) = sel_last_iter.next() { - structural_wits - [index * self.n_structural_witin + sel_last_structural_witin] = - E::BaseField::ONE; - } - if let Some(index) = sel_all_iter.next() { - structural_wits - [index * self.n_structural_witin + sel_all_structural_witin] = - E::BaseField::ONE; - } + rows.chunks_mut(wits_width * num_rotation) + .zip_eq(eqs.chunks_mut(structural_wits_width * num_rotation)) + .zip_eq(phase1_instances) + .for_each(|((wits, structural_wits), KeccakInstance { witin, .. })| { + let state_32_iter = witin.instance.iter().map(|e| *e as u64); + let mut state64 = [[0u64; 5]; 5]; + zip_eq(iproduct!(0..5, 0..5), state_32_iter.tuples()) + .map(|((x, y), (lo, hi))| { + state64[x][y] = lo | (hi << 32); + }) + .count(); + + let bh = BooleanHypercube::new(ROUNDS_CEIL_LOG2); + let mut cyclic_group = bh.into_iter(); + + let (mut sel_first_iter, sel_first_structural_witin) = ( + self.selector_type_layout.sel_first.sparse_indices().iter(), + self.selector_type_layout.sel_first.selector_expr().id(), + ); - let mut state8 = [[[0u64; 8]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - state8[x][y] = split_mask_to_array(state64[x][y], &BYTE_SPLIT_SIZES); - } - } + let (mut sel_last_iter, sel_last_structural_witin) = ( + self.selector_type_layout.sel_last.sparse_indices().iter(), + self.selector_type_layout.sel_last.selector_expr().id(), + ); - push_instance::( - wits, - input8_witin[0].id.into(), - state8.into_iter().flatten().flatten(), - ); + let (mut sel_all_iter, sel_all_structural_witin) = ( + self.selector_type_layout.sel_all.sparse_indices().iter(), + self.selector_type_layout.sel_all.selector_expr().id(), + ); - let mut c_aux64 = [[0u64; 5]; 5]; - let mut c_aux8 = [[[0u64; 8]; 5]; 5]; + #[allow(clippy::needless_range_loop)] + for round in 0..ROUNDS { + let round_index = cyclic_group.next().unwrap(); + let wits = &mut wits[round_index as usize * wits_width..] + [layout_start..][..num_layout_wit_cols]; + + // set selector + if let Some(index) = sel_first_iter.next() { + structural_wits + [index * structural_wits_width + sel_first_structural_witin] = + E::BaseField::ONE; + } + if let Some(index) = sel_last_iter.next() { + structural_wits + [index * structural_wits_width + sel_last_structural_witin] = + E::BaseField::ONE; + } + if let Some(index) = sel_all_iter.next() { + structural_wits + [index * structural_wits_width + sel_all_structural_witin] = + E::BaseField::ONE; + } - for i in 0..5 { - c_aux64[i][0] = state64[0][i]; - c_aux8[i][0] = split_mask_to_array(c_aux64[i][0], &BYTE_SPLIT_SIZES); - for j in 1..5 { - c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; - for k in 0..8 { - lk_multiplicity - .lookup_xor_byte(c_aux8[i][j - 1][k], state8[j][i][k]); + let mut state8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + state8[x][y] = + split_mask_to_array(state64[x][y], &BYTE_SPLIT_SIZES); + } } - c_aux8[i][j] = split_mask_to_array(c_aux64[i][j], &BYTE_SPLIT_SIZES); - } - } - let mut c64 = [0u64; 5]; - let mut c8 = [[0u64; 8]; 5]; + push_instance::( + wits, + input8_witin[0].id.into(), + state8.into_iter().flatten().flatten(), + ); - for x in 0..5 { - c64[x] = c_aux64[x][4]; - c8[x] = split_mask_to_array(c64[x], &BYTE_SPLIT_SIZES); - } + let mut c_aux64 = [[0u64; 5]; 5]; + let mut c_aux8 = [[[0u64; 8]; 5]; 5]; + + for i in 0..5 { + c_aux64[i][0] = state64[0][i]; + c_aux8[i][0] = + split_mask_to_array(c_aux64[i][0], &BYTE_SPLIT_SIZES); + for j in 1..5 { + c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; + for k in 0..8 { + lk_multiplicity + .lookup_xor_byte(c_aux8[i][j - 1][k], state8[j][i][k]); + } + c_aux8[i][j] = + split_mask_to_array(c_aux64[i][j], &BYTE_SPLIT_SIZES); + } + } - let mut c_temp = [[0u64; 8]; 5]; - for i in 0..5 { - let chunks = split_mask_to_array(c64[i], &C_TEMP_SPLIT_SIZES); - for (chunk, size) in chunks.iter().zip(C_TEMP_SPLIT_SIZES.iter()) { - lk_multiplicity.assert_const_range(*chunk, *size); - } - c_temp[i] = chunks; - } + let mut c64 = [0u64; 5]; + let mut c8 = [[0u64; 8]; 5]; - let mut crot64 = [0u64; 5]; - let mut crot8 = [[0u64; 8]; 5]; - for i in 0..5 { - crot64[i] = c64[i].rotate_left(1); - crot8[i] = split_mask_to_array(crot64[i], &BYTE_SPLIT_SIZES); - } + for x in 0..5 { + c64[x] = c_aux64[x][4]; + c8[x] = split_mask_to_array(c64[x], &BYTE_SPLIT_SIZES); + } - let mut d64 = [0u64; 5]; - let mut d8 = [[0u64; 8]; 5]; - for x in 0..5 { - d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); - for k in 0..8 { - lk_multiplicity.lookup_xor_byte( - c_aux8[(x + 5 - 1) % 5][4][k], - crot8[(x + 1) % 5][k], - ); - } - d8[x] = split_mask_to_array(d64[x], &BYTE_SPLIT_SIZES); - } + let mut c_temp = [[0u64; 8]; 5]; + for i in 0..5 { + let chunks = split_mask_to_array(c64[i], &C_TEMP_SPLIT_SIZES); + for (chunk, size) in chunks.iter().zip(C_TEMP_SPLIT_SIZES.iter()) { + lk_multiplicity.assert_const_range(*chunk, *size); + } + c_temp[i] = chunks; + } - let mut theta_state64 = state64; - let mut theta_state8 = [[[0u64; 8]; 5]; 5]; - let mut rotation_witness = Vec::with_capacity(ROTATION_WITNESS_LEN); + let mut crot64 = [0u64; 5]; + let mut crot8 = [[0u64; 8]; 5]; + for i in 0..5 { + crot64[i] = c64[i].rotate_left(1); + crot8[i] = split_mask_to_array(crot64[i], &BYTE_SPLIT_SIZES); + } - for x in 0..5 { - for y in 0..5 { - theta_state64[y][x] ^= d64[x]; - for k in 0..8 { - lk_multiplicity.lookup_xor_byte(state8[y][x][k], d8[x][k]) + let mut d64 = [0u64; 5]; + let mut d8 = [[0u64; 8]; 5]; + for x in 0..5 { + d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); + for k in 0..8 { + lk_multiplicity.lookup_xor_byte( + c_aux8[(x + 5 - 1) % 5][4][k], + crot8[(x + 1) % 5][k], + ); + } + d8[x] = split_mask_to_array(d64[x], &BYTE_SPLIT_SIZES); } - theta_state8[y][x] = - split_mask_to_array(theta_state64[y][x], &BYTE_SPLIT_SIZES); - let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); - let rotation_chunks = - MaskRepresentation::from_mask(Mask::new(64, theta_state64[y][x])) + let mut theta_state64 = state64; + let mut theta_state8 = [[[0u64; 8]; 5]; 5]; + let mut rotation_witness = Vec::with_capacity(ROTATION_WITNESS_LEN); + + for x in 0..5 { + for y in 0..5 { + theta_state64[y][x] ^= d64[x]; + for k in 0..8 { + lk_multiplicity.lookup_xor_byte(state8[y][x][k], d8[x][k]) + } + theta_state8[y][x] = + split_mask_to_array(theta_state64[y][x], &BYTE_SPLIT_SIZES); + + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); + let rotation_chunks = MaskRepresentation::from_mask(Mask::new( + 64, + theta_state64[y][x], + )) .convert(&sizes) .values(); - for (chunk, size) in rotation_chunks.iter().zip(sizes.iter()) { - lk_multiplicity.assert_const_range(*chunk, *size); + for (chunk, size) in rotation_chunks.iter().zip(sizes.iter()) { + lk_multiplicity.assert_const_range(*chunk, *size); + } + rotation_witness.extend(rotation_chunks); + } + } + assert_eq!(rotation_witness.len(), rotation_witness_witin.len()); + + // Rho and Pi steps + let mut rhopi_output64 = [[0u64; 5]; 5]; + let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = theta_state64[y] + [x] + .rotate_left(ROTATION_CONSTANTS[y][x] as u32); + } } - rotation_witness.extend(rotation_chunks); - } - } - assert_eq!(rotation_witness.len(), rotation_witness_witin.len()); - - // Rho and Pi steps - let mut rhopi_output64 = [[0u64; 5]; 5]; - let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = - theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); - } - } + for x in 0..5 { + for y in 0..5 { + rhopi_output8[x][y] = split_mask_to_array( + rhopi_output64[x][y], + &BYTE_SPLIT_SIZES, + ); + } + } - for x in 0..5 { - for y in 0..5 { - rhopi_output8[x][y] = - split_mask_to_array(rhopi_output64[x][y], &BYTE_SPLIT_SIZES); - } - } + // Chi step + let mut nonlinear64 = [[0u64; 5]; 5]; + let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + nonlinear64[y][x] = !rhopi_output64[y][(x + 1) % 5] + & rhopi_output64[y][(x + 2) % 5]; + for k in 0..8 { + lk_multiplicity.lookup_and_byte( + 0xFF - rhopi_output8[y][(x + 1) % 5][k], + rhopi_output8[y][(x + 2) % 5][k], + ); + } + nonlinear8[y][x] = + split_mask_to_array(nonlinear64[y][x], &BYTE_SPLIT_SIZES); + } + } - // Chi step - let mut nonlinear64 = [[0u64; 5]; 5]; - let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - nonlinear64[y][x] = - !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; - for k in 0..8 { - lk_multiplicity.lookup_and_byte( - 0xFF - rhopi_output8[y][(x + 1) % 5][k], - rhopi_output8[y][(x + 2) % 5][k], - ); + let mut chi_output64 = [[0u64; 5]; 5]; + let mut chi_output8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x]; + for k in 0..8 { + lk_multiplicity.lookup_xor_byte( + rhopi_output8[y][x][k], + nonlinear8[y][x][k], + ); + } + chi_output8[y][x] = + split_mask_to_array(chi_output64[y][x], &BYTE_SPLIT_SIZES); + } } - nonlinear8[y][x] = - split_mask_to_array(nonlinear64[y][x], &BYTE_SPLIT_SIZES); - } - } - let mut chi_output64 = [[0u64; 5]; 5]; - let mut chi_output8 = [[[0u64; 8]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x]; + // Iota step + let mut iota_output64 = chi_output64; + let mut iota_output8 = [[[0u64; 8]; 5]; 5]; + // TODO figure out how to deal with RC, since it's not a constant in rotation + iota_output64[0][0] ^= RC[round]; + for k in 0..8 { - lk_multiplicity - .lookup_xor_byte(rhopi_output8[y][x][k], nonlinear8[y][x][k]); + let rc8 = split_mask_to_array(RC[round], &BYTE_SPLIT_SIZES); + lk_multiplicity.lookup_xor_byte(chi_output8[0][0][k], rc8[k]); } - chi_output8[y][x] = - split_mask_to_array(chi_output64[y][x], &BYTE_SPLIT_SIZES); - } - } - // Iota step - let mut iota_output64 = chi_output64; - let mut iota_output8 = [[[0u64; 8]; 5]; 5]; - // TODO figure out how to deal with RC, since it's not a constant in rotation - iota_output64[0][0] ^= RC[round]; + for x in 0..5 { + for y in 0..5 { + iota_output8[x][y] = + split_mask_to_array(iota_output64[x][y], &BYTE_SPLIT_SIZES); + } + } - for k in 0..8 { - let rc8 = split_mask_to_array(RC[round], &BYTE_SPLIT_SIZES); - lk_multiplicity.lookup_xor_byte(chi_output8[0][0][k], rc8[k]); - } + // set witness + push_instance::( + wits, + c_aux_witin[0].id.into(), + c_aux8.into_iter().flatten().flatten(), + ); + push_instance::( + wits, + c_temp_witin[0].id.into(), + c_temp.into_iter().flatten(), + ); + push_instance::( + wits, + c_rot_witin[0].id.into(), + crot8.into_iter().flatten(), + ); + push_instance::( + wits, + d_witin[0].id.into(), + d8.into_iter().flatten(), + ); + push_instance::( + wits, + theta_output_witin[0].id.into(), + theta_state8.into_iter().flatten().flatten(), + ); + push_instance::( + wits, + rotation_witness_witin[0].id.into(), + rotation_witness.into_iter(), + ); + push_instance::( + wits, + rhopi_output_witin[0].id.into(), + rhopi_output8.into_iter().flatten().flatten(), + ); + push_instance::( + wits, + nonlinear_witin[0].id.into(), + nonlinear8.into_iter().flatten().flatten(), + ); + push_instance::( + wits, + chi_output_witin[0].id.into(), + chi_output8[0][0].iter().copied(), + ); + push_instance::( + wits, + iota_output_witin[0].id.into(), + iota_output8.into_iter().flatten().flatten(), + ); + // TODO temporarily move RC to witness + push_instance::( + wits, + rc_witin[0].id.into(), + (0..8).map(|i| (RC[round] >> (i << 3)) & 0xFF), + ); - for x in 0..5 { - for y in 0..5 { - iota_output8[x][y] = - split_mask_to_array(iota_output64[x][y], &BYTE_SPLIT_SIZES); + state64 = iota_output64; } - } - - // set witness - push_instance::( - wits, - c_aux_witin[0].id.into(), - c_aux8.into_iter().flatten().flatten(), - ); - push_instance::( - wits, - c_temp_witin[0].id.into(), - c_temp.into_iter().flatten(), - ); - push_instance::( - wits, - c_rot_witin[0].id.into(), - crot8.into_iter().flatten(), - ); - push_instance::(wits, d_witin[0].id.into(), d8.into_iter().flatten()); - push_instance::( - wits, - theta_output_witin[0].id.into(), - theta_state8.into_iter().flatten().flatten(), - ); - push_instance::( - wits, - rotation_witness_witin[0].id.into(), - rotation_witness.into_iter(), - ); - push_instance::( - wits, - rhopi_output_witin[0].id.into(), - rhopi_output8.into_iter().flatten().flatten(), - ); - push_instance::( - wits, - nonlinear_witin[0].id.into(), - nonlinear8.into_iter().flatten().flatten(), - ); - push_instance::( - wits, - chi_output_witin[0].id.into(), - chi_output8[0][0].iter().copied(), - ); - push_instance::( - wits, - iota_output_witin[0].id.into(), - iota_output8.into_iter().flatten().flatten(), - ); - // TODO temporarily move RC to witness - push_instance::( - wits, - rc_witin[0].id.into(), - (0..8).map(|i| (RC[round] >> (i << 3)) & 0xFF), - ); - - state64 = iota_output64; - } + }); }); } } @@ -959,14 +943,13 @@ pub fn setup_gkr_circuit() -> Result<(TestKeccakLayout, GKRCircuit, u16, u16), ZKVMError> { let mut cs = ConstraintSystem::new(|| "lookup_keccak"); let mut cb = CircuitBuilder::::new(&mut cs); + let layout = KeccakLayout::build_layer_logic(&mut cb, ())?; // constrain vmstate let vm_state = StateInOut::construct_circuit(&mut cb, false)?; let state_ptr = cb.create_witin(|| "state_ptr"); - let mut layout = KeccakLayout::build_layer_logic(&mut cb, KeccakParams {})?; - let mem_rw = izip!(&layout.input32_exprs, &layout.output32_exprs) .enumerate() .map(|(i, (val_before, val_after))| { @@ -981,16 +964,7 @@ pub fn setup_gkr_circuit() }) .collect::, _>>()?; - let (out_evals, mut chip) = layout.finalize(&mut cb); - - let layer = Layer::from_circuit_builder( - &cb, - "lookup_keccak".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); - + let chip = layout.finalize("lookup_keccak".to_string(), &mut cb); Ok(( TestKeccakLayout { layout, @@ -1208,11 +1182,9 @@ pub fn run_lookup_keccakf .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - wit.num_vars()..].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>() }; @@ -1226,7 +1198,9 @@ pub fn run_lookup_keccakf } let span = entered_span!("create_proof", profiling_2 = true); - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); 3]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -1285,9 +1259,12 @@ mod tests { #[test] fn test_keccakf() { - for num_instances in 1..32 { - test_keccakf_helper(num_instances) - } + test_keccakf_helper(8) + } + + #[test] + fn test_keccakf_non_pow2() { + test_keccakf_helper(7) } fn test_keccakf_helper(num_instances: usize) { diff --git a/ceno_zkvm/src/precompiles/mod.rs b/ceno_zkvm/src/precompiles/mod.rs index 3d9a6e545..bf50efedf 100644 --- a/ceno_zkvm/src/precompiles/mod.rs +++ b/ceno_zkvm/src/precompiles/mod.rs @@ -8,22 +8,20 @@ mod weierstrass; pub use lookup_keccakf::{ AND_LOOKUPS, KECCAK_INPUT32_SIZE, KECCAK_OUT_EVAL_SIZE, KeccakInstance, KeccakLayout, - KeccakParams, KeccakStateInstance, KeccakTrace, KeccakWitInstance, RANGE_LOOKUPS, - ROUNDS as KECCAK_ROUNDS, ROUNDS_CEIL_LOG2 as KECCAK_ROUNDS_CEIL_LOG2, XOR_LOOKUPS, - run_lookup_keccakf, setup_gkr_circuit as setup_lookup_keccak_gkr_circuit, + KeccakStateInstance, KeccakTrace, KeccakWitInstance, RANGE_LOOKUPS, ROUNDS as KECCAK_ROUNDS, + ROUNDS_CEIL_LOG2 as KECCAK_ROUNDS_CEIL_LOG2, XOR_LOOKUPS, run_lookup_keccakf, + setup_gkr_circuit as setup_lookup_keccak_gkr_circuit, }; pub use bitwise_keccakf::{ KeccakLayout as BitwiseKeccakLayout, run_keccakf as run_bitwise_keccakf, setup_gkr_circuit as setup_bitwise_keccak_gkr_circuit, }; -use ff_ext::ExtensionField; pub use fptower::{ fp::{FpOpInstance, FpOpLayout, FpOpTrace}, fp2_addsub::{Fp2AddSubAssignLayout, Fp2AddSubInstance, Fp2AddSubTrace}, fp2_mul::{Fp2MulAssignLayout, Fp2MulInstance, Fp2MulTrace}, }; -use gkr_iop::selector::SelectorType; pub use sha256::{ SHA_EXTEND_ROUNDS, ShaExtendInstance, ShaExtendLayout, ShaExtendTrace, ShaExtendWitInstance, }; @@ -47,10 +45,3 @@ pub use weierstrass::{ setup_gkr_circuit as setup_weierstrass_double_circuit, }, }; - -#[derive(Clone, Debug)] -pub struct SelectorTypeLayout { - pub sel_first: Option>, - pub sel_last: Option>, - pub sel_all: SelectorType, -} diff --git a/ceno_zkvm/src/precompiles/sha256/extend.rs b/ceno_zkvm/src/precompiles/sha256/extend.rs index 235e37b95..491fa88c3 100644 --- a/ceno_zkvm/src/precompiles/sha256/extend.rs +++ b/ceno_zkvm/src/precompiles/sha256/extend.rs @@ -27,11 +27,13 @@ use std::{array, borrow::BorrowMut, mem::size_of}; use derive::AlignedBorrow; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, - circuit_builder::CircuitBuilder, error::CircuitBuilderError, selector::SelectorType, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::CircuitBuilder, + default_out_eval_groups, error::CircuitBuilderError, gkr::layer::Layer, selector::SelectorType, }; use itertools::Itertools; -use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; +use multilinear_extensions::{ + Expression, StructuralWitIn, ToExpr, WitIn, util::max_usable_threads, +}; use p3::field::{FieldAlgebra, TwoAdicField}; use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, @@ -44,7 +46,7 @@ use crate::{ gadgets::{ Add4Operation, FixedRotateRightOperation, FixedShiftRightOperation, Word, XorOperation, }, - precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, + precompiles::utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, witness::LkMultiplicity, }; @@ -126,13 +128,9 @@ pub struct ShaExtendLayer { #[derive(Clone, Debug)] pub struct ShaExtendLayout { pub layer_exprs: ShaExtendLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, pub input32_exprs: [MemoryExpr; 4], pub output32_expr: MemoryExpr, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl ShaExtendLayout { @@ -173,27 +171,15 @@ impl ShaExtendLayout { s2: Add4Operation::create(cb, || "ShaExtendLayer::s2"), }; - let sel_all = cb.create_placeholder_structural_witin(|| "sha_extend_sel_all"); - - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: SelectorType::::Prefix(sel_all.expr()), - }; - let input32_exprs: [MemoryExpr; 4] = array::from_fn(|_| array::from_fn(|_| Expression::WitIn(0))); let output32_expr: MemoryExpr = array::from_fn(|_| Expression::WitIn(0)); Self { layer_exprs: ShaExtendLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "sha_extend_sel"), input32_exprs, output32_expr, - n_fixed: 0, - n_committed: 0, - n_structural_witin: 6, - n_challenges: 0, } } } @@ -268,31 +254,15 @@ impl ProtocolBuilder for ShaExtendLayout { Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - (0..r_len).collect_vec(), - (r_len..r_len + w_len).collect_vec(), - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); + + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -327,7 +297,7 @@ impl ProtocolWitnessGenerator for ShaExtendLayout { wits: [&mut RowMajorMatrix; 2], lk_multiplicity: &mut LkMultiplicity, ) { - let (wits_start, num_wit_cols) = ( + let (layout_start, num_layout_wit_cols) = ( self.layer_exprs.wits.w_i_minus_15.0[0].id as usize, size_of::>(), ); @@ -335,6 +305,8 @@ impl ProtocolWitnessGenerator for ShaExtendLayout { let num_instances = wits.num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -342,17 +314,14 @@ impl ProtocolWitnessGenerator for ShaExtendLayout { .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(instances.iter()) .for_each(|((rows, eqs), phase1_instance)| { - let sel_all_structural_witin = - self.selector_type_layout.sel_all.selector_expr().id(); - eqs[sel_all_structural_witin] = E::BaseField::ONE; - let cols: &mut ShaExtendWitCols = - rows[wits_start..][..num_wit_cols].borrow_mut(); + rows[layout_start..][..num_layout_wit_cols].borrow_mut(); cols.populate(&phase1_instance.witin, &mut lk_multiplicity); + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -370,7 +339,7 @@ mod tests { use ff_ext::BabyBearExt4; use gkr_iop::{ cpu::{CpuBackend, CpuProver}, - gkr::{GKRProverOutput, layer::Layer}, + gkr::GKRProverOutput, selector::SelectorContext, }; use itertools::Itertools; @@ -389,16 +358,9 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "sha_extend_test"); let mut cb = CircuitBuilder::::new(&mut cs); - let mut layout = + let layout = ShaExtendLayout::::build_layer_logic(&mut cb, ()).expect("build_layer_logic failed"); - let (out_evals, mut chip) = layout.finalize(&mut cb); - let layer = Layer::from_circuit_builder( - &cb, - "sha_extend".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("sha_extend".to_string(), &mut cb); let gkr_circuit = chip.gkr_circuit(); let mut rng = StdRng::seed_from_u64(1); @@ -427,12 +389,12 @@ mod tests { let num_instances = num_instances * SHA_EXTEND_ROUNDS; let mut phase1 = RowMajorMatrix::new( num_instances, - layout.n_committed, + cb.cs.num_witin as usize, InstancePaddingStrategy::Default, ); let mut structural = RowMajorMatrix::new( num_instances, - layout.n_structural_witin, + cb.cs.num_structural_witin as usize, InstancePaddingStrategy::Default, ); let mut lk_multiplicity = LkMultiplicity::default(); @@ -463,8 +425,8 @@ mod tests { } let num_instances_rounds = next_pow2_instance_padding(num_instances); - let log2_num_instance_rounds = ceil_log2(num_instances_rounds); - let num_threads = optimal_sumcheck_threads(log2_num_instance_rounds); + let log2_num_instance = ceil_log2(num_instances_rounds); + let num_threads = optimal_sumcheck_threads(log2_num_instance); let mut prover_transcript = BasicTranscript::::new(b"protocol"); let challenges = [ prover_transcript.read_challenge().elements, @@ -492,28 +454,22 @@ mod tests { ); let out_evals = { - let mut point = Vec::with_capacity(log2_num_instance_rounds); - point.extend( - prover_transcript - .sample_vec(log2_num_instance_rounds) - .to_vec(), - ); + let mut point = Vec::with_capacity(log2_num_instance); + point.extend(prover_transcript.sample_vec(log2_num_instance).to_vec()); let out_evals = gkr_output .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - log2_num_instance..point.len()].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>(); if out_evals.is_empty() { vec![PointAndEval { - point: point[point.len() - log2_num_instance_rounds..point.len()].to_vec(), + point: point[point.len() - log2_num_instance..point.len()].to_vec(), eval: E::ZERO, }] } else { @@ -521,12 +477,13 @@ mod tests { } }; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); let selector_ctxs = - vec![SelectorContext::new(0, num_instances, log2_num_instance_rounds); 1]; + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, - log2_num_instance_rounds, + log2_num_instance, gkr_witness, &out_evals, &[], @@ -541,16 +498,12 @@ mod tests { verifier_transcript.read_challenge().elements, verifier_transcript.read_challenge().elements, ]; - let mut point = Vec::with_capacity(log2_num_instance_rounds); - point.extend( - verifier_transcript - .sample_vec(log2_num_instance_rounds) - .to_vec(), - ); + let mut point = Vec::with_capacity(log2_num_instance); + point.extend(verifier_transcript.sample_vec(log2_num_instance).to_vec()); gkr_circuit .verify( - log2_num_instance_rounds, + log2_num_instance, gkr_proof, &out_evals, &[], diff --git a/ceno_zkvm/src/precompiles/uint256.rs b/ceno_zkvm/src/precompiles/uint256.rs index e59e97c55..2999622e6 100644 --- a/ceno_zkvm/src/precompiles/uint256.rs +++ b/ceno_zkvm/src/precompiles/uint256.rs @@ -29,7 +29,7 @@ use crate::{ error::ZKVMError, gadgets::{FieldOperation, IsZeroOperation, field_op::FieldOpCols, range::FieldLtCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, - precompiles::{SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend}, + precompiles::utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, scheme::utils::gkr_witness, structs::PointAndEval, witness::LkMultiplicity, @@ -39,9 +39,10 @@ use derive::AlignedBorrow; use ff_ext::{ExtensionField, SmallField}; use generic_array::{GenericArray, sequence::GenericSequence}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, cpu::{CpuBackend, CpuProver}, + default_out_eval_groups, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, selector::{SelectorContext, SelectorType}, @@ -49,7 +50,7 @@ use gkr_iop::{ use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, ToExpr, WitIn, + Expression, StructuralWitIn, ToExpr, WitIn, util::{ceil_log2, max_usable_threads}, }; use num::{BigUint, One, Zero}; @@ -103,14 +104,10 @@ pub struct Uint256MulLayer { #[derive(Clone, Debug)] pub struct Uint256MulLayout { pub layer_exprs: Uint256MulLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, /// Read x, y, and modulus from memory. pub input32_exprs: [GenericArray, ::WordsFieldElement>; 3], pub output32_exprs: GenericArray, ::WordsFieldElement>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl Uint256MulLayout { @@ -127,14 +124,6 @@ impl Uint256MulLayout { output_range_check: FieldLtCols::create(cb, || "uint256_mul_output_range_check"), }; - let eq = cb.create_placeholder_structural_witin(|| "uint256_mul_structural_witin"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - // Default expression, will be updated in build_layer_logic let input32_exprs: [GenericArray, ::WordsFieldElement>; 3] = array::from_fn(|_| { @@ -148,13 +137,9 @@ impl Uint256MulLayout { Self { layer_exprs: Uint256MulLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "uint256_mul_sel"), input32_exprs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_challenges: 0, - n_structural_witin: 0, } } @@ -281,56 +266,15 @@ impl ProtocolBuilder for Uint256MulLayout { Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) - } - - fn n_committed(&self) -> usize { - todo!() - } - - fn n_fixed(&self) -> usize { - todo!() - } + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); - fn n_challenges(&self) -> usize { - todo!() - } - - fn n_evaluations(&self) -> usize { - todo!() - } - - fn n_layers(&self) -> usize { - todo!() + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -354,8 +298,13 @@ impl ProtocolWitnessGenerator for Uint256MulLayout { let num_instances = wits[0].num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); - let num_wit_cols = size_of::>(); + let (layout_start, num_layout_wit_cols) = ( + self.layer_exprs.wits.x_limbs.0[0].id as usize, + size_of::>(), + ); let [wits, structural_wits] = wits; + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -363,17 +312,14 @@ impl ProtocolWitnessGenerator for Uint256MulLayout { .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { - let cols: &mut Uint256MulWitCols = row - [self.layer_exprs.wits.x_limbs.0[0].id as usize..][..num_wit_cols] // TODO: Find a better way to write it. - .borrow_mut(); + let cols: &mut Uint256MulWitCols = + row[layout_start..][..num_layout_wit_cols].borrow_mut(); Self::populate_row(&mut lk_multiplicity, cols, phase1_instance); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -408,15 +354,11 @@ pub struct Uint256InvLayer { #[derive(Clone, Debug)] pub struct Uint256InvLayout { pub layer_exprs: Uint256InvLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, // y from memory pub input32_exprs: GenericArray, ::WordsFieldElement>, pub modulus_limbs: Limbs, ::Limbs>, pub output32_exprs: GenericArray, ::WordsFieldElement>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, phantom: PhantomData, } @@ -429,14 +371,6 @@ impl Uint256InvLayout { }; let modulus_limbs = Spec::P::to_limbs_expr(&Spec::modulus()); - let eq = cb.create_placeholder_structural_witin(|| "uint256_mul_structural_witin"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - // Default expression, will be updated in build_layer_logic let input32_exprs = GenericArray::generate(|_| array::from_fn(|_| Expression::WitIn(0))); // Default expression, will be updated in build_layer_logic @@ -444,14 +378,10 @@ impl Uint256InvLayout { Self { layer_exprs: Uint256InvLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "uint256_inv_sel"), input32_exprs, modulus_limbs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_challenges: 0, - n_structural_witin: 0, phantom: Default::default(), } } @@ -532,56 +462,15 @@ impl ProtocolBuilder for Uint256InvL Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) - } - - fn n_committed(&self) -> usize { - todo!() - } - - fn n_fixed(&self) -> usize { - todo!() - } - - fn n_challenges(&self) -> usize { - todo!() - } - - fn n_evaluations(&self) -> usize { - todo!() - } + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); - fn n_layers(&self) -> usize { - todo!() + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -607,8 +496,13 @@ impl ProtocolWitnessGenerator let num_instances = wits[0].num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); - let num_wit_cols = size_of::>(); + let (layout_start, num_layout_wit_cols) = ( + self.layer_exprs.wits.y_limbs.0[0].id as usize, + size_of::>(), + ); let [wits, structural_wits] = wits; + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -616,17 +510,14 @@ impl ProtocolWitnessGenerator .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { - let cols: &mut Uint256InvWitCols = row - [self.layer_exprs.wits.y_limbs.0[0].id as usize..][..num_wit_cols] // TODO: Find a better way to write it. - .borrow_mut(); + let cols: &mut Uint256InvWitCols = + row[layout_start..][..num_layout_wit_cols].borrow_mut(); Self::populate_row(&mut lk_multiplicity, cols, phase1_instance); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -663,7 +554,7 @@ pub fn setup_uint256mul_gkr_circuit() let number_ptr = cb.create_witin(|| "state_ptr_0"); - let mut layout = Uint256MulLayout::build_layer_logic(&mut cb, ())?; + let layout = Uint256MulLayout::build_layer_logic(&mut cb, ())?; // Write the result to the same address of the first input point. let limb_len = layout.output32_exprs.len(); @@ -705,16 +596,7 @@ pub fn setup_uint256mul_gkr_circuit() }) .collect::, _>>()?; - let (out_evals, mut chip) = layout.finalize(&mut cb); - - let layer = Layer::from_circuit_builder( - &cb, - "weierstrass_add".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); - + let chip = layout.finalize("uint256_mul".to_string(), &mut cb); Ok(( TestUint256MulLayout { layout, @@ -884,11 +766,9 @@ pub fn run_uint256_mul + ' .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - log2_num_instance..point.len()].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>(); @@ -911,7 +791,9 @@ pub fn run_uint256_mul + ' } let span = entered_span!("create_proof", profiling_2 = true); - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs index 012dcab80..9aef9453e 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add.rs @@ -30,10 +30,11 @@ use derive::AlignedBorrow; use ff_ext::{ExtensionField, SmallField}; use generic_array::{GenericArray, sequence::GenericSequence, typenum::Unsigned}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::{CircuitBuilder, ConstraintSystem}, cpu::{CpuBackend, CpuProver}, + default_out_eval_groups, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, selector::{SelectorContext, SelectorType}, @@ -41,7 +42,7 @@ use gkr_iop::{ use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, ToExpr, WitIn, + Expression, StructuralWitIn, ToExpr, WitIn, util::{ceil_log2, max_usable_threads}, }; use num::BigUint; @@ -68,8 +69,7 @@ use crate::{ gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, precompiles::{ - SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, - weierstrass::EllipticCurveAddInstance, + utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, weierstrass::EllipticCurveAddInstance, }, scheme::utils::gkr_witness, structs::PointAndEval, @@ -104,14 +104,10 @@ pub struct WeierstrassAddAssignLayer { #[derive(Clone, Debug)] pub struct WeierstrassAddAssignLayout { pub layer_exprs: WeierstrassAddAssignLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, pub input32_exprs: [GenericArray, ::WordsCurvePoint>; 2], pub output32_exprs: GenericArray, ::WordsCurvePoint>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl WeierstrassAddAssignLayout { @@ -132,14 +128,6 @@ impl WeierstrassAddAssignLayout { slope_times_p_x_minus_x: FieldOpCols::create(cb, || "slope_times_p_x_minus_x"), }; - let eq = cb.create_placeholder_structural_witin(|| "weierstrass_add_eq"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - // Default expression, will be updated in build_layer_logic let input32_exprs: [GenericArray< MemoryExpr, @@ -155,13 +143,9 @@ impl WeierstrassAddAssignLayout { Self { layer_exprs: WeierstrassAddAssignLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "weierstrass_add_sel"), input32_exprs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_challenges: 0, - n_structural_witin: 0, } } @@ -326,56 +310,15 @@ impl ProtocolBuilder Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) - } - - fn n_committed(&self) -> usize { - todo!() - } - - fn n_fixed(&self) -> usize { - todo!() - } - - fn n_challenges(&self) -> usize { - todo!() - } + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); - fn n_evaluations(&self) -> usize { - todo!() - } - - fn n_layers(&self) -> usize { - todo!() + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -402,8 +345,13 @@ impl ProtocolWitnessGenerator let num_instances = wits[0].num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); - let num_wit_cols = size_of::>(); + let (layout_start, num_layout_wit_cols) = ( + self.layer_exprs.wits.p_x.0[0].id as usize, + size_of::>(), + ); let [wits, structural_wits] = wits; + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -411,17 +359,14 @@ impl ProtocolWitnessGenerator .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { let cols: &mut WeierstrassAddAssignWitCols = - row[self.layer_exprs.wits.p_x.0[0].id as usize..][..num_wit_cols] // TODO: Find a better way to write it. - .borrow_mut(); + row[layout_start..][..num_layout_wit_cols].borrow_mut(); Self::populate_row(phase1_instance, cols, &mut lk_multiplicity); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -469,7 +414,7 @@ pub fn setup_gkr_circuit() let point_ptr_0 = cb.create_witin(|| "state_ptr_0"); - let mut layout = WeierstrassAddAssignLayout::build_layer_logic(&mut cb, ())?; + let layout = WeierstrassAddAssignLayout::build_layer_logic(&mut cb, ())?; // Write the result to the same address of the first input point. let mut mem_rw = izip!(&layout.input32_exprs[0], &layout.output32_exprs) @@ -508,15 +453,7 @@ pub fn setup_gkr_circuit() .collect::, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(&mut cb); - - let layer = Layer::from_circuit_builder( - &cb, - "weierstrass_add".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("weierstrass_add".to_string(), &mut cb); Ok(( TestWeierstrassAddLayout { @@ -718,11 +655,9 @@ pub fn run_weierstrass_add< .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - log2_num_instance..point.len()].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>(); @@ -745,7 +680,9 @@ pub fn run_weierstrass_add< } let span = entered_span!("create_proof", profiling_2 = true); - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -824,22 +761,22 @@ mod tests { #[test] fn test_weierstrass_add_bn254() { - test_weierstrass_add_helper::(); + test_weierstrass_add_helper::() } #[test] fn test_weierstrass_add_bls12381() { - test_weierstrass_add_helper::(); + test_weierstrass_add_helper::() } #[test] fn test_weierstrass_add_secp256k1() { - test_weierstrass_add_helper::(); + test_weierstrass_add_helper::() } #[test] fn test_weierstrass_add_secp256r1() { - test_weierstrass_add_helper::(); + test_weierstrass_add_helper::() } fn test_weierstrass_add_nonpow2_helper() { diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs index d6400a2d7..d637fcb6b 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_decompress.rs @@ -30,10 +30,11 @@ use derive::AlignedBorrow; use ff_ext::{ExtensionField, SmallField}; use generic_array::{GenericArray, sequence::GenericSequence, typenum::Unsigned}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::{CircuitBuilder, ConstraintSystem}, cpu::{CpuBackend, CpuProver}, + default_out_eval_groups, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, selector::{SelectorContext, SelectorType}, @@ -41,7 +42,7 @@ use gkr_iop::{ use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, ToExpr, WitIn, + Expression, StructuralWitIn, ToExpr, WitIn, macros::{entered_span, exit_span}, util::{ceil_log2, max_usable_threads}, }; @@ -78,7 +79,7 @@ use crate::{ insn_base::{StateInOut, WriteMEM}, }, precompiles::{ - SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, + utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, weierstrass::EllipticCurveDecompressInstance, }, scheme::utils::gkr_witness, @@ -113,15 +114,11 @@ pub struct WeierstrassDecompressLayer { #[derive(Clone, Debug)] pub struct WeierstrassDecompressLayout { pub layer_exprs: WeierstrassDecompressLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, pub input32_exprs: GenericArray, ::WordsFieldElement>, pub old_output32_exprs: GenericArray, ::WordsFieldElement>, pub output32_exprs: GenericArray, ::WordsFieldElement>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl @@ -150,14 +147,6 @@ impl neg_y: FieldOpCols::create(cb, || "neg_y"), }; - let eq = cb.create_placeholder_structural_witin(|| "weierstrass_decompress_eq"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - let input32_exprs: GenericArray< MemoryExpr, ::WordsFieldElement, @@ -175,14 +164,10 @@ impl Self { layer_exprs: WeierstrassDecompressLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "weierstrass_decompress_sel"), input32_exprs, old_output32_exprs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_structural_witin: 0, - n_challenges: 0, } } @@ -336,56 +321,15 @@ impl ProtocolBuild Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) - } - - fn n_committed(&self) -> usize { - todo!() - } - - fn n_fixed(&self) -> usize { - todo!() - } - - fn n_challenges(&self) -> usize { - todo!() - } - - fn n_evaluations(&self) -> usize { - todo!() - } + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); - fn n_layers(&self) -> usize { - todo!() + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -415,9 +359,14 @@ impl ProtocolWitne let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); // The number of columns used for weierstrass decompress subcircuit. - let num_main_wit_cols = size_of::>(); + let (layout_start, num_layout_wit_cols) = ( + self.layer_exprs.wits.sign_bit.id as usize, + size_of::>(), + ); let [wits, structural_wits] = wits; + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -425,18 +374,15 @@ impl ProtocolWitne .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { let cols: &mut WeierstrassDecompressWitCols = - row[self.layer_exprs.wits.sign_bit.id as usize..][..num_main_wit_cols] // TODO: Find a better way to write it. - .borrow_mut(); + row[layout_start..][..num_layout_wit_cols].borrow_mut(); Self::populate(&mut lk_multiplicity, cols, phase1_instance); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -468,7 +414,7 @@ pub fn setup_gkr_circuit::Limbs::U32; let mut mem_rw = layout @@ -506,15 +452,7 @@ pub fn setup_gkr_circuit, _>>()?, ); - let (out_evals, mut chip) = layout.finalize(&mut cb); - - let layer = Layer::from_circuit_builder( - &cb, - "weierstrass_decompress".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("weierstrass_decompress".to_string(), &mut cb); Ok(( TestWeierstrassDecompressLayout { @@ -698,11 +636,9 @@ pub fn run_weierstrass_decompress< .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); - PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), - } + let point = point[point.len() - log2_num_instance..point.len()].to_vec(); + let eval = wit.evaluate(&point); + PointAndEval { point, eval } }) .collect::>(); @@ -725,7 +661,9 @@ pub fn run_weierstrass_decompress< } let span = entered_span!("create_proof", profiling_2 = true); - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs index 686baa397..8a28dadb6 100644 --- a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs +++ b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_double.rs @@ -30,10 +30,11 @@ use derive::AlignedBorrow; use ff_ext::{ExtensionField, SmallField}; use generic_array::{GenericArray, sequence::GenericSequence, typenum::Unsigned}; use gkr_iop::{ - OutEvalGroups, ProtocolBuilder, ProtocolWitnessGenerator, + ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, circuit_builder::{CircuitBuilder, ConstraintSystem}, cpu::{CpuBackend, CpuProver}, + default_out_eval_groups, error::{BackendError, CircuitBuilderError}, gkr::{GKRCircuit, GKRProof, GKRProverOutput, layer::Layer, mock::MockProver}, selector::{SelectorContext, SelectorType}, @@ -41,7 +42,7 @@ use gkr_iop::{ use itertools::{Itertools, izip}; use mpcs::PolynomialCommitmentScheme; use multilinear_extensions::{ - Expression, ToExpr, WitIn, + Expression, StructuralWitIn, ToExpr, WitIn, util::{ceil_log2, max_usable_threads}, }; use num::BigUint; @@ -69,7 +70,7 @@ use crate::{ gadgets::{FieldOperation, field_op::FieldOpCols}, instructions::riscv::insn_base::{StateInOut, WriteMEM}, precompiles::{ - SelectorTypeLayout, utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, + utils::merge_u8_slice_to_u16_limbs_pairs_and_extend, weierstrass::EllipticCurveDoubleInstance, }, scheme::utils::gkr_witness, @@ -105,13 +106,9 @@ pub struct WeierstrassDoubleAssignLayer { #[derive(Clone, Debug)] pub struct WeierstrassDoubleAssignLayout { pub layer_exprs: WeierstrassDoubleAssignLayer, - pub selector_type_layout: SelectorTypeLayout, + pub sel: StructuralWitIn, pub input32_exprs: GenericArray, ::WordsCurvePoint>, pub output32_exprs: GenericArray, ::WordsCurvePoint>, - pub n_fixed: usize, - pub n_committed: usize, - pub n_structural_witin: usize, - pub n_challenges: usize, } impl @@ -134,14 +131,6 @@ impl slope_times_p_x_minus_x: FieldOpCols::create(cb, || "slope_times_p_x_minus_x"), }; - let eq = cb.create_placeholder_structural_witin(|| "weierstrass_double_eq"); - let sel = SelectorType::Prefix(eq.expr()); - let selector_type_layout = SelectorTypeLayout { - sel_first: None, - sel_last: None, - sel_all: sel.clone(), - }; - let input32_exprs: GenericArray< MemoryExpr, ::WordsCurvePoint, @@ -153,13 +142,9 @@ impl Self { layer_exprs: WeierstrassDoubleAssignLayer { wits }, - selector_type_layout, + sel: cb.create_placeholder_structural_witin(|| "weierstrass_double_sel"), input32_exprs, output32_exprs, - n_fixed: 0, - n_committed: 0, - n_challenges: 0, - n_structural_witin: 0, } } @@ -354,56 +339,15 @@ impl ProtocolBuild Ok(layout) } - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip) { - self.n_fixed = cb.cs.num_fixed; - self.n_committed = cb.cs.num_witin as usize; - self.n_structural_witin = cb.cs.num_structural_witin as usize; - self.n_challenges = 0; - - // register selector to legacy constrain system - cb.cs.r_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.w_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.lk_selector = Some(self.selector_type_layout.sel_all.clone()); - cb.cs.zero_selector = Some(self.selector_type_layout.sel_all.clone()); - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, self.n_challenges), - ) - } - - fn n_committed(&self) -> usize { - todo!() - } - - fn n_fixed(&self) -> usize { - todo!() - } - - fn n_challenges(&self) -> usize { - todo!() - } + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip { + let sel = SelectorType::Prefix(self.sel.expr()); + cb.cs.set_all_default_selectors(sel); - fn n_evaluations(&self) -> usize { - todo!() - } - - fn n_layers(&self) -> usize { - todo!() + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); + let layer = Layer::from_circuit_builder(cb, name, 0, out_evals); + chip.add_layer(layer); + chip } } @@ -430,9 +374,14 @@ impl ProtocolWitne let num_instances = wits[0].num_instances(); let nthreads = max_usable_threads(); let num_instance_per_batch = num_instances.div_ceil(nthreads).max(1); - let num_wit_cols = size_of::>(); + let (layout_start, num_layout_wit_cols) = ( + self.layer_exprs.wits.p_x.0[0].id as usize, + size_of::>(), + ); let [wits, structural_wits] = wits; + let wits_width = wits.width; + let structural_wits_width = structural_wits.width; let raw_witin_iter = wits.par_batch_iter_mut(num_instance_per_batch); let raw_structural_wits_iter = structural_wits.par_batch_iter_mut(num_instance_per_batch); raw_witin_iter @@ -440,17 +389,15 @@ impl ProtocolWitne .zip_eq(phase1.instances.par_chunks(num_instance_per_batch)) .for_each(|((rows, eqs), phase1_instances)| { let mut lk_multiplicity = lk_multiplicity.clone(); - rows.chunks_mut(self.n_committed) - .zip_eq(eqs.chunks_mut(self.n_structural_witin)) + rows.chunks_mut(wits_width) + .zip_eq(eqs.chunks_mut(structural_wits_width)) .zip_eq(phase1_instances) .for_each(|((row, eqs), phase1_instance)| { let cols: &mut WeierstrassDoubleAssignWitCols = - row[self.layer_exprs.wits.p_x.0[0].id as usize..][..num_wit_cols] // TODO: Find a better way to write it. + row[layout_start..][..num_layout_wit_cols] // TODO: Find a better way to write it. .borrow_mut(); // We should construct the circuit to guarantee this part occurs first. Self::populate_row(phase1_instance, cols, &mut lk_multiplicity); - for x in eqs.iter_mut() { - *x = E::BaseField::ONE; - } + eqs[self.sel.id as usize] = E::BaseField::ONE; }); }); } @@ -496,7 +443,7 @@ pub fn setup_gkr_circuit, _>>()?; - let (out_evals, mut chip) = layout.finalize(&mut cb); - - let layer = Layer::from_circuit_builder( - &cb, - "weierstrass_double".to_string(), - layout.n_challenges, - out_evals, - ); - chip.add_layer(layer); + let chip = layout.finalize("weierstrass_double".to_string(), &mut cb); Ok(( TestWeierstrassDoubleLayout { @@ -720,10 +659,15 @@ pub fn run_weierstrass_double< .0 .par_iter() .map(|wit| { - let point = point[point.len() - wit.num_vars()..point.len()].to_vec(); + let full_point = point.clone(); + let eval_point = if wit.num_vars() == 0 { + Vec::new() + } else { + full_point[full_point.len() - wit.num_vars()..].to_vec() + }; PointAndEval { - point: point.clone(), - eval: wit.evaluate(&point), + point: full_point, + eval: wit.evaluate(&eval_point), } }) .collect::>(); @@ -747,7 +691,9 @@ pub fn run_weierstrass_double< } let span = entered_span!("create_proof", profiling_2 = true); - let selector_ctxs = vec![SelectorContext::new(0, num_instances, log2_num_instance); 1]; + let selector_ctxs_len = gkr_circuit.layers[0].selector_ctxs_len(); + let selector_ctxs = + vec![SelectorContext::new(0, num_instances, log2_num_instance); selector_ctxs_len]; let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove::, CpuProver<_>>( num_threads, @@ -826,7 +772,7 @@ mod tests { #[test] fn test_weierstrass_double_bn254() { - test_weierstrass_double_helper::(); + test_weierstrass_double_helper::() } #[test] @@ -841,7 +787,7 @@ mod tests { #[test] fn test_weierstrass_double_secp256r1() { - test_weierstrass_double_helper::(); + test_weierstrass_double_helper::() } fn test_weierstrass_double_nonpow2_helper() { diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 4007f7fd7..740eff8a1 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -8,7 +8,7 @@ use crate::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, hal::{DeviceProvingKey, EccQuarkProver, ProofInput, TowerProverSpec}, septic_curve::{SepticExtension, SepticPoint, SymbolicSepticExtension}, - utils::{infer_tower_logup_witness, infer_tower_product_witness}, + utils::{global_selector_ctxs, infer_tower_logup_witness, infer_tower_product_witness}, }, structs::{ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs}, }; @@ -604,163 +604,176 @@ impl> TowerProver>(); - let w_set_last_layer = r_set_last_layer.split_off(r_set_wit.len()); + let mut read_evals = vec![]; + let mut write_evals = vec![]; + let mut logup_evals = vec![]; + let mut prod_specs = vec![]; + let mut logup_specs = vec![]; - let mut lk_numerator_last_layer = lk_n_wit - .iter() - .chain(lk_d_wit.iter()) - .map(|wit| wit.as_view_chunks(NUM_FANIN)) - .collect::>(); - let lk_denominator_last_layer = lk_numerator_last_layer.split_off(lk_n_wit.len()); - exit_span!(span); + let mut offset = 0; + cs.expression_groups.iter().for_each(|(_, group)| { + let num_reads = group.r_expressions.len() + group.r_table_expressions.len(); + let num_writes = group.w_expressions.len() + group.w_table_expressions.len(); + let r_set_wit = &records[offset..][..num_reads]; + offset += num_reads; + let w_set_wit = &records[offset..][..num_writes]; + offset += num_writes; + let lk_table_len = group.lk_table_expressions.len(); + let lk_n_wit = &records[offset..][..lk_table_len]; + offset += lk_table_len; + let lk_d_wit = if lk_table_len > 0 { + &records[offset..][..lk_table_len] + } else { + &records[offset..][..group.lk_expressions.len()] + }; + + // infer all tower witness after last layer + let span = entered_span!("tower_witness_last_layer"); + let mut r_set_last_layer = r_set_wit + .iter() + .chain(w_set_wit.iter()) + .map(|wit| wit.as_view_chunks(NUM_FANIN)) + .collect::>(); + let w_set_last_layer = r_set_last_layer.split_off(r_set_wit.len()); - let span = entered_span!("tower_tower_witness"); - let r_wit_layers = r_set_last_layer - .into_iter() - .map(|last_layer| { - infer_tower_product_witness(num_var_with_rotation, last_layer, NUM_FANIN) - }) - .collect_vec(); - let w_wit_layers = w_set_last_layer - .into_iter() - .map(|last_layer| { - infer_tower_product_witness(num_var_with_rotation, last_layer, NUM_FANIN) - }) - .collect_vec(); - let lk_wit_layers = if !lk_numerator_last_layer.is_empty() { - lk_numerator_last_layer + let mut lk_numerator_last_layer = lk_n_wit + .iter() + .chain(lk_d_wit.iter()) + .map(|wit| wit.as_view_chunks(NUM_FANIN)) + .collect::>(); + let lk_denominator_last_layer = lk_numerator_last_layer.split_off(lk_n_wit.len()); + exit_span!(span); + + let span = entered_span!("tower_tower_witness"); + let r_wit_layers = r_set_last_layer .into_iter() - .zip(lk_denominator_last_layer) - .map(|(lk_n, lk_d)| infer_tower_logup_witness(Some(lk_n), lk_d)) - .collect_vec() - } else { - lk_denominator_last_layer + .map(|last_layer| { + infer_tower_product_witness(num_var_with_rotation, last_layer, NUM_FANIN) + }) + .collect_vec(); + let w_wit_layers = w_set_last_layer .into_iter() - .map(|lk_d| infer_tower_logup_witness(None, lk_d)) - .collect_vec() - }; - exit_span!(span); - - if cfg!(test) { - // sanity check - assert_eq!(r_wit_layers.len(), num_reads); - assert!( - r_wit_layers - .iter() - .zip(r_set_wit.iter()) // depth equals to num_vars - .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) - ); - assert!(r_wit_layers.iter().all(|layers| { - layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - w[0].evaluations().len() == expected_size - && w[1].evaluations().len() == expected_size + .map(|last_layer| { + infer_tower_product_witness(num_var_with_rotation, last_layer, NUM_FANIN) }) - })); + .collect_vec(); + let lk_wit_layers = if !lk_numerator_last_layer.is_empty() { + lk_numerator_last_layer + .into_iter() + .zip(lk_denominator_last_layer) + .map(|(lk_n, lk_d)| infer_tower_logup_witness(Some(lk_n), lk_d)) + .collect_vec() + } else { + lk_denominator_last_layer + .into_iter() + .map(|lk_d| infer_tower_logup_witness(None, lk_d)) + .collect_vec() + }; + exit_span!(span); + + if cfg!(test) { + // sanity check + assert_eq!(r_wit_layers.len(), num_reads); + assert!( + r_wit_layers + .iter() + .zip(r_set_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(r_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size + }) + })); - assert_eq!(w_wit_layers.len(), num_writes); - assert!( - w_wit_layers - .iter() - .zip(w_set_wit.iter()) // depth equals to num_vars - .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) - ); - assert!(w_wit_layers.iter().all(|layers| { - layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - w[0].evaluations().len() == expected_size - && w[1].evaluations().len() == expected_size - }) - })); + assert_eq!(w_wit_layers.len(), num_writes); + assert!( + w_wit_layers + .iter() + .zip(w_set_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(w_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + w[0].evaluations().len() == expected_size + && w[1].evaluations().len() == expected_size + }) + })); - assert_eq!( - lk_wit_layers.len(), - cs.lk_table_expressions.len() + cs.lk_expressions.len() - ); - assert!( - lk_wit_layers - .iter() - .zip(lk_n_wit.iter()) // depth equals to num_vars - .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) - ); - assert!(lk_wit_layers.iter().all(|layers| { - layers.iter().enumerate().all(|(i, w)| { - let expected_size = 1 << i; - let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); - p1.evaluations().len() == expected_size - && p2.evaluations().len() == expected_size - && q1.evaluations().len() == expected_size - && q2.evaluations().len() == expected_size + assert_eq!( + lk_wit_layers.len(), + group.lk_table_expressions.len() + group.lk_expressions.len() + ); + assert!( + lk_wit_layers + .iter() + .zip(lk_n_wit.iter()) // depth equals to num_vars + .all(|(layers, origin_mle)| layers.len() == origin_mle.num_vars()) + ); + assert!(lk_wit_layers.iter().all(|layers| { + layers.iter().enumerate().all(|(i, w)| { + let expected_size = 1 << i; + let (p1, p2, q1, q2) = (&w[0], &w[1], &w[2], &w[3]); + p1.evaluations().len() == expected_size + && p2.evaluations().len() == expected_size + && q1.evaluations().len() == expected_size + && q2.evaluations().len() == expected_size + }) + })); + } + + // final evals for verifier + let r_out_evals = r_wit_layers + .iter() + .map(|r_wit_layers| { + r_wit_layers[0] + .iter() + .map(|mle| mle.get_ext_field_vec()[0]) + .collect_vec() }) - })); - } + .collect_vec(); + let w_out_evals = w_wit_layers + .iter() + .map(|w_wit_layers| { + w_wit_layers[0] + .iter() + .map(|mle| mle.get_ext_field_vec()[0]) + .collect_vec() + }) + .collect_vec(); + let lk_out_evals = lk_wit_layers + .iter() + .map(|lk_wit_layers| { + lk_wit_layers[0] + .iter() + .map(|mle| mle.get_ext_field_vec()[0]) + .collect_vec() + }) + .collect_vec(); - // final evals for verifier - let r_out_evals = r_wit_layers - .iter() - .map(|r_wit_layers| { - r_wit_layers[0] - .iter() - .map(|mle| mle.get_ext_field_vec()[0]) - .collect_vec() - }) - .collect_vec(); - let w_out_evals = w_wit_layers - .iter() - .map(|w_wit_layers| { - w_wit_layers[0] - .iter() - .map(|mle| mle.get_ext_field_vec()[0]) - .collect_vec() - }) - .collect_vec(); - let lk_out_evals = lk_wit_layers - .iter() - .map(|lk_wit_layers| { - lk_wit_layers[0] - .iter() - .map(|mle| mle.get_ext_field_vec()[0]) - .collect_vec() - }) - .collect_vec(); + let prods = r_wit_layers + .into_iter() + .chain(w_wit_layers) + .map(|witness| TowerProverSpec::> { witness }) + .collect_vec(); + let logups = lk_wit_layers + .into_iter() + .map(|witness| TowerProverSpec::> { witness }) + .collect_vec(); - let prod_specs = r_wit_layers - .into_iter() - .chain(w_wit_layers) - .map(|witness| TowerProverSpec { witness }) - .collect_vec(); - let lookup_specs = lk_wit_layers - .into_iter() - .map(|witness| TowerProverSpec { witness }) - .collect_vec(); + read_evals.extend(r_out_evals); + write_evals.extend(w_out_evals); + logup_evals.extend(lk_out_evals); + prod_specs.extend(prods); + logup_specs.extend(logups); + }); - let out_evals = vec![r_out_evals, w_out_evals, lk_out_evals]; + let out_evals = vec![read_evals, write_evals, logup_evals]; - (out_evals, prod_specs, lookup_specs) + (out_evals, prod_specs, logup_specs) } #[tracing::instrument( @@ -854,28 +867,12 @@ impl> MainSumcheckProver( prod_gpu: &[ceno_gpu::GpuProverSpec], // GPU-built product towers logup_gpu: &[ceno_gpu::GpuProverSpec], // GPU-built logup towers - r_set_len: usize, + prod_kinds: &[ProdKind], ) -> (Vec>, Vec>, Vec>) { + assert_eq!( + prod_gpu.len(), + prod_kinds.len(), + "prod_gpu_specs and prod_kinds length mismatch" + ); + // Extract product out_evals from GPU towers let mut r_out_evals = Vec::new(); let mut w_out_evals = Vec::new(); - for (i, gpu_spec) in prod_gpu.iter().enumerate() { + for (gpu_spec, kind) in prod_gpu.iter().zip(prod_kinds.iter()) { let first_layer_evals: Vec = gpu_spec .get_output_evals() .expect("Failed to extract final evals from GPU product tower"); @@ -92,11 +101,10 @@ fn extract_out_evals_from_gpu_towers( "Product tower first layer should have 2 MLEs" ); - // Split into r_out_evals and w_out_evals based on r_set_len - if i < r_set_len { - r_out_evals.push(first_layer_evals); - } else { - w_out_evals.push(first_layer_evals); + // Split into r_out_evals and w_out_evals based on prod_kinds order + match kind { + ProdKind::Read => r_out_evals.push(first_layer_evals), + ProdKind::Write => w_out_evals.push(first_layer_evals), } } @@ -257,6 +265,7 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( ( Vec>, Vec>, + Vec, ), String, > { @@ -279,31 +288,91 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( >(records) }; - // Parse records into different categories (same as build_tower_witness) - let num_reads = cs.r_expressions.len() + cs.r_table_expressions.len(); - let num_writes = cs.w_expressions.len() + cs.w_table_expressions.len(); let mut offset = 0; - let r_set_wit = &records[offset..][..num_reads]; - offset += num_reads; - let w_set_wit = &records[offset..][..num_writes]; - offset += num_writes; - let lk_n_wit = &records[offset..][..cs.lk_table_expressions.len()]; - offset += cs.lk_table_expressions.len(); - let lk_d_wit = if !cs.lk_table_expressions.is_empty() { - &records[offset..][..cs.lk_table_expressions.len()] - } else { - &records[offset..][..cs.lk_expressions.len()] - }; + let mut prod_last_layers = Vec::new(); + let mut logup_last_layers = Vec::new(); + let mut prod_kinds = Vec::new(); + + // Parse records into different categories (same as build_tower_witness) + for (_, group) in cs.expression_groups.iter() { + let num_reads = group.r_expressions.len() + group.r_table_expressions.len(); + let num_writes = group.w_expressions.len() + group.w_table_expressions.len(); + let r_set_wit = &records[offset..][..num_reads]; + offset += num_reads; + let w_set_wit = &records[offset..][..num_writes]; + offset += num_writes; + let lk_table_len = group.lk_table_expressions.len(); + let lk_n_wit = &records[offset..][..lk_table_len]; + offset += lk_table_len; + let lk_d_wit = if lk_table_len > 0 { + &records[offset..][..lk_table_len] + } else { + &records[offset..][..group.lk_expressions.len()] + }; + + // prod: last layers + for wit in r_set_wit.iter() { + prod_last_layers.push(wit.as_view_chunks(NUM_FANIN)); + prod_kinds.push(ProdKind::Read); + } + for wit in w_set_wit.iter() { + prod_last_layers.push(wit.as_view_chunks(NUM_FANIN)); + prod_kinds.push(ProdKind::Write); + } + + // logup: last layers + let lk_numerator_last_layer = lk_n_wit + .iter() + .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) + .collect::>(); + let lk_denominator_last_layer = lk_d_wit + .iter() + .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) + .collect::>(); + if !lk_numerator_last_layer.is_empty() { + // Case when we have both numerator and denominator + // Combine [p1, p2] from numerator and [q1, q2] from denominator + for (lk_n_chunks, lk_d_chunks) in lk_numerator_last_layer + .into_iter() + .zip(lk_denominator_last_layer) + { + let mut last_layer = lk_n_chunks; + last_layer.extend(lk_d_chunks); + logup_last_layers.push(last_layer); + } + } else if !lk_denominator_last_layer.is_empty() { + // Case when numerator is empty - create shared ones_buffer and use views + let nv = lk_denominator_last_layer[0][0].num_vars(); + + let ones_poly = GpuPolynomialExt::new_with_scalar(&cuda_hal.inner, nv, BB31Ext::ONE) + .map_err(|e| format!("Failed to create shared ones_buffer: {:?}", e))?; + let ones_poly_static: GpuPolynomialExt<'static> = + unsafe { std::mem::transmute(ones_poly) }; + ones_buffer.push(ones_poly_static); + + // Get reference from storage to ensure proper lifetime + let ones_poly_ref = ones_buffer.last().unwrap(); + let mle_len_bytes = ones_poly_ref.evaluations().len() * std::mem::size_of::(); + + // Create views referencing the shared ones_buffer for each tower's p1, p2 + for lk_d_chunks in lk_denominator_last_layer { + let p1_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); + let p2_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); + let p1_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p1_view), nv); + let p2_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p2_view), nv); + let p1_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p1_gpu) }; + let p2_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p2_gpu) }; + let mut last_layer = vec![p1_gpu, p2_gpu]; + last_layer.extend(lk_d_chunks); + logup_last_layers.push(last_layer); + } + } + } assert_eq!(big_buffers.len(), 0, "expect no big buffers"); // prod: last layes & buffer let mut is_prod_buffer_exists = false; - let prod_last_layers = r_set_wit - .iter() - .chain(w_set_wit.iter()) - .map(|wit| wit.as_view_chunks(NUM_FANIN)) - .collect::>(); if !prod_last_layers.is_empty() { let first_layer = &prod_last_layers[0]; assert_eq!(first_layer.len(), 2, "prod last_layer must have 2 MLEs"); @@ -325,66 +394,7 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( is_prod_buffer_exists = true; } - // logup: last layes let mut is_logup_buffer_exists = false; - let lk_numerator_last_layer = lk_n_wit - .iter() - .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) - .collect::>(); - let lk_denominator_last_layer = lk_d_wit - .iter() - .map(|wit| wit.as_view_chunks(NUM_FANIN_LOGUP)) - .collect::>(); - let logup_last_layers = if !lk_numerator_last_layer.is_empty() { - // Case when we have both numerator and denominator - // Combine [p1, p2] from numerator and [q1, q2] from denominator - lk_numerator_last_layer - .into_iter() - .zip(lk_denominator_last_layer) - .map(|(lk_n_chunks, lk_d_chunks)| { - let mut last_layer = lk_n_chunks; - last_layer.extend(lk_d_chunks); - last_layer - }) - .collect::>() - } else if lk_denominator_last_layer.is_empty() { - vec![] - } else { - // Case when numerator is empty - create shared ones_buffer and use views - // This saves memory by having all p1, p2 polynomials reference the same buffer - let nv = lk_denominator_last_layer[0][0].num_vars(); - - // Create one shared ones_buffer as Owned (can be 'static) - let ones_poly = GpuPolynomialExt::new_with_scalar(&cuda_hal.inner, nv, BB31Ext::ONE) - .map_err(|e| format!("Failed to create shared ones_buffer: {:?}", e)) - .unwrap(); - // SAFETY: Owned buffer can be safely treated as 'static - let ones_poly_static: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(ones_poly) }; - ones_buffer.push(ones_poly_static); - - // Get reference from storage to ensure proper lifetime - let ones_poly_ref = ones_buffer.last().unwrap(); - let mle_len_bytes = ones_poly_ref.evaluations().len() * std::mem::size_of::(); - - // Create views referencing the shared ones_buffer for each tower's p1, p2 - lk_denominator_last_layer - .into_iter() - .map(|lk_d_chunks| { - // Create views of ones_buffer for p1 and p2 - let p1_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p2_view = ones_poly_ref.evaluations().as_slice_range(0..mle_len_bytes); - let p1_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p1_view), nv); - let p2_gpu = GpuPolynomialExt::new(BufferImpl::new_from_view(p2_view), nv); - // SAFETY: views from 'static buffer can be 'static - let p1_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p1_gpu) }; - let p2_gpu: GpuPolynomialExt<'static> = unsafe { std::mem::transmute(p2_gpu) }; - // Use [p1, p2, q1, q2] format for the last layer - let mut last_layer = vec![p1_gpu, p2_gpu]; - last_layer.extend(lk_d_chunks); - last_layer - }) - .collect::>() - }; if !logup_last_layers.is_empty() { let first_layer = &logup_last_layers[0]; assert_eq!(first_layer.len(), 4, "logup last_layer must have 4 MLEs"); @@ -492,7 +502,7 @@ fn build_tower_witness_gpu<'buf, E: ExtensionField>( logup_gpu_specs.extend(gpu_specs); exit_span!(span_logup); } - Ok((prod_gpu_specs, logup_gpu_specs)) + Ok((prod_gpu_specs, logup_gpu_specs, prod_kinds)) } impl> TowerProver> @@ -545,12 +555,6 @@ impl> TowerProver> TowerProver> = Vec::new(); let mut _ones_buffer: Vec> = Vec::new(); let mut _view_last_layers: Vec>>> = Vec::new(); - let (prod_gpu, logup_gpu) = - info_span!("[ceno] build_tower_witness_gpu").in_scope(|| { + let (prod_gpu, logup_gpu, prod_kinds) = info_span!("[ceno] build_tower_witness_gpu") + .in_scope(|| { build_tower_witness_gpu( composed_cs, input, @@ -579,7 +583,7 @@ impl> TowerProver>> BasicTranscript @@ -662,28 +666,12 @@ impl> MainSumcheckProver MockProver { fixed: &[ArcMultilinearExtension<'a, E>], wits_in: &[ArcMultilinearExtension<'a, E>], structural_witin: &[ArcMultilinearExtension<'a, E>], + num_instances: usize, challenge: [E; 2], lkm: Option>, ) -> Result<(), Vec>> { @@ -513,6 +514,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { &[], &[], &[], + num_instances, Some(challenge), lkm, ) @@ -521,10 +523,23 @@ impl<'a, E: ExtensionField + Hash> MockProver { pub fn run( cb: &CircuitBuilder, wits_in: &[ArcMultilinearExtension<'a, E>], + structural_witin: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], + num_instances: usize, lkm: Option>, ) -> Result<(), Vec>> { - Self::run_maybe_challenge(cb, &[], wits_in, &[], program, &[], &[], None, lkm) + Self::run_maybe_challenge( + cb, + &[], + wits_in, + structural_witin, + program, + &[], + &[], + num_instances, + None, + lkm, + ) } #[allow(clippy::too_many_arguments)] @@ -536,6 +551,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { program: &[ceno_emul::Instruction], pi_mles: &[ArcMultilinearExtension<'a, E>], pub_io_evals: &[Either], + num_instances: usize, challenge: Option<[E; 2]>, lkm: Option>, ) -> Result<(), Vec>> { @@ -550,7 +566,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { structural_witin, pi_mles, pub_io_evals, - 1, + num_instances, challenge, lkm, ) @@ -573,102 +589,132 @@ impl<'a, E: ExtensionField + Hash> MockProver { let mut shared_lkm = LkMultiplicityRaw::::default(); let mut errors = vec![]; - let num_instance_padded = wits_in - .first() - .or_else(|| fixed.first()) - .or_else(|| pi_mles.first()) - .or_else(|| structural_witin.first()) - .map(|mle| mle.evaluations().len()) - .unwrap_or_else(|| next_pow2_instance_padding(num_instances)); - - // Assert zero expressions - for (expr, name) in cs - .assert_zero_expressions - .iter() - .chain(&cs.assert_zero_sumcheck_expressions) - .zip_eq( - cs.assert_zero_expressions_namespace_map - .iter() - .chain(&cs.assert_zero_sumcheck_expressions_namespace_map), - ) - { - if expr.degree() > MAX_CONSTRAINT_DEGREE { - errors.push(MockProverError::DegreeTooHigh { - expression: expr.clone(), - degree: expr.degree(), - name: name.clone(), - }); - } - - let zero_selector: ArcMultilinearExtension<_> = - if let Some(zero_selector) = &cs.zero_selector { - structural_witin[zero_selector.selector_expr().id()].clone() - } else { - let mut selector = vec![E::BaseField::ONE; num_instances]; - selector.resize(num_instance_padded, E::BaseField::ZERO); - MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(num_instance_padded), - selector, - ) - .into() - }; + let default_structural_witin = Self::default_structural_witin_from_wits(cs, num_instances); + let structural_witin = if structural_witin.is_empty() { + default_structural_witin.as_slice() + } else { + structural_witin + }; - // require_equal does not always have the form of Expr::Sum as - // the sum of witness and constant is expressed as scaled sum - if let Expression::Sum(left, right) = expr - && name.contains("require_equal") + // Assert zero expressions grouped by selector + for (selector, group) in &cs.expression_groups { + let zero_selector = if let Some(selector) = selector { + &structural_witin[selector.selector_expr().id()] + } else { + &structural_witin[0] + }; + for r in group + .assert_zero_expressions + .iter() + .chain(group.assert_zero_sumcheck_expressions.iter()) { - let right = -right.as_ref(); - - let left_evaluated = wit_infer_by_expr( - left, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - wits_in, - structural_witin, - pi_mles, - pub_io_evals, - &challenge, - ); - let left_evaluated = - filter_mle_by_selector_mle(left_evaluated, zero_selector.clone()); - - let right_evaluated = wit_infer_by_expr( - &right, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - wits_in, - structural_witin, - pi_mles, - pub_io_evals, - &challenge, - ); - let right_evaluated = - filter_mle_by_selector_mle(right_evaluated, zero_selector.clone()); + let expr = &r.expression; + let name = &r.expression_namespace_map; + if expr.degree() > MAX_CONSTRAINT_DEGREE { + errors.push(MockProverError::DegreeTooHigh { + expression: expr.clone(), + degree: expr.degree(), + name: name.clone(), + }); + } - // left_evaluated.len() ?= right_evaluated.len() due to padding instance - for (inst_id, (left_element, right_element)) in - izip!(left_evaluated, right_evaluated).enumerate() + // require_equal does not always have the form of Expr::Sum as + // the sum of witness and constant is expressed as scaled sum + if let Expression::Sum(left, right) = expr + && name.contains("require_equal") { - if left_element != right_element { - errors.push(MockProverError::AssertEqualError { - left_expression: *left.clone(), - right_expression: right.clone(), - left: Either::Right(left_element), - right: Either::Right(right_element), - name: name.clone(), - inst_id, - }); + let right = -right.as_ref(); + + let left_evaluated = wit_infer_by_expr( + left, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + wits_in, + structural_witin, + pi_mles, + pub_io_evals, + &challenge, + ); + let left_evaluated = + filter_mle_by_selector_mle(left_evaluated, zero_selector.clone()); + + let right_evaluated = wit_infer_by_expr( + &right, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + wits_in, + structural_witin, + pi_mles, + pub_io_evals, + &challenge, + ); + let right_evaluated = + filter_mle_by_selector_mle(right_evaluated, zero_selector.clone()); + + // left_evaluated.len() ?= right_evaluated.len() due to padding instance + for (inst_id, (left_element, right_element)) in + izip!(left_evaluated, right_evaluated).enumerate() + { + if left_element != right_element { + errors.push(MockProverError::AssertEqualError { + left_expression: *left.clone(), + right_expression: right.clone(), + left: Either::Right(left_element), + right: Either::Right(right_element), + name: name.clone(), + inst_id, + }); + } + } + } else { + // contains require_zero + let expr_evaluated = wit_infer_by_expr( + expr, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + wits_in, + structural_witin, + pi_mles, + pub_io_evals, + &challenge, + ); + let expr_evaluated = + filter_mle_by_selector_mle(expr_evaluated, zero_selector.clone()); + + for (inst_id, element) in enumerate(expr_evaluated) { + if element != E::ZERO { + errors.push(MockProverError::AssertZeroError { + expression: expr.clone(), + evaluated: Either::Right(element), + name: name.clone(), + inst_id, + }); + } } } + } + } + + // Lookup expressions + for (selector, group) in &cs.expression_groups { + let expressions = &group.lk_expressions; + if expressions.is_empty() { + continue; + } + let lk_selector = if let Some(selector) = selector { + &structural_witin[selector.selector_expr().id()] } else { - // contains require_zero + &structural_witin[0] + }; + for r in expressions.iter() { let expr_evaluated = wit_infer_by_expr( - expr, + &r.expression, cs.num_witin, cs.num_fixed as WitnessId, cs.instance_openings.len(), @@ -680,131 +726,103 @@ impl<'a, E: ExtensionField + Hash> MockProver { &challenge, ); let expr_evaluated = - filter_mle_by_selector_mle(expr_evaluated, zero_selector.clone()); - - for (inst_id, element) in enumerate(expr_evaluated) { - if element != E::ZERO { - errors.push(MockProverError::AssertZeroError { - expression: expr.clone(), - evaluated: Either::Right(element), - name: name.clone(), + filter_mle_by_selector_mle(expr_evaluated, lk_selector.clone()); + + // Check each lookup expr exists in t vec + for (inst_id, element) in enumerate(&expr_evaluated) { + if !table.contains(&element.to_canonical_u64_vec()) { + errors.push(MockProverError::LookupError { + rom_type: r.meta.0, + expression: r.expression.clone(), + evaluated: *element, + name: r.expression_namespace_map.clone(), inst_id, }); } } - } - } - let lk_selector: ArcMultilinearExtension<_> = if let Some(lk_selector) = &cs.lk_selector { - structural_witin[lk_selector.selector_expr().id()].clone() - } else { - let mut selector = vec![E::BaseField::ONE; num_instances]; - selector.resize(num_instance_padded, E::BaseField::ZERO); - MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(num_instance_padded), - selector, - ) - .into() - }; - - // Lookup expressions - for (expr, (name, (rom_type, _))) in cs.lk_expressions.iter().zip( - cs.lk_expressions_namespace_map - .iter() - .zip_eq(cs.lk_expressions_items_map.iter()), - ) { - let expr_evaluated = wit_infer_by_expr( - expr, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - wits_in, - structural_witin, - pi_mles, - pub_io_evals, - &challenge, - ); - let expr_evaluated = filter_mle_by_selector_mle(expr_evaluated, lk_selector.clone()); - - // Check each lookup expr exists in t vec - for (inst_id, element) in enumerate(&expr_evaluated) { - if !table.contains(&element.to_canonical_u64_vec()) { - errors.push(MockProverError::LookupError { - rom_type: *rom_type, - expression: expr.clone(), - evaluated: *element, - name: name.clone(), - inst_id, - }); + // Increment shared LK Multiplicity + for element in expr_evaluated { + shared_lkm.increment(r.meta.0, element); } } - - // Increment shared LK Multiplicity - for element in expr_evaluated { - shared_lkm.increment(*rom_type, element); - } } // LK Multiplicity check if let Some(lkm_from_assignment) = expected_lkm { - let selected_count = lk_selector - .get_base_field_vec() - .iter() - .filter(|sel| **sel == E::BaseField::ONE) - .count(); // Infer LK Multiplicity from constraint system. let mut lkm_from_cs = LkMultiplicity::default(); - for (rom_type, args) in &cs.lk_expressions_items_map { - let args_eval: Vec<_> = args + for (selector, group) in &cs.expression_groups { + let expressions = &group.lk_expressions; + if expressions.is_empty() { + continue; + } + let lk_selector = if let Some(selector) = selector { + &structural_witin[selector.selector_expr().id()] + } else { + &structural_witin[0] + }; + let selected_count = lk_selector + .get_base_field_vec() .iter() - .map(|arg_expr| { - let arg_eval = wit_infer_by_expr( - arg_expr, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - wits_in, - structural_witin, - pi_mles, - pub_io_evals, - &challenge, - ); - if arg_expr.is_constant() && arg_eval.evaluations.len() == 1 { - vec![arg_eval.get_ext_field_vec()[0].to_canonical_u64(); selected_count] - } else { - filter_mle_by_selector_mle(arg_eval, lk_selector.clone()) - .iter() - .map(E::to_canonical_u64) - .collect_vec() - } - }) - .collect(); + .filter(|sel| **sel == E::BaseField::ONE) + .count(); - // Count lookups infered from ConstraintSystem from all instances into lkm_from_cs. - for (arg0, arg1) in args_eval[0] - .iter() - .zip(args_eval[1].iter()) - .take(selected_count) - { - match rom_type { - ROMType::Dynamic => { - lkm_from_cs.assert_dynamic_range(*arg0, *arg1); - } - ROMType::DoubleU8 => { - lkm_from_cs.assert_double_u8(*arg0, *arg1); - } - ROMType::And => lkm_from_cs.lookup_and_byte(*arg0, *arg1), - ROMType::Or => lkm_from_cs.lookup_or_byte(*arg0, *arg1), - ROMType::Xor => lkm_from_cs.lookup_xor_byte(*arg0, *arg1), - ROMType::Ltu => lkm_from_cs.lookup_ltu_byte(*arg0, *arg1), - ROMType::Pow => { - assert_eq!(*arg0, 2); - lkm_from_cs.lookup_pow2(*arg1) - } - ROMType::Instruction => lkm_from_cs.fetch(*arg0 as u32), - }; + for r in expressions.iter() { + let (rom_type, args) = &r.meta; + let args_eval: Vec<_> = args + .iter() + .map(|arg_expr| { + let arg_eval = wit_infer_by_expr( + arg_expr, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + wits_in, + structural_witin, + pi_mles, + pub_io_evals, + &challenge, + ); + if arg_expr.is_constant() && arg_eval.evaluations.len() == 1 { + vec![ + arg_eval.get_ext_field_vec()[0].to_canonical_u64(); + selected_count + ] + } else { + filter_mle_by_selector_mle(arg_eval, lk_selector.clone()) + .iter() + .map(E::to_canonical_u64) + .collect_vec() + } + }) + .collect(); + + // Count lookups infered from ConstraintSystem from all instances into lkm_from_cs. + for (arg0, arg1) in args_eval[0] + .iter() + .zip(args_eval[1].iter()) + .take(selected_count) + { + match rom_type { + ROMType::Dynamic => { + lkm_from_cs.assert_dynamic_range(*arg0, *arg1); + } + ROMType::DoubleU8 => { + lkm_from_cs.assert_double_u8(*arg0, *arg1); + } + ROMType::And => lkm_from_cs.lookup_and_byte(*arg0, *arg1), + ROMType::Or => lkm_from_cs.lookup_or_byte(*arg0, *arg1), + ROMType::Xor => lkm_from_cs.lookup_xor_byte(*arg0, *arg1), + ROMType::Ltu => lkm_from_cs.lookup_ltu_byte(*arg0, *arg1), + ROMType::Pow => { + assert_eq!(*arg0, 2); + lkm_from_cs.lookup_pow2(*arg1) + } + ROMType::Instruction => lkm_from_cs.fetch(*arg0 as u32), + }; + } } } @@ -850,7 +868,8 @@ impl<'a, E: ExtensionField + Hash> MockProver { let mut cb = CircuitBuilder::new(&mut cs); let config = ProgramTableCircuit::<_>::construct_circuit(&mut cb, ¶ms).unwrap(); let fixed = ProgramTableCircuit::::generate_fixed_traces(&config, cs.num_fixed, program); - for table_expr in &cs.lk_table_expressions { + + for table_expr in cs.lk_table_expressions_all() { for row in fixed.iter_rows() { // TODO: Find a better way to obtain the row content. let row: Vec = row.iter().map(|v| (*v).into()).collect(); @@ -874,13 +893,22 @@ impl<'a, E: ExtensionField + Hash> MockProver { structural_witin: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], constraint_names: &[&str], + num_instances: usize, challenge: Option<[E; 2]>, lkm: Option>, ) { let error_groups = if let Some(challenge) = challenge { - Self::run_with_challenge(cb, fixed, wits_in, structural_witin, challenge, lkm) + Self::run_with_challenge( + cb, + fixed, + wits_in, + structural_witin, + num_instances, + challenge, + lkm, + ) } else { - Self::run(cb, wits_in, program, lkm) + Self::run(cb, wits_in, structural_witin, program, num_instances, lkm) } .err() .into_iter() @@ -917,6 +945,7 @@ Hints: cb: &CircuitBuilder, [raw_witin, raw_structural_witin]: RMMCollections, program: &[ceno_emul::Instruction], + num_instances: usize, challenge: Option<[E; 2]>, lkm: Option>, ) { @@ -930,7 +959,15 @@ Hints: .into_iter() .map(|v| v.into()) .collect_vec(); - Self::assert_satisfied(cb, &wits_in, &structural_witin, program, challenge, lkm); + Self::assert_satisfied( + cb, + &wits_in, + &structural_witin, + program, + num_instances, + challenge, + lkm, + ); } pub fn assert_satisfied( @@ -938,6 +975,7 @@ Hints: wits_in: &[ArcMultilinearExtension<'a, E>], structural_witin: &[ArcMultilinearExtension<'a, E>], program: &[ceno_emul::Instruction], + num_instances: usize, challenge: Option<[E; 2]>, lkm: Option>, ) { @@ -949,6 +987,7 @@ Hints: structural_witin, program, &[], + num_instances, challenge, lkm, ); @@ -1045,7 +1084,7 @@ Hints: fixed.to_mles().into_iter().map(|f| f.into()).collect_vec() }); // not lookup table - if cs.lk_table_expressions.is_empty() { + if cs.lk_table_expressions_len() == 0 { tracing::info!( "Mock proving opcode {} with {} entries", circuit_name, @@ -1082,9 +1121,8 @@ Hints: num_rows ); // gather lookup tables - for (expr, (rom_type, _)) in - izip!(&cs.lk_table_expressions, &cs.lk_expressions_items_map) - { + for expr in cs.lk_table_expressions_all() { + let rom_type = expr.meta.0; let lk_table = wit_infer_by_expr( &expr.values, cs.num_witin, @@ -1117,7 +1155,7 @@ Hints: for (key, multiplicity) in izip!(lk_table, multiplicity) { lkm_tables.set_count( - *rom_type, + rom_type, key, multiplicity.to_canonical_u64() as usize, ); @@ -1162,6 +1200,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); + let pi_mles = cs .instance_openings .iter() @@ -1172,87 +1211,86 @@ Hints: if *num_rows == 0 { continue; } - let w_selector: ArcMultilinearExtension<_> = - if let Some(w_selector) = &cs.w_selector { - structural_witness[w_selector.selector_expr().id()].clone() - } else { - let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); - MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(next_pow2_instance_padding(*num_rows)), - selector, - ) - .into() + for (selector, group) in &cs.expression_groups { + let Some(selector) = selector else { + assert!(group.is_empty(), "all expressions must have a selector"); + continue; }; - - for ((w_rlc_expr, annotation), (ram_type_expr, _)) in (cs - .w_expressions - .iter() - .chain(cs.w_table_expressions.iter().map(|expr| &expr.expr))) - .zip_eq( - cs.w_expressions_namespace_map - .iter() - .chain(cs.w_table_expressions_namespace_map.iter()), - ) - .zip_eq(cs.w_ram_types.iter()) - { - let ram_type_mle = wit_infer_by_expr( - ram_type_expr, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - witness, - structural_witness, - &pi_mles, - &pub_io_evals, - &challenges, - ); - let ram_type_vec = ram_type_mle.get_ext_field_vec(); - let write_rlc_records = wit_infer_by_expr( - w_rlc_expr, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - witness, - structural_witness, - &pi_mles, - &pub_io_evals, - &challenges, - ); - let w_selector_vec = w_selector.get_base_field_vec(); - let write_rlc_records = - filter_mle_by_predicate(write_rlc_records, |i, _v| { - ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) - && w_selector_vec[i] == E::BaseField::ONE - }); - if write_rlc_records.is_empty() { + let w_expressions = &group.w_expressions; + if w_expressions.is_empty() && group.w_table_expressions.is_empty() { continue; } + let w_selector = structural_witness[selector.selector_expr().id()].clone(); - let mut records = vec![]; - let mut writes_within_expr_dedup = HashSet::new(); - for (row, record_rlc) in enumerate(write_rlc_records) { - // TODO: report error - assert_eq!( - writes_within_expr_dedup.insert(record_rlc), - true, - "circuit name {circuit_name} within expression write duplicated on RAMType {:?} annotation {:?} on row {row}", - $ram_type, - annotation + for (w_rlc_expr, annotation, (ram_type_expr, _)) in w_expressions + .iter() + .map(|r| (&r.expression, &r.expression_namespace_map, &r.meta)) + .chain( + group + .w_table_expressions + .iter() + .map(|t| (&t.expr, &t.expression_namespace_map, &t.meta)), + ) + { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + witness, + structural_witness, + &pi_mles, + &pub_io_evals, + &challenges, ); - assert_eq!( - writes.insert(record_rlc), - true, - "circuit name {circuit_name} crossing-chip write duplicated on RAMType {:?} annotation {:?} on row {row}", - $ram_type, - annotation + let ram_type_vec = ram_type_mle.get_ext_field_vec(); + let write_rlc_records = wit_infer_by_expr( + w_rlc_expr, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + witness, + structural_witness, + &pi_mles, + &pub_io_evals, + &challenges, ); - records.push((record_rlc, row)); + let w_selector_vec = w_selector.get_base_field_vec(); + let write_rlc_records = + filter_mle_by_predicate(write_rlc_records, |i, _v| { + ram_type_vec[i] + == E::from_canonical_u32($ram_type as u32) + && w_selector_vec[i] == E::BaseField::ONE + }); + if write_rlc_records.is_empty() { + continue; + } + + let mut records = vec![]; + let mut writes_within_expr_dedup = HashSet::new(); + for (row, record_rlc) in enumerate(write_rlc_records) { + // TODO: report error + assert_eq!( + writes_within_expr_dedup.insert(record_rlc), + true, + "circuit name {circuit_name} within expression write duplicated on RAMType {:?} annotation {:?} on row {row}", + $ram_type, + annotation + ); + assert_eq!( + writes.insert(record_rlc), + true, + "circuit name {circuit_name} crossing-chip write duplicated on RAMType {:?} annotation {:?} on row {row}", + $ram_type, + annotation + ); + records.push((record_rlc, row)); + } + writes_grp_by_annotations + .insert(annotation.clone(), (records, circuit_name.clone())); } - writes_grp_by_annotations - .insert(annotation.clone(), (records, circuit_name.clone())); } } @@ -1268,6 +1306,7 @@ Hints: let fixed = fixed_mles.get(circuit_name).unwrap(); let witness = wit_mles.get(circuit_name).unwrap(); let structural_witness = structural_wit_mles.get(circuit_name).unwrap(); + let pi_mles = cs .instance_openings .iter() @@ -1277,119 +1316,117 @@ Hints: if *num_rows == 0 { continue; } - let r_selector: ArcMultilinearExtension<_> = - if let Some(r_selector) = &cs.r_selector { - structural_witness[r_selector.selector_expr().id()].clone() - } else { - let mut selector = vec![E::BaseField::ONE; *num_rows]; - selector.resize(next_pow2_instance_padding(*num_rows), E::BaseField::ZERO); - MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(next_pow2_instance_padding(*num_rows)), - selector, - ) - .into() + for (selector, group) in &cs.expression_groups { + let Some(selector) = selector else { + assert!(group.is_empty(), "all expressions must have a selector"); + continue; }; - for ((r_rlc_expr, annotation), (ram_type_expr, r_exprs)) in (cs - .r_expressions - .iter() - .chain(cs.r_table_expressions.iter().map(|expr| &expr.expr))) - .zip_eq( - cs.r_expressions_namespace_map - .iter() - .chain(cs.r_table_expressions_namespace_map.iter()), - ) - .zip_eq(cs.r_ram_types.iter()) - { - let ram_type_mle = wit_infer_by_expr( - ram_type_expr, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - witness, - structural_witness, - &pi_mles, - &pub_io_evals, - &challenges, - ); - let ram_type_vec = ram_type_mle.get_ext_field_vec(); - let read_records = wit_infer_by_expr( - r_rlc_expr, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - witness, - structural_witness, - &pi_mles, - &pub_io_evals, - &challenges, - ); - let r_selector_vec = r_selector.get_base_field_vec(); - let read_records = filter_mle_by_predicate(read_records, |i, _v| { - ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) - && r_selector_vec[i] == E::BaseField::ONE - }); - if read_records.is_empty() { + let r_expressions = &group.r_expressions; + if r_expressions.is_empty() && group.r_table_expressions.is_empty() { continue; } + let r_selector = structural_witness[selector.selector_expr().id()].clone(); - if $ram_type == RAMType::GlobalState { - // r_exprs = [GlobalState, pc, timestamp] - assert_eq!(r_exprs.len(), 3); - let r = r_exprs - .into_iter() - .skip(1) - .map(|expr| { - let v = wit_infer_by_expr( - expr, - cs.num_witin, - cs.num_fixed as WitnessId, - cs.instance_openings.len(), - fixed, - witness, - structural_witness, - &pi_mles, - &pub_io_evals, - &challenges, - ); - filter_mle_by_selector_mle(v, r_selector.clone()) - }) - .collect_vec(); - // convert [[pc], [timestamp]] into [[pc, timestamp]] - let r = (0..r[0].len()) - // TODO: use transpose - .map(|row| r.iter().map(|r| r[row]).collect_vec()) - .collect_vec(); - - assert!(gs.insert(circuit_name.clone(), r).is_none()); - }; - - let mut records = vec![]; - let mut reads_within_expr_dedup = HashSet::new(); - for (row, record) in enumerate(read_records) { - // TODO: return error - assert_eq!( - reads_within_expr_dedup.insert(record), - true, - "circuit name {circuit_name} within expression read duplicated on RAMType {:?} annotation {:?} on row {row}", - $ram_type, - annotation, + for (r_rlc_expr, annotation, (ram_type_expr, r_exprs)) in r_expressions + .iter() + .map(|r| (&r.expression, &r.expression_namespace_map, &r.meta)) + .chain( + group + .r_table_expressions + .iter() + .map(|t| (&t.expr, &t.expression_namespace_map, &t.meta)), + ) + { + let ram_type_mle = wit_infer_by_expr( + ram_type_expr, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + witness, + structural_witness, + &pi_mles, + &pub_io_evals, + &challenges, ); - assert_eq!( - reads.insert(record), - true, - "circuit name {circuit_name} crossing-chip read duplicated on RAMType {:?} annotation {:?} on row {row}", - $ram_type, - annotation, + let ram_type_vec = ram_type_mle.get_ext_field_vec(); + let read_records = wit_infer_by_expr( + r_rlc_expr, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + witness, + structural_witness, + &pi_mles, + &pub_io_evals, + &challenges, ); - records.push((record, row)); + let r_selector_vec = r_selector.get_base_field_vec(); + let read_records = filter_mle_by_predicate(read_records, |i, _v| { + ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && r_selector_vec[i] == E::BaseField::ONE + }); + if read_records.is_empty() { + continue; + } + + if $ram_type == RAMType::GlobalState { + // r_exprs = [GlobalState, pc, timestamp] + assert_eq!(r_exprs.len(), 3); + let r = r_exprs + .into_iter() + .skip(1) + .map(|expr| { + let v = wit_infer_by_expr( + expr, + cs.num_witin, + cs.num_fixed as WitnessId, + cs.instance_openings.len(), + fixed, + witness, + structural_witness, + &pi_mles, + &pub_io_evals, + &challenges, + ); + filter_mle_by_selector_mle(v, r_selector.clone()) + }) + .collect_vec(); + // convert [[pc], [timestamp]] into [[pc, timestamp]] + let r = (0..r[0].len()) + // TODO: use transpose + .map(|row| r.iter().map(|r| r[row]).collect_vec()) + .collect_vec(); + + assert!(gs.insert(circuit_name.clone(), r).is_none()); + }; + + let mut records = vec![]; + let mut reads_within_expr_dedup = HashSet::new(); + for (row, record) in enumerate(read_records) { + // TODO: return error + assert_eq!( + reads_within_expr_dedup.insert(record), + true, + "circuit name {circuit_name} within expression read duplicated on RAMType {:?} annotation {:?} on row {row}", + $ram_type, + annotation, + ); + assert_eq!( + reads.insert(record), + true, + "circuit name {circuit_name} crossing-chip read duplicated on RAMType {:?} annotation {:?} on row {row}", + $ram_type, + annotation, + ); + records.push((record, row)); + } + reads_grp_by_annotations + .insert(annotation.clone(), (records, circuit_name.clone())); } - reads_grp_by_annotations - .insert(annotation.clone(), (records, circuit_name.clone())); } } - ( reads, reads_grp_by_annotations, @@ -1552,6 +1589,18 @@ Hints: panic!("found {} r/w mismatch errors", num_rw_mismatch_errors); } } + + fn default_structural_witin_from_wits( + _cs: &ConstraintSystem, + eval_len: usize, + ) -> Vec> { + let mut selector_eval = vec![E::BaseField::ONE; eval_len]; + selector_eval.resize(eval_len.next_power_of_two(), E::BaseField::ZERO); + let selector = + MultilinearExtension::from_evaluation_vec_smart(ceil_log2(eval_len), selector_eval) + .into(); + vec![selector; 1] + } } fn compare_lkm(lkm_a: Multiplicity, lkm_b: Multiplicity) -> Vec> @@ -1631,6 +1680,7 @@ fn filter_mle_by_selector_mle( #[cfg(test)] mod tests { + use super::*; use crate::{ ROMType, @@ -1693,7 +1743,7 @@ mod tests { .map(|f| f.into_mle().into()) .collect_vec(); - MockProver::assert_satisfied(&builder, &wits_in, &[], &[], None, None); + MockProver::assert_satisfied(&builder, &wits_in, &[], &[], 2, None, None); } #[derive(Debug)] @@ -1729,7 +1779,7 @@ mod tests { ]; let challenge = [1.into_f(), 1000.into_f()]; - MockProver::assert_satisfied(&builder, &wits_in, &[], &[], Some(challenge), None); + MockProver::assert_satisfied(&builder, &wits_in, &[], &[], 2, Some(challenge), None); } #[test] @@ -1743,7 +1793,8 @@ mod tests { let wits_in = vec![(vec![123u64.into_f()] as Vec).into_mle().into()]; let challenge = [2.into_f(), 1000.into_f()]; - let result = MockProver::run_with_challenge(&builder, &[], &wits_in, &[], challenge, None); + let result = + MockProver::run_with_challenge(&builder, &[], &wits_in, &[], 1, challenge, None); assert!(result.is_err(), "Expected error"); let err = result.unwrap_err(); assert_eq!( @@ -1800,10 +1851,14 @@ mod tests { } impl AssertLtCircuit { - fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + fn construct_circuit( + cb: &mut CircuitBuilder, + max_bits: usize, + ) -> Result { let a = cb.create_witin(|| "a"); let b = cb.create_witin(|| "b"); - let lt_wtns = AssertLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = + AssertLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), max_bits)?; Ok(Self { a, b, lt_wtns }) } @@ -1849,7 +1904,7 @@ mod tests { let mut cs = ConstraintSystem::new(|| "test_assert_lt_1"); let mut builder = CircuitBuilder::::new(&mut cs); - let circuit = AssertLtCircuit::construct_circuit(&mut builder).unwrap(); + let circuit = AssertLtCircuit::construct_circuit(&mut builder, 2).unwrap(); let mut lk_multiplicity = LkMultiplicity::default(); let raw_witin = circuit @@ -1867,6 +1922,7 @@ mod tests { &builder, [raw_witin, RowMajorMatrix::empty()], &[], + 2, Some([1.into_f(), 1000.into_f()]), None, ); @@ -1877,7 +1933,7 @@ mod tests { let mut cs = ConstraintSystem::new(|| "test_assert_lt_u32"); let mut builder = CircuitBuilder::::new(&mut cs); - let circuit = AssertLtCircuit::construct_circuit(&mut builder).unwrap(); + let circuit = AssertLtCircuit::construct_circuit(&mut builder, 2).unwrap(); let mut lk_multiplicity = LkMultiplicity::default(); let raw_witin = circuit .assign_instances::( @@ -1900,6 +1956,7 @@ mod tests { &builder, [raw_witin, RowMajorMatrix::empty()], &[], + 2, Some([1.into_f(), 1000.into_f()]), None, ); @@ -1918,10 +1975,13 @@ mod tests { } impl LtCircuit { - fn construct_circuit(cb: &mut CircuitBuilder) -> Result { + fn construct_circuit( + cb: &mut CircuitBuilder, + max_bits: usize, + ) -> Result { let a = cb.create_witin(|| "a"); let b = cb.create_witin(|| "b"); - let lt_wtns = IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), 1)?; + let lt_wtns = IsLtConfig::construct_circuit(cb, || "lt", a.expr(), b.expr(), max_bits)?; Ok(Self { a, b, lt_wtns }) } @@ -1967,7 +2027,7 @@ mod tests { let mut cs = ConstraintSystem::new(|| "test_lt_1"); let mut builder = CircuitBuilder::::new(&mut cs); - let circuit = LtCircuit::construct_circuit(&mut builder).unwrap(); + let circuit = LtCircuit::construct_circuit(&mut builder, 2).unwrap(); let mut lk_multiplicity = LkMultiplicity::default(); let raw_witin = circuit @@ -1985,6 +2045,7 @@ mod tests { &builder, [raw_witin, RowMajorMatrix::empty()], &[], + 2, Some([1.into_f(), 1000.into_f()]), None, ); @@ -1995,7 +2056,7 @@ mod tests { let mut cs = ConstraintSystem::new(|| "test_lt_u32"); let mut builder = CircuitBuilder::::new(&mut cs); - let circuit = LtCircuit::construct_circuit(&mut builder).unwrap(); + let circuit = LtCircuit::construct_circuit(&mut builder, 2).unwrap(); let mut lk_multiplicity = LkMultiplicity::default(); let raw_witin = circuit @@ -2019,6 +2080,7 @@ mod tests { &builder, [raw_witin, RowMajorMatrix::empty()], &[], + 2, Some([1.into_f(), 1000.into_f()]), None, ); diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index dfb6c35ef..b089b4025 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -1,5 +1,5 @@ use crate::{ - circuit_builder::CircuitBuilder, + circuit_builder::{CircuitBuilder, ConstraintSystem}, error::ZKVMError, instructions::{ Instruction, @@ -9,9 +9,12 @@ use crate::{ constants::DYNAMIC_RANGE_MAX_BITS, cpu::CpuTowerProver, create_backend, create_prover, - hal::{ProofInput, TowerProverSpec}, + hal::{ProofInput, TowerProver, TowerProverSpec}, + }, + structs::{ + ComposedConstrainSystem, ProgramParams, RAMType, ZKVMConstraintSystem, ZKVMFixedTraces, + ZKVMWitnesses, }, - structs::{ProgramParams, RAMType, ZKVMConstraintSystem, ZKVMFixedTraces, ZKVMWitnesses}, tables::ProgramTableCircuit, witness::{LkMultiplicity, set_val}, }; @@ -22,12 +25,14 @@ use ceno_emul::{ }; use ff_ext::{ExtensionField, FieldInto, FromUniformBytes, GoldilocksExt2}; use gkr_iop::cpu::default_backend_config; +use gkr_iop::selector::SelectorType; #[cfg(feature = "gpu")] use gkr_iop::gpu::{MultilinearExtensionGpu, gpu_prover::*}; -use multilinear_extensions::{ToExpr, WitIn, mle::MultilinearExtension}; -use std::marker::PhantomData; -#[cfg(feature = "gpu")] -use std::sync::Arc; +use multilinear_extensions::{ + ToExpr, WitIn, + mle::{ArcMultilinearExtension, MultilinearExtension}, +}; +use std::{marker::PhantomData, rc::Rc, sync::Arc}; #[cfg(debug_assertions)] use ff_ext::{Instrumented, PoseidonField}; @@ -87,9 +92,9 @@ impl Instruction for Test cb.assert_ux::<_, _, 16>(|| "regid_in_range", reg_id.expr())?; Result::<(), ZKVMError>::Ok(()) })?; - assert_eq!(cb.cs.lk_expressions.len(), L); - assert_eq!(cb.cs.r_expressions.len(), RW); - assert_eq!(cb.cs.w_expressions.len(), RW); + assert_eq!(cb.cs.lk_expressions_len(), L); + assert_eq!(cb.cs.r_expressions_len(), RW); + assert_eq!(cb.cs.w_expressions_len(), RW); Ok(TestConfig { reg_id }) } @@ -268,6 +273,85 @@ fn test_rw_lk_expression_combination() { test_rw_lk_expression_combination_inner::<17, 61, E, Pcs>(); } +#[test] +fn test_tower_record_order_multi_selector() { + type E = GoldilocksExt2; + type Pcs = WhirDefault; + + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let reg_id = cb.create_witin(|| "reg_id"); + let record = vec![1.into(), reg_id.expr()]; + + cb.with_selector(SelectorType::None, |cb| { + cb.read_record(|| "read_a", RAMType::Register, record.clone())?; + cb.write_record(|| "write_a", RAMType::Register, record.clone())?; + Ok(()) + }) + .unwrap(); + + let sel_b = cb.create_placeholder_structural_witin(|| "sel_b"); + let sel_b = SelectorType::Prefix(sel_b.expr()); + cb.with_selector(sel_b, |cb| { + cb.read_record(|| "read_b", RAMType::Register, record.clone())?; + cb.write_record(|| "write_b", RAMType::Register, record)?; + Ok(()) + }) + .unwrap(); + + assert_eq!(cs.expression_groups.len(), 2); + + let composed_cs = ComposedConstrainSystem { + zkvm_v1_css: cs, + gkr_circuit: None, + }; + + let (max_num_variables, security_level) = default_backend_config(); + let backend = Rc::new(gkr_iop::cpu::CpuBackend::::new( + max_num_variables, + security_level, + )); + let prover = gkr_iop::cpu::CpuProver::new(backend); + + let input = ProofInput::> { + witness: vec![], + structural_witness: vec![], + fixed: vec![], + public_input: vec![], + pub_io_evals: vec![], + num_instances: vec![2], + has_ecc_ops: false, + }; + + let to_e = |v: u64| E::from(::BaseField::from_canonical_u64(v)); + let make_record = |a: u64, b: u64| -> ArcMultilinearExtension { + Arc::new(vec![to_e(a), to_e(b)].into_mle()) + }; + let records = vec![ + make_record(1, 2), // group A read + make_record(3, 4), // group A write + make_record(5, 6), // group B read + make_record(7, 8), // group B write + ]; + + let (out_evals, _prod_specs, _logup_specs) = + prover.build_tower_witness(&composed_cs, &input, &records); + + assert_eq!(out_evals.len(), 3); + let r_out_evals = &out_evals[0]; + let w_out_evals = &out_evals[1]; + let lk_out_evals = &out_evals[2]; + + assert!(lk_out_evals.is_empty()); + assert_eq!(r_out_evals.len(), 2); + assert_eq!(w_out_evals.len(), 2); + + assert_eq!(r_out_evals[0], vec![to_e(1), to_e(2)]); + assert_eq!(w_out_evals[0], vec![to_e(3), to_e(4)]); + assert_eq!(r_out_evals[1], vec![to_e(5), to_e(6)]); + assert_eq!(w_out_evals[1], vec![to_e(7), to_e(8)]); +} + const PROGRAM_CODE: [ceno_emul::Instruction; 4] = [ encode_rv32(ADD, 4, 1, 4, 0), encode_rv32(ECALL, 0, 0, 0, 0), diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index f9bfae7df..7ef084114 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -11,6 +11,7 @@ use gkr_iop::{ evaluation::EvalExpression, gkr::{GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, layer::LayerWitness}, hal::{MultilinearPolynomial, ProtocolWitnessGeneratorProver, ProverBackend}, + selector::SelectorContext, }; use itertools::Itertools; use mpcs::PolynomialCommitmentScheme; @@ -332,13 +333,7 @@ pub fn build_main_witness< // circuit must have at least one read/write/lookup assert!( - cs.r_expressions.len() - + cs.w_expressions.len() - + cs.lk_expressions.len() - + cs.r_table_expressions.len() - + cs.w_table_expressions.len() - + cs.lk_table_expressions.len() - > 0, + composed_cs.num_reads() + composed_cs.num_writes() + composed_cs.num_lks() > 0, "assert circuit" ); @@ -522,6 +517,33 @@ pub fn gkr_witness< ) } +/// This assumes the order is always zero, read, write, lookup +/// TODO: make it more general if needed +pub fn global_selector_ctxs( + _cs: &gkr_iop::circuit_builder::ConstraintSystem, + _gkr_circuit: &gkr_iop::gkr::GKRCircuit, + num_instances: &[usize], + num_var_with_rotation: usize, +) -> Vec { + vec![ + SelectorContext { + offset: 0, + num_instances: num_instances[0], + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: num_instances[0], + num_instances: num_instances[1], + num_vars: num_var_with_rotation, + }, + SelectorContext { + offset: 0, + num_instances: num_instances[0] + num_instances[1], + num_vars: num_var_with_rotation, + }, + ] +} + #[cfg(test)] mod tests { diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index a040108ab..2ae458be9 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -1,29 +1,10 @@ -use either::Either; -use ff_ext::ExtensionField; use std::{ iter::{self, once, repeat_n}, marker::PhantomData, }; -#[cfg(debug_assertions)] -use ff_ext::{Instrumented, PoseidonField}; - -use super::{ZKVMChipProof, ZKVMProof}; -use crate::{ - error::ZKVMError, - instructions::riscv::constants::{ - END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, SHARD_ID_IDX, - }, - scheme::{ - constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, - septic_curve::{SepticExtension, SepticPoint}, - }, - structs::{ - ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, - ZKVMVerifyingKey, - }, -}; use ceno_emul::{FullTracer as Tracer, WORD_SIZE}; +use ff_ext::ExtensionField; use gkr_iop::{ self, selector::{SelectorContext, SelectorType}, @@ -46,6 +27,26 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; +use super::{ZKVMChipProof, ZKVMProof}; +use crate::{ + error::ZKVMError, + instructions::riscv::constants::{ + END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, SHARD_ID_IDX, + }, + scheme::{ + constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, + septic_curve::{SepticExtension, SepticPoint}, + utils::global_selector_ctxs, + }, + structs::{ + ComposedConstrainSystem, EccQuarkProof, PointAndEval, TowerProofs, VerifyingKey, + ZKVMVerifyingKey, + }, +}; + +#[cfg(debug_assertions)] +use ff_ext::{Instrumented, PoseidonField}; + pub struct ZKVMVerifier> { pub vk: ZKVMVerifyingKey, } @@ -478,9 +479,9 @@ impl> ZKVMVerifier } = &composed_cs; let num_instances = proof.num_instances.iter().sum(); let (r_counts_per_instance, w_counts_per_instance, lk_counts_per_instance) = ( - cs.r_expressions.len() + cs.r_table_expressions.len(), - cs.w_expressions.len() + cs.w_table_expressions.len(), - cs.lk_expressions.len() + cs.lk_table_expressions.len(), + composed_cs.num_reads(), + composed_cs.num_writes(), + composed_cs.num_lks(), ); let num_batched = r_counts_per_instance + w_counts_per_instance + lk_counts_per_instance; @@ -494,9 +495,8 @@ impl> ZKVMVerifier let num_var_with_rotation = log2_num_instances + composed_cs.rotation_vars().unwrap_or(0); // constrain log2_num_instances within max length - cs.r_table_expressions - .iter() - .chain(&cs.w_table_expressions) + cs.r_table_expressions_all() + .chain(cs.w_table_expressions_all()) .for_each(|set_table_expr| { // iterate through structural witins and collect max round. let num_vars = set_table_expr @@ -518,7 +518,7 @@ impl> ZKVMVerifier }); assert_eq!(num_vars, log2_num_instances); }); - cs.lk_table_expressions.iter().for_each(|l| { + cs.lk_table_expressions_all().for_each(|l| { // iterate through structural witins and collect max round. let num_vars = l.table_spec.len.map(ceil_log2).unwrap_or_else(|| { l.table_spec @@ -566,7 +566,7 @@ impl> ZKVMVerifier transcript, )?; - if cs.lk_table_expressions.is_empty() { + if cs.lk_table_expressions_len() == 0 { // verify LogUp witness nominator p(x) ?= constant vector 1 logup_p_evals .iter() @@ -593,17 +593,33 @@ impl> ZKVMVerifier debug_assert_eq!(logup_p_evals.len(), lk_counts_per_instance); debug_assert_eq!(logup_q_evals.len(), lk_counts_per_instance); - let evals = record_evals - .iter() - // append p_evals if there got lk table expressions - .chain(if cs.lk_table_expressions.is_empty() { - Either::Left(iter::empty()) + let (r_record_evals, w_record_evals) = record_evals.split_at(r_counts_per_instance); + let mut r_iter = r_record_evals.iter(); + let mut w_iter = w_record_evals.iter(); + let mut p_iter = logup_p_evals.iter(); + let mut q_iter = logup_q_evals.iter(); + + let mut evals = Vec::with_capacity(cs.output_evaluations_len()); + for (_, group) in cs.expression_groups.iter() { + for _ in 0..(group.r_expressions.len() + group.r_table_expressions.len()) { + evals.push(r_iter.next().expect("missing r eval").clone()); + } + for _ in 0..(group.w_expressions.len() + group.w_table_expressions.len()) { + evals.push(w_iter.next().expect("missing w eval").clone()); + } + if group.lk_table_expressions.is_empty() { + for _ in 0..group.lk_expressions.len() { + evals.push(q_iter.next().expect("missing lk q eval").clone()); + } } else { - Either::Right(logup_p_evals.iter()) - }) - .chain(&logup_q_evals) - .cloned() - .collect_vec(); + for _ in 0..group.lk_table_expressions.len() { + evals.push(p_iter.next().expect("missing lk p eval").clone()); + } + for _ in 0..group.lk_table_expressions.len() { + evals.push(q_iter.next().expect("missing lk q eval").clone()); + } + } + } let gkr_circuit = gkr_circuit.as_ref().unwrap(); let selector_ctxs = if cs.ec_final_sum.is_empty() { @@ -614,7 +630,7 @@ impl> ZKVMVerifier gkr_circuit .layers .first() - .map(|layer| layer.out_sel_and_eval_exprs.len()) + .map(|layer| layer.selector_ctxs_len()) .unwrap_or(0) ] } else { @@ -626,23 +642,7 @@ impl> ZKVMVerifier proof.num_instances[1], proof.num_instances[0] + proof.num_instances[1], ); - vec![ - SelectorContext { - offset: 0, - num_instances: proof.num_instances[0], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: proof.num_instances[0], - num_instances: proof.num_instances[1], - num_vars: num_var_with_rotation, - }, - SelectorContext { - offset: 0, - num_instances: proof.num_instances[0] + proof.num_instances[1], - num_vars: num_var_with_rotation, - }, - ] + global_selector_ctxs(cs, gkr_circuit, &proof.num_instances, num_var_with_rotation) }; let (_, rt) = gkr_circuit.verify( num_var_with_rotation, diff --git a/ceno_zkvm/src/stats.rs b/ceno_zkvm/src/stats.rs index 35ec58025..c1993ce6a 100644 --- a/ceno_zkvm/src/stats.rs +++ b/ceno_zkvm/src/stats.rs @@ -4,8 +4,8 @@ use crate::{ utils, }; use ff_ext::ExtensionField; +use gkr_iop::circuit_builder::RecordExpression; use itertools::Itertools; -use multilinear_extensions::Expression; use prettytable::{Table, row}; use serde_json::json; use std::{ @@ -82,35 +82,47 @@ impl std::ops::Add for CircuitStats { impl CircuitStats { pub fn new(system: &ConstraintSystem) -> Self { - let just_degrees_grouped = |exprs: &Vec>| { + let just_degrees_grouped = |expressions: Vec<&RecordExpression>| { let mut counter = HashMap::new(); - for expr in exprs { + for expr in expressions.into_iter().map(|r| &r.expression) { *counter.entry(expr.degree()).or_insert(0) += 1; } counter }; - let is_opcode = system.lk_table_expressions.is_empty() - && system.r_table_expressions.is_empty() - && system.w_table_expressions.is_empty(); + let is_opcode = system.lk_table_expressions_len() == 0 + && system.r_table_expressions_len() == 0 + && system.w_table_expressions_len() == 0; // distinguishing opcodes from tables as done in ZKVMProver::create_proof if is_opcode { CircuitStats::OpCode(OpCodeStats { namespace: system.ns.clone(), witnesses: system.num_witin as usize, - reads: system.r_expressions.len(), - writes: system.w_expressions.len(), - lookups: system.lk_expressions.len(), - assert_zero_expr_degrees: just_degrees_grouped(&system.assert_zero_expressions), - assert_zero_sumcheck_expr_degrees: just_degrees_grouped( - &system.assert_zero_sumcheck_expressions, - ), + reads: system.r_expressions_len(), + writes: system.w_expressions_len(), + lookups: system.lk_expressions_len(), + assert_zero_expr_degrees: { + let exprs = system + .expression_groups + .values() + .flat_map(|g| g.assert_zero_expressions.iter()) + .collect::>(); + just_degrees_grouped(exprs) + }, + assert_zero_sumcheck_expr_degrees: { + let exprs = system + .expression_groups + .values() + .flat_map(|g| g.assert_zero_sumcheck_expressions.iter()) + .collect::>(); + just_degrees_grouped(exprs) + }, }) } else { - let table_len = if !system.lk_table_expressions.is_empty() { - system.lk_table_expressions[0].table_spec.len.unwrap_or(0) - } else { - 0 - }; + let table_len = system + .lk_table_expressions_all() + .next() + .map(|t| t.table_spec.len.unwrap_or(0)) + .unwrap_or(0); CircuitStats::Table(TableStats { table_len }) } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 76a2d9334..b3a815c76 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -148,11 +148,11 @@ impl ComposedConstrainSystem { } pub fn num_reads(&self) -> usize { - self.zkvm_v1_css.r_expressions.len() + self.zkvm_v1_css.r_table_expressions.len() + self.zkvm_v1_css.r_expressions_len() + self.zkvm_v1_css.r_table_expressions_len() } pub fn num_writes(&self) -> usize { - self.zkvm_v1_css.w_expressions.len() + self.zkvm_v1_css.w_table_expressions.len() + self.zkvm_v1_css.w_expressions_len() + self.zkvm_v1_css.w_table_expressions_len() } pub fn instance_openings(&self) -> &[Instance] { @@ -163,12 +163,12 @@ impl ComposedConstrainSystem { } pub fn is_with_lk_table(&self) -> bool { - !self.zkvm_v1_css.lk_table_expressions.is_empty() + self.zkvm_v1_css.lk_table_expressions_len() > 0 } /// return number of lookup operation pub fn num_lks(&self) -> usize { - self.zkvm_v1_css.lk_expressions.len() + self.zkvm_v1_css.lk_table_expressions.len() + self.zkvm_v1_css.lk_expressions_len() + self.zkvm_v1_css.lk_table_expressions_len() } /// return num_vars belongs to rotation @@ -222,7 +222,8 @@ impl ZKVMConstraintSystem { } pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { - let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); + let mut cs = + ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let (config, gkr_iop_circuit) = OC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); @@ -234,9 +235,9 @@ impl ZKVMConstraintSystem { "opcode circuit {} has {} witnesses, {} reads, {} writes, {} lookups", OC::name(), cs.num_witin(), - cs.zkvm_v1_css.r_expressions.len(), - cs.zkvm_v1_css.w_expressions.len(), - cs.zkvm_v1_css.lk_expressions.len(), + cs.zkvm_v1_css.r_expressions_len(), + cs.zkvm_v1_css.w_expressions_len(), + cs.zkvm_v1_css.lk_expressions_len(), ); assert!( self.circuit_css.insert(OC::name(), cs).is_none(), @@ -247,7 +248,8 @@ impl ZKVMConstraintSystem { } pub fn register_table_circuit>(&mut self) -> TC::TableConfig { - let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); + let mut cs = + ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let (config, gkr_iop_circuit) = TC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 0acfed059..6e1915753 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -2,10 +2,10 @@ use crate::{circuit_builder::CircuitBuilder, error::ZKVMError, structs::ProgramP use ff_ext::ExtensionField; use gkr_iop::{ chip::Chip, + default_out_eval_groups, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, }; -use itertools::Itertools; use multilinear_extensions::ToExpr; use std::collections::HashMap; use witness::RowMajorMatrix; @@ -45,38 +45,15 @@ pub trait TableCircuit { param: &ProgramParams, ) -> Result<(Self::TableConfig, Option>), ZKVMError> { let config = Self::construct_circuit(cb, param)?; - let r_table_len = cb.cs.r_table_expressions.len(); - let w_table_len = cb.cs.w_table_expressions.len(); - let lk_table_len = cb.cs.lk_table_expressions.len() * 2; let selector = cb.create_placeholder_structural_witin(|| "selector"); let selector_type = SelectorType::Prefix(selector.expr()); + cb.cs.set_default_read_selector(selector_type.clone()); + cb.cs.set_default_write_selector(selector_type.clone()); + cb.cs.set_default_lookup_selector(selector_type.clone()); - // all shared the same selector - let (out_evals, mut chip) = ( - [ - // r_record - (0..r_table_len).collect_vec(), - // w_record - (r_table_len..r_table_len + w_table_len).collect_vec(), - // lk_record - (r_table_len + w_table_len..r_table_len + w_table_len + lk_table_len).collect_vec(), - // zero_record - vec![], - ], - Chip::new_from_cb(cb, 0), - ); - - // register selector to legacy constrain system - if r_table_len > 0 { - cb.cs.r_selector = Some(selector_type.clone()); - } - if w_table_len > 0 { - cb.cs.w_selector = Some(selector_type.clone()); - } - if lk_table_len > 0 { - cb.cs.lk_selector = Some(selector_type.clone()); - } + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); let layer = Layer::from_circuit_builder(cb, Self::name(), 0, out_evals); chip.add_layer(layer); diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..97f208cf8 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -13,11 +13,11 @@ use ceno_emul::{Addr, Cycle, GetAddr, WORD_SIZE, Word}; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::{ chip::Chip, + default_out_eval_groups, error::CircuitBuilderError, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, }; -use itertools::Itertools; use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; use std::{collections::HashMap, marker::PhantomData, ops::Range}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -339,28 +339,13 @@ impl TableCircuit for LocalFinalRamC param: &ProgramParams, ) -> Result<(Self::TableConfig, Option>), ZKVMError> { let config = Self::construct_circuit(cb, param)?; - let r_table_len = cb.cs.r_table_expressions.len(); let selector = cb.create_placeholder_structural_witin(|| "selector"); let selector_type = SelectorType::Prefix(selector.expr()); + cb.cs.set_default_read_selector(selector_type.clone()); - // all shared the same selector - let (out_evals, mut chip) = ( - [ - // r_record - (0..r_table_len).collect_vec(), - // w_record - vec![], - // lk_record - vec![], - // zero_record - vec![], - ], - Chip::new_from_cb(cb, 0), - ); - - // register selector to legacy constrain system - cb.cs.r_selector = Some(selector_type.clone()); + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); let layer = Layer::from_circuit_builder(cb, Self::name(), 0, out_evals); chip.add_layer(layer); diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index d98161fea..78b2ccb6f 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -14,11 +14,11 @@ use crate::{ use ff_ext::ExtensionField; use gkr_iop::{ chip::Chip, + default_out_eval_groups, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, tables::LookupTable, }; -use itertools::Itertools; use multilinear_extensions::ToExpr; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -108,28 +108,13 @@ impl Result<(Self::TableConfig, Option>), ZKVMError> { let config = Self::construct_circuit(cb, param)?; - let lk_table_len = cb.cs.lk_table_expressions.len() * 2; let selector = cb.create_placeholder_structural_witin(|| "selector"); - let selector_type = SelectorType::Whole(selector.expr()); - - // all shared the same selector - let (out_evals, mut chip) = ( - [ - // r_record - vec![], - // w_record - vec![], - // lk_record - (0..lk_table_len).collect_vec(), - // zero_record - vec![], - ], - Chip::new_from_cb(cb, 0), - ); - - // register selector to legacy constrain system - cb.cs.lk_selector = Some(selector_type.clone()); + cb.cs + .set_default_lookup_selector(SelectorType::Whole(selector.expr())); + + let out_evals = default_out_eval_groups(cb); + let mut chip = Chip::new_from_cb(cb, 0); let layer = Layer::from_circuit_builder(cb, Self::name(), 0, out_evals); chip.add_layer(layer); diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index f07e0644b..cadc0eac7 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -17,12 +17,15 @@ use ff_ext::{ExtensionField, FieldInto, PoseidonField, SmallField}; use gkr_iop::{ chip::Chip, circuit_builder::CircuitBuilder, + default_out_eval_groups, error::CircuitBuilderError, gkr::{GKRCircuit, layer::Layer}, selector::SelectorType, }; use itertools::{Itertools, chain}; -use multilinear_extensions::{Expression, ToExpr, WitIn, util::max_usable_threads}; +use multilinear_extensions::{ + Expression, StructuralWitIn, StructuralWitInType, ToExpr, WitIn, util::max_usable_threads, +}; use p3::{ field::{Field, FieldAlgebra}, matrix::{Matrix, dense::RowMajorMatrix}, @@ -419,14 +422,6 @@ impl TableCircuit for ShardRamCircuit { let selector_w = cb.create_placeholder_structural_witin(|| "selector_w"); let selector_zero = cb.create_placeholder_structural_witin(|| "selector_zero"); - let config = Self::construct_circuit(cb, param)?; - - let w_len = cb.cs.w_expressions.len(); - let r_len = cb.cs.r_expressions.len(); - let lk_len = cb.cs.lk_expressions.len(); - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - let selector_r = SelectorType::Prefix(selector_r.expr()); // note that the actual offset should be set by prover // depending on the number of local read instances @@ -435,25 +430,17 @@ impl TableCircuit for ShardRamCircuit { // when selector_w = 1 => selector_zero = 1 let selector_zero = SelectorType::Prefix(selector_zero.expr()); - cb.cs.r_selector = Some(selector_r); - cb.cs.w_selector = Some(selector_w); - cb.cs.zero_selector = Some(selector_zero.clone()); - cb.cs.lk_selector = Some(selector_zero); - - // all shared the same selector - let (out_evals, mut chip) = ( - [ - // r_record - (0..r_len).collect_vec(), - // w_record - (r_len..r_len + w_len).collect_vec(), - // lk_record - (r_len + w_len..r_len + w_len + lk_len).collect_vec(), - // zero_record - (0..zero_len).collect_vec(), - ], - Chip::new_from_cb(cb, 0), - ); + cb.cs.set_default_read_selector(selector_r.clone()); + cb.cs.set_default_write_selector(selector_w.clone()); + cb.cs.set_default_lookup_selector(selector_zero.clone()); + cb.cs.set_default_zero_selector(selector_zero.clone()); + let config = Self::construct_circuit(cb, param)?; + + let out_evals = default_out_eval_groups(cb); + + // note that the actual offset should be set by prover + // depending on the number of local read instances + let mut chip = Chip::new_from_cb(cb, 0); let layer = Layer::from_circuit_builder(cb, format!("{}_main", Self::name()), 0, out_evals); chip.add_layer(layer); @@ -488,9 +475,18 @@ impl TableCircuit for ShardRamCircuit { // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` assert_eq!(num_structural_witin, 3); - let selector_r_witin = WitIn { id: 0 }; - let selector_w_witin = WitIn { id: 1 }; - let selector_zero_witin = WitIn { id: 2 }; + let selector_r_witin = StructuralWitIn { + id: 0, + witin_type: StructuralWitInType::Empty, + }; + let selector_w_witin = StructuralWitIn { + id: 1, + witin_type: StructuralWitInType::Empty, + }; + let selector_zero_witin = StructuralWitIn { + id: 2, + witin_type: StructuralWitInType::Empty, + }; let nthreads = max_usable_threads(); diff --git a/ceno_zkvm/src/uint/arithmetic.rs b/ceno_zkvm/src/uint/arithmetic.rs index d07a66e46..aa8e50c52 100644 --- a/ceno_zkvm/src/uint/arithmetic.rs +++ b/ceno_zkvm/src/uint/arithmetic.rs @@ -793,7 +793,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], 1, None, None); } #[test] @@ -843,7 +843,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_g) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], 1, None, None); } #[test] @@ -882,7 +882,7 @@ mod tests { .require_equal(|| "assert_e", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], 1, None, None); } #[test] @@ -921,7 +921,7 @@ mod tests { .require_equal(|| "assert_e", &mut cb, &uint_e) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], 1, None, None); } #[test] @@ -958,7 +958,7 @@ mod tests { .require_equal(|| "assert_g", &mut cb, &uint_c) .unwrap(); - MockProver::assert_satisfied(&cb, &witness_values, &[], &[], None, None); + MockProver::assert_satisfied(&cb, &witness_values, &[], &[], 1, None, None); } } } diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 2b72f251e..cbc808a8b 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -17,6 +17,8 @@ pub struct Chip { pub n_fixed: usize, /// The number of base inputs committed in the whole protocol. pub n_committed: usize, + /// The number of structural within-layer inputs committed in the whole protocol. + pub n_structural_witin: usize, /// The number of challenges generated through the whole protocols /// (except the ones inside sumcheck protocols). @@ -36,23 +38,10 @@ impl Chip { Self { n_fixed: cb.cs.num_fixed, n_committed: cb.cs.num_witin as usize, + n_structural_witin: cb.cs.num_structural_witin as usize, n_challenges, - n_evaluations: cb.cs.w_expressions.len() - + cb.cs.r_expressions.len() - + cb.cs.lk_expressions.len() - + cb.cs.w_table_expressions.len() - + cb.cs.r_table_expressions.len() - + cb.cs.lk_table_expressions.len() * 2 - + cb.cs.num_fixed - + cb.cs.num_witin as usize - + cb.cs.instance_openings.len(), - final_out_evals: (0..cb.cs.w_expressions.len() - + cb.cs.r_expressions.len() - + cb.cs.lk_expressions.len() - + cb.cs.w_table_expressions.len() - + cb.cs.r_table_expressions.len() - + cb.cs.lk_table_expressions.len() * 2) - .collect_vec(), + n_evaluations: cb.cs.input_evaluations_len(), + final_out_evals: (0..cb.cs.output_evaluations_len()).collect_vec(), layers: vec![], } } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index d84cb4d55..42d3ede3e 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -4,7 +4,12 @@ use multilinear_extensions::{ rlc_chip_record, }; use serde::de::DeserializeOwned; -use std::{collections::HashMap, iter::once, marker::PhantomData}; +use std::{ + collections::{BTreeMap, HashMap}, + iter::once, + marker::PhantomData, + mem, +}; use ff_ext::ExtensionField; @@ -68,6 +73,8 @@ pub struct LogupTableExpression { pub multiplicity: Expression, pub values: Expression, pub table_spec: SetTableSpec, + pub expression_namespace_map: String, + pub meta: (LookupTable, Vec>), } #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] @@ -85,8 +92,30 @@ pub struct SetTableExpression { // TODO make decision to have enum/struct // for which option is more friendly to be processed by ConstrainSystem + recursive verifier pub table_spec: SetTableSpec, + + pub expression_namespace_map: String, + pub meta: (Expression, Vec>), +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +#[serde(bound = "E: ExtensionField + DeserializeOwned, M: serde::Serialize + DeserializeOwned")] +pub struct RecordExpression { + pub expression: Expression, + pub expression_namespace_map: String, + pub meta: M, +} + +impl RecordExpression { + pub fn new(expression: Expression, expression_namespace_map: String, meta: M) -> Self { + Self { + expression, + expression_namespace_map, + meta, + } + } } +#[allow(clippy::type_complexity)] #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[serde(bound = "E: ExtensionField + DeserializeOwned")] pub struct ConstraintSystem { @@ -108,44 +137,16 @@ pub struct ConstraintSystem { pub ec_slope_exprs: Vec>, pub ec_final_sum: Vec>, - pub r_selector: Option>, - pub r_expressions: Vec>, - pub r_expressions_namespace_map: Vec, - // for each read expression we store its ram type and original value before doing RLC - // the original value will be used for debugging - pub r_ram_types: Vec<(Expression, Vec>)>, - - pub w_selector: Option>, - pub w_expressions: Vec>, - pub w_expressions_namespace_map: Vec, - // for each write expression we store its ram type and original value before doing RLC - // the original value will be used for debugging - pub w_ram_types: Vec<(Expression, Vec>)>, - - /// init/final ram expression - pub r_table_expressions: Vec>, - pub r_table_expressions_namespace_map: Vec, - pub w_table_expressions: Vec>, - pub w_table_expressions_namespace_map: Vec, + pub expression_groups: BTreeMap>, ExpressionGroup>, + // specify whether constrains system cover only init_w // as it imply w/r set and final_w might happen ACROSS shards pub with_omc_init_only: bool, - pub lk_selector: Option>, - /// lookup expression - pub lk_expressions: Vec>, - pub lk_table_expressions: Vec>, - pub lk_expressions_namespace_map: Vec, - pub lk_expressions_items_map: Vec<(LookupTable, Vec>)>, - - pub zero_selector: Option>, - /// main constraints zero expression - pub assert_zero_expressions: Vec>, - pub assert_zero_expressions_namespace_map: Vec, - - /// main constraints zero expression for expression degree > 1, which require sumcheck to prove - pub assert_zero_sumcheck_expressions: Vec>, - pub assert_zero_sumcheck_expressions_namespace_map: Vec, + pub default_read_selector: Option>, + pub default_write_selector: Option>, + pub default_lookup_selector: Option>, + pub default_zero_selector: Option>, /// max zero sumcheck degree pub max_non_lc_degree: usize, @@ -163,6 +164,36 @@ pub struct ConstraintSystem { pub(crate) phantom: PhantomData, } +#[allow(clippy::type_complexity)] +#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)] +#[serde(bound( + serialize = "E::BaseField: serde::Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct ExpressionGroup { + pub r_expressions: Vec, Vec>)>>, /* RLC records */ + pub w_expressions: Vec, Vec>)>>, /* RLC records */ + pub r_table_expressions: Vec>, + pub w_table_expressions: Vec>, + pub lk_expressions: Vec>)>>, // RLC records + pub lk_table_expressions: Vec>, + pub assert_zero_expressions: Vec>, + pub assert_zero_sumcheck_expressions: Vec>, +} + +impl ExpressionGroup { + pub fn is_empty(&self) -> bool { + self.r_expressions.is_empty() + && self.w_expressions.is_empty() + && self.r_table_expressions.is_empty() + && self.w_table_expressions.is_empty() + && self.lk_expressions.is_empty() + && self.lk_table_expressions.is_empty() + && self.assert_zero_expressions.is_empty() + && self.assert_zero_sumcheck_expressions.is_empty() + } +} + impl ConstraintSystem { pub fn new, N: FnOnce() -> NR>(root_name_fn: N) -> Self { Self { @@ -179,29 +210,12 @@ impl ConstraintSystem { ec_final_sum: vec![], ec_slope_exprs: vec![], ec_point_exprs: vec![], - r_selector: None, - r_expressions: vec![], - r_expressions_namespace_map: vec![], - r_ram_types: vec![], - w_selector: None, - w_expressions: vec![], - w_expressions_namespace_map: vec![], - w_ram_types: vec![], - r_table_expressions: vec![], - r_table_expressions_namespace_map: vec![], - w_table_expressions: vec![], - w_table_expressions_namespace_map: vec![], with_omc_init_only: false, - lk_selector: None, - lk_expressions: vec![], - lk_table_expressions: vec![], - lk_expressions_namespace_map: vec![], - lk_expressions_items_map: vec![], - zero_selector: None, - assert_zero_expressions: vec![], - assert_zero_expressions_namespace_map: vec![], - assert_zero_sumcheck_expressions: vec![], - assert_zero_sumcheck_expressions_namespace_map: vec![], + expression_groups: BTreeMap::new(), + default_read_selector: None, + default_write_selector: None, + default_lookup_selector: None, + default_zero_selector: None, max_non_lc_degree: 0, rotations: vec![], rotation_params: None, @@ -299,12 +313,12 @@ impl ConstraintSystem { .chain(record.clone()) .collect(), ); - self.lk_expressions.push(rlc_record); let path = self.ns.compute_path(name_fn().into()); - self.lk_expressions_namespace_map.push(path); // Since lk_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.lk_expressions_items_map.push((rom_type, record)); + self.group_mut(self.default_lookup_selector.clone()) + .lk_expressions + .push(RecordExpression::new(rlc_record, path, (rom_type, record))); Ok(()) } @@ -332,16 +346,16 @@ impl ConstraintSystem { "rlc lk_table_record degree ({})", name_fn().into() ); - self.lk_table_expressions.push(LogupTableExpression { - values: rlc_record, - multiplicity, - table_spec, - }); let path = self.ns.compute_path(name_fn().into()); - self.lk_expressions_namespace_map.push(path); - // Since lk_expression is RLC(record) and when we're debugging - // it's helpful to recover the value of record itself. - self.lk_expressions_items_map.push((rom_type, record)); + self.group_mut(self.default_lookup_selector.clone()) + .lk_table_expressions + .push(LogupTableExpression { + values: rlc_record, + multiplicity, + table_spec, + expression_namespace_map: path, + meta: (rom_type, record), + }); Ok(()) } @@ -379,14 +393,15 @@ impl ConstraintSystem { NR: Into, N: FnOnce() -> NR, { - self.r_table_expressions.push(SetTableExpression { - expr: rlc_record, - table_spec, - }); let path = self.ns.compute_path(name_fn().into()); - self.r_table_expressions_namespace_map.push(path); - self.r_ram_types.push((ram_type, record)); - + self.group_mut(self.default_read_selector.clone()) + .r_table_expressions + .push(SetTableExpression { + expr: rlc_record, + table_spec, + expression_namespace_map: path, + meta: (ram_type, record), + }); Ok(()) } @@ -423,13 +438,15 @@ impl ConstraintSystem { NR: Into, N: FnOnce() -> NR, { - self.w_table_expressions.push(SetTableExpression { - expr: rlc_record, - table_spec, - }); let path = self.ns.compute_path(name_fn().into()); - self.w_table_expressions_namespace_map.push(path); - self.w_ram_types.push((ram_type, record)); + self.group_mut(self.default_write_selector.clone()) + .w_table_expressions + .push(SetTableExpression { + expr: rlc_record, + table_spec, + expression_namespace_map: path, + meta: (ram_type, record), + }); Ok(()) } @@ -451,12 +468,46 @@ impl ConstraintSystem { record: Vec>, rlc_record: Expression, ) -> Result<(), CircuitBuilderError> { - self.r_expressions.push(rlc_record); + self.read_rlc_record_with_selector( + self.default_read_selector.clone(), + name_fn, + ram_type, + record, + rlc_record, + ) + } + + pub fn read_record_with_selector, N: FnOnce() -> NR>( + &mut self, + selector: Option>, + name_fn: N, + ram_type: RAMType, + record: Vec>, + ) -> Result<(), CircuitBuilderError> { + let rlc_record = self.rlc_chip_record(record.clone()); + self.read_rlc_record_with_selector( + selector, + name_fn, + (ram_type as u64).into(), + record, + rlc_record, + ) + } + + pub fn read_rlc_record_with_selector, N: FnOnce() -> NR>( + &mut self, + selector: Option>, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { let path = self.ns.compute_path(name_fn().into()); - self.r_expressions_namespace_map.push(path); // Since r_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.r_ram_types.push((ram_type, record)); + self.group_mut(selector) + .r_expressions + .push(RecordExpression::new(rlc_record, path, (ram_type, record))); Ok(()) } @@ -477,12 +528,46 @@ impl ConstraintSystem { record: Vec>, rlc_record: Expression, ) -> Result<(), CircuitBuilderError> { - self.w_expressions.push(rlc_record); + self.write_rlc_record_with_selector( + self.default_write_selector.clone(), + name_fn, + ram_type, + record, + rlc_record, + ) + } + + pub fn write_record_with_selector, N: FnOnce() -> NR>( + &mut self, + selector: Option>, + name_fn: N, + ram_type: RAMType, + record: Vec>, + ) -> Result<(), CircuitBuilderError> { + let rlc_record = self.rlc_chip_record(record.clone()); + self.write_rlc_record_with_selector( + selector, + name_fn, + (ram_type as u64).into(), + record, + rlc_record, + ) + } + + pub fn write_rlc_record_with_selector, N: FnOnce() -> NR>( + &mut self, + selector: Option>, + name_fn: N, + ram_type: Expression, + record: Vec>, + rlc_record: Expression, + ) -> Result<(), CircuitBuilderError> { let path = self.ns.compute_path(name_fn().into()); - self.w_expressions_namespace_map.push(path); // Since w_expression is RLC(record) and when we're debugging // it's helpful to recover the value of record itself. - self.w_ram_types.push((ram_type, record)); + self.group_mut(selector) + .w_expressions + .push(RecordExpression::new(rlc_record, path, (ram_type, record))); Ok(()) } @@ -510,15 +595,29 @@ impl ConstraintSystem { &mut self, name_fn: N, assert_zero_expr: Expression, + ) -> Result<(), CircuitBuilderError> { + self.require_zero_with_selector( + self.default_zero_selector.clone(), + name_fn, + assert_zero_expr, + ) + } + + pub fn require_zero_with_selector, N: FnOnce() -> NR>( + &mut self, + selector: Option>, + name_fn: N, + assert_zero_expr: Expression, ) -> Result<(), CircuitBuilderError> { assert!( assert_zero_expr.degree() > 0, "constant expression assert to zero ?" ); if assert_zero_expr.degree() == 1 { - self.assert_zero_expressions.push(assert_zero_expr); let path = self.ns.compute_path(name_fn().into()); - self.assert_zero_expressions_namespace_map.push(path); + self.group_mut(selector) + .assert_zero_expressions + .push(RecordExpression::new(assert_zero_expr, path, ())); } else { let assert_zero_expr = if assert_zero_expr.is_monomial_form() { assert_zero_expr @@ -528,10 +627,10 @@ impl ConstraintSystem { e }; self.max_non_lc_degree = self.max_non_lc_degree.max(assert_zero_expr.degree()); - self.assert_zero_sumcheck_expressions.push(assert_zero_expr); let path = self.ns.compute_path(name_fn().into()); - self.assert_zero_sumcheck_expressions_namespace_map - .push(path); + self.group_mut(selector) + .assert_zero_sumcheck_expressions + .push(RecordExpression::new(assert_zero_expr, path, ())); } Ok(()) } @@ -550,6 +649,188 @@ impl ConstraintSystem { pub fn set_omc_init_only(&mut self) { self.with_omc_init_only = true; } + + pub fn set_default_read_selector(&mut self, selector: SelectorType) { + self.default_read_selector = Some(selector); + let read_expressions = mem::take(&mut self.group_mut(None).r_expressions); + let read_table_expressions = mem::take(&mut self.group_mut(None).r_table_expressions); + let group = self.group_mut(self.default_read_selector.clone()); + group.r_expressions.extend(read_expressions); + group.r_table_expressions.extend(read_table_expressions); + } + + pub fn set_default_write_selector(&mut self, selector: SelectorType) { + self.default_write_selector = Some(selector); + let write_expressions = mem::take(&mut self.group_mut(None).w_expressions); + let write_table_expressions = mem::take(&mut self.group_mut(None).w_table_expressions); + let group = self.group_mut(self.default_write_selector.clone()); + group.w_expressions.extend(write_expressions); + group.w_table_expressions.extend(write_table_expressions); + } + + pub fn set_default_lookup_selector(&mut self, selector: SelectorType) { + self.default_lookup_selector = Some(selector); + let lk_expressions = mem::take(&mut self.group_mut(None).lk_expressions); + let lk_table_expressions = mem::take(&mut self.group_mut(None).lk_table_expressions); + let group = self.group_mut(self.default_lookup_selector.clone()); + group.lk_expressions.extend(lk_expressions); + group.lk_table_expressions.extend(lk_table_expressions); + } + + pub fn set_default_zero_selector(&mut self, selector: SelectorType) { + self.default_zero_selector = Some(selector); + let assert_zero_expressions = mem::take(&mut self.group_mut(None).assert_zero_expressions); + let assert_zero_sumcheck_expressions = + mem::take(&mut self.group_mut(None).assert_zero_sumcheck_expressions); + let group = self.group_mut(self.default_zero_selector.clone()); + group + .assert_zero_expressions + .extend(assert_zero_expressions); + group + .assert_zero_sumcheck_expressions + .extend(assert_zero_sumcheck_expressions); + } + + pub fn set_all_default_selectors(&mut self, selector: SelectorType) { + self.set_default_read_selector(selector.clone()); + self.set_default_write_selector(selector.clone()); + self.set_default_lookup_selector(selector.clone()); + self.set_default_zero_selector(selector); + self.expression_groups.remove(&None); + } +} + +#[allow(clippy::type_complexity)] +impl ConstraintSystem { + fn group_mut(&mut self, selector: Option>) -> &mut ExpressionGroup { + self.expression_groups.entry(selector).or_default() + } + + pub fn group(&self, selector: &Option>) -> Option<&ExpressionGroup> { + self.expression_groups.get(selector) + } + + pub fn selector_len(&self) -> usize { + self.expression_groups.keys().len() + } + + pub fn expressions_len(&self) -> usize { + self.non_zero_expressions_len() + self.zero_expressions_len() + } + + pub fn non_zero_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| { + g.r_expressions.len() + + g.w_expressions.len() + + g.r_table_expressions.len() + + g.w_table_expressions.len() + + g.lk_expressions.len() + + g.lk_table_expressions.len() * 2 + }) + .sum() + } + + pub fn zero_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| g.assert_zero_expressions.len() + g.assert_zero_sumcheck_expressions.len()) + .sum() + } + + pub fn input_evaluations_len(&self) -> usize { + self.non_zero_expressions_len() + + self.num_fixed + + self.num_witin as usize + + self.instance_openings.len() + } + + pub fn output_evaluations_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| { + g.r_expressions.len() + + g.w_expressions.len() + + g.r_table_expressions.len() + + g.w_table_expressions.len() + + g.lk_expressions.len() + + g.lk_table_expressions.len() * 2 + }) + .sum::() + } + + pub fn r_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| g.r_expressions.len()) + .sum() + } + + pub fn r_table_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| g.r_table_expressions.len()) + .sum() + } + + pub fn r_table_expressions_all(&self) -> impl Iterator> + '_ { + self.expression_groups + .values() + .flat_map(|g| g.r_table_expressions.iter()) + } + + pub fn w_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| g.w_expressions.len()) + .sum() + } + + pub fn w_table_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| g.w_table_expressions.len()) + .sum() + } + + pub fn w_table_expressions_all(&self) -> impl Iterator> + '_ { + self.expression_groups + .values() + .flat_map(|g| g.w_table_expressions.iter()) + } + + pub fn lk_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| g.lk_expressions.len()) + .sum() + } + + pub fn lk_table_expressions_len(&self) -> usize { + self.expression_groups + .values() + .map(|g| g.lk_table_expressions.len()) + .sum() + } + + pub fn lk_table_expressions_all(&self) -> impl Iterator> + '_ { + self.expression_groups + .values() + .flat_map(|g| g.lk_table_expressions.iter()) + } + + pub fn unique_expr_group(&self) -> &ExpressionGroup { + assert_eq!( + self.expression_groups.len(), + 1, + "only support single selector for unique_expr_group retrieval" + ); + self.expression_groups + .values() + .next() + .expect("at least one expression group") + } } impl ConstraintSystem { @@ -592,7 +873,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { cb: impl for<'b> FnOnce(&mut CircuitBuilder<'b, E>) -> Result, ) -> Result { self.cs.namespace(name_fn, |cs| { - let mut inner_circuit_builder = CircuitBuilder::<'_, E>::new(cs); + let mut inner_circuit_builder = CircuitBuilder::<'_, E> { cs }; cb(&mut inner_circuit_builder) }) } @@ -648,6 +929,23 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { self.cs.create_fixed(name_fn) } + pub fn with_selector( + &mut self, + selector: SelectorType, + action: impl FnOnce(&mut Self) -> Result<(), CircuitBuilderError>, + ) -> Result<(), CircuitBuilderError> { + let previous_read_selector = self.cs.default_read_selector.replace(selector.clone()); + let previous_write_selector = self.cs.default_write_selector.replace(selector.clone()); + let previous_lookup_selector = self.cs.default_lookup_selector.replace(selector.clone()); + let previous_zero_selector = self.cs.default_zero_selector.replace(selector.clone()); + let result = action(self); + self.cs.default_read_selector = previous_read_selector; + self.cs.default_write_selector = previous_write_selector; + self.cs.default_lookup_selector = previous_lookup_selector; + self.cs.default_zero_selector = previous_zero_selector; + result + } + pub fn lk_record( &mut self, name_fn: N, diff --git a/gkr_iop/src/circuit_builder/ram.rs b/gkr_iop/src/circuit_builder/ram.rs index 30b63fc06..1270f2f50 100644 --- a/gkr_iop/src/circuit_builder/ram.rs +++ b/gkr_iop/src/circuit_builder/ram.rs @@ -1,4 +1,4 @@ -use crate::{RAMType, error::CircuitBuilderError}; +use crate::{RAMType, error::CircuitBuilderError, selector::SelectorType}; use ff_ext::ExtensionField; use crate::circuit_builder::DebugIndex; @@ -117,4 +117,76 @@ impl CircuitBuilder<'_, E> { Ok((next_ts, lt_cfg)) }) } + + #[allow(clippy::too_many_arguments)] + pub fn ram_type_write_with_rw_selectors< + const LIMBS: usize, + NR: Into, + N: FnOnce() -> NR, + >( + &mut self, + name_fn: N, + ram_type: RAMType, + identifier: impl ToExpr>, + prev_ts: Expression, + ts: Expression, + prev_values: [Expression; LIMBS], + value: [Expression; LIMBS], + r_selector: &SelectorType, + w_selector: &SelectorType, + ) -> Result<(Expression, AssertLtConfig), CircuitBuilderError> { + assert!(identifier.expr().degree() <= 1); + assert!(E::BaseField::bits() > Self::MAX_TS_BITS); + self.namespace(name_fn, |cb| { + // READ (a, v, t) + let read_record = [ + vec![ram_type.into()], + vec![identifier.expr()], + prev_values.to_vec(), + vec![prev_ts.clone()], + ] + .concat(); + // Write (a, v, t) + let write_record = [ + vec![ram_type.into()], + vec![identifier.expr()], + value.to_vec(), + vec![ts.clone()], + ] + .concat(); + + cb.with_selector(r_selector.clone(), |cb| { + cb.read_record(|| "read_record", ram_type, read_record) + })?; + cb.with_selector(w_selector.clone(), |cb| { + cb.write_record(|| "write_record", ram_type, write_record) + })?; + + let lt_cfg = AssertLtConfig::construct_circuit( + cb, + || "prev_ts < ts", + prev_ts, + ts.clone(), + Self::MAX_TS_BITS, + )?; + + let next_ts = ts + 1; + + if matches!(ram_type, RAMType::Register) { + let pow_u16 = power_sequence((1 << u16::BITS as u64).into()); + cb.register_debug_expr( + DebugIndex::RdWrite as usize, + izip!(value.clone(), pow_u16).map(|(v, pow)| v * pow).sum(), + ); + } else if matches!(ram_type, RAMType::Memory) { + let pow_u16 = power_sequence((1 << u16::BITS as u64).into()); + cb.register_debug_expr( + DebugIndex::MemWrite as usize, + izip!(value, pow_u16).map(|(v, pow)| v * pow).sum(), + ); + } + + Ok((next_ts, lt_cfg)) + }) + } } diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index b025aa1e4..f31c75244 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,4 +1,5 @@ use core::fmt; +use std::fmt::Debug; use ff_ext::ExtensionField; use itertools::{Itertools, izip}; @@ -29,14 +30,30 @@ pub struct GKRCircuit { pub n_evaluations: usize, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct GKRCircuitWitness<'a, PB: ProverBackend> { pub layers: Vec>, } -#[derive(Clone, Debug)] +impl Debug for GKRCircuitWitness<'_, PB> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GKRCircuitWitness") + .field("layers", &self.layers) + .finish() + } +} + +#[derive(Clone)] pub struct GKRCircuitOutput<'a, PB: ProverBackend>(pub LayerWitness<'a, PB>); +impl Debug for GKRCircuitOutput<'_, PB> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GKRCircuitOutput") + .field("layer", &self.0) + .finish() + } +} + #[derive(Clone, Serialize, Deserialize)] #[serde(bound( serialize = "E::BaseField: Serialize, Evaluation: Serialize", diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a67862df7..c62a4da8a 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -1,4 +1,3 @@ -use either::Either; use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use linear_layer::{LayerClaims, LinearLayer}; @@ -10,7 +9,7 @@ use multilinear_extensions::{ use p3::field::FieldAlgebra; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use std::{ops::Neg, sync::Arc, vec::IntoIter}; +use std::{fmt::Debug, ops::Neg, sync::Arc, vec::IntoIter}; use sumcheck_layer::LayerProof; use transcript::Transcript; use zerocheck_layer::ZerocheckLayer; @@ -125,9 +124,15 @@ pub struct Layer { pub rotation_sumcheck_expression: Option>, } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct LayerWitness<'a, PB: ProverBackend>(pub Vec>>); +impl Debug for LayerWitness<'_, PB> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list().entries(self.0.iter()).finish() + } +} + impl<'a, PB: ProverBackend> std::ops::Index for LayerWitness<'a, PB> { type Output = Arc>; @@ -297,6 +302,10 @@ impl Layer { Ok(in_point) } + pub fn selector_ctxs_len(&self) -> usize { + self.out_sel_and_eval_exprs.len() + } + // extract claim and dudup point fn extract_claim_and_point( &self, @@ -348,143 +357,200 @@ impl Layer { cb: &CircuitBuilder, layer_name: String, n_challenges: usize, - out_evals: OutEvalGroups, + out_evals: OutEvalGroups, ) -> Layer { - let w_len = cb.cs.w_expressions.len() + cb.cs.w_table_expressions.len(); - let r_len = cb.cs.r_expressions.len() + cb.cs.r_table_expressions.len(); - let lk_len = cb.cs.lk_expressions.len() + cb.cs.lk_table_expressions.len() * 2; // logup lk table include p, q - let zero_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - - let [r_record_evals, w_record_evals, lookup_evals, zero_evals] = out_evals; - assert_eq!(r_record_evals.len(), r_len); - assert_eq!(w_record_evals.len(), w_len); - assert_eq!(lookup_evals.len(), lk_len); - assert_eq!(zero_evals.len(), zero_len); - - let non_zero_expr_len = cb.cs.w_expressions.len() - + cb.cs.w_table_expressions.len() - + cb.cs.r_expressions.len() - + cb.cs.r_table_expressions.len() - + cb.cs.lk_expressions.len() - + cb.cs.lk_table_expressions.len() * 2; - let zero_expr_len = - cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); - - let mut expr_evals = Vec::with_capacity(4); - let mut expr_names = Vec::with_capacity(non_zero_expr_len + zero_expr_len); - let mut expressions = Vec::with_capacity(non_zero_expr_len + zero_expr_len); - - if let Some(r_selector) = cb.cs.r_selector.as_ref() { - // process r_record - let evals = Self::dedup_last_selector_evals(r_selector, &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in (cb - .cs - .r_expressions - .iter() - .chain(cb.cs.r_table_expressions.iter().map(|t| &t.expr))) - .zip_eq( - cb.cs - .r_expressions_namespace_map + let mut expr_evals = vec![]; + let mut expr_names = Vec::with_capacity(cb.cs.expressions_len()); + let mut expressions = Vec::with_capacity(cb.cs.expressions_len()); + + for (selector, group) in cb.cs.expression_groups.iter() { + let Some(selector) = selector else { + assert!(group.is_empty(), "all expressions must have a selector"); + continue; + }; + let [r_record_evals, w_record_evals, lookup_evals] = out_evals.get(selector).unwrap(); + let (r_expr_evals, r_table_evals) = r_record_evals.split_at(group.r_expressions.len()); + let (w_expr_evals, w_table_evals) = w_record_evals.split_at(group.w_expressions.len()); + let (lk_expr_evals, lk_table_evals) = lookup_evals.split_at(group.lk_expressions.len()); + let (lk_table_mult_evals, lk_table_val_evals) = + lk_table_evals.split_at(group.lk_table_expressions.len()); + + extend_evals_and_exprs( + selector, + group + .r_expressions .iter() - .chain(&cb.cs.r_table_expressions_namespace_map), - ) - .zip_eq(&r_record_evals) - .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + .map(|re| (&re.expression, &re.expression_namespace_map)), + r_expr_evals, + &mut expr_evals, + &mut expressions, + &mut expr_names, + |ram_expr| ram_expr - E::BaseField::ONE.expr(), + |ram_eval| { + EvalExpression::Linear( + // evaluation = claim * one - one (padding) + ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + ) + }, + ); + if !group.r_table_expressions.is_empty() { + extend_evals_and_exprs( + selector, + group + .r_table_expressions + .iter() + .map(|re| (&re.expr, &re.expression_namespace_map)), + r_table_evals, + &mut expr_evals, + &mut expressions, + &mut expr_names, + |ram_expr| ram_expr - E::BaseField::ONE.expr(), + |ram_eval| { + EvalExpression::Linear( + // evaluation = claim * one - one (padding) + ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + ) + }, + ); } - } - if let Some(w_selector) = cb.cs.w_selector.as_ref() { - // process w_record - let evals = Self::dedup_last_selector_evals(w_selector, &mut expr_evals); - for (idx, ((ram_expr, name), ram_eval)) in (cb - .cs - .w_expressions - .iter() - .chain(cb.cs.w_table_expressions.iter().map(|t| &t.expr))) - .zip_eq( - cb.cs - .w_expressions_namespace_map + extend_evals_and_exprs( + selector, + group + .w_expressions .iter() - .chain(&cb.cs.w_table_expressions_namespace_map), - ) - .zip_eq(&w_record_evals) - .enumerate() - { - expressions.push(ram_expr - E::BaseField::ONE.expr()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - one (padding) - *ram_eval, - E::BaseField::ONE.expr().into(), - E::BaseField::ONE.neg().expr().into(), - )); - expr_names.push(format!("{}/{idx}", name)); - } - } - - if let Some(lk_selector) = cb.cs.lk_selector.as_ref() { - // process lookup records - let evals = Self::dedup_last_selector_evals(lk_selector, &mut expr_evals); - for (idx, ((lookup, name), lookup_eval)) in (cb - .cs - .lk_expressions - .iter() - .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.multiplicity)) - .chain(cb.cs.lk_table_expressions.iter().map(|t| &t.values))) - .zip_eq(if cb.cs.lk_table_expressions.is_empty() { - Either::Left(cb.cs.lk_expressions_namespace_map.iter()) - } else { - // repeat expressions_namespace_map twice to deal with lk p, q - Either::Right( - cb.cs - .lk_expressions_namespace_map + .map(|re| (&re.expression, &re.expression_namespace_map)), + w_expr_evals, + &mut expr_evals, + &mut expressions, + &mut expr_names, + |ram_expr| ram_expr - E::BaseField::ONE.expr(), + |ram_eval| { + EvalExpression::Linear( + // evaluation = claim * one - one (padding) + ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + ) + }, + ); + if !group.w_table_expressions.is_empty() { + extend_evals_and_exprs( + selector, + group + .w_table_expressions .iter() - .chain(&cb.cs.lk_expressions_namespace_map), - ) - }) - .zip_eq(&lookup_evals) - .enumerate() - { - expressions.push(lookup - cb.cs.chip_record_alpha.clone()); - evals.push(EvalExpression::::Linear( - // evaluation = claim * one - alpha (padding) - *lookup_eval, - E::BaseField::ONE.expr().into(), - cb.cs.chip_record_alpha.clone().neg().into(), - )); - expr_names.push(format!("{}/{idx}", name)); + .map(|re| (&re.expr, &re.expression_namespace_map)), + w_table_evals, + &mut expr_evals, + &mut expressions, + &mut expr_names, + |ram_expr| ram_expr - E::BaseField::ONE.expr(), + |ram_eval| { + EvalExpression::Linear( + // evaluation = claim * one - one (padding) + ram_eval, + E::BaseField::ONE.expr().into(), + E::BaseField::ONE.neg().expr().into(), + ) + }, + ); } - } - if let Some(zero_selector) = cb.cs.zero_selector.as_ref() { - // process zero_record - let evals = Self::dedup_last_selector_evals(zero_selector, &mut expr_evals); - for (idx, (zero_expr, name)) in izip!( - 0.., - chain!( - cb.cs - .assert_zero_expressions + extend_evals_and_exprs( + selector, + group + .lk_expressions + .iter() + .map(|re| (&re.expression, &re.expression_namespace_map)), + lk_expr_evals, + &mut expr_evals, + &mut expressions, + &mut expr_names, + |lookup| lookup - cb.cs.chip_record_alpha.clone(), + |lookup_eval| { + EvalExpression::::Linear( + // evaluation = claim * one - alpha (padding) + lookup_eval, + E::BaseField::ONE.expr().into(), + cb.cs.chip_record_alpha.clone().neg().into(), + ) + }, + ); + if !group.lk_table_expressions.is_empty() { + extend_evals_and_exprs( + selector, + group + .lk_table_expressions .iter() - .zip_eq(&cb.cs.assert_zero_expressions_namespace_map), - cb.cs - .assert_zero_sumcheck_expressions + .map(|re| (&re.multiplicity, &re.expression_namespace_map)), + lk_table_mult_evals, + &mut expr_evals, + &mut expressions, + &mut expr_names, + |lookup| lookup - cb.cs.chip_record_alpha.clone(), + |lookup_eval| { + EvalExpression::::Linear( + // evaluation = claim * one - alpha (padding) + lookup_eval, + E::BaseField::ONE.expr().into(), + cb.cs.chip_record_alpha.clone().neg().into(), + ) + }, + ); + extend_evals_and_exprs( + selector, + group + .lk_table_expressions .iter() - .zip_eq(&cb.cs.assert_zero_sumcheck_expressions_namespace_map) - ) - ) { - expressions.push(zero_expr.clone()); - evals.push(EvalExpression::Zero); - expr_names.push(format!("{}/{idx}", name)); + .map(|re| (&re.values, &re.expression_namespace_map)), + lk_table_val_evals, + &mut expr_evals, + &mut expressions, + &mut expr_names, + |lookup| lookup - cb.cs.chip_record_alpha.clone(), + |lookup_eval| { + EvalExpression::::Linear( + // evaluation = claim * one - alpha (padding) + lookup_eval, + E::BaseField::ONE.expr().into(), + cb.cs.chip_record_alpha.clone().neg().into(), + ) + }, + ); } + + extend_evals_and_exprs( + selector, + group + .assert_zero_expressions + .iter() + .map(|re| (&re.expression, &re.expression_namespace_map)), + &vec![0; group.assert_zero_expressions.len()], + &mut expr_evals, + &mut expressions, + &mut expr_names, + |zero_expr| zero_expr.clone(), + |_| EvalExpression::Zero, + ); + + extend_evals_and_exprs( + selector, + group + .assert_zero_sumcheck_expressions + .iter() + .map(|re| (&re.expression, &re.expression_namespace_map)), + &vec![0; group.assert_zero_sumcheck_expressions.len()], + &mut expr_evals, + &mut expressions, + &mut expr_names, + |zero_expr| zero_expr.clone(), + |_| EvalExpression::Zero, + ); } // Sort expressions, expr_names, and evals according to eval.0 and classify evals. @@ -494,7 +560,7 @@ impl Layer { .. } = &cb.cs; - let in_eval_expr = (non_zero_expr_len..) + let in_eval_expr = (cb.cs.non_zero_expressions_len()..) .take(cb.cs.num_witin as usize + cb.cs.num_fixed + cb.cs.instance_openings.len()) .collect_vec(); if rotations.is_empty() { @@ -545,26 +611,28 @@ impl Layer { ) } } +} - // return previous evals for extend, if new selector match with last selector - // otherwise push new evals and return it for mutability - fn dedup_last_selector_evals<'a>( - new_selector: &SelectorType, - expr_evals: &'a mut Vec<(SelectorType, Vec>)>, - ) -> &'a mut Vec> - where - SelectorType: Clone + PartialEq, - { - let need_push = match expr_evals.last() { - Some((last_sel, _)) => last_sel != new_selector, - None => true, - }; - - if need_push { - expr_evals.push((new_selector.clone(), vec![])); - } +#[allow(clippy::too_many_arguments)] +fn extend_evals_and_exprs<'a, E: ExtensionField>( + selector: &SelectorType, + record_exprs: impl Iterator, &'a String)>, + record_evals: &[usize], + expr_evals: &mut Vec>, + expressions: &mut Vec>, + expr_names: &mut Vec, + compute_expr: impl Fn(&Expression) -> Expression, + compute_eval: impl Fn(usize) -> EvalExpression, +) { + if expr_evals.is_empty() || expr_evals.last().unwrap().0 != *selector { + expr_evals.push((selector.clone(), vec![])); + } - &mut expr_evals.last_mut().unwrap().1 + let evals: &mut Vec> = expr_evals.last_mut().unwrap().1.as_mut(); + for (idx, ((expr, name), eval)) in record_exprs.zip_eq(record_evals).enumerate() { + expressions.push(compute_expr(expr)); + evals.push(compute_eval(*eval)); + expr_names.push(format!("{}/{idx}", name)); } } diff --git a/gkr_iop/src/gkr/layer/gpu/mod.rs b/gkr_iop/src/gkr/layer/gpu/mod.rs index 2ce1436d7..a5cfe5e4a 100644 --- a/gkr_iop/src/gkr/layer/gpu/mod.rs +++ b/gkr_iop/src/gkr/layer/gpu/mod.rs @@ -130,6 +130,12 @@ impl> ZerocheckLayerProver layer.out_sel_and_eval_exprs.len(), out_points.len(), ); + assert_eq!( + layer.out_sel_and_eval_exprs.len(), + selector_ctxs.len(), + "selector_ctxs length {}", + selector_ctxs.len() + ); let (_, raw_rotation_exprs) = &layer.rotation_exprs; let (rotation_proof, rotation_left, rotation_right, rotation_point) = @@ -174,34 +180,36 @@ impl> ZerocheckLayerProver let span_eq = entered_span!("build eqs", profiling_2 = true); let cuda_hal = get_cuda_hal().unwrap(); - let eqs_gpu = layer + let eqs = layer .out_sel_and_eval_exprs .iter() .zip(out_points.iter()) .zip(selector_ctxs.iter()) - .map(|(((sel_type, _), point), selector_ctx)| { - build_eq_x_r_with_sel_gpu(&cuda_hal, point, selector_ctx, sel_type) + .filter_map(|(((sel_type, _), point), selector_ctx)| { + Some(build_eq_x_r_with_sel_gpu( + &cuda_hal, + point, + selector_ctx, + sel_type, + )) }) - // for rotation left point .chain( rotation_left .iter() .map(|rotation_left| build_eq_x_r_gpu(&cuda_hal, rotation_left)), ) - // for rotation right point .chain( rotation_right .iter() .map(|rotation_right| build_eq_x_r_gpu(&cuda_hal, rotation_right)), ) - // for rotation point .chain( rotation_point .iter() .map(|rotation_point| build_eq_x_r_gpu(&cuda_hal, rotation_point)), ) .collect::>(); - // `wit` := witin ++ fixed ++ pubio + let all_witins_gpu = wit .iter() .take(layer.n_witin + layer.n_fixed + layer.n_instance) @@ -222,7 +230,7 @@ impl> ZerocheckLayerProver ) .map(|mle| mle.as_ref()), ) - .chain(eqs_gpu.iter()) + .chain(eqs.iter()) .collect_vec(); assert_eq!( all_witins_gpu.len(), diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 4f6217399..8ecdb82d0 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -1,14 +1,14 @@ #![feature(variant_count)] use crate::{ chip::Chip, circuit_builder::CircuitBuilder, error::CircuitBuilderError, - utils::lk_multiplicity::LkMultiplicity, + selector::SelectorType, utils::lk_multiplicity::LkMultiplicity, }; use either::Either; use ff_ext::ExtensionField; +use itertools::Itertools; use multilinear_extensions::{Expression, impl_expr_from_unsigned, mle::ArcMultilinearExtension}; -use std::marker::PhantomData; +use std::collections::BTreeMap; use strum_macros::EnumIter; -use transcript::Transcript; use witness::RowMajorMatrix; pub mod chip; @@ -28,8 +28,7 @@ pub mod utils; pub type Phase1WitnessGroup<'a, E> = Vec>; // format: [r_records, w_records, lk_records, zero_records] -pub type OutEvalGroups = [Vec; 4]; - +pub type OutEvalGroups = BTreeMap, [Vec; 3]>; pub trait ProtocolBuilder: Sized { type Params; @@ -41,24 +40,7 @@ pub trait ProtocolBuilder: Sized { params: Self::Params, ) -> Result; - fn finalize(&mut self, cb: &mut CircuitBuilder) -> (OutEvalGroups, Chip); - - fn n_committed(&self) -> usize { - todo!() - } - fn n_fixed(&self) -> usize { - todo!() - } - fn n_challenges(&self) -> usize { - todo!() - } - fn n_evaluations(&self) -> usize { - todo!() - } - - fn n_layers(&self) -> usize { - todo!() - } + fn finalize(&self, name: String, cb: &mut CircuitBuilder) -> Chip; } pub trait ProtocolWitnessGenerator { type Trace; @@ -75,18 +57,6 @@ pub trait ProtocolWitnessGenerator { ); } -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, -// `gkr_phase` and `opening_phase`. -pub struct ProtocolProver, PCS>( - PhantomData<(E, Trans, PCS)>, -); - -// TODO: the following trait consists of `commit_phase1`, `commit_phase2`, -// `gkr_phase` and `opening_phase`. -pub struct ProtocolVerifier, PCS>( - PhantomData<(E, Trans, PCS)>, -); - #[derive(Clone, Debug, Copy, EnumIter, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[repr(usize)] pub enum RAMType { @@ -97,3 +67,28 @@ pub enum RAMType { } impl_expr_from_unsigned!(RAMType); + +pub fn default_out_eval_groups(cb: &CircuitBuilder) -> OutEvalGroups { + let mut next_idx = 0usize; + let mut evals = BTreeMap::new(); + for (selector, group) in cb.cs.expression_groups.iter() { + let Some(selector) = selector else { + assert!(group.is_empty(), "all expressions must have a selector"); + continue; + }; + let r_len = group.r_expressions.len() + group.r_table_expressions.len(); + let r_evals = (next_idx..next_idx + r_len).collect_vec(); + next_idx += r_len; + + let w_len = group.w_expressions.len() + group.w_table_expressions.len(); + let w_evals = (next_idx..next_idx + w_len).collect_vec(); + next_idx += w_len; + + let lk_len = group.lk_expressions.len() + group.lk_table_expressions.len() * 2; + let lk_evals = (next_idx..next_idx + lk_len).collect_vec(); + next_idx += lk_len; + + evals.insert(selector.clone(), [r_evals, w_evals, lk_evals]); + } + evals +} diff --git a/gkr_iop/src/selector.rs b/gkr_iop/src/selector.rs index 857dbd588..80644a03e 100644 --- a/gkr_iop/src/selector.rs +++ b/gkr_iop/src/selector.rs @@ -1,4 +1,4 @@ -use std::iter::repeat_n; +use std::{cmp::Ordering, iter::repeat_n}; use rayon::iter::IndexedParallelIterator; @@ -37,7 +37,7 @@ impl SelectorContext { } /// Selector selects part of the witnesses in the sumcheck protocol. -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(bound( serialize = "E::BaseField: Serialize", deserialize = "E::BaseField: DeserializeOwned" @@ -380,6 +380,31 @@ impl SelectorType { } } +impl Ord for SelectorType { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (SelectorType::None, SelectorType::None) => Ordering::Equal, + (SelectorType::None, _) => Ordering::Less, + (_, SelectorType::None) => Ordering::Greater, + _ => self.selector_expr().cmp(other.selector_expr()), + } + } +} + +impl PartialOrd for SelectorType { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl PartialEq for SelectorType { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +impl Eq for SelectorType {} + #[cfg(test)] mod tests { use ff_ext::{BabyBearExt4, FromUniformBytes}; From 29f4af0ef601808fabc1519c428f2add6646c777 Mon Sep 17 00:00:00 2001 From: spherel <101384151+spherel@users.noreply.github.com> Date: Mon, 2 Feb 2026 19:30:14 +0800 Subject: [PATCH 2/3] cargo clippy --- ceno_zkvm/src/chip_handler.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/ceno_zkvm/src/chip_handler.rs b/ceno_zkvm/src/chip_handler.rs index 0d596b30f..4eb98bc34 100644 --- a/ceno_zkvm/src/chip_handler.rs +++ b/ceno_zkvm/src/chip_handler.rs @@ -53,6 +53,7 @@ pub type AddressExpr = Expression; /// Format: `[u16; UINT_LIMBS]`, least-significant-first. pub type MemoryExpr = [Expression; UINT_LIMBS]; +#[allow(clippy::type_complexity)] pub trait MemoryChipOperations, N: FnOnce() -> NR> { fn memory_read( &mut self, From 5efb7ab8af3e4a1d330d2fbf60855c35b308aab0 Mon Sep 17 00:00:00 2001 From: spherel <101384151+spherel@users.noreply.github.com> Date: Mon, 2 Feb 2026 19:40:13 +0800 Subject: [PATCH 3/3] cargo fmt --- .../src/instructions/riscv/branch/test.rs | 36 ++++++++++++++++--- ceno_zkvm/src/scheme/tests.rs | 4 +-- ceno_zkvm/src/structs.rs | 6 ++-- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 28acf2991..29b6f8293 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -138,7 +138,14 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); + MockProver::assert_satisfied_raw( + &circuit_builder, + raw_witin, + &[insn_code], + 1, + None, + Some(lkm), + ); Ok(()) } @@ -183,7 +190,14 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); + MockProver::assert_satisfied_raw( + &circuit_builder, + raw_witin, + &[insn_code], + 1, + None, + Some(lkm), + ); Ok(()) } @@ -235,7 +249,14 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); + MockProver::assert_satisfied_raw( + &circuit_builder, + raw_witin, + &[insn_code], + 1, + None, + Some(lkm), + ); Ok(()) } @@ -287,6 +308,13 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() ) .unwrap(); - MockProver::assert_satisfied_raw(&circuit_builder, raw_witin, &[insn_code], 1, None, Some(lkm)); + MockProver::assert_satisfied_raw( + &circuit_builder, + raw_witin, + &[insn_code], + 1, + None, + Some(lkm), + ); Ok(()) } diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index b089b4025..3518e30ad 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -24,8 +24,8 @@ use ceno_emul::{ Platform, Program, StepRecord, VMState, encode_rv32, }; use ff_ext::{ExtensionField, FieldInto, FromUniformBytes, GoldilocksExt2}; -use gkr_iop::cpu::default_backend_config; -use gkr_iop::selector::SelectorType; +use gkr_iop::{cpu::default_backend_config, selector::SelectorType}; + #[cfg(feature = "gpu")] use gkr_iop::gpu::{MultilinearExtensionGpu, gpu_prover::*}; use multilinear_extensions::{ diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index b3a815c76..dfac7eade 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -222,8 +222,7 @@ impl ZKVMConstraintSystem { } pub fn register_opcode_circuit>(&mut self) -> OC::InstructionConfig { - let mut cs = - ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); + let mut cs = ConstraintSystem::new(|| format!("riscv_opcode/{}", OC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let (config, gkr_iop_circuit) = OC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap(); @@ -248,8 +247,7 @@ impl ZKVMConstraintSystem { } pub fn register_table_circuit>(&mut self) -> TC::TableConfig { - let mut cs = - ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); + let mut cs = ConstraintSystem::new(|| format!("riscv_table/{}", TC::name())); let mut circuit_builder = CircuitBuilder::::new(&mut cs); let (config, gkr_iop_circuit) = TC::build_gkr_iop_circuit(&mut circuit_builder, &self.params).unwrap();