From 77c43d128846d29858e2b48755192671e944d139 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:02:34 +0800 Subject: [PATCH 01/73] repr(C) StepRecord --- ceno_emul/src/addr.rs | 4 +- ceno_emul/src/disassemble/mod.rs | 41 +-- ceno_emul/src/lib.rs | 2 +- ceno_emul/src/platform.rs | 4 +- ceno_emul/src/rv32im.rs | 18 +- ceno_emul/src/test_utils.rs | 8 +- ceno_emul/src/tracer.rs | 285 ++++++++++++++---- ceno_emul/src/vm_state.rs | 8 +- ceno_host/tests/test_elf.rs | 87 ++++-- ceno_zkvm/src/e2e.rs | 28 +- .../instructions/riscv/dummy/dummy_ecall.rs | 3 +- .../src/instructions/riscv/dummy/test.rs | 6 +- .../instructions/riscv/ecall/fptower_fp.rs | 5 +- .../riscv/ecall/fptower_fp2_add.rs | 5 +- .../riscv/ecall/fptower_fp2_mul.rs | 5 +- .../src/instructions/riscv/ecall/keccak.rs | 5 +- .../instructions/riscv/ecall/sha_extend.rs | 6 +- .../src/instructions/riscv/ecall/uint256.rs | 10 +- .../riscv/ecall/weierstrass_add.rs | 5 +- .../riscv/ecall/weierstrass_decompress.rs | 5 +- .../riscv/ecall/weierstrass_double.rs | 5 +- .../src/instructions/riscv/ecall_base.rs | 4 +- 22 files changed, 394 insertions(+), 155 deletions(-) diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 24a9ceea2..14c938e07 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -30,12 +30,14 @@ pub type Word = u32; pub type SWord = i32; pub type Addr = u32; pub type Cycle = u64; -pub type RegIdx = usize; +pub type RegIdx = u8; #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[repr(C)] pub struct ByteAddr(pub u32); #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(C)] pub struct WordAddr(pub u32); impl From for WordAddr { diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index 8332a6d6f..853617d63 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -1,4 +1,7 @@ -use crate::rv32im::{InsnKind, Instruction}; +use crate::{ + 
addr::RegIdx, + rv32im::{InsnKind, Instruction}, +}; use itertools::izip; use rrs_lib::{ InstructionProcessor, @@ -19,9 +22,9 @@ impl Instruction { pub const fn from_r_type(kind: InsnKind, dec_insn: &RType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: 0, raw, } @@ -32,8 +35,8 @@ impl Instruction { pub const fn from_i_type(kind: InsnKind, dec_insn: &IType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.imm, rs2: 0, raw, @@ -45,8 +48,8 @@ impl Instruction { pub const fn from_i_type_shamt(kind: InsnKind, dec_insn: &ITypeShamt, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.shamt as i32, rs2: 0, raw, @@ -59,8 +62,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -72,8 +75,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -231,7 +234,7 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jal(&mut self, dec_insn: JType) -> Self::InstructionResult { Instruction { kind: InsnKind::JAL, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -242,8 +245,8 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jalr(&mut self, dec_insn: IType) -> Self::InstructionResult { Instruction { kind: InsnKind::JALR, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, rs2: 0, imm: dec_insn.imm, raw: self.word, @@ -265,7 +268,7 @@ impl InstructionProcessor for InstructionTranspiler { // See 
[`InstructionTranspiler::process_auipc`] for more background on the conversion. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -276,7 +279,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::LUI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -311,7 +314,7 @@ impl InstructionProcessor for InstructionTranspiler { // real world scenarios like a `reth` run. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm.wrapping_add(pc as i32), @@ -322,7 +325,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::AUIPC, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 6b16a3587..915edd18f 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -34,7 +34,7 @@ pub use syscalls::{ BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, SyscallSpec, - UINT256_MUL, + SyscallWitness, UINT256_MUL, bn254::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..4c84b96c9 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -134,7 +134,7 @@ impl Platform { /// Virtual address of a register. pub const fn register_vma(index: RegIdx) -> Addr { // Register VMAs are aligned, cannot be confused with indices, and readable in hex. - (index << 8) as Addr + (index as Addr) << 8 } /// Register index from a virtual address (unchecked). 
@@ -220,7 +220,7 @@ mod tests { // Registers do not overlap with ROM or RAM. for reg in [ Platform::register_vma(0), - Platform::register_vma(VMState::::REG_COUNT - 1), + Platform::register_vma((VMState::::REG_COUNT - 1) as RegIdx), ] { assert!(!p.is_rom(reg)); assert!(!p.is_ram(reg)); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index a3eac8896..48f7beab9 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -29,9 +29,9 @@ use super::addr::{ByteAddr, RegIdx, WORD_SIZE, Word, WordAddr}; pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm, raw: 0, } @@ -43,9 +43,9 @@ pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) pub const fn encode_rv32u(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm: imm as i32, raw: 0, } @@ -113,6 +113,7 @@ pub enum TrapCause { } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] pub struct Instruction { pub kind: InsnKind, pub rs1: RegIdx, @@ -162,6 +163,7 @@ use InsnFormat::*; ToPrimitive, Default, )] +#[repr(u8)] #[allow(clippy::upper_case_acronyms)] pub enum InsnKind { #[default] @@ -425,7 +427,7 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) if !new_pc.is_aligned() { return ctx.trap(TrapCause::InstructionAddressMisaligned); } - ctx.store_register(insn.rd_internal() as usize, out)?; + ctx.store_register(insn.rd_internal() as RegIdx, out)?; ctx.set_pc(new_pc); Ok(true) } @@ -502,7 +504,7 @@ fn step_load(ctx: &mut M, kind: InsnKind, decoded: &Instruction) } _ => unreachable!(), }; - ctx.store_register(decoded.rd_internal() as usize, out)?; + 
ctx.store_register(decoded.rd_internal() as RegIdx, out)?; ctx.set_pc(ctx.get_pc() + WORD_SIZE); Ok(true) } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 39577c13c..92ad64a97 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -1,10 +1,11 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, Platform, Program, StepRecord, VMState, encode_rv32, - encode_rv32u, syscalls::KECCAK_PERMUTE, + encode_rv32u, + syscalls::{KECCAK_PERMUTE, SyscallWitness}, }; use anyhow::Result; -pub fn keccak_step() -> (StepRecord, Vec) { +pub fn keccak_step() -> (StepRecord, Vec, Vec) { let instructions = vec![ // Call Keccak-f. load_immediate(Platform::reg_arg0() as u32, CENO_PLATFORM.heap.start), @@ -26,8 +27,9 @@ pub fn keccak_step() -> (StepRecord, Vec) { let mut vm = VMState::new(CENO_PLATFORM.clone(), program.into()); vm.iter_until_halt().collect::>>().unwrap(); let steps = vm.tracer().recorded_steps(); + let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions) + (steps[2].clone(), instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 45821ae64..4aa7b3080 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -22,7 +22,8 @@ use std::{collections::BTreeMap, fmt, sync::Arc}; /// - Any of `rs1 / rs2 / rd` **may be `x0`**. The trace handles this like any register, including the value that was _supposed_ to be stored. The circuits must handle this case: either **store `0` or skip `x0` operations**. /// /// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers. 
-#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(C)] pub struct StepRecord { cycle: Cycle, pc: Change, @@ -30,14 +31,45 @@ pub struct StepRecord { pub hint_maxtouch_addr: Change, pub insn: Instruction, - rs1: Option, - rs2: Option, + has_rs1: bool, + has_rs2: bool, + has_rd: bool, + has_memory_op: bool, - rd: Option, + rs1: ReadOp, + rs2: ReadOp, + rd: WriteOp, + memory_op: WriteOp, - memory_op: Option, + /// Index into the separate syscall witness storage. + /// `u32::MAX` means no syscall for this step. + syscall_index: u32, +} - syscall: Option, +impl StepRecord { + /// Sentinel value indicating no syscall is associated with this step. + pub const NO_SYSCALL: u32 = u32::MAX; +} + +impl Default for StepRecord { + fn default() -> Self { + Self { + cycle: 0, + pc: Default::default(), + heap_maxtouch_addr: Default::default(), + hint_maxtouch_addr: Default::default(), + insn: Default::default(), + has_rs1: false, + has_rs2: false, + has_rd: false, + has_memory_op: false, + rs1: Default::default(), + rs2: Default::default(), + rd: Default::default(), + memory_op: Default::default(), + syscall_index: StepRecord::NO_SYSCALL, + } + } } pub type StepIndex = usize; @@ -152,7 +184,8 @@ pub trait Tracer { ) -> Option<(WordAddr, WordAddr)>; } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[repr(C)] pub struct MemOp { /// Virtual Memory Address. /// For registers, get it from `Platform::register_vma(idx)`. 
@@ -605,27 +638,41 @@ impl StepRecord { heap_maxtouch_addr: Change, hint_maxtouch_addr: Change, ) -> StepRecord { + let has_rs1 = rs1_read.is_some(); + let has_rs2 = rs2_read.is_some(); + let has_rd = rd.is_some(); + let has_memory_op = memory_op.is_some(); StepRecord { cycle, pc, - rs1: rs1_read.map(|rs1| ReadOp { - addr: Platform::register_vma(insn.rs1).into(), - value: rs1, - previous_cycle, - }), - rs2: rs2_read.map(|rs2| ReadOp { - addr: Platform::register_vma(insn.rs2).into(), - value: rs2, - previous_cycle, - }), - rd: rd.map(|rd| WriteOp { - addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), - value: rd, - previous_cycle, - }), + has_rs1, + has_rs2, + has_rd, + has_memory_op, + rs1: rs1_read + .map(|rs1| ReadOp { + addr: Platform::register_vma(insn.rs1).into(), + value: rs1, + previous_cycle, + }) + .unwrap_or_default(), + rs2: rs2_read + .map(|rs2| ReadOp { + addr: Platform::register_vma(insn.rs2).into(), + value: rs2, + previous_cycle, + }) + .unwrap_or_default(), + rd: rd + .map(|rd| WriteOp { + addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), + value: rd, + previous_cycle, + }) + .unwrap_or_default(), insn, - memory_op, - syscall: None, + memory_op: memory_op.unwrap_or_default(), + syscall_index: StepRecord::NO_SYSCALL, heap_maxtouch_addr, hint_maxtouch_addr, } @@ -645,19 +692,23 @@ impl StepRecord { } pub fn rs1(&self) -> Option { - self.rs1.clone() + if self.has_rs1 { Some(self.rs1) } else { None } } pub fn rs2(&self) -> Option { - self.rs2.clone() + if self.has_rs2 { Some(self.rs2) } else { None } } pub fn rd(&self) -> Option { - self.rd.clone() + if self.has_rd { Some(self.rd) } else { None } } pub fn memory_op(&self) -> Option { - self.memory_op.clone() + if self.has_memory_op { + Some(self.memory_op) + } else { + None + } } #[inline(always)] @@ -665,8 +716,19 @@ impl StepRecord { self.pc.before == self.pc.after } - pub fn syscall(&self) -> Option<&SyscallWitness> { - self.syscall.as_ref() + /// Returns true if 
this step has a syscall witness. + pub fn has_syscall(&self) -> bool { + self.syscall_index != Self::NO_SYSCALL + } + + /// Look up the syscall witness from a separate store. + /// The store is typically obtained from `FullTracer::syscall_witnesses()`. + pub fn syscall<'a>(&self, store: &'a [SyscallWitness]) -> Option<&'a SyscallWitness> { + if self.syscall_index == Self::NO_SYSCALL { + None + } else { + Some(&store[self.syscall_index as usize]) + } } } @@ -684,6 +746,9 @@ pub struct FullTracer { pending_index: usize, pending_cycle: Cycle, + /// Syscall witnesses stored separately (StepRecord references by index). + syscall_witnesses: Vec, + // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, @@ -724,6 +789,7 @@ impl FullTracer { len: 0, pending_index: 0, pending_cycle: Self::SUBCYCLES_PER_INSN, + syscall_witnesses: Vec::new(), mmio_min_max_access: Some(mmio_max_access), platform: platform.clone(), latest_accesses: LatestAccesses::new(platform), @@ -760,6 +826,7 @@ impl FullTracer { pub fn reset_step_buffer(&mut self) { self.len = 0; self.pending_index = 0; + self.syscall_witnesses.clear(); self.reset_pending_slot(); } @@ -767,6 +834,11 @@ impl FullTracer { &self.records[..self.len] } + /// Returns the syscall witness store. Pass this to `StepRecord::syscall()`. 
+ pub fn syscall_witnesses(&self) -> &[SyscallWitness] { + &self.syscall_witnesses + } + #[inline(always)] pub fn step_record(&self, index: StepIndex) -> &StepRecord { assert!( @@ -822,41 +894,41 @@ impl FullTracer { #[inline(always)] pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = Platform::register_vma(idx).into(); - match ( - self.records[self.pending_index].rs1.as_ref(), - self.records[self.pending_index].rs2.as_ref(), - ) { - (None, None) => { - self.records[self.pending_index].rs1 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS1), - }); - } - (Some(_), None) => { - self.records[self.pending_index].rs2 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS2), - }); - } - _ => unimplemented!("Only two register reads are supported"), + if !self.records[self.pending_index].has_rs1 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS1); + self.records[self.pending_index].rs1 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs1 = true; + } else if !self.records[self.pending_index].has_rs2 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS2); + self.records[self.pending_index].rs2 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs2 = true; + } else { + unimplemented!("Only two register reads are supported"); } } #[inline(always)] pub fn store_register(&mut self, idx: RegIdx, value: Change) { - if self.records[self.pending_index].rd.is_some() { + if self.records[self.pending_index].has_rd { unimplemented!("Only one register write is supported"); } let addr = Platform::register_vma(idx).into(); let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RD); - self.records[self.pending_index].rd = Some(WriteOp { + self.records[self.pending_index].rd = WriteOp { addr, value, previous_cycle, - }); + }; + self.records[self.pending_index].has_rd = true; } 
#[inline(always)] @@ -866,7 +938,7 @@ impl FullTracer { #[inline(always)] pub fn store_memory(&mut self, addr: WordAddr, value: Change) { - if self.records[self.pending_index].memory_op.is_some() { + if self.records[self.pending_index].has_memory_op { unimplemented!("Only one memory access is supported"); } @@ -899,19 +971,26 @@ impl FullTracer { } } - self.records[self.pending_index].memory_op = Some(WriteOp { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_MEM); + self.records[self.pending_index].memory_op = WriteOp { addr, value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_MEM), - }); + previous_cycle, + }; + self.records[self.pending_index].has_memory_op = true; } #[inline(always)] pub fn track_syscall(&mut self, effects: SyscallEffects) { let witness = effects.finalize(self); let record = &mut self.records[self.pending_index]; - assert!(record.syscall.is_none(), "Only one syscall per step"); - record.syscall = Some(witness); + assert!( + record.syscall_index == StepRecord::NO_SYSCALL, + "Only one syscall per step" + ); + let idx = self.syscall_witnesses.len(); + self.syscall_witnesses.push(witness); + record.syscall_index = idx as u32; } #[inline(always)] @@ -1371,6 +1450,7 @@ impl Tracer for FullTracer { } #[derive(Copy, Clone, Default, PartialEq, Eq)] +#[repr(C)] pub struct Change { pub before: T, pub after: T, @@ -1387,3 +1467,92 @@ impl fmt::Debug for Change { write!(f, "{:?} -> {:?}", self.before, self.after) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_step_record_is_copy_and_compact() { + // Verify StepRecord is Copy (this compiles only if Copy is implemented) + fn assert_copy() {} + assert_copy::(); + + // Verify repr(C) compactness — should be well under 128 bytes + let size = std::mem::size_of::(); + eprintln!("StepRecord size: {} bytes", size); + assert!( + size <= 144, + "StepRecord should be compact for GPU transfer: got {} bytes", + size + ); + } + + #[test] + fn 
test_supporting_types_are_copy() { + fn assert_copy() {} + assert_copy::(); + assert_copy::(); + assert_copy::>(); + assert_copy::>(); + } + + /// Verify exact byte offsets of StepRecord fields for CUDA struct alignment. + /// If this test fails, the CUDA step_record.cuh header must be updated to match. + #[test] + fn test_step_record_layout_for_gpu() { + use std::mem; + + macro_rules! offset_of { + ($type:ty, $field:ident) => {{ + let val = <$type>::default(); + let base = &val as *const _ as usize; + let field = &val.$field as *const _ as usize; + field - base + }}; + } + + // Sub-type sizes + assert_eq!(mem::size_of::(), 12, "Instruction size"); + assert_eq!(mem::size_of::(), 16, "ReadOp size"); + assert_eq!(mem::size_of::(), 24, "WriteOp size"); + assert_eq!( + mem::size_of::>(), + 8, + "Change size" + ); + + // StepRecord field offsets — these must match step_record.cuh + assert_eq!(offset_of!(StepRecord, cycle), 0); + assert_eq!(offset_of!(StepRecord, pc), 8); + assert_eq!(offset_of!(StepRecord, heap_maxtouch_addr), 16); + assert_eq!(offset_of!(StepRecord, hint_maxtouch_addr), 24); + assert_eq!(offset_of!(StepRecord, insn), 32); + assert_eq!(offset_of!(StepRecord, has_rs1), 44); + assert_eq!(offset_of!(StepRecord, has_rs2), 45); + assert_eq!(offset_of!(StepRecord, has_rd), 46); + assert_eq!(offset_of!(StepRecord, has_memory_op), 47); + assert_eq!(offset_of!(StepRecord, rs1), 48); + assert_eq!(offset_of!(StepRecord, rs2), 64); + assert_eq!(offset_of!(StepRecord, rd), 80); + assert_eq!(offset_of!(StepRecord, memory_op), 104); + assert_eq!(offset_of!(StepRecord, syscall_index), 128); + + // Total size + assert_eq!(mem::size_of::(), 136, "StepRecord total size"); + assert_eq!(mem::align_of::(), 8, "StepRecord alignment"); + + // InsnKind must be repr(u8) for CUDA compatibility + assert_eq!( + mem::size_of::(), + 1, + "InsnKind must be 1 byte (repr(u8))" + ); + + eprintln!( + "StepRecord layout verified: {} bytes, {} align", + mem::size_of::(), + mem::align_of::() 
+ ); + } +} diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index d65844682..0613436be 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -155,7 +155,7 @@ impl VMState { } pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { - self.registers[idx] = value; + self.registers[idx as usize] = value; } fn halt(&mut self, exit_code: u32) { @@ -171,7 +171,7 @@ impl VMState { } for (idx, value) in effects.iter_reg_values() { - self.registers[idx] = value; + self.registers[idx as usize] = value; } let next_pc = effects.next_pc.unwrap_or(self.pc + PC_STEP_SIZE as u32); @@ -252,7 +252,7 @@ impl EmuContext for VMState { if idx != 0 { let before = self.peek_register(idx); self.tracer.store_register(idx, Change { before, after }); - self.registers[idx] = after; + self.registers[idx as usize] = after; } Ok(()) } @@ -276,7 +276,7 @@ impl EmuContext for VMState { /// Get the value of a register without side-effects. fn peek_register(&self, idx: RegIdx) -> Word { - self.registers[idx] + self.registers[idx as usize] } /// Get the value of a memory word without side-effects. 
diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index ff752267d..b06c48d6e 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind, - Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, + Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, SyscallWitness, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp, host_utils::{read_all_messages, read_all_messages_as_words}, }; @@ -21,7 +21,7 @@ fn test_ceno_rt_mini() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; Ok(()) } @@ -39,7 +39,7 @@ fn test_ceno_rt_panic() { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let steps = run(&mut state).unwrap(); + let (steps, _syscall_witnesses) = run(&mut state).unwrap(); let last = steps.last().unwrap(); assert_eq!(last.insn().kind, InsnKind::ECALL); assert_eq!(last.rs1().unwrap().value, Platform::ecall_halt()); @@ -56,7 +56,7 @@ fn test_ceno_rt_mem() -> Result<()> { }; let sheap = program.sheap.into(); let mut state = VMState::new(platform, Arc::new(program.clone())); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let value = state.peek_memory(sheap); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); @@ -72,7 +72,7 @@ fn test_ceno_rt_alloc() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; // Search for the RAM action of the test program. 
let mut found = (false, false); @@ -102,7 +102,7 @@ fn test_ceno_rt_io() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let all_messages = messages_to_strings(&read_all_messages(&state)); for msg in &all_messages { @@ -235,7 +235,7 @@ fn test_hashing() -> Result<()> { fn test_keccak_syscall() -> Result<()> { let program_elf = ceno_examples::keccak_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; // Expect the program to have written successive states between Keccak permutations. let keccak_first_iter_outs = sample_keccak_f(1); @@ -251,7 +251,10 @@ fn test_keccak_syscall() -> Result<()> { } // Find the syscall records. - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 100); // Check the syscall effects. 
@@ -293,9 +296,12 @@ fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { fn test_secp256k1() -> Result<()> { let program_elf = ceno_examples::secp256k1; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -305,9 +311,12 @@ fn test_secp256k1() -> Result<()> { fn test_secp256k1_add() -> Result<()> { let program_elf = ceno_examples::secp256k1_add_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -358,9 +367,12 @@ fn test_secp256k1_double() -> Result<()> { let program_elf = ceno_examples::secp256k1_double_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -394,9 +406,12 @@ fn test_secp256k1_decompress() -> Result<()> { let program_elf = ceno_examples::secp256k1_decompress_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + 
.iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -456,8 +471,11 @@ fn test_secp256k1_ecrecover() -> Result<()> { let program_elf = ceno_examples::secp256k1_ecrecover; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -479,8 +497,11 @@ fn test_sha256_extend() -> Result<()> { ]; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 48); for round in 0..48 { @@ -534,10 +555,13 @@ fn test_sha256_full() -> Result<()> { fn test_bn254_fptower_syscalls() -> Result<()> { let program_elf = ceno_examples::bn254_fptower_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; const RUNS: usize = 10; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 4 * RUNS); for witness in syscalls.iter() { @@ -584,9 +608,12 @@ fn test_bn254_fptower_syscalls() -> Result<()> { fn test_bn254_curve() -> Result<()> { let program_elf = ceno_examples::bn254_curve_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, 
syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 3); // add @@ -652,9 +679,12 @@ fn test_uint256_mul() -> Result<()> { let program_elf = ceno_examples::uint256_mul_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -805,9 +835,10 @@ fn messages_to_strings(messages: &[Vec]) -> Vec { .collect() } -fn run(state: &mut VMState) -> Result> { +fn run(state: &mut VMState) -> Result<(Vec, Vec)> { state.iter_until_halt().collect::>>()?; let steps = state.tracer().recorded_steps().to_vec(); + let syscall_witnesses = state.tracer().syscall_witnesses().to_vec(); eprintln!("Emulator ran for {} steps.", steps.len()); - Ok(steps) + Ok((steps, syscall_witnesses)) } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..7a3c4710c 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,9 +24,9 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, - NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, - StepIndex, StepRecord, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, - host_utils::read_all_messages, + NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, RegIdx, + StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, + WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; use clap::ValueEnum; use 
either::Either; @@ -199,6 +199,8 @@ pub struct ShardContext<'a> { pub platform: Platform, pub shard_heap_addr_range: Range, pub shard_hint_addr_range: Range, + /// Syscall witnesses for StepRecord::syscall() lookups. + pub syscall_witnesses: Arc>, } impl<'a> Default for ShardContext<'a> { @@ -233,6 +235,7 @@ impl<'a> Default for ShardContext<'a> { platform: CENO_PLATFORM.clone(), shard_heap_addr_range: CENO_PLATFORM.heap.clone(), shard_hint_addr_range: CENO_PLATFORM.hints.clone(), + syscall_witnesses: Arc::new(Vec::new()), } } } @@ -279,6 +282,7 @@ impl<'a> ShardContext<'a> { platform: self.platform.clone(), shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), }) .collect_vec(), _ => panic!("invalid type"), @@ -750,6 +754,7 @@ pub trait StepSource: Iterator { fn start_new_shard(&mut self); fn shard_steps(&self) -> &[StepRecord]; fn step_record(&self, idx: StepIndex) -> &StepRecord; + fn syscall_witnesses(&self) -> &[SyscallWitness]; } /// Lazily replays `StepRecord`s by re-running the VM up to the number of steps @@ -822,6 +827,10 @@ impl StepSource for StepReplay { fn step_record(&self, idx: StepIndex) -> &StepRecord { self.vm.tracer().step_record(idx) } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + self.vm.tracer().syscall_witnesses() + } } pub fn emulate_program<'a>( @@ -899,8 +908,8 @@ pub fn emulate_program<'a>( let reg_final = reg_init .iter() .map(|rec| { - let index = rec.addr as usize; - if index < VM_REG_COUNT { + if (rec.addr as usize) < VM_REG_COUNT { + let index = rec.addr as RegIdx; let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { ram_type: RAMType::Register, @@ -1282,6 +1291,7 @@ pub fn generate_witness<'a, E: ExtensionField>( }; tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); let shard_steps = step_iter.shard_steps(); + shard_ctx.syscall_witnesses = 
Arc::new(step_iter.syscall_witnesses().to_vec()); let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); @@ -2122,7 +2132,9 @@ pub fn verify + serde::Ser #[cfg(test)] mod tests { use crate::e2e::{MultiProver, ShardContextBuilder}; - use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; + use ceno_emul::{ + CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord, SyscallWitness, + }; use itertools::Itertools; use std::sync::Arc; @@ -2224,6 +2236,10 @@ mod tests { fn step_record(&self, idx: StepIndex) -> &StepRecord { &self.steps[self.shard_start + idx] } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + &[] // Test replay doesn't track syscalls + } } let mut steps_iter = TestReplay::new(steps); diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 3ae516e9c..650d5d97a 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -99,7 +99,8 @@ impl Instruction for LargeEcallDummy lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // Assign instruction. 
config diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index b74c8ca39..bf952a4f1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -18,10 +18,12 @@ fn test_large_ecall_dummy_keccak() { let mut cb = CircuitBuilder::new(&mut cs); let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let (step, program) = ceno_emul::test_utils::keccak_step(); + let (step, program, syscall_witnesses) = ceno_emul::test_utils::keccak_step(); + let mut shard_ctx = ShardContext::default(); + shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); let (raw_witin, lkm) = KeccakDummy::assign_instances_from_steps( &config, - &mut ShardContext::default(), + &mut shard_ctx, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, &[step], diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 57d824b01..7ee531cc8 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -358,7 +358,8 @@ fn assign_fp_op_instances( .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); config .vm_state .assign_instance(instance, &shard_ctx, step)?; @@ -419,7 +420,7 @@ fn assign_fp_op_instances( .map(|&idx| { let step = &steps[idx]; let values: Vec = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6715a0b74..2d99ed4a4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ 
b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -273,7 +273,8 @@ fn assign_fp2_add_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index 7537d31ca..4aacf3418 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -271,7 +271,8 @@ fn assign_fp2_mul_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index e088cc0cc..f9c9f1712 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -221,7 +221,8 @@ impl Instruction for KeccakInstruction { .zip_eq(indices.iter().copied()) .map(|(instance_with_rotation, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); let bh = BooleanHypercube::new(KECCAK_ROUNDS_CEIL_LOG2); let mut cyclic_group = bh.into_iter(); @@ -285,7 +286,7 @@ impl Instruction for KeccakInstruction { .map(|&idx| -> KeccakInstance { let step = &steps[idx]; let (instance, prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index b61673e4a..3a0f42ab2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -218,7 +218,8 @@ impl Instruction for ShaExtendInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = step.syscall().expect("syscall 
step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); // vm_state config @@ -285,7 +286,8 @@ impl Instruction for ShaExtendInstruction { .iter() .map(|&idx| -> ShaExtendInstance { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); let w_i_minus_2 = ops.mem_ops[0].value.before; let w_i_minus_7 = ops.mem_ops[1].value.before; let w_i_minus_15 = ops.mem_ops[2].value.before; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index f3a39093f..67a042cd7 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -270,7 +270,8 @@ impl Instruction for Uint256MulInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -336,7 +337,7 @@ impl Instruction for Uint256MulInstruction { .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() @@ -593,7 +594,8 @@ impl Instruction for Uint256InvInstr .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -646,7 +648,7 @@ impl Instruction for Uint256InvInstr .map(|&idx| { let step = &steps[idx]; let (instance, _): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs 
b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 80c85ef7a..ca4f59ed3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -278,7 +278,8 @@ impl Instruction .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -345,7 +346,7 @@ impl Instruction .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index a07fc00b2..1d79efd63 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -278,7 +278,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 72f5f71d8..206ef143d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -250,7 +250,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall_base.rs b/ceno_zkvm/src/instructions/riscv/ecall_base.rs index 7e655c408..69d9d286f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_base.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_base.rs @@ -18,13 +18,13 @@ use ceno_emul::FullTracer as Tracer; use multilinear_extensions::{ToExpr, WitIn}; #[derive(Debug)] -pub 
struct OpFixedRS { +pub struct OpFixedRS { pub prev_ts: WitIn, pub prev_value: Option>, pub lt_cfg: AssertLtConfig, } -impl OpFixedRS { +impl OpFixedRS { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, rd_written: RegisterExpr, From ac2bf54d394f0db191cb7f2b5fb139d0b074cc94 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:25:28 +0800 Subject: [PATCH 02/73] fix --- ceno_emul/src/syscalls.rs | 8 ++------ ceno_emul/src/test_utils.rs | 2 +- ceno_zkvm/src/instructions/riscv/memory/gadget.rs | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 5d9674fc6..31c87b271 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -60,19 +60,15 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< /// A syscall event, available to the circuit witness generators. /// TODO: separate mem_ops into two stages: reads-and-writes #[derive(Clone, Debug, Default, PartialEq, Eq)] +#[non_exhaustive] pub struct SyscallWitness { pub mem_ops: Vec, pub reg_ops: Vec, - _marker: (), } impl SyscallWitness { fn new(mem_ops: Vec, reg_ops: Vec) -> SyscallWitness { - SyscallWitness { - mem_ops, - reg_ops, - _marker: (), - } + SyscallWitness { mem_ops, reg_ops } } } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 92ad64a97..625ed52c5 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -29,7 +29,7 @@ pub fn keccak_step() -> (StepRecord, Vec, Vec) { let steps = vm.tracer().recorded_steps(); let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions, syscall_witnesses) + (steps[2], instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 3a8da4a09..a37be1f61 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs 
+++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -138,7 +138,7 @@ impl MemWordUtil { step: &StepRecord, shift: u32, ) -> Result<(), ZKVMError> { - let memory_op = step.memory_op().clone().unwrap(); + let memory_op = step.memory_op().unwrap(); let prev_value = Value::new_unchecked(memory_op.value.before); let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); From f550783d59da1d1b0d9918f1828a44426d3179ac Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:42:27 +0800 Subject: [PATCH 03/73] witgen: add --- ceno_zkvm/benches/witgen_add_gpu.rs | 117 ++++++++ ceno_zkvm/src/instructions/riscv.rs | 2 + ceno_zkvm/src/instructions/riscv/arith.rs | 8 +- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 286 ++++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + gkr_iop/src/gadgets/is_lt.rs | 2 +- 6 files changed, 412 insertions(+), 5 deletions(-) create mode 100644 ceno_zkvm/benches/witgen_add_gpu.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/add.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/mod.rs diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs new file mode 100644 index 000000000..360582d11 --- /dev/null +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -0,0 +1,117 @@ +use std::time::Duration; + +use ceno_zkvm::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{ + Instruction, + riscv::arith::AddInstruction, + }, + structs::ProgramParams, +}; +use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; +use criterion::*; +use ff_ext::BabyBearExt4; + +#[cfg(feature = "gpu")] +use ceno_gpu::bb31::CudaHalBB31; +#[cfg(feature = "gpu")] +use ceno_zkvm::instructions::riscv::gpu::add::{extract_add_column_map, pack_add_soa}; + +mod alloc; + +type E = BabyBearExt4; + +criterion_group! 
{ + name = witgen_add; + config = Criterion::default().warm_up_time(Duration::from_millis(2000)); + targets = bench_witgen_add +} + +criterion_main!(witgen_add); + +fn make_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 1; + let rs2 = (i as u32) % 500 + 3; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect() +} + +fn bench_witgen_add(c: &mut Criterion) { + let mut cs = ConstraintSystem::::new(|| "bench"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + #[cfg(feature = "gpu")] + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + #[cfg(feature = "gpu")] + let col_map = extract_add_column_map(&config, num_witin); + + for pow in [10, 12, 14, 16, 18] { + let n = 1usize << pow; + let mut group = c.benchmark_group(format!("witgen_add_2^{}", pow)); + group.sample_size(10); + + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU benchmark + group.bench_function("cpu_assign_instances", |b| { + b.iter(|| { + let mut shard_ctx = ShardContext::default(); + AddInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + }) + }); + + // GPU benchmark (total: SOA pack + H2D + kernel + synchronize) + #[cfg(feature = "gpu")] + group.bench_function("gpu_total", |b| { + b.iter(|| { + let shard_ctx = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx, &steps, &indices); + 
hal.witgen_add(&col_map, &soa, None).unwrap() + }) + }); + + // GPU benchmark (kernel only: pre-upload SOA, measure only kernel) + #[cfg(feature = "gpu")] + { + let shard_ctx = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx, &steps, &indices); + + group.bench_function("gpu_kernel_only", |b| { + b.iter(|| hal.witgen_add(&col_map, &soa, None).unwrap()) + }); + } + + group.finish(); + } +} diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index c77b707b4..c70264071 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -32,6 +32,8 @@ mod r_insn; mod ecall_insn; +pub mod gpu; + #[cfg(feature = "u16limb_circuit")] mod auipc; mod im_insn; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index a5f6e006f..1fd5f98d4 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -11,11 +11,11 @@ use ff_ext::ExtensionField; /// This config handles R-Instructions that represent registers values as 2 * u16. #[derive(Debug)] pub struct ArithConfig { - r_insn: RInstructionConfig, + pub r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - rd_written: UInt, + pub rs1_read: UInt, + pub rs2_read: UInt, + pub rd_written: UInt, } pub struct ArithInstruction(PhantomData<(E, I)>); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs new file mode 100644 index 000000000..8eefc05f7 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -0,0 +1,286 @@ +use ceno_emul::StepIndex; +use ceno_gpu::common::witgen_types::{AddColumnMap, AddStepRecordSOA}; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith::ArithConfig; +use crate::e2e::ShardContext; +use ceno_emul::StepRecord; + +/// Extract column map from a constructed ArithConfig (ADD variant). 
+/// +/// This reads all WitIn.id values from the config tree and packs them +/// into an AddColumnMap suitable for GPU kernel dispatch. +pub fn extract_add_column_map( + config: &ArithConfig, + num_witin: usize, +) -> AddColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.r_insn.rd.prev_value.wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // Arithmetic: rs1/rs2 u16 limbs + let rs1_limbs: [u32; 2] = { + let limbs = config.rs1_read.wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs1_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config.rs2_read.wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); + [limbs[0].id 
as u32, limbs[1].id as u32] + }; + + // rd carries + let rd_carries: [u32; 2] = { + let carries = config.rd_written.carries.as_ref() + .expect("rd_written should have carries"); + assert_eq!(carries.len(), 2, "Expected 2 rd_written carries"); + [carries[0].id as u32, carries[1].id as u32] + }; + + AddColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_carries, + num_cols: num_witin as u32, + } +} + +/// Pack step records into SOA format for GPU transfer. +/// +/// Pre-computes shard-adjusted timing values on CPU so the GPU kernel +/// only needs to do witness filling. +pub fn pack_add_soa( + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> AddStepRecordSOA { + let n = step_indices.len(); + let mut soa = AddStepRecordSOA::with_capacity(n); + + let offset = shard_ctx.current_shard_offset_cycle(); + + for &idx in step_indices { + let step = &shard_steps[idx]; + let rs1 = step.rs1().expect("ADD requires rs1"); + let rs2 = step.rs2().expect("ADD requires rs2"); + let rd = step.rd().expect("ADD requires rd"); + + soa.pc_before.push(step.pc().before.0); + soa.cycle.push(step.cycle() - offset); + soa.rs1_reg.push(rs1.register_index() as u32); + soa.rs1_val.push(rs1.value); + soa.rs1_prev_cycle.push(aligned_prev_ts(rs1.previous_cycle, offset)); + soa.rs2_reg.push(rs2.register_index() as u32); + soa.rs2_val.push(rs2.value); + soa.rs2_prev_cycle.push(aligned_prev_ts(rs2.previous_cycle, offset)); + soa.rd_reg.push(rd.register_index() as u32); + soa.rd_val_before.push(rd.value.before); + soa.rd_prev_cycle.push(aligned_prev_ts(rd.previous_cycle, offset)); + } + + soa +} + +/// Inline version of ShardContext::aligned_prev_ts for SOA packing. 
+fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { + let mut ts = prev_cycle.saturating_sub(shard_offset); + if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { + ts = 0; + } + ts +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::Instruction, + instructions::riscv::arith::AddInstruction, + structs::ProgramParams, + }; + use ceno_emul::{Change, encode_rv32, InsnKind, ByteAddr}; + use ceno_gpu::bb31::CudaHalBB31; + use ceno_gpu::Buffer; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn make_test_steps(n: usize) -> Vec { + // Use small PC values that fit within BabyBear field (P ≈ 2×10^9) + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 1; + let rs2 = (i as u32) % 500 + 3; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; // cycles start at 4 (SUBCYCLES_PER_INSN) + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, // prev_cycle + ) + }) + .collect() + } + + #[test] + fn test_extract_add_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + + let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + // All column IDs should be unique and within range + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + // Check uniqueness + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column 
ID: {}", col); + } + } + + #[test] + fn test_pack_add_soa() { + let steps = make_test_steps(4); + let indices: Vec = (0..steps.len()).collect(); + let shard_ctx = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx, &steps, &indices); + + assert_eq!(soa.len(), 4); + // Check first step + assert_eq!(soa.rs1_val[0], 1); // 0 * 7 + 1 + assert_eq!(soa.rs2_val[0], 3); // 0 * 13 + 3 + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_add_correctness() { + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Construct circuit + let mut cs = ConstraintSystem::::new(|| "test_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + // Generate test data + let n = 1024; + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = AddInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; // witness matrix (not structural) + + // GPU path + let col_map = extract_add_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx_gpu, &steps, &indices); + let gpu_result = hal.witgen_add(&col_map, &soa, None).unwrap(); + + // D2H copy + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + + // Compare element by element + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), + "Size mismatch: GPU {} vs CPU {}", gpu_data.len(), cpu_data.len()); + + let mut mismatches = 0; + for row in 0..n { + for col in 0..num_witin { + let gpu_val = gpu_data[row * num_witin + col]; + let cpu_val = cpu_data[row * num_witin + col]; + 
if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, col, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches out of {} elements", + mismatches, n * num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs new file mode 100644 index 000000000..b0179cee0 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "gpu")] +pub mod add; diff --git a/gkr_iop/src/gadgets/is_lt.rs b/gkr_iop/src/gadgets/is_lt.rs index d3f4a2ac6..b6dc4720f 100644 --- a/gkr_iop/src/gadgets/is_lt.rs +++ b/gkr_iop/src/gadgets/is_lt.rs @@ -12,7 +12,7 @@ use crate::{ }; #[derive(Debug, Clone)] -pub struct AssertLtConfig(InnerLtConfig); +pub struct AssertLtConfig(pub InnerLtConfig); impl AssertLtConfig { pub fn construct_circuit< From 5a10916e8a34b9dba1305d8cee5f4bc1e03f0ecc Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:42:42 +0800 Subject: [PATCH 04/73] witgen: lw --- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 311 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + ceno_zkvm/src/instructions/riscv/im_insn.rs | 8 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 6 +- ceno_zkvm/src/instructions/riscv/memory.rs | 2 +- .../src/instructions/riscv/memory/load.rs | 16 +- .../src/instructions/riscv/memory/load_v2.rs | 18 +- 7 files changed, 338 insertions(+), 25 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/lw.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs new file mode 100644 index 000000000..65accacd2 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -0,0 +1,311 @@ +use ceno_emul::StepIndex; +use ceno_gpu::common::witgen_types::{LwColumnMap, LwStepRecordSOA}; +use ff_ext::ExtensionField; + +use crate::e2e::ShardContext; +#[cfg(not(feature = 
"u16limb_circuit"))] +use crate::instructions::riscv::memory::load::LoadConfig; +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::memory::load_v2::LoadConfig; +use crate::tables::InsnRecord; +use ceno_emul::{ByteAddr, StepRecord}; + +/// Extract column map from a constructed LoadConfig (LW variant). +pub fn extract_lw_column_map( + config: &LoadConfig, + num_witin: usize, +) -> LwColumnMap { + let im = &config.im_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadMEM + let mem_prev_ts = im.mem_read.prev_ts.id as u32; + let mem_lt_diff = { + let d = &im.mem_read.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Load-specific + let rs1_limbs = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + #[cfg(feature = "u16limb_circuit")] + let imm_sign = Some(config.imm_sign.id as u32); + #[cfg(not(feature = "u16limb_circuit"))] + let imm_sign = None; + let mem_addr_limbs = { + let l = config.memory_addr.addr.wits_in().expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_read_limbs = { + let l = config.memory_read.wits_in().expect("memory_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + 
+ LwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr_limbs, + mem_read_limbs, + num_cols: num_witin as u32, + } +} + +/// Pack step records into SOA format for LW GPU transfer. +pub fn pack_lw_soa( + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> LwStepRecordSOA { + use p3::field::PrimeField32; + type B = ::BaseField; + + let n = step_indices.len(); + let mut soa = LwStepRecordSOA::with_capacity(n); + let offset = shard_ctx.current_shard_offset_cycle(); + + for &idx in step_indices { + let step = &shard_steps[idx]; + let rs1_op = step.rs1().expect("LW requires rs1"); + let rd_op = step.rd().expect("LW requires rd"); + let mem_op = step.memory_op().expect("LW requires memory_op"); + + // Compute imm field value (signed immediate as BabyBear) + let imm_pair = InsnRecord::::imm_internal(&step.insn()); + let imm_field_val: B = imm_pair.1; + + // Compute unaligned address + let unaligned_addr = + ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); + + soa.pc_before.push(step.pc().before.0); + soa.cycle.push(step.cycle() - offset); + soa.rs1_reg.push(rs1_op.register_index() as u32); + soa.rs1_val.push(rs1_op.value); + soa.rs1_prev_cycle + .push(aligned_prev_ts(rs1_op.previous_cycle, offset)); + soa.rd_reg.push(rd_op.register_index() as u32); + soa.rd_val_before.push(rd_op.value.before); + soa.rd_prev_cycle + .push(aligned_prev_ts(rd_op.previous_cycle, offset)); + soa.mem_prev_cycle + .push(aligned_prev_ts(mem_op.previous_cycle, offset)); + soa.mem_val.push(mem_op.value.before); + soa.imm_field.push(imm_field_val.as_canonical_u32()); + soa.unaligned_addr.push(unaligned_addr.0); + + // imm_sign for v2 variant + #[cfg(feature = "u16limb_circuit")] + { + let imm_sign_extend = + crate::utils::imm_sign_extend(true, step.insn().imm as i16); + let is_neg = if imm_sign_extend[1] > 
0 { 1u32 } else { 0u32 }; + if soa.imm_sign_field.is_none() { + soa.imm_sign_field = Some(Vec::with_capacity(n)); + } + soa.imm_sign_field.as_mut().unwrap().push(is_neg); + } + } + + soa +} + +fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { + let mut ts = prev_cycle.saturating_sub(shard_offset); + if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { + ts = 0; + } + ts +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::Instruction, + structs::ProgramParams, + }; + use ceno_emul::{ + ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + type LwInstruction = crate::instructions::riscv::LwInstruction; + + fn make_lw_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1_val = 0x100u32 + (i as u32) * 4; // base address, 4-byte aligned + let imm: i32 = 0; // zero offset for simplicity + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 1000000; // some value < P + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LW, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, mem_val), + mem_read_op, + 0, // prev_cycle + ) + }) + .collect() + } + + #[test] + fn test_extract_lw_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lw"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); + let (n_entries, flat) = 
col_map.to_flat(); + + // All column IDs should be within range + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + // Check uniqueness + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_gpu_witgen_lw_correctness() { + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps = make_lw_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = LwInstruction::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_lw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let soa = pack_lw_soa::(&shard_ctx_gpu, &steps, &indices); + let gpu_result = hal.witgen_lw(&col_map, &soa, None).unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + + let cpu_data = cpu_witness.values(); + assert_eq!( + gpu_data.len(), + cpu_data.len(), + "Size mismatch: GPU {} vs CPU {}", + gpu_data.len(), + cpu_data.len() + ); + + // Only compare columns that the GPU fills (the col_map columns) + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + 
let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!( + mismatches, 0, + "Found {} mismatches in GPU-filled columns", + mismatches + ); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index b0179cee0..6d06c2672 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -1,2 +1,4 @@ #[cfg(feature = "gpu")] pub mod add; +#[cfg(feature = "gpu")] +pub mod lw; diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index c7f6cace0..26b8ce7b9 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -17,10 +17,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Register reads and writes /// - Memory reads pub struct IMInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rd: WriteRD, - mem_read: ReadMEM, + pub vm_state: StateInOut, + pub rs1: ReadRS1, + pub rd: WriteRD, + pub mem_read: ReadMEM, } impl IMInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 1a378ad8c..69ea105b7 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -438,9 +438,9 @@ impl WriteMEM { #[derive(Debug)] pub struct MemAddr { - addr: UInt, - low_bits: Vec, - max_bits: usize, + pub addr: UInt, + pub low_bits: Vec, + pub max_bits: usize, } impl MemAddr { diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index bb29491f7..ca432360b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -6,7 +6,7 @@ pub mod load; pub mod store; #[cfg(feature = 
"u16limb_circuit")] -mod load_v2; +pub mod load_v2; #[cfg(feature = "u16limb_circuit")] mod store_v2; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 818e8902a..e25d9c4c6 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -22,16 +22,16 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 5a9ed40eb..b5e4ba807 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -25,17 +25,17 @@ use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - imm_sign: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub imm_sign: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); From 07360f8fb9b68c758be565fc4f896e5a8e54a254 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:42:57 +0800 Subject: [PATCH 05/73] witgen: 
integration --- ceno_zkvm/Cargo.toml | 5 + ceno_zkvm/src/instructions.rs | 70 +++++ ceno_zkvm/src/instructions/riscv/arith.rs | 42 +++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 240 ++++++++++++++++++ .../src/instructions/riscv/memory/load.rs | 39 +++ .../src/instructions/riscv/memory/load_v2.rs | 39 +++ 7 files changed, 437 insertions(+) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 42d4e869b..eb350c444 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -134,3 +134,8 @@ name = "weierstrass_add" [[bench]] harness = false name = "weierstrass_double" + +[[bench]] +harness = false +name = "witgen_add_gpu" +required-features = ["gpu"] diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 9dd99ef92..521879386 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -190,3 +190,73 @@ pub trait Instruction { pub fn full_step_indices(steps: &[StepRecord]) -> Vec { (0..steps.len()).collect() } + +/// CPU-only assign_instances. Extracted so GPU-enabled instructions can call this as fallback. 
+pub fn cpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result<(crate::tables::RMMCollections, gkr_iop::utils::lk_multiplicity::Multiplicity), ZKVMError> { + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + + let nthreads = multilinear_extensions::util::max_usable_threads(); + let total_instances = step_indices.len(); + let num_instance_per_batch = if total_instances > 256 { + total_instances.div_ceil(nthreads) + } else { + total_instances + } + .max(1); + let lk_multiplicity = crate::witness::LkMultiplicity::default(); + let mut raw_witin = RowMajorMatrix::::new( + total_instances, + num_witin, + I::padding_strategy(), + ); + let mut raw_structual_witin = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = + raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), indices), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(indices.iter().copied()) + .map(|((instance, structural_instance), step_idx)| { + *structural_instance.last_mut().unwrap() = E::BaseField::ONE; + I::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + }) + .collect::>() + }, + ) + .collect::>()?; + + raw_witin.padding_by_strategy(); + raw_structual_witin.padding_by_strategy(); + Ok(( + 
[raw_witin, raw_structual_witin], + lk_multiplicity.into_finalize_result(), + )) +} diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 1fd5f98d4..aa8a05093 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -8,6 +8,13 @@ use crate::{ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + /// This config handles R-Instructions that represent registers values as 2 * u16. #[derive(Debug)] pub struct ArithConfig { @@ -132,6 +139,41 @@ impl Instruction for ArithInstruction Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + // Only ADD gets GPU path; SUB and others fall through to CPU + if I::INST_KIND == InsnKind::ADD { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Add, + )? 
{ + return Ok(result); + } + } + // Fallback to CPU path + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 6d06c2672..5ebf0d50b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -2,3 +2,5 @@ pub mod add; #[cfg(feature = "gpu")] pub mod lw; +#[cfg(feature = "gpu")] +pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs new file mode 100644 index 000000000..0d419ed56 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -0,0 +1,240 @@ +/// GPU witness generation dispatcher for the proving pipeline. +/// +/// This module provides `try_gpu_assign_instances` which: +/// 1. Runs the GPU kernel to fill the witness matrix (fast) +/// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) +/// 3. Returns the GPU-generated witness + CPU-collected side effects +use ceno_emul::{StepIndex, StepRecord}; +use ceno_gpu::bb31::CudaHalBB31; +use ceno_gpu::Buffer; +use ff_ext::ExtensionField; +use gkr_iop::utils::lk_multiplicity::Multiplicity; +use multilinear_extensions::util::max_usable_threads; +use p3::field::FieldAlgebra; +use rayon::iter::{IndexedParallelIterator, ParallelIterator}; +use rayon::slice::ParallelSlice; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::e2e::ShardContext; +use crate::error::ZKVMError; +use crate::instructions::Instruction; +use crate::tables::RMMCollections; +use crate::witness::LkMultiplicity; + +#[derive(Debug, Clone, Copy)] +pub enum GpuWitgenKind { + Add, + Lw, +} + +/// Try to run GPU witness generation for the given instruction. 
+/// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). +pub fn try_gpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result, Multiplicity)>, ZKVMError> { + use gkr_iop::gpu::get_cuda_hal; + + let total_instances = step_indices.len(); + if total_instances == 0 { + // Empty: just return empty matrices + let num_structural_witin = num_structural_witin.max(1); + let raw_witin = + RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); + let raw_structural = + RowMajorMatrix::::new(0, num_structural_witin, I::padding_strategy()); + let lk = LkMultiplicity::default(); + return Ok(Some(([raw_witin, raw_structural], lk.into_finalize_result()))); + } + + let hal = match get_cuda_hal() { + Ok(hal) => hal, + Err(_) => return Ok(None), // GPU not available, fallback to CPU + }; + + info_span!("gpu_witgen", kind = ?kind, n = total_instances).in_scope(|| { + gpu_assign_instances_inner::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &hal, + ) + .map(Some) + }) +} + +fn gpu_assign_instances_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + hal: &CudaHalBB31, +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + let num_structural_witin = num_structural_witin.max(1); + let total_instances = step_indices.len(); + + // Step 1: GPU fills witness matrix + let gpu_witness = info_span!("gpu_kernel").in_scope(|| { + gpu_fill_witness::(hal, config, shard_ctx, num_witin, shard_steps, step_indices, kind) + })?; + + // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) + // We run assign_instance with a scratch buffer per 
thread and discard the witness data. + let lk_multiplicity = info_span!("cpu_side_effects").in_scope(|| { + collect_side_effects::(config, shard_ctx, num_witin, shard_steps, step_indices) + })?; + + // Step 3: Build structural witness (just selector = ONE) + let mut raw_structural = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + for row in raw_structural.iter_mut() { + *row.last_mut().unwrap() = E::BaseField::ONE; + } + raw_structural.padding_by_strategy(); + + // Step 4: Convert GPU witness to RowMajorMatrix + let mut raw_witin = info_span!("d2h_copy").in_scope(|| { + gpu_witness_to_rmm::(gpu_witness, total_instances, num_witin, I::padding_strategy()) + })?; + raw_witin.padding_by_strategy(); + + Ok(([raw_witin, raw_structural], lk_multiplicity.into_finalize_result())) +} + +/// GPU kernel dispatch based on instruction kind. +fn gpu_fill_witness>( + hal: &CudaHalBB31, + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result::BaseField>>, ZKVMError> { + match kind { + GpuWitgenKind::Add => { + // Safety: we know config is ArithConfig when kind == Add + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = + super::add::extract_add_column_map(arith_config, num_witin); + let soa = super::add::pack_add_soa(shard_ctx, shard_steps, step_indices); + hal.witgen_add(&col_map, &soa, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into())) + } + GpuWitgenKind::Lw => { + // LoadConfig location depends on the u16limb_circuit feature + #[cfg(feature = "u16limb_circuit")] + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + #[cfg(not(feature = "u16limb_circuit"))] + let load_config = 
unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load::LoadConfig) + }; + let col_map = + super::lw::extract_lw_column_map(load_config, num_witin); + let soa = super::lw::pack_lw_soa::(shard_ctx, shard_steps, step_indices); + hal.witgen_lw(&col_map, &soa, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into())) + } + } +} + +/// CPU-side loop to collect side effects only (shard_ctx.send, lk_multiplicity). +/// Runs assign_instance with a scratch buffer per thread. +fn collect_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result { + let nthreads = max_usable_threads(); + let total = step_indices.len(); + let batch_size = if total > 256 { + total.div_ceil(nthreads) + } else { + total + } + .max(1); + + let lk_multiplicity = LkMultiplicity::default(); + let shard_ctx_vec = shard_ctx.get_forked(); + + step_indices + .par_chunks(batch_size) + .zip(shard_ctx_vec) + .flat_map(|(indices, mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + // Reusable scratch buffer for this thread's assign_instance calls + let mut scratch = vec![E::BaseField::ZERO; num_witin]; + indices + .iter() + .copied() + .map(|step_idx| { + // Zero out scratch for each step + scratch.fill(E::BaseField::ZERO); + I::assign_instance( + config, + &mut shard_ctx, + &mut scratch, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + }) + .collect::>() + }) + .collect::>()?; + + Ok(lk_multiplicity) +} + +/// Convert GPU device buffer to RowMajorMatrix via D2H copy. 
+fn gpu_witness_to_rmm( + gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< + ceno_gpu::common::BufferImpl<'static, ::BaseField>, + >, + num_rows: usize, + num_cols: usize, + padding: InstancePaddingStrategy, +) -> Result, ZKVMError> { + let gpu_data: Vec<::BaseField> = gpu_result + .device_buffer + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?; + + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + let data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok(RowMajorMatrix::::new_by_values( + data, num_cols, padding, + )) +} diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index e25d9c4c6..08ca6c878 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -226,4 +226,43 @@ impl Instruction for LoadInstruction Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + if I::INST_KIND == InsnKind::LW { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lw, + )? 
{ + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index b5e4ba807..efe5b8a3b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -251,4 +251,43 @@ impl Instruction for LoadInstruction Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + if I::INST_KIND == InsnKind::LW { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lw, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } From 9b673ca7c0ae5066c84dadd90ee36b97e75e7d66 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:40:40 +0800 Subject: [PATCH 06/73] minor --- ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 0d419ed56..ad007e9e6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -53,11 +53,19 @@ pub fn try_gpu_assign_instances>( return Ok(Some(([raw_witin, raw_structural], lk.into_finalize_result()))); } + // GPU only supports BabyBear field + if std::any::TypeId::of::() + != std::any::TypeId::of::<::BaseField>() + { + return Ok(None); + } + let hal = match get_cuda_hal() { Ok(hal) => hal, Err(_) => return Ok(None), // GPU not available, fallback to CPU }; 
+ tracing::debug!("[GPU witgen] {:?} with {} instances", kind, total_instances); info_span!("gpu_witgen", kind = ?kind, n = total_instances).in_scope(|| { gpu_assign_instances_inner::( config, From 35154b11b1719f9bb65a8a8df58c873c08c51cd9 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:43:22 +0800 Subject: [PATCH 07/73] fmt --- ceno_zkvm/benches/witgen_add_gpu.rs | 7 +- ceno_zkvm/src/instructions.rs | 18 ++--- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 70 +++++++++++++------ ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 30 ++++---- .../src/instructions/riscv/gpu/witgen_gpu.rs | 66 +++++++++++------ 5 files changed, 118 insertions(+), 73 deletions(-) diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs index 360582d11..811d69998 100644 --- a/ceno_zkvm/benches/witgen_add_gpu.rs +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -1,15 +1,12 @@ use std::time::Duration; +use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_zkvm::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, - instructions::{ - Instruction, - riscv::arith::AddInstruction, - }, + instructions::{Instruction, riscv::arith::AddInstruction}, structs::ProgramParams, }; -use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use criterion::*; use ff_ext::BabyBearExt4; diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 521879386..89deb6fbc 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -199,7 +199,13 @@ pub fn cpu_assign_instances>( num_structural_witin: usize, shard_steps: &[StepRecord], step_indices: &[StepIndex], -) -> Result<(crate::tables::RMMCollections, gkr_iop::utils::lk_multiplicity::Multiplicity), ZKVMError> { +) -> Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + ZKVMError, +> { assert!(num_structural_witin == 0 || num_structural_witin == 1); let num_structural_witin = 
num_structural_witin.max(1); @@ -212,19 +218,15 @@ pub fn cpu_assign_instances>( } .max(1); let lk_multiplicity = crate::witness::LkMultiplicity::default(); - let mut raw_witin = RowMajorMatrix::::new( - total_instances, - num_witin, - I::padding_strategy(), - ); + let mut raw_witin = + RowMajorMatrix::::new(total_instances, num_witin, I::padding_strategy()); let mut raw_structual_witin = RowMajorMatrix::::new( total_instances, num_structural_witin, I::padding_strategy(), ); let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); - let raw_structual_witin_iter = - raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 8eefc05f7..a8718e870 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -2,8 +2,7 @@ use ceno_emul::StepIndex; use ceno_gpu::common::witgen_types::{AddColumnMap, AddStepRecordSOA}; use ff_ext::ExtensionField; -use crate::instructions::riscv::arith::ArithConfig; -use crate::e2e::ShardContext; +use crate::{e2e::ShardContext, instructions::riscv::arith::ArithConfig}; use ceno_emul::StepRecord; /// Extract column map from a constructed ArithConfig (ADD variant). 
@@ -40,7 +39,11 @@ pub fn extract_add_column_map( let rd_id = config.r_insn.rd.id.id as u32; let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.r_insn.rd.prev_value.wits_in() + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() .expect("WriteRD prev_value should have WitIn limbs"); assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); [limbs[0].id as u32, limbs[1].id as u32] @@ -53,13 +56,17 @@ pub fn extract_add_column_map( // Arithmetic: rs1/rs2 u16 limbs let rs1_limbs: [u32; 2] = { - let limbs = config.rs1_read.wits_in() + let limbs = config + .rs1_read + .wits_in() .expect("rs1_read should have WitIn limbs"); assert_eq!(limbs.len(), 2, "Expected 2 rs1_read limbs"); [limbs[0].id as u32, limbs[1].id as u32] }; let rs2_limbs: [u32; 2] = { - let limbs = config.rs2_read.wits_in() + let limbs = config + .rs2_read + .wits_in() .expect("rs2_read should have WitIn limbs"); assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); [limbs[0].id as u32, limbs[1].id as u32] @@ -67,7 +74,10 @@ pub fn extract_add_column_map( // rd carries let rd_carries: [u32; 2] = { - let carries = config.rd_written.carries.as_ref() + let carries = config + .rd_written + .carries + .as_ref() .expect("rd_written should have carries"); assert_eq!(carries.len(), 2, "Expected 2 rd_written carries"); [carries[0].id as u32, carries[1].id as u32] @@ -117,13 +127,16 @@ pub fn pack_add_soa( soa.cycle.push(step.cycle() - offset); soa.rs1_reg.push(rs1.register_index() as u32); soa.rs1_val.push(rs1.value); - soa.rs1_prev_cycle.push(aligned_prev_ts(rs1.previous_cycle, offset)); + soa.rs1_prev_cycle + .push(aligned_prev_ts(rs1.previous_cycle, offset)); soa.rs2_reg.push(rs2.register_index() as u32); soa.rs2_val.push(rs2.value); - soa.rs2_prev_cycle.push(aligned_prev_ts(rs2.previous_cycle, offset)); + soa.rs2_prev_cycle + .push(aligned_prev_ts(rs2.previous_cycle, offset)); soa.rd_reg.push(rd.register_index() as u32); 
soa.rd_val_before.push(rd.value.before); - soa.rd_prev_cycle.push(aligned_prev_ts(rd.previous_cycle, offset)); + soa.rd_prev_cycle + .push(aligned_prev_ts(rd.previous_cycle, offset)); } soa @@ -144,13 +157,11 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, - instructions::Instruction, - instructions::riscv::arith::AddInstruction, + instructions::{Instruction, riscv::arith::AddInstruction}, structs::ProgramParams, }; - use ceno_emul::{Change, encode_rv32, InsnKind, ByteAddr}; - use ceno_gpu::bb31::CudaHalBB31; - use ceno_gpu::Buffer; + use ceno_emul::{ByteAddr, Change, InsnKind, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::BabyBearExt4; type E = BabyBearExt4; @@ -184,8 +195,8 @@ mod tests { fn test_extract_add_column_map() { let mut cs = ConstraintSystem::::new(|| "test"); let mut cb = CircuitBuilder::new(&mut cs); - let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) - .unwrap(); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); @@ -195,7 +206,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } // Check uniqueness @@ -226,8 +240,8 @@ mod tests { // Construct circuit let mut cs = ConstraintSystem::::new(|| "test_gpu"); let mut cb = CircuitBuilder::new(&mut cs); - let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) - .unwrap(); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; @@ -261,8 +275,13 @@ mod tests { // Compare element by element let cpu_data = cpu_witness.values(); - 
assert_eq!(gpu_data.len(), cpu_data.len(), - "Size mismatch: GPU {} vs CPU {}", gpu_data.len(), cpu_data.len()); + assert_eq!( + gpu_data.len(), + cpu_data.len(), + "Size mismatch: GPU {} vs CPU {}", + gpu_data.len(), + cpu_data.len() + ); let mut mismatches = 0; for row in 0..n { @@ -280,7 +299,12 @@ mod tests { } } } - assert_eq!(mismatches, 0, "Found {} mismatches out of {} elements", - mismatches, n * num_witin); + assert_eq!( + mismatches, + 0, + "Found {} mismatches out of {} elements", + mismatches, + n * num_witin + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 65accacd2..6ab14a86c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -2,12 +2,11 @@ use ceno_emul::StepIndex; use ceno_gpu::common::witgen_types::{LwColumnMap, LwStepRecordSOA}; use ff_ext::ExtensionField; -use crate::e2e::ShardContext; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::memory::load::LoadConfig; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::memory::load_v2::LoadConfig; -use crate::tables::InsnRecord; +use crate::{e2e::ShardContext, tables::InsnRecord}; use ceno_emul::{ByteAddr, StepRecord}; /// Extract column map from a constructed LoadConfig (LW variant). 
@@ -64,7 +63,11 @@ pub fn extract_lw_column_map( #[cfg(not(feature = "u16limb_circuit"))] let imm_sign = None; let mem_addr_limbs = { - let l = config.memory_addr.addr.wits_in().expect("memory_addr WitIns"); + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); assert_eq!(l.len(), 2); [l[0].id as u32, l[1].id as u32] }; @@ -119,8 +122,7 @@ pub fn pack_lw_soa( let imm_field_val: B = imm_pair.1; // Compute unaligned address - let unaligned_addr = - ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); + let unaligned_addr = ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); soa.pc_before.push(step.pc().before.0); soa.cycle.push(step.cycle() - offset); @@ -141,8 +143,7 @@ pub fn pack_lw_soa( // imm_sign for v2 variant #[cfg(feature = "u16limb_circuit")] { - let imm_sign_extend = - crate::utils::imm_sign_extend(true, step.insn().imm as i16); + let imm_sign_extend = crate::utils::imm_sign_extend(true, step.insn().imm as i16); let is_neg = if imm_sign_extend[1] > 0 { 1u32 } else { 0u32 }; if soa.imm_sign_field.is_none() { soa.imm_sign_field = Some(Vec::with_capacity(n)); @@ -171,9 +172,7 @@ mod tests { instructions::Instruction, structs::ProgramParams, }; - use ceno_emul::{ - ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::BabyBearExt4; @@ -216,8 +215,7 @@ mod tests { fn test_extract_lw_column_map() { let mut cs = ConstraintSystem::::new(|| "test_lw"); let mut cb = CircuitBuilder::new(&mut cs); - let config = - LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); let (n_entries, flat) = col_map.to_flat(); @@ -227,7 +225,10 @@ mod tests { assert!( 
(col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } // Check uniqueness @@ -243,8 +244,7 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); let mut cb = CircuitBuilder::new(&mut cs); - let config = - LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index ad007e9e6..6f47d9914 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -5,22 +5,22 @@ /// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) /// 3. Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::bb31::CudaHalBB31; -use ceno_gpu::Buffer; +use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::util::max_usable_threads; use p3::field::FieldAlgebra; -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use rayon::slice::ParallelSlice; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; -use crate::e2e::ShardContext; -use crate::error::ZKVMError; -use crate::instructions::Instruction; -use crate::tables::RMMCollections; -use crate::witness::LkMultiplicity; +use crate::{ + e2e::ShardContext, error::ZKVMError, instructions::Instruction, tables::RMMCollections, + witness::LkMultiplicity, +}; #[derive(Debug, Clone, Copy)] pub enum GpuWitgenKind { @@ -45,12 
+45,14 @@ pub fn try_gpu_assign_instances>( if total_instances == 0 { // Empty: just return empty matrices let num_structural_witin = num_structural_witin.max(1); - let raw_witin = - RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); + let raw_witin = RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); let raw_structural = RowMajorMatrix::::new(0, num_structural_witin, I::padding_strategy()); let lk = LkMultiplicity::default(); - return Ok(Some(([raw_witin, raw_structural], lk.into_finalize_result()))); + return Ok(Some(( + [raw_witin, raw_structural], + lk.into_finalize_result(), + ))); } // GPU only supports BabyBear field @@ -96,7 +98,15 @@ fn gpu_assign_instances_inner>( // Step 1: GPU fills witness matrix let gpu_witness = info_span!("gpu_kernel").in_scope(|| { - gpu_fill_witness::(hal, config, shard_ctx, num_witin, shard_steps, step_indices, kind) + gpu_fill_witness::( + hal, + config, + shard_ctx, + num_witin, + shard_steps, + step_indices, + kind, + ) })?; // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) @@ -118,11 +128,19 @@ fn gpu_assign_instances_inner>( // Step 4: Convert GPU witness to RowMajorMatrix let mut raw_witin = info_span!("d2h_copy").in_scope(|| { - gpu_witness_to_rmm::(gpu_witness, total_instances, num_witin, I::padding_strategy()) + gpu_witness_to_rmm::( + gpu_witness, + total_instances, + num_witin, + I::padding_strategy(), + ) })?; raw_witin.padding_by_strategy(); - Ok(([raw_witin, raw_structural], lk_multiplicity.into_finalize_result())) + Ok(( + [raw_witin, raw_structural], + lk_multiplicity.into_finalize_result(), + )) } /// GPU kernel dispatch based on instruction kind. 
@@ -134,7 +152,12 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result::BaseField>>, ZKVMError> { +) -> Result< + ceno_gpu::common::witgen_types::GpuWitnessResult< + ceno_gpu::common::BufferImpl<'static, ::BaseField>, + >, + ZKVMError, +> { match kind { GpuWitgenKind::Add => { // Safety: we know config is ArithConfig when kind == Add @@ -142,11 +165,11 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::arith::ArithConfig) }; - let col_map = - super::add::extract_add_column_map(arith_config, num_witin); + let col_map = super::add::extract_add_column_map(arith_config, num_witin); let soa = super::add::pack_add_soa(shard_ctx, shard_steps, step_indices); - hal.witgen_add(&col_map, &soa, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into())) + hal.witgen_add(&col_map, &soa, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + }) } GpuWitgenKind::Lw => { // LoadConfig location depends on the u16limb_circuit feature @@ -160,8 +183,7 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load::LoadConfig) }; - let col_map = - super::lw::extract_lw_column_map(load_config, num_witin); + let col_map = super::lw::extract_lw_column_map(load_config, num_witin); let soa = super::lw::pack_lw_soa::(shard_ctx, shard_steps, step_indices); hal.witgen_lw(&col_map, &soa, None) .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into())) From 2d823063d3f2419df422bc2869607fea97db90b1 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 10:13:52 +0800 Subject: [PATCH 08/73] minor --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 5 +++-- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 5 +++-- ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs | 10 +++++++++- 3 files changed, 15 insertions(+), 5 deletions(-) 
diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index a8718e870..539ed2cf3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -250,9 +250,10 @@ mod tests { let steps = make_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path + // CPU path — use cpu_assign_instances directly to avoid going through + // the GPU override in assign_instances (which would make this GPU vs GPU). let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = AddInstruction::::assign_instances( + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( &config, &mut shard_ctx, num_witin, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 6ab14a86c..639a04db0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -252,9 +252,10 @@ mod tests { let steps = make_lw_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path + // CPU path — use cpu_assign_instances directly to avoid going through + // the GPU override in assign_instances (which would make this GPU vs GPU). let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = LwInstruction::assign_instances( + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::( &config, &mut shard_ctx, num_witin, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 6f47d9914..842a5f41e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -30,7 +30,15 @@ pub enum GpuWitgenKind { /// Try to run GPU witness generation for the given instruction. /// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). 
-pub fn try_gpu_assign_instances>( +/// +/// # Safety invariant +/// +/// The caller **must** ensure that `I::InstructionConfig` matches `kind`: +/// - `GpuWitgenKind::Add` requires `I` to be `ArithInstruction` (config = `ArithConfig`) +/// - `GpuWitgenKind::Lw` requires `I` to be `LoadInstruction` (config = `LoadConfig`) +/// +/// Violating this will cause undefined behavior via pointer cast in [`gpu_fill_witness`]. +pub(crate) fn try_gpu_assign_instances>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, num_witin: usize, From 65985ff0c0ad970bb35281e6c00f6ca96a219d1a Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 11:06:23 +0800 Subject: [PATCH 09/73] dev-local --- Cargo.lock | 112 +++++++++++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 4 +- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2d88e5a73..f47740ca1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1598,10 +1598,48 @@ version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2" +[[package]] +name = "cuda-config" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee74643f7430213a1a78320f88649de309b20b80818325575e393f848f79f5d" +dependencies = [ + "glob", +] + +[[package]] +name = "cuda-runtime-sys" +version = "0.3.0-alpha.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d070b301187fee3c611e75a425cf12247b7c75c09729dbdef95cb9cb64e8c39" +dependencies = [ + "cuda-config", +] + [[package]] name = "cuda_hal" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno-gpu-mock.git?branch=main#fe8f7923b7d3a3823c27949fab0aab8e31011aa9" +dependencies = [ + "anyhow", + "cuda-runtime-sys", + "cudarc", + "downcast-rs", + "ff_ext", + "itertools 0.13.0", + "mpcs", + "multilinear_extensions", + "p3", + "rand 0.8.5", + "rayon", + "sha2", + "sppark", + 
"sppark_plug", + "sumcheck", + "thiserror 1.0.69", + "tracing", + "transcript", + "witness", +] [[package]] name = "cudarc" @@ -2668,6 +2706,15 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.1", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -3099,6 +3146,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -5721,6 +5774,19 @@ dependencies = [ "semver 1.0.26", ] +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + [[package]] name = "rustix" version = "1.0.7" @@ -5730,7 +5796,7 @@ dependencies = [ "bitflags", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] @@ -6115,6 +6181,25 @@ dependencies = [ "der", ] +[[package]] +name = "sppark" +version = "0.1.11" +dependencies = [ + "cc", + "which", +] + +[[package]] +name = "sppark_plug" +version = "0.1.0" +dependencies = [ + "cc", + "ff_ext", + "itertools 0.13.0", + "p3", + "sppark", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -6304,7 +6389,7 @@ dependencies = [ "fastrand", "getrandom 0.3.2", "once_cell", - "rustix", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -6921,6 +7006,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "whir" version = "0.1.0" @@ -7052,6 +7149,15 @@ dependencies = [ "windows-targets 0.53.4", ] +[[package]] +name = "windows-sys" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f109e41dd4a3c848907eb83d5a42ea98b3769495597450cf6d153507b166f0f" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index b20888473..ab7009cfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,8 +127,8 @@ lto = "thin" #ceno_crypto_primitives = { path = "../ceno-patch/crypto-primitives", package = "ceno_crypto_primitives" } #ceno_syscall = { path = "../ceno-patch/syscall", package = "ceno_syscall" } -#[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] -#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } +[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] +ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } #[patch."https://github.com/scroll-tech/gkr-backend"] #ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } From 0f9603324443dfaaad6cc0dc882657cebf1b0a59 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:29:26 +0800 Subject: [PATCH 10/73] GPU: AOS StepRecord --- ceno_zkvm/benches/witgen_add_gpu.rs | 32 ++++- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 115 ++++------------- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 121 ++++-------------- .../src/instructions/riscv/gpu/witgen_gpu.rs | 35 +++-- 4 files changed, 102 insertions(+), 201 deletions(-) diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs index 811d69998..f1583606a 100644 --- a/ceno_zkvm/benches/witgen_add_gpu.rs +++ 
b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -13,7 +13,7 @@ use ff_ext::BabyBearExt4; #[cfg(feature = "gpu")] use ceno_gpu::bb31::CudaHalBB31; #[cfg(feature = "gpu")] -use ceno_zkvm::instructions::riscv::gpu::add::{extract_add_column_map, pack_add_soa}; +use ceno_zkvm::instructions::riscv::gpu::add::extract_add_column_map; mod alloc; @@ -51,6 +51,16 @@ fn make_test_steps(n: usize) -> Vec { .collect() } +#[cfg(feature = "gpu")] +fn step_records_to_bytes(records: &[StepRecord]) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + records.as_ptr() as *const u8, + records.len() * std::mem::size_of::(), + ) + } +} + fn bench_witgen_add(c: &mut Criterion) { let mut cs = ConstraintSystem::::new(|| "bench"); let mut cb = CircuitBuilder::new(&mut cs); @@ -88,24 +98,32 @@ fn bench_witgen_add(c: &mut Criterion) { }) }); - // GPU benchmark (total: SOA pack + H2D + kernel + synchronize) + // GPU benchmark (total: H2D + kernel + synchronize) #[cfg(feature = "gpu")] group.bench_function("gpu_total", |b| { + let steps_bytes = step_records_to_bytes(&steps); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); b.iter(|| { let shard_ctx = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx, &steps, &indices); - hal.witgen_add(&col_map, &soa, None).unwrap() + let shard_offset = shard_ctx.current_shard_offset_cycle(); + hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap() }) }); - // GPU benchmark (kernel only: pre-upload SOA, measure only kernel) + // GPU benchmark (kernel only: same as total since H2D is inside HAL) #[cfg(feature = "gpu")] { + let steps_bytes = step_records_to_bytes(&steps); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let shard_ctx = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx, &steps, &indices); + let shard_offset = shard_ctx.current_shard_offset_cycle(); group.bench_function("gpu_kernel_only", |b| { - b.iter(|| hal.witgen_add(&col_map, &soa, None).unwrap()) + b.iter(|| 
{ + hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap() + }) }); } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 539ed2cf3..281a311dc 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -1,9 +1,7 @@ -use ceno_emul::StepIndex; -use ceno_gpu::common::witgen_types::{AddColumnMap, AddStepRecordSOA}; +use ceno_gpu::common::witgen_types::AddColumnMap; use ff_ext::ExtensionField; -use crate::{e2e::ShardContext, instructions::riscv::arith::ArithConfig}; -use ceno_emul::StepRecord; +use crate::instructions::riscv::arith::ArithConfig; /// Extract column map from a constructed ArithConfig (ADD variant). /// @@ -103,71 +101,20 @@ pub fn extract_add_column_map( } } -/// Pack step records into SOA format for GPU transfer. -/// -/// Pre-computes shard-adjusted timing values on CPU so the GPU kernel -/// only needs to do witness filling. -pub fn pack_add_soa( - shard_ctx: &ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> AddStepRecordSOA { - let n = step_indices.len(); - let mut soa = AddStepRecordSOA::with_capacity(n); - - let offset = shard_ctx.current_shard_offset_cycle(); - - for &idx in step_indices { - let step = &shard_steps[idx]; - let rs1 = step.rs1().expect("ADD requires rs1"); - let rs2 = step.rs2().expect("ADD requires rs2"); - let rd = step.rd().expect("ADD requires rd"); - - soa.pc_before.push(step.pc().before.0); - soa.cycle.push(step.cycle() - offset); - soa.rs1_reg.push(rs1.register_index() as u32); - soa.rs1_val.push(rs1.value); - soa.rs1_prev_cycle - .push(aligned_prev_ts(rs1.previous_cycle, offset)); - soa.rs2_reg.push(rs2.register_index() as u32); - soa.rs2_val.push(rs2.value); - soa.rs2_prev_cycle - .push(aligned_prev_ts(rs2.previous_cycle, offset)); - soa.rd_reg.push(rd.register_index() as u32); - soa.rd_val_before.push(rd.value.before); - soa.rd_prev_cycle - 
.push(aligned_prev_ts(rd.previous_cycle, offset)); - } - - soa -} - -/// Inline version of ShardContext::aligned_prev_ts for SOA packing. -fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { - let mut ts = prev_cycle.saturating_sub(shard_offset); - if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { - ts = 0; - } - ts -} - #[cfg(test)] mod tests { use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::ShardContext, instructions::{Instruction, riscv::arith::AddInstruction}, structs::ProgramParams, }; - use ceno_emul::{ByteAddr, Change, InsnKind, encode_rv32}; - use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ff_ext::BabyBearExt4; type E = BabyBearExt4; fn make_test_steps(n: usize) -> Vec { - // Use small PC values that fit within BabyBear field (P ≈ 2×10^9) let pc_start = 0x1000u32; (0..n) .map(|i| { @@ -175,7 +122,7 @@ mod tests { let rs2 = (i as u32) % 500 + 3; let rd_before = (i as u32) % 200; let rd_after = rs1.wrapping_add(rs2); - let cycle = 4 + (i as u64) * 4; // cycles start at 4 (SUBCYCLES_PER_INSN) + let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(pc_start + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); StepRecord::new_r_instruction( @@ -185,7 +132,7 @@ mod tests { rs1, rs2, Change::new(rd_before, rd_after), - 0, // prev_cycle + 0, ) }) .collect() @@ -219,22 +166,12 @@ mod tests { } } - #[test] - fn test_pack_add_soa() { - let steps = make_test_steps(4); - let indices: Vec = (0..steps.len()).collect(); - let shard_ctx = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx, &steps, &indices); - - assert_eq!(soa.len(), 4); - // Check first step - assert_eq!(soa.rs1_val[0], 1); // 0 * 7 + 1 - assert_eq!(soa.rs2_val[0], 3); // 0 * 13 + 3 - } - #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_add_correctness() { + use crate::e2e::ShardContext; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + let hal = 
CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); // Construct circuit @@ -250,8 +187,7 @@ mod tests { let steps = make_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path — use cpu_assign_instances directly to avoid going through - // the GPU override in assign_instances (which would make this GPU vs GPU). + // CPU path let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( &config, @@ -262,13 +198,22 @@ mod tests { &indices, ) .unwrap(); - let cpu_witness = &cpu_rmms[0]; // witness matrix (not structural) + let cpu_witness = &cpu_rmms[0]; - // GPU path + // GPU path (AOS with indirect indexing) let col_map = extract_add_column_map(&config, num_witin); let shard_ctx_gpu = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx_gpu, &steps, &indices); - let gpu_result = hal.witgen_add(&col_map, &soa, None).unwrap(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap(); // D2H copy let gpu_data: Vec<::BaseField> = @@ -276,13 +221,7 @@ mod tests { // Compare element by element let cpu_data = cpu_witness.values(); - assert_eq!( - gpu_data.len(), - cpu_data.len(), - "Size mismatch: GPU {} vs CPU {}", - gpu_data.len(), - cpu_data.len() - ); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); let mut mismatches = 0; for row in 0..n { @@ -300,12 +239,6 @@ mod tests { } } } - assert_eq!( - mismatches, - 0, - "Found {} mismatches out of {} elements", - mismatches, - n * num_witin - ); + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs 
index 639a04db0..36db9d0eb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -1,13 +1,10 @@ -use ceno_emul::StepIndex; -use ceno_gpu::common::witgen_types::{LwColumnMap, LwStepRecordSOA}; +use ceno_gpu::common::witgen_types::LwColumnMap; use ff_ext::ExtensionField; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::memory::load::LoadConfig; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::memory::load_v2::LoadConfig; -use crate::{e2e::ShardContext, tables::InsnRecord}; -use ceno_emul::{ByteAddr, StepRecord}; /// Extract column map from a constructed LoadConfig (LW variant). pub fn extract_lw_column_map( @@ -98,82 +95,15 @@ pub fn extract_lw_column_map( } } -/// Pack step records into SOA format for LW GPU transfer. -pub fn pack_lw_soa( - shard_ctx: &ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> LwStepRecordSOA { - use p3::field::PrimeField32; - type B = ::BaseField; - - let n = step_indices.len(); - let mut soa = LwStepRecordSOA::with_capacity(n); - let offset = shard_ctx.current_shard_offset_cycle(); - - for &idx in step_indices { - let step = &shard_steps[idx]; - let rs1_op = step.rs1().expect("LW requires rs1"); - let rd_op = step.rd().expect("LW requires rd"); - let mem_op = step.memory_op().expect("LW requires memory_op"); - - // Compute imm field value (signed immediate as BabyBear) - let imm_pair = InsnRecord::::imm_internal(&step.insn()); - let imm_field_val: B = imm_pair.1; - - // Compute unaligned address - let unaligned_addr = ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); - - soa.pc_before.push(step.pc().before.0); - soa.cycle.push(step.cycle() - offset); - soa.rs1_reg.push(rs1_op.register_index() as u32); - soa.rs1_val.push(rs1_op.value); - soa.rs1_prev_cycle - .push(aligned_prev_ts(rs1_op.previous_cycle, offset)); - soa.rd_reg.push(rd_op.register_index() as u32); - 
soa.rd_val_before.push(rd_op.value.before); - soa.rd_prev_cycle - .push(aligned_prev_ts(rd_op.previous_cycle, offset)); - soa.mem_prev_cycle - .push(aligned_prev_ts(mem_op.previous_cycle, offset)); - soa.mem_val.push(mem_op.value.before); - soa.imm_field.push(imm_field_val.as_canonical_u32()); - soa.unaligned_addr.push(unaligned_addr.0); - - // imm_sign for v2 variant - #[cfg(feature = "u16limb_circuit")] - { - let imm_sign_extend = crate::utils::imm_sign_extend(true, step.insn().imm as i16); - let is_neg = if imm_sign_extend[1] > 0 { 1u32 } else { 0u32 }; - if soa.imm_sign_field.is_none() { - soa.imm_sign_field = Some(Vec::with_capacity(n)); - } - soa.imm_sign_field.as_mut().unwrap().push(is_neg); - } - } - - soa -} - -fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { - let mut ts = prev_cycle.saturating_sub(shard_offset); - if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { - ts = 0; - } - ts -} - #[cfg(test)] mod tests { use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::ShardContext, instructions::Instruction, structs::ProgramParams, }; use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; - use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::BabyBearExt4; type E = BabyBearExt4; @@ -181,12 +111,14 @@ mod tests { fn make_lw_test_steps(n: usize) -> Vec { let pc_start = 0x1000u32; + // Use varying immediates including negative values to test imm_field encoding + let imm_values: [i32; 4] = [0, 4, -4, -8]; (0..n) .map(|i| { - let rs1_val = 0x100u32 + (i as u32) * 4; // base address, 4-byte aligned - let imm: i32 = 0; // zero offset for simplicity + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let imm: i32 = imm_values[i % imm_values.len()]; let mem_addr = rs1_val.wrapping_add_signed(imm); - let mem_val = (i as u32) * 111 % 1000000; // some value < P + let mem_val = (i as u32) * 111 % 1000000; let rd_before = (i as u32) % 200; let cycle = 4 + (i as u64) * 
4; let pc = ByteAddr(pc_start + (i as u32) * 4); @@ -205,7 +137,7 @@ mod tests { rs1_val, Change::new(rd_before, mem_val), mem_read_op, - 0, // prev_cycle + 0, ) }) .collect() @@ -220,7 +152,6 @@ mod tests { let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); let (n_entries, flat) = col_map.to_flat(); - // All column IDs should be within range for (i, &col) in flat[..n_entries].iter().enumerate() { assert!( (col as usize) < col_map.num_cols as usize, @@ -231,7 +162,6 @@ mod tests { col_map.num_cols ); } - // Check uniqueness let mut seen = std::collections::HashSet::new(); for &col in &flat[..n_entries] { assert!(seen.insert(col), "Duplicate column ID: {}", col); @@ -239,7 +169,11 @@ mod tests { } #[test] + #[cfg(feature = "gpu")] fn test_gpu_witgen_lw_correctness() { + use crate::e2e::ShardContext; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); @@ -252,8 +186,7 @@ mod tests { let steps = make_lw_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path — use cpu_assign_instances directly to avoid going through - // the GPU override in assign_instances (which would make this GPU vs GPU). 
+ // CPU path let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::( &config, @@ -266,25 +199,27 @@ mod tests { .unwrap(); let cpu_witness = &cpu_rmms[0]; - // GPU path + // GPU path (AOS with indirect indexing) let col_map = extract_lw_column_map(&config, num_witin); let shard_ctx_gpu = ShardContext::default(); - let soa = pack_lw_soa::(&shard_ctx_gpu, &steps, &indices); - let gpu_result = hal.witgen_lw(&col_map, &soa, None).unwrap(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_lw(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap(); let gpu_data: Vec<::BaseField> = gpu_result.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); - assert_eq!( - gpu_data.len(), - cpu_data.len(), - "Size mismatch: GPU {} vs CPU {}", - gpu_data.len(), - cpu_data.len() - ); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - // Only compare columns that the GPU fills (the col_map columns) let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { @@ -303,10 +238,6 @@ mod tests { } } } - assert_eq!( - mismatches, 0, - "Found {} mismatches in GPU-filled columns", - mismatches - ); + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 842a5f41e..a4431c34c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -166,6 +166,18 @@ fn gpu_fill_witness>( >, ZKVMError, > { + // Cast shard_steps to bytes for bulk H2D (no gather — GPU does indirect access). 
+ let shard_steps_bytes: &[u8] = info_span!("shard_steps_bytes").in_scope(|| unsafe { + std::slice::from_raw_parts( + shard_steps.as_ptr() as *const u8, + shard_steps.len() * std::mem::size_of::(), + ) + }); + // Convert step_indices from usize to u32 for GPU. + let indices_u32: Vec = info_span!("indices_u32", n = step_indices.len()) + .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + match kind { GpuWitgenKind::Add => { // Safety: we know config is ArithConfig when kind == Add @@ -173,10 +185,13 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::arith::ArithConfig) }; - let col_map = super::add::extract_add_column_map(arith_config, num_witin); - let soa = super::add::pack_add_soa(shard_ctx, shard_steps, step_indices); - hal.witgen_add(&col_map, &soa, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + let col_map = info_span!("col_map") + .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); + info_span!("hal_witgen_add").in_scope(|| { + hal.witgen_add(&col_map, shard_steps_bytes, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + }) }) } GpuWitgenKind::Lw => { @@ -191,10 +206,14 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load::LoadConfig) }; - let col_map = super::lw::extract_lw_column_map(load_config, num_witin); - let soa = super::lw::pack_lw_soa::(shard_ctx, shard_steps, step_indices); - hal.witgen_lw(&col_map, &soa, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into())) + let col_map = info_span!("col_map") + .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); + info_span!("hal_witgen_lw").in_scope(|| { + hal.witgen_lw(&col_map, shard_steps_bytes, &indices_u32, 
shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) + }) + }) } } } From 4fc7368ff126bb7dce89442e1bdb665e9177b752 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:31:05 +0800 Subject: [PATCH 11/73] SHARD_STEPS_DEVICE --- ceno_zkvm/benches/witgen_add_gpu.rs | 10 +- ceno_zkvm/src/e2e.rs | 5 + ceno_zkvm/src/instructions/riscv/gpu/add.rs | 3 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 3 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 123 +++++++++++++++--- 5 files changed, 117 insertions(+), 27 deletions(-) diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs index f1583606a..7b68d93bb 100644 --- a/ceno_zkvm/benches/witgen_add_gpu.rs +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -98,30 +98,32 @@ fn bench_witgen_add(c: &mut Criterion) { }) }); - // GPU benchmark (total: H2D + kernel + synchronize) + // GPU benchmark (total: H2D records + indices + kernel + synchronize) #[cfg(feature = "gpu")] group.bench_function("gpu_total", |b| { let steps_bytes = step_records_to_bytes(&steps); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); b.iter(|| { + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let shard_ctx = ShardContext::default(); let shard_offset = shard_ctx.current_shard_offset_cycle(); - hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap() }) }); - // GPU benchmark (kernel only: same as total since H2D is inside HAL) + // GPU benchmark (kernel only: records pre-uploaded) #[cfg(feature = "gpu")] { let steps_bytes = step_records_to_bytes(&steps); + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let shard_ctx = ShardContext::default(); let shard_offset = shard_ctx.current_shard_offset_cycle(); 
group.bench_function("gpu_kernel_only", |b| { b.iter(|| { - hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap() }) }); diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 7a3c4710c..4fc66df5a 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1352,6 +1352,11 @@ pub fn generate_witness<'a, E: ExtensionField>( ) .unwrap(); tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); + + // Free GPU shard_steps cache after all opcode circuits are done. + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + let time = std::time::Instant::now(); system_config .dummy_config diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 281a311dc..2ee07edf7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -210,9 +210,10 @@ mod tests { steps.len() * std::mem::size_of::(), ) }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap(); // D2H copy diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 36db9d0eb..fdef5a686 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -209,9 +209,10 @@ mod tests { steps.len() * std::mem::size_of::(), ) }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_lw(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .witgen_lw(&col_map, &gpu_records, 
&indices_u32, shard_offset, None) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index a4431c34c..f2a170991 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -5,7 +5,7 @@ /// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) /// 3. Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::{Buffer, bb31::CudaHalBB31}; +use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::util::max_usable_threads; @@ -14,6 +14,7 @@ use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, }; +use std::cell::RefCell; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -28,6 +29,89 @@ pub enum GpuWitgenKind { Lw, } +/// Cached shard_steps device buffer with metadata for logging. +struct ShardStepsCache { + host_ptr: usize, + byte_len: usize, + shard_id: usize, + n_steps: usize, + device_buf: CudaSlice, +} + +// Thread-local cache for shard_steps device buffer. Invalidated when shard changes. +thread_local! { + static SHARD_STEPS_DEVICE: RefCell> = + const { RefCell::new(None) }; +} + +/// Upload shard_steps to GPU, reusing cached device buffer if the same data. 
+fn upload_shard_steps_cached( + hal: &CudaHalBB31, + shard_steps: &[StepRecord], + shard_id: usize, +) -> Result<(), ZKVMError> { + let ptr = shard_steps.as_ptr() as usize; + let byte_len = shard_steps.len() * std::mem::size_of::(); + + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.host_ptr == ptr && c.byte_len == byte_len { + return Ok(()); // cache hit + } + } + // Cache miss: upload + let mb = byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] uploading shard_steps: shard_id={}, n_steps={}, {:.2} MB", + shard_id, + shard_steps.len(), + mb, + ); + let bytes: &[u8] = + unsafe { std::slice::from_raw_parts(shard_steps.as_ptr() as *const u8, byte_len) }; + let device_buf = hal.inner.htod_copy_stream(None, bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard_steps H2D failed: {e}").into()) + })?; + *cache = Some(ShardStepsCache { + host_ptr: ptr, + byte_len, + shard_id, + n_steps: shard_steps.len(), + device_buf, + }); + Ok(()) + }) +} + +/// Borrow the cached device buffer for kernel launch. +/// Panics if `upload_shard_steps_cached` was not called first. +fn with_cached_shard_steps(f: impl FnOnce(&CudaSlice) -> R) -> R { + SHARD_STEPS_DEVICE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard_steps not uploaded"); + f(&c.device_buf) + }) +} + +/// Invalidate the cached shard_steps device buffer. +/// Call this when shard processing is complete to free GPU memory. +pub fn invalidate_shard_steps_cache() { + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + let mb = c.byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] releasing shard_steps cache: shard_id={}, n_steps={}, {:.2} MB", + c.shard_id, + c.n_steps, + mb, + ); + } + *cache = None; + }); +} + /// Try to run GPU witness generation for the given instruction. 
/// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). /// @@ -118,7 +202,6 @@ fn gpu_assign_instances_inner>( })?; // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) - // We run assign_instance with a scratch buffer per thread and discard the witness data. let lk_multiplicity = info_span!("cpu_side_effects").in_scope(|| { collect_side_effects::(config, shard_ctx, num_witin, shard_steps, step_indices) })?; @@ -166,13 +249,11 @@ fn gpu_fill_witness>( >, ZKVMError, > { - // Cast shard_steps to bytes for bulk H2D (no gather — GPU does indirect access). - let shard_steps_bytes: &[u8] = info_span!("shard_steps_bytes").in_scope(|| unsafe { - std::slice::from_raw_parts( - shard_steps.as_ptr() as *const u8, - shard_steps.len() * std::mem::size_of::(), - ) - }); + // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). + let shard_id = shard_ctx.shard_id; + info_span!("upload_shard_steps") + .in_scope(|| upload_shard_steps_cached(hal, shard_steps, shard_id))?; + // Convert step_indices from usize to u32 for GPU. 
let indices_u32: Vec = info_span!("indices_u32", n = step_indices.len()) .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); @@ -180,7 +261,6 @@ fn gpu_fill_witness>( match kind { GpuWitgenKind::Add => { - // Safety: we know config is ArithConfig when kind == Add let arith_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::arith::ArithConfig) @@ -188,14 +268,15 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); info_span!("hal_witgen_add").in_scope(|| { - hal.witgen_add(&col_map, shard_steps_bytes, &indices_u32, shard_offset, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) - }) + with_cached_shard_steps(|gpu_records| { + hal.witgen_add(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + }) + }) }) } GpuWitgenKind::Lw => { - // LoadConfig location depends on the u16limb_circuit feature #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { &*(config as *const I::InstructionConfig @@ -209,10 +290,12 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { - hal.witgen_lw(&col_map, shard_steps_bytes, &indices_u32, shard_offset, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) - }) + with_cached_shard_steps(|gpu_records| { + hal.witgen_lw(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) + }) + }) }) } } @@ -244,13 +327,11 @@ fn collect_side_effects>( .zip(shard_ctx_vec) .flat_map(|(indices, mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); - // Reusable scratch buffer for this thread's 
assign_instance calls let mut scratch = vec![E::BaseField::ZERO; num_witin]; indices .iter() .copied() .map(|step_idx| { - // Zero out scratch for each step scratch.fill(E::BaseField::ZERO); I::assign_instance( config, From 72dd155458023764d3348bec38892edd9790c619 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 00:40:36 +0800 Subject: [PATCH 12/73] batch-1234 --- ceno_zkvm/src/instructions/riscv/arith.rs | 10 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 2 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 48 +++- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 114 +++++++++ .../src/instructions/riscv/gpu/logic_i.rs | 120 ++++++++++ .../src/instructions/riscv/gpu/logic_r.rs | 201 ++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 8 + ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 223 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 78 ++++++ ceno_zkvm/src/instructions/riscv/logic.rs | 2 +- .../instructions/riscv/logic/logic_circuit.rs | 44 +++- ceno_zkvm/src/instructions/riscv/logic_imm.rs | 2 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 46 +++- 13 files changed, 880 insertions(+), 18 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/addi.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sub.rs diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index aa8a05093..7ce605971 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -150,8 +150,12 @@ impl Instruction for ArithInstruction Result<(RMMCollections, Multiplicity), ZKVMError> { use crate::instructions::riscv::gpu::witgen_gpu; - // Only ADD gets GPU path; SUB and others fall through to CPU - if I::INST_KIND == InsnKind::ADD { + let gpu_kind = match I::INST_KIND { + InsnKind::ADD => 
Some(witgen_gpu::GpuWitgenKind::Add), + InsnKind::SUB => Some(witgen_gpu::GpuWitgenKind::Sub), + _ => None, + }; + if let Some(kind) = gpu_kind { if let Some(result) = witgen_gpu::try_gpu_assign_instances::( config, shard_ctx, @@ -159,7 +163,7 @@ impl Instruction for ArithInstruction(PhantomData); pub struct InstructionConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, - rd_written: UInt, + pub(crate) imm_sign: WitIn, + pub(crate) rd_written: UInt, } impl Instruction for AddiInstruction { @@ -104,4 +111,35 @@ impl Instruction for AddiInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Addi, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs new file mode 100644 index 000000000..59019bf18 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -0,0 +1,114 @@ +use ceno_gpu::common::witgen_types::AddiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig; + +/// Extract column map from a constructed InstructionConfig (ADDI v2). 
+pub fn extract_addi_column_map( + config: &InstructionConfig, + num_witin: usize, +) -> AddiColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // rs1 u16 limbs + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm and imm_sign + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // rd carries (from the add operation: rs1 + sign_extend(imm)) + let rd_carries: [u32; 2] = { + let carries = config + .rd_written + .carries + .as_ref() + .expect("rd_written should have carries for ADDI"); + assert_eq!(carries.len(), 2); + [carries[0].id as u32, carries[1].id as u32] + }; + + AddiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + rd_carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith_imm::AddiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_addi_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_addi"); 
+ let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_addi_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs new file mode 100644 index 000000000..235a25fd1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -0,0 +1,120 @@ +use ceno_gpu::common::witgen_types::LogicIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig; + +/// Extract column map from a constructed LogicConfig (I-type v2: ANDI/ORI/XORI). 
+pub fn extract_logic_i_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicIColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // rs1 u8 bytes + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // rd u8 bytes + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // imm_lo u8 bytes (UIntLimbs<16,8> = 2 x u8) + let imm_lo_bytes: [u32; 2] = { + let l = config.imm_lo.wits_in().expect("imm_lo WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm_hi u8 bytes (UIntLimbs<16,8> = 2 x u8) + let imm_hi_bytes: [u32; 2] = { + let l = config.imm_hi.wits_in().expect("imm_hi WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + LogicIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm_lo_bytes, + imm_hi_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, 
ConstraintSystem}, + instructions::{Instruction, riscv::logic_imm::AndiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_logic_i_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_logic_i"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_logic_i_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs new file mode 100644 index 000000000..1abc851fa --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -0,0 +1,201 @@ +use ceno_gpu::common::witgen_types::LogicRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::logic::logic_circuit::LogicConfig; + +/// Extract column map from a constructed LogicConfig (R-type: AND/OR/XOR). 
+pub fn extract_logic_r_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicRColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rs2_bytes: [u32; 4] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + LogicRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + 
rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::logic::AndInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_logic_r_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_logic_r"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_logic_r_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_r_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_and_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = 0xDEAD_0000u32 | (i as u32); + let rs2 = 0x00FF_FF00u32 | ((i as u32) << 8); + let rd_after = rs1 & rs2; // AND + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); + 
StepRecord::new_r_instruction( + cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_logic_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 5ebf0d50b..0e0d082e3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -1,6 +1,14 @@ #[cfg(feature = "gpu")] pub mod add; +#[cfg(all(feature = 
"gpu", feature = "u16limb_circuit"))] +pub mod addi; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod logic_i; +#[cfg(feature = "gpu")] +pub mod logic_r; #[cfg(feature = "gpu")] pub mod lw; #[cfg(feature = "gpu")] +pub mod sub; +#[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs new file mode 100644 index 000000000..a3aaaa9a1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -0,0 +1,223 @@ +use ceno_gpu::common::witgen_types::SubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith::ArithConfig; + +/// Extract column map from a constructed ArithConfig (SUB variant). +/// +/// SUB proves: rs1 = rs2 + rd. The carries come from the (rs2 + rd) addition, +/// stored in rs1_read.carries (since rs1_read = rs2.add(rd) in construct_circuit). +pub fn extract_sub_column_map( + config: &ArithConfig, + num_witin: usize, +) -> SubColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should 
have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // SUB: rs2_read limbs (rs2 value u16 decomposition) + let rs2_limbs: [u32; 2] = { + let limbs = config + .rs2_read + .wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // SUB: rd_written limbs (rd.value.after u16 decomposition) + let rd_limbs: [u32; 2] = { + let limbs = config + .rd_written + .wits_in() + .expect("rd_written should have WitIn limbs for SUB"); + assert_eq!(limbs.len(), 2, "Expected 2 rd_written limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // SUB: carries from rs1_read (= rs2 + rd) + let carries: [u32; 2] = { + let carries = config + .rs1_read + .carries + .as_ref() + .expect("rs1_read should have carries for SUB"); + assert_eq!(carries.len(), 2, "Expected 2 rs1_read carries"); + [carries[0].id as u32, carries[1].id as u32] + }; + + SubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs2_limbs, + rd_limbs, + carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith::SubInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sub_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sub"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = 
extract_sub_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sub_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sub_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 500; + let rs2 = (i as u32) % 300 + 1; + let rd_after = rs1.wrapping_sub(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SUB, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_sub_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = 
unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index f2a170991..79d3c40d0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -26,6 +26,12 @@ use crate::{ #[derive(Debug, Clone, Copy)] pub enum GpuWitgenKind { Add, + Sub, + LogicR, + #[cfg(feature = "u16limb_circuit")] + LogicI, + #[cfg(feature = "u16limb_circuit")] + Addi, Lw, } @@ -276,6 +282,78 @@ fn gpu_fill_witness>( }) }) } + GpuWitgenKind::Sub => { + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); + info_span!("hal_witgen_sub").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sub(&col_map, gpu_records, &indices_u32, shard_offset, None) + 
.map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sub failed: {e}").into()) + }) + }) + }) + } + GpuWitgenKind::LogicR => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); + info_span!("hal_witgen_logic_r").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_logic_r(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); + info_span!("hal_witgen_logic_i").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_logic_i(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => { + let addi_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); + info_span!("hal_witgen_addi").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_addi(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => 
{ #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/logic.rs b/ceno_zkvm/src/instructions/riscv/logic.rs index 9ac2cd4c1..9684c36bf 100644 --- a/ceno_zkvm/src/instructions/riscv/logic.rs +++ b/ceno_zkvm/src/instructions/riscv/logic.rs @@ -1,4 +1,4 @@ -mod logic_circuit; +pub(crate) mod logic_circuit; use gkr_iop::tables::ops::{AndTable, OrTable, XorTable}; use logic_circuit::{LogicInstruction, LogicOp}; diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 4d2cf6db8..f6ce31288 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -18,6 +18,13 @@ use crate::{ }; use ceno_emul::{InsnKind, StepRecord}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + /// This trait defines a logic instruction, connecting an instruction type to a lookup table. pub trait LogicOp { const INST_KIND: InsnKind; @@ -72,16 +79,47 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::LogicR, + )? 
{ + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } /// This config implements R-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. #[derive(Debug)] pub struct LogicConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub(crate) rd_written: UInt8, } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs index a4b46edcc..44a51233b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -2,7 +2,7 @@ mod logic_imm_circuit; #[cfg(feature = "u16limb_circuit")] -mod logic_imm_circuit_v2; +pub(crate) mod logic_imm_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use crate::instructions::riscv::logic_imm::logic_imm_circuit::LogicInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 14c2adeb0..6f20710e3 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -24,6 +24,13 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; + +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::ToExpr; /// The Instruction circuit for a given LogicOp. 
@@ -124,18 +131,49 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::LogicI, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } /// This config implements I-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. #[derive(Debug)] pub struct LogicConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt8, + pub(crate) rs1_read: UInt8, pub(crate) rd_written: UInt8, - imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, - imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, } impl LogicConfig { From 273cf7c7db68d39f28e24a2cbc0db96aa5959278 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 00:57:07 +0800 Subject: [PATCH 13/73] batch-5,12 --- ceno_zkvm/src/instructions/riscv/auipc.rs | 38 +++++++ ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 107 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 86 ++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 98 ++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 6 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 63 +++++++++++ ceno_zkvm/src/instructions/riscv/jump.rs | 2 +- .../src/instructions/riscv/jump/jal_v2.rs | 38 
+++++++ ceno_zkvm/src/instructions/riscv/lui.rs | 38 +++++++ 9 files changed, 475 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/auipc.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/jal.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/lui.rs diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 6311fc2aa..46573f09d 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -24,6 +24,13 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct AuipcConfig { pub i_insn: IInstructionConfig, // The limbs of the immediate except the least significant limb since it is always 0 @@ -185,6 +192,37 @@ impl Instruction for AuipcInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Auipc, + )? 
{ + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs new file mode 100644 index 000000000..198bf178b --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -0,0 +1,107 @@ +use ceno_gpu::common::witgen_types::AuipcColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::auipc::AuipcConfig; + +/// Extract column map from a constructed AuipcConfig. +pub fn extract_auipc_column_map( + config: &AuipcConfig, + num_witin: usize, +) -> AuipcColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // AUIPC-specific + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let pc_limbs: [u32; 2] = [ + config.pc_limbs[0].id as u32, + config.pc_limbs[1].id as u32, + ]; + let imm_limbs: [u32; 3] = [ + config.imm_limbs[0].id as u32, + config.imm_limbs[1].id as u32, + config.imm_limbs[2].id as u32, + ]; + + AuipcColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + 
rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + pc_limbs, + imm_limbs, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::auipc::AuipcInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_auipc_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_auipc"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_auipc_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs new file mode 100644 index 000000000..07fd072aa --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -0,0 +1,86 @@ +use ceno_gpu::common::witgen_types::JalColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::jump::jal_v2::JalConfig; + +/// Extract column map from a constructed JalConfig. 
+pub fn extract_jal_column_map( + config: &JalConfig, + num_witin: usize, +) -> JalColumnMap { + let jm = &config.j_insn; + + // StateInOut (J-type: has next_pc) + let pc = jm.vm_state.pc.id as u32; + let next_pc = jm.vm_state.next_pc.expect("JAL must have next_pc").id as u32; + let ts = jm.vm_state.ts.id as u32; + + // WriteRD + let rd_id = jm.rd.id.id as u32; + let rd_prev_ts = jm.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = jm.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &jm.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // JAL-specific: rd u8 bytes + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + JalColumnMap { + pc, + next_pc, + ts, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_jal_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_jal"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_jal_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), 
"Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs new file mode 100644 index 000000000..2f8bd21bc --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -0,0 +1,98 @@ +use ceno_gpu::common::witgen_types::LuiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::lui::LuiConfig; + +/// Extract column map from a constructed LuiConfig. +pub fn extract_lui_column_map( + config: &LuiConfig, + num_witin: usize, +) -> LuiColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // LUI-specific: rd bytes (skip byte 0) + imm + let rd_bytes: [u32; 3] = [ + config.rd_written[0].id as u32, + config.rd_written[1].id as u32, + config.rd_written[2].id as u32, + ]; + let imm = config.imm.id as u32; + + LuiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::lui::LuiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_lui_column_map() 
{ + let mut cs = ConstraintSystem::::new(|| "test_lui"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_lui_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 0e0d082e3..79068cd06 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -3,9 +3,15 @@ pub mod add; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod addi; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod auipc; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jal; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod logic_i; #[cfg(feature = "gpu")] pub mod logic_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod lui; #[cfg(feature = "gpu")] pub mod lw; #[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 79d3c40d0..d26fc9244 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -32,6 +32,12 @@ pub enum GpuWitgenKind { LogicI, #[cfg(feature = "u16limb_circuit")] Addi, + #[cfg(feature = "u16limb_circuit")] + Lui, + #[cfg(feature = "u16limb_circuit")] + Auipc, + #[cfg(feature = "u16limb_circuit")] + Jal, Lw, } @@ -354,6 +360,63 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => 
{ + let lui_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::lui::LuiConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); + info_span!("hal_witgen_lui").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_lui(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lui failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => { + let auipc_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::auipc::AuipcConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); + info_span!("hal_witgen_auipc").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_auipc(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => { + let jal_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::jump::jal_v2::JalConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); + info_span!("hal_witgen_jal").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_jal(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jal failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 7bf1a41f6..8a1d82ea4 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ 
b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod jal; #[cfg(feature = "u16limb_circuit")] -mod jal_v2; +pub(crate) mod jal_v2; #[cfg(not(feature = "u16limb_circuit"))] mod jalr; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index a766ea795..85db8e91f 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -22,6 +22,13 @@ use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr}; use p3::field::FieldAlgebra; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct JalConfig { pub j_insn: JInstructionConfig, pub rd_written: UInt8, @@ -121,4 +128,35 @@ impl Instruction for JalInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Jal, + )? 
{ + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index deb7b5736..38882b78e 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -23,6 +23,13 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct LuiConfig { pub i_insn: IInstructionConfig, pub imm: WitIn, @@ -113,6 +120,37 @@ impl Instruction for LuiInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lui, + )? 
{ + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] From 31697bff0c964190facb676a52a54582766c0012 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 01:12:55 +0800 Subject: [PATCH 14/73] batch-6-shift --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + .../src/instructions/riscv/gpu/shift_i.rs | 124 ++++++++++++++++ .../src/instructions/riscv/gpu/shift_r.rs | 138 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 ++++++ .../riscv/shift/shift_circuit_v2.rs | 96 +++++++++++- 5 files changed, 396 insertions(+), 8 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 79068cd06..4f07624b6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -14,6 +14,10 @@ pub mod logic_r; pub mod lui; #[cfg(feature = "gpu")] pub mod lw; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_i; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_r; #[cfg(feature = "gpu")] pub mod sub; #[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs new file mode 100644 index 000000000..8eacae84a --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -0,0 +1,124 @@ +use ceno_gpu::common::witgen_types::ShiftIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig; + +/// Extract column map from a constructed ShiftImmConfig (I-type: SLLI/SRLI/SRAI). 
+pub fn extract_shift_i_column_map( + config: &ShiftImmConfig, + num_witin: usize, +) -> ShiftIColumnMap { + // StateInOut + let pc = config.i_insn.vm_state.pc.id as u32; + let ts = config.i_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.i_insn.rs1.id.id as u32; + let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.i_insn.rd.id.id as u32; + let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.i_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // Immediate + let imm = config.imm.id as u32; + + // ShiftBase + let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { + config.shift_base_config.bit_shift_marker[i].id as u32 + }); + let limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { + config.shift_base_config.limb_shift_marker[i].id as u32 + }); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { + 
config.shift_base_config.bit_shift_carry[i].id as u32 + }); + + ShiftIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift_imm::SlliInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shift_i_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_shift_i"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_shift_i_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs new file mode 100644 index 000000000..81840a553 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -0,0 +1,138 @@ +use ceno_gpu::common::witgen_types::ShiftRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig; + +/// Extract column map from a constructed ShiftRTypeConfig (R-type: SLL/SRL/SRA). 
+pub fn extract_shift_r_column_map( + config: &ShiftRTypeConfig, + num_witin: usize, +) -> ShiftRColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rs2_bytes: [u32; 4] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // ShiftBase + let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { + config.shift_base_config.bit_shift_marker[i].id as u32 + }); + let 
limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { + config.shift_base_config.limb_shift_marker[i].id as u32 + }); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { + config.shift_base_config.bit_shift_carry[i].id as u32 + }); + + ShiftRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift::SllInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shift_r_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_shift_r"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_shift_r_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index d26fc9244..dbb2179d5 100644 --- 
a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -38,6 +38,10 @@ pub enum GpuWitgenKind { Auipc, #[cfg(feature = "u16limb_circuit")] Jal, + #[cfg(feature = "u16limb_circuit")] + ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA + #[cfg(feature = "u16limb_circuit")] + ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI Lw, } @@ -417,6 +421,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(shift_kind) => { + let shift_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); + info_span!("hal_witgen_shift_r").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_shift_r(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(shift_kind) => { + let shift_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); + info_span!("hal_witgen_shift_i").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_shift_i(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 310d17491..3093943d7 100644 --- 
a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,4 +1,10 @@ use crate::e2e::ShardContext; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -265,11 +271,11 @@ impl } pub struct ShiftRTypeConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub rd_written: UInt8, - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, } pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); @@ -363,14 +369,51 @@ impl Instruction for ShiftLogicalInstru Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let shift_kind = match I::INST_KIND { + InsnKind::SLL => 0u32, + InsnKind::SRL => 1u32, + InsnKind::SRA => 2u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::ShiftR(shift_kind), + )? 
{ + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } pub struct ShiftImmConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, pub rd_written: UInt8, - i_insn: IInstructionConfig, - imm: WitIn, + pub(crate) i_insn: IInstructionConfig, + pub(crate) imm: WitIn, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); @@ -466,6 +509,43 @@ impl Instruction for ShiftImmInstructio Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let shift_kind = match I::INST_KIND { + InsnKind::SLLI => 0u32, + InsnKind::SRLI => 1u32, + InsnKind::SRAI => 2u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::ShiftI(shift_kind), + )? 
{ + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } fn run_shift( From 707ea1d26c5d1677afeb6178e66ae92f464a26e2 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 01:39:04 +0800 Subject: [PATCH 15/73] batch-8,9-slt --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 137 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 120 +++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 ++++++ ceno_zkvm/src/instructions/riscv/slt.rs | 2 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 51 ++++++- ceno_zkvm/src/instructions/riscv/slti.rs | 2 +- .../riscv/slti/slti_circuit_v2.rs | 53 ++++++- 8 files changed, 400 insertions(+), 11 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/slt.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/slti.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 4f07624b6..848cdaf71 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -18,6 +18,10 @@ pub mod lw; pub mod shift_i; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod shift_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slt; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slti; #[cfg(feature = "gpu")] pub mod sub; #[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs new file mode 100644 index 000000000..7cd25266e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -0,0 +1,137 @@ +use ceno_gpu::common::witgen_types::SltColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig; + +/// Extract column map from a constructed SetLessThanConfig 
(SLT/SLTU). +pub fn extract_slt_column_map( + config: &SetLessThanConfig, + num_witin: usize, +) -> SltColumnMap { + // rs1_read: UInt (2 u16 limbs) + let rs1_limbs: [u32; 2] = { + let limbs = config + .rs1_read + .wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // rs2_read: UInt (2 u16 limbs) + let rs2_limbs: [u32; 2] = { + let limbs = config + .rs2_read + .wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + // R-type base: StateInOut + ReadRS1 + ReadRS2 + WriteRD + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as 
u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + SltColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slt::SltInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_slt_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_slt_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs new file mode 100644 index 000000000..0b38c7a8e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -0,0 +1,120 @@ +use ceno_gpu::common::witgen_types::SltiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig; + +/// Extract column map from a constructed SetLessThanImmConfig (SLTI/SLTIU). 
+pub fn extract_slti_column_map( + config: &SetLessThanImmConfig, + num_witin: usize, +) -> SltiColumnMap { + // rs1_read: UInt (2 u16 limbs) + let rs1_limbs: [u32; 2] = { + let limbs = config + .rs1_read + .wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + // I-type base: StateInOut + ReadRS1 + WriteRD + let pc = config.i_insn.vm_state.pc.id as u32; + let ts = config.i_insn.vm_state.ts.id as u32; + + let rs1_id = config.i_insn.rs1.id.id as u32; + let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rd_id = config.i_insn.rd.id.id as u32; + let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .i_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + SltiColumnMap { + rs1_limbs, + imm, + imm_sign, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod 
tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slti::SltiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_slti_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_slti_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index dbb2179d5..01b5e35a5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -42,6 +42,10 @@ pub enum GpuWitgenKind { ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA #[cfg(feature = "u16limb_circuit")] ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI + #[cfg(feature = "u16limb_circuit")] + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) + #[cfg(feature = "u16limb_circuit")] + Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) Lw, } @@ -459,6 +463,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(is_signed) => { + let slt_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); + info_span!("hal_witgen_slt").in_scope(|| { + 
with_cached_shard_steps(|gpu_records| { + hal.witgen_slt(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(is_signed) => { + let slti_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); + info_span!("hal_witgen_slti").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_slti(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index a0dc51bd6..01d39a49c 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod slt_circuit; #[cfg(feature = "u16limb_circuit")] -mod slt_circuit_v2; +pub(crate) mod slt_circuit_v2; use ceno_emul::InsnKind; diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index d57aeb2cd..b5e41f6ac 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -15,18 +15,25 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct 
SetLessThanInstruction(PhantomData<(E, I)>); /// This config handles R-Instructions that represent registers values as 2 * u16. pub struct SetLessThanConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; @@ -113,4 +120,40 @@ impl Instruction for SetLessThanInstruc )?; Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let is_signed = match I::INST_KIND { + InsnKind::SLT => 1u32, + InsnKind::SLTU => 0u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Slt(is_signed), + )? 
{ + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 90dcb8448..474b664ee 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -1,5 +1,5 @@ #[cfg(feature = "u16limb_circuit")] -mod slti_circuit_v2; +pub(crate) mod slti_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] mod slti_circuit; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index b2449614e..471a70866 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -23,18 +23,25 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + #[derive(Debug)] pub struct SetLessThanImmConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, + pub(crate) imm_sign: WitIn, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); @@ -133,4 +140,40 @@ impl Instruction for SetLessThanImmInst )?; Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), 
crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let is_signed = match I::INST_KIND { + InsnKind::SLTI => 1u32, + InsnKind::SLTIU => 0u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Slti(is_signed), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } From 3fcc70e54bbcd7dba139f184cc7ea7af866e6ef4 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 01:54:14 +0800 Subject: [PATCH 16/73] test: orrectness --- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 85 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 86 +++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 85 ++++++++++++++++++ .../src/instructions/riscv/gpu/logic_i.rs | 85 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 85 ++++++++++++++++++ .../src/instructions/riscv/gpu/shift_i.rs | 85 ++++++++++++++++++ .../src/instructions/riscv/gpu/shift_r.rs | 81 +++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 82 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 85 ++++++++++++++++++ 9 files changed, 759 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 59019bf18..ef3339d82 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -111,4 +111,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_addi_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = 
CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_addi_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32) * 137 + 1; + let imm = ((i as i32) % 2048 - 1024) as i32; + let rd_after = rs1.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_addi_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { 
+ for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 198bf178b..9ba8fe9b4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -104,4 +104,90 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_auipc_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_auipc_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // AUIPC immediate is upper 20 bits + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rd_after = pc.0.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); 
+ let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_auipc_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 07fd072aa..815f7b809 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -83,4 +83,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_jal_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, 
bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jal_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // JAL offset must be even; use small positive/negative offsets + let offset = (((i as i32) % 256) - 128) * 2; // even offsets + let new_pc = ByteAddr(pc.0.wrapping_add_signed(offset)); + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, offset); + StepRecord::new_j_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jal_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + 
assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index 235a25fd1..ce4999c10 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -117,4 +117,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_i_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_logic_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff; + let imm = (i as u32) % 4096; // 0..4095 (12-bit unsigned imm) + let rd_after = rs1 & imm; // ANDI + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32u(InsnKind::ANDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i 
as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_logic_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 2f8bd21bc..5cd798073 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -95,4 +95,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_lui_correctness() { + use 
crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lui_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // LUI immediate is upper 20 bits + let rd_after = imm as u32; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_lui_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + 
gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index 8eacae84a..f58bcf497 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -121,4 +121,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_i_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32).wrapping_mul(0x01010101); + let shamt = (i as i32) % 32; // 0..31 + let rd_after = rs1 << (shamt as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLLI, 2, 0, 4, shamt); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + 
PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 81840a553..e286936e8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -135,4 +135,85 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + 
#[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_r_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_r_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32).wrapping_mul(0x01010101); + let rs2 = (i as u32) % 32; + let rd_after = rs1 << rs2; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + 
gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index 7cd25266e..985f5dc8b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -134,4 +134,86 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_slt_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + // Mix positive, negative, equal cases + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let rd_after = if (rs1 as i32) < (rs2 as i32) { 1u32 } else { 0u32 }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLT, 2, 3, 4, 0); + StepRecord::new_r_instruction( + 
cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slt_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 0b38c7a8e..713f9c19e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -117,4 +117,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + 
fn test_gpu_witgen_slti_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slti_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let imm = ((i as i32) % 2048 - 1024) as i32; // -1024..1023 + let rd_after = if (rs1 as i32) < (imm as i32) { 1u32 } else { 0u32 }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slti_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + 
.unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } From 7191624efe1df0ba07c8b57bd55edf6a16c73eaf Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 02:12:08 +0800 Subject: [PATCH 17/73] batch-10,11-branch --- ceno_zkvm/src/instructions/riscv/branch.rs | 2 +- .../riscv/branch/branch_circuit_v2.rs | 46 ++++ .../src/instructions/riscv/gpu/branch_cmp.rs | 219 ++++++++++++++++++ .../src/instructions/riscv/gpu/branch_eq.rs | 212 +++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 ++++ 6 files changed, 524 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index ab080ac0d..e1e5f40df 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod branch_circuit; #[cfg(feature = "u16limb_circuit")] -mod branch_circuit_v2; +pub(crate) mod branch_circuit_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 85ef6914b..8bec503bb 100644 --- 
a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -204,4 +204,50 @@ impl Instruction for BranchCircuit Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + let kind = match I::INST_KIND { + InsnKind::BEQ => witgen_gpu::GpuWitgenKind::BranchEq(1), + InsnKind::BNE => witgen_gpu::GpuWitgenKind::BranchEq(0), + InsnKind::BLT => witgen_gpu::GpuWitgenKind::BranchCmp(1), + InsnKind::BGE => witgen_gpu::GpuWitgenKind::BranchCmp(1), + InsnKind::BLTU => witgen_gpu::GpuWitgenKind::BranchCmp(0), + InsnKind::BGEU => witgen_gpu::GpuWitgenKind::BranchCmp(0), + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs new file mode 100644 index 000000000..acd311265 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -0,0 +1,219 @@ +use ceno_gpu::common::witgen_types::BranchCmpColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; + +/// Extract column map from a constructed BranchConfig (BLT/BGE/BLTU/BGEU variant). 
+pub fn extract_branch_cmp_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchCmpColumnMap { + let rs1_limbs: [u32; 2] = { + let limbs = config + .read_rs1 + .wits_in() + .expect("rs1 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config + .read_rs2 + .wits_in() + .expect("rs2 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let lt_config = config.uint_lt_config.as_ref().unwrap(); + let cmp_lt = lt_config.cmp_lt.id as u32; + let a_msb_f = lt_config.a_msb_f.id as u32; + let b_msb_f = lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + lt_config.diff_marker[0].id as u32, + lt_config.diff_marker[1].id as u32, + ]; + let diff_val = lt_config.diff_val.id as u32; + + let pc = config.b_insn.vm_state.pc.id as u32; + let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; + let ts = config.b_insn.vm_state.ts.id as u32; + + let rs1_id = config.b_insn.rs1.id.id as u32; + let rs1_prev_ts = config.b_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = config.b_insn.rs2.id.id as u32; + let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let imm = config.b_insn.imm.id as u32; + + BranchCmpColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::branch::BltInstruction}, + structs::ProgramParams, + }; + use 
ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_branch_cmp_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_branch_cmp_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_cmp_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_blt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let taken = (rs1 as i32) < (rs2 as i32); + let pc = ByteAddr(0x2000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx 
= ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_cmp_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_branch_cmp( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs new file mode 100644 index 000000000..6d1621633 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -0,0 +1,212 @@ +use ceno_gpu::common::witgen_types::BranchEqColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; + +/// Extract column map from a constructed BranchConfig (BEQ/BNE variant). 
+pub fn extract_branch_eq_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchEqColumnMap { + let rs1_limbs: [u32; 2] = { + let limbs = config + .read_rs1 + .wits_in() + .expect("rs1 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config + .read_rs2 + .wits_in() + .expect("rs2 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let branch_taken = config.eq_branch_taken_bit.as_ref().unwrap().id as u32; + let diff_inv_marker: [u32; 2] = { + let markers = config.eq_diff_inv_marker.as_ref().unwrap(); + [markers[0].id as u32, markers[1].id as u32] + }; + + let pc = config.b_insn.vm_state.pc.id as u32; + let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; + let ts = config.b_insn.vm_state.ts.id as u32; + + let rs1_id = config.b_insn.rs1.id.id as u32; + let rs1_prev_ts = config.b_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = config.b_insn.rs2.id.id as u32; + let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let imm = config.b_insn.imm.id as u32; + + BranchEqColumnMap { + rs1_limbs, + rs2_limbs, + branch_taken, + diff_inv_marker, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::branch::BeqInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_branch_eq_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + 
let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_branch_eq_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_eq_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_beq_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as u32) * 137) ^ 0xABCD; + let rs2 = if i % 3 == 0 { rs1 } else { ((i as u32) * 89) ^ 0x1234 }; + let taken = rs1 == rs2; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + 
&steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_eq_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_branch_eq( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 848cdaf71..2913d1824 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -24,5 +24,9 @@ pub mod slt; pub mod slti; #[cfg(feature = "gpu")] pub mod sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_cmp; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 01b5e35a5..1d812a7c4 
100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -46,6 +46,10 @@ pub enum GpuWitgenKind { Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) #[cfg(feature = "u16limb_circuit")] Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) + #[cfg(feature = "u16limb_circuit")] + BranchEq(u32), // 1=BEQ, 0=BNE + #[cfg(feature = "u16limb_circuit")] + BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) Lw, } @@ -501,6 +505,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(is_beq) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin)); + info_span!("hal_witgen_branch_eq").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_branch_eq(&col_map, gpu_records, &indices_u32, shard_offset, is_beq, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(is_signed) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin)); + info_span!("hal_witgen_branch_cmp").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_branch_cmp(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { From 
1b6f346f83cbf8f65fece7f6c4747cc221a9031f Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:16:19 +0800 Subject: [PATCH 18/73] batch-13-JALR --- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 215 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 ++ ceno_zkvm/src/instructions/riscv/jump.rs | 2 +- .../src/instructions/riscv/jump/jalr_v2.rs | 38 ++++ 5 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/jalr.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs new file mode 100644 index 000000000..8da4e6ba4 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -0,0 +1,215 @@ +use ceno_gpu::common::witgen_types::JalrColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::jump::jalr_v2::JalrConfig; + +/// Extract column map from a constructed JalrConfig. +pub fn extract_jalr_column_map( + config: &JalrConfig, + num_witin: usize, +) -> JalrColumnMap { + let im = &config.i_insn; + + // StateInOut (branching=true → has next_pc) + let pc = im.vm_state.pc.id as u32; + let next_pc = im.vm_state.next_pc.expect("JALR must have next_pc").id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // JALR-specific: rs1 u16 limbs + let rs1_limbs: [u32; 2] = { + 
let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm, imm_sign + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // jump_pc_addr: MemAddr has addr (UInt = 2 limbs) + low_bits (Vec) + let jump_pc_addr: [u32; 2] = { + let l = config.jump_pc_addr.addr.wits_in().expect("jump_pc_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let jump_pc_addr_bit: [u32; 2] = { + let bits = &config.jump_pc_addr.low_bits; + assert_eq!(bits.len(), 2, "JALR MemAddr with n_zeros=0 must have 2 low_bits"); + [bits[0].id as u32, bits[1].id as u32] + }; + + // rd_high + let rd_high = config.rd_high.id as u32; + + JalrColumnMap { + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + jump_pc_addr, + jump_pc_addr_bit, + rd_high, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalrInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_jalr_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_jalr"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_jalr_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn 
test_gpu_witgen_jalr_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jalr_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val: u32 = 0x0010_0000u32.wrapping_add(i as u32 * 137); + let imm: i32 = ((i as i32) % 2048) - 1024; // range [-1024, 1023] + let jump_raw = rs1_val.wrapping_add(imm as u32); + let new_pc = ByteAddr(jump_raw & !1u32); // aligned to 2 bytes + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JALR, 1, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + rs1_val, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jalr_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as 
u32).collect(); + let gpu_result = hal + .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 2913d1824..a80b533c4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -27,6 +27,8 @@ pub mod sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_cmp; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jalr; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 1d812a7c4..d9908f619 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -50,6 +50,8 @@ pub enum GpuWitgenKind { BranchEq(u32), // 1=BEQ, 0=BNE #[cfg(feature = "u16limb_circuit")] BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) + #[cfg(feature = "u16limb_circuit")] + Jalr, Lw, } @@ -543,6 +545,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => { + let jalr_config = unsafe { + &*(config as *const I::InstructionConfig 
+ as *const crate::instructions::riscv::jump::jalr_v2::JalrConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); + info_span!("hal_witgen_jalr").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_jalr(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jalr failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 8a1d82ea4..c0b121827 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -6,7 +6,7 @@ pub(crate) mod jal_v2; #[cfg(not(feature = "u16limb_circuit"))] mod jalr; #[cfg(feature = "u16limb_circuit")] -mod jalr_v2; +pub(crate) mod jalr_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7c51728ac..e4c838253 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -25,6 +25,13 @@ use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct JalrConfig { pub i_insn: IInstructionConfig, pub rs1_read: UInt, @@ -188,4 +195,35 @@ impl Instruction for JalrInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> 
Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Jalr, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } From 107f72f1e6bb0919ef7e70c593e1bec236f60a3f Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:24:21 +0800 Subject: [PATCH 19/73] batch-14-SW --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 233 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 ++ ceno_zkvm/src/instructions/riscv/memory.rs | 2 +- .../src/instructions/riscv/memory/store_v2.rs | 57 ++++- ceno_zkvm/src/instructions/riscv/s_insn.rs | 8 +- 6 files changed, 310 insertions(+), 13 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sw.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index a80b533c4..9782869c9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -29,6 +29,8 @@ pub mod branch_cmp; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod jalr; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sw; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs new file mode 100644 index 000000000..81143a803 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -0,0 +1,233 @@ +use ceno_gpu::common::witgen_types::SwColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// 
Extract column map from a constructed StoreConfig (SW variant, N_ZEROS=2). +pub fn extract_sw_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SwColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // SW-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + SwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + num_cols: num_witin as u32, + } 
+} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SwInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sw_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sw"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sw_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sw_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ + ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 4, -4, -8]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let rs2_val = (i as u32) * 111 % 1000000; // value to store + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as 
u32) * 77 % 500000; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SW, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: Change::new(prev_mem_val, rs2_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs 
b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index d9908f619..ccf32c88b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -52,6 +52,8 @@ pub enum GpuWitgenKind { BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) #[cfg(feature = "u16limb_circuit")] Jalr, + #[cfg(feature = "u16limb_circuit")] + Sw, Lw, } @@ -564,6 +566,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => { + let sw_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); + info_span!("hal_witgen_sw").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sw(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sw failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index ca432360b..294d7fd44 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -8,7 +8,7 @@ pub mod store; #[cfg(feature = "u16limb_circuit")] pub mod load_v2; #[cfg(feature = "u16limb_circuit")] -mod store_v2; +pub(crate) mod store_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index a1bd7a812..d4b7c00af 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -23,17 +23,24 @@ use multilinear_extensions::{ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; +#[cfg(feature = "gpu")] +use 
crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct StoreConfig { - s_insn: SInstructionConfig, + pub(crate) s_insn: SInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - imm: WitIn, - imm_sign: WitIn, - prev_memory_value: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + pub(crate) imm: WitIn, + pub(crate) imm_sign: WitIn, + pub(crate) prev_memory_value: UInt, - memory_addr: MemAddr, - next_memory_value: Option>, + pub(crate) memory_addr: MemAddr, + pub(crate) next_memory_value: Option>, } pub struct StoreInstruction(PhantomData<(E, I)>); @@ -171,4 +178,38 @@ impl Instruction Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + // Only SW (N_ZEROS=2) has GPU support currently + if I::INST_KIND == InsnKind::SW { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Sw, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index f252a7c60..3ffa77f3f 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -16,10 +16,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Registers reads. 
/// - Memory write pub struct SInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rs2: ReadRS2, - mem_write: WriteMEM, + pub(crate) vm_state: StateInOut, + pub(crate) rs1: ReadRS1, + pub(crate) rs2: ReadRS2, + pub(crate) mem_write: WriteMEM, } impl SInstructionConfig { From 4ac98ab86c5d801b3a7a3307321aa06e2eb7a7fa Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:34:42 +0800 Subject: [PATCH 20/73] batch-15-SH,SB --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 269 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 246 ++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 +++ .../src/instructions/riscv/memory/gadget.rs | 6 +- .../src/instructions/riscv/memory/store_v2.rs | 11 +- 6 files changed, 572 insertions(+), 6 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sb.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sh.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 9782869c9..79badd87a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -31,6 +31,10 @@ pub mod jalr; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sw; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sh; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sb; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs new file mode 100644 index 000000000..8b5c8689e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -0,0 +1,269 @@ +use ceno_gpu::common::witgen_types::SbColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed 
StoreConfig (SB variant, N_ZEROS=0). +pub fn extract_sb_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SbColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Store-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // SB-specific: 2 low_bits (bit_0, bit_1) + assert_eq!(config.memory_addr.low_bits.len(), 2, "SB should have 2 low_bits"); + let mem_addr_bit_0 = config.memory_addr.low_bits[0].id as u32; + let mem_addr_bit_1 = config.memory_addr.low_bits[1].id as u32; + + // MemWordUtil fields 
(SB has N_ZEROS=0 so these exist) + let mem_word_util = config + .next_memory_value + .as_ref() + .expect("SB must have next_memory_value (MemWordUtil)"); + assert_eq!(mem_word_util.prev_limb_bytes.len(), 2); + let prev_limb_bytes: [u32; 2] = [ + mem_word_util.prev_limb_bytes[0].id as u32, + mem_word_util.prev_limb_bytes[1].id as u32, + ]; + assert_eq!(mem_word_util.rs2_limb_bytes.len(), 1); + let rs2_limb_byte = mem_word_util.rs2_limb_bytes[0].id as u32; + let expected_limb = mem_word_util + .expected_limb + .as_ref() + .expect("SB must have expected_limb") + .id as u32; + + SbColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_0, + mem_addr_bit_1, + prev_limb_bytes, + rs2_limb_byte, + expected_limb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SbInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sb_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sb"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sb_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sb_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ + 
ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sb_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 1, -1, -3]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as u32) * 77 % 500000; + // SB stores the low byte of rs2 into the selected byte + let bit_0 = mem_addr & 1; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_byte = (rs2_val & 0xFF) as u8; + let byte_idx = (bit_1 * 2 + bit_0) as usize; + let mut bytes = prev_mem_val.to_le_bytes(); + bytes[byte_idx] = rs2_byte; + let new_mem_val = u32::from_le_bytes(bytes); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SB, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sb_column_map(&config, num_witin); + let shard_ctx_gpu = 
ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs new file mode 100644 index 000000000..1e4c103f5 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -0,0 +1,246 @@ +use ceno_gpu::common::witgen_types::ShColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed StoreConfig (SH variant, N_ZEROS=1). 
+pub fn extract_sh_column_map( + config: &StoreConfig, + num_witin: usize, +) -> ShColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Store-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // SH-specific: 1 low_bit (bit_1 for halfword select) + assert_eq!(config.memory_addr.low_bits.len(), 1, "SH should have 1 low_bit"); + let mem_addr_bit_1 = config.memory_addr.low_bits[0].id as u32; + + ShColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + 
mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_1, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::ShInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sh_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sh_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ + ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sh_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 2, -2, -6]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: i32 = 
imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + // SH stores the low halfword of rs2 into the selected halfword + let prev_mem_val = (i as u32) * 77 % 500000; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_hw = rs2_val & 0xFFFF; + let new_mem_val = if bit_1 == 0 { + (prev_mem_val & 0xFFFF0000) | rs2_hw + } else { + (prev_mem_val & 0x0000FFFF) | (rs2_hw << 16) + }; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SH, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sh_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * 
num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index ccf32c88b..e918b60a9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -54,6 +54,10 @@ pub enum GpuWitgenKind { Jalr, #[cfg(feature = "u16limb_circuit")] Sw, + #[cfg(feature = "u16limb_circuit")] + Sh, + #[cfg(feature = "u16limb_circuit")] + Sb, Lw, } @@ -585,6 +589,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => { + let sh_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); + info_span!("hal_witgen_sh").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sh(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sh failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => { + let sb_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); + info_span!("hal_witgen_sb").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sb(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sb failed: {e}").into(), + ) + }) + }) + }) 
+ } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index a37be1f61..5b95ddb05 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -14,10 +14,10 @@ use p3::field::{Field, FieldAlgebra}; use witness::set_val; pub struct MemWordUtil { - prev_limb_bytes: Vec, - rs2_limb_bytes: Vec, + pub(crate) prev_limb_bytes: Vec, + pub(crate) rs2_limb_bytes: Vec, - expected_limb: Option, + pub(crate) expected_limb: Option, expect_limbs_expr: [Expression; 2], } diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index d4b7c00af..84f6a87ce 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -189,8 +189,13 @@ impl Instruction step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { use crate::instructions::riscv::gpu::witgen_gpu; - // Only SW (N_ZEROS=2) has GPU support currently - if I::INST_KIND == InsnKind::SW { + let gpu_kind = match I::INST_KIND { + InsnKind::SW => Some(witgen_gpu::GpuWitgenKind::Sw), + InsnKind::SH => Some(witgen_gpu::GpuWitgenKind::Sh), + InsnKind::SB => Some(witgen_gpu::GpuWitgenKind::Sb), + _ => None, + }; + if let Some(kind) = gpu_kind { if let Some(result) = witgen_gpu::try_gpu_assign_instances::( config, shard_ctx, @@ -198,7 +203,7 @@ impl Instruction num_structural_witin, shard_steps, step_indices, - witgen_gpu::GpuWitgenKind::Sw, + kind, )? 
{ return Ok(result); } From 3b9e4b53a31ccd724dc1a34c7fd08ab795a60b1b Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:44:47 +0800 Subject: [PATCH 21/73] batch-16-LH,LB --- ceno_zkvm/src/gadgets/signed_ext.rs | 4 + .../src/instructions/riscv/gpu/load_sub.rs | 391 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 23 ++ .../src/instructions/riscv/memory/load_v2.rs | 12 +- 5 files changed, 430 insertions(+), 2 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index 4be082386..40683274d 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -44,6 +44,10 @@ impl SignedExtendConfig { self.msb.expr() } + pub(crate) fn msb(&self) -> WitIn { + self.msb + } + fn construct_circuit( cb: &mut CircuitBuilder, n_bits: usize, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs new file mode 100644 index 000000000..9c42e909b --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -0,0 +1,391 @@ +use ceno_gpu::common::witgen_types::LoadSubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::load_v2::LoadConfig; + +/// Extract column map from a constructed LoadConfig for sub-word loads (LH/LHU/LB/LBU). 
+pub fn extract_load_sub_column_map( + config: &LoadConfig, + num_witin: usize, + is_byte: bool, // true for LB/LBU + is_signed: bool, // true for LH/LB +) -> LoadSubColumnMap { + let im = &config.im_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadMEM + let mem_prev_ts = im.mem_read.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &im.mem_read.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Load-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_read: [u32; 2] = { + let l = config.memory_read.wits_in().expect("memory_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // Sub-word specific: addr_bit_1 (all sub-word loads have at least 1 low_bit) + let low_bits = &config.memory_addr.low_bits; + let addr_bit_1 = if is_byte { + // LB/LBU: 2 low_bits, [0]=bit_0, [1]=bit_1 + assert_eq!(low_bits.len(), 2, "LB/LBU should have 2 low_bits"); + 
low_bits[1].id as u32 + } else { + // LH/LHU: 1 low_bit, [0]=bit_1 + assert_eq!(low_bits.len(), 1, "LH/LHU should have 1 low_bit"); + low_bits[0].id as u32 + }; + + let target_limb = config + .target_limb + .expect("sub-word loads must have target_limb") + .id as u32; + + // LB/LBU: addr_bit_0, target_byte, dummy_byte + let (addr_bit_0, target_byte, dummy_byte) = if is_byte { + let bytes = config + .target_limb_bytes + .as_ref() + .expect("LB/LBU must have target_limb_bytes"); + assert_eq!(bytes.len(), 2); + ( + Some(low_bits[0].id as u32), + Some(bytes[0].id as u32), + Some(bytes[1].id as u32), + ) + } else { + (None, None, None) + }; + + // Signed: msb + let msb = if is_signed { + let sec = config + .signed_extend_config + .as_ref() + .expect("signed loads must have signed_extend_config"); + Some(sec.msb().id as u32) + } else { + None + }; + + LoadSubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr, + mem_read, + addr_bit_1, + target_limb, + addr_bit_0, + target_byte, + dummy_byte, + msb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::{LhInstruction, LhuInstruction, LbInstruction, LbuInstruction}}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &LoadSubColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn 
test_extract_lh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, true); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_some()); + assert!(col_map.addr_bit_0.is_none()); + } + + #[test] + fn test_extract_lhu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lhu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, false); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_none()); + assert!(col_map.addr_bit_0.is_none()); + } + + #[test] + fn test_extract_lb_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lb"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, true); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_some()); + assert!(col_map.addr_bit_0.is_some()); + assert!(col_map.target_byte.is_some()); + } + + #[test] + fn test_extract_lbu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lbu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, false); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_none()); + assert!(col_map.addr_bit_0.is_some()); + assert!(col_map.target_byte.is_some()); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_load_sub_correctness() { + use 
crate::e2e::ShardContext; + use ceno_emul::{ + ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Test all 4 variants + let variants: &[(InsnKind, bool, bool, &str)] = &[ + (InsnKind::LH, false, true, "LH"), + (InsnKind::LHU, false, false, "LHU"), + (InsnKind::LB, true, true, "LB"), + (InsnKind::LBU, true, false, "LBU"), + ]; + + for &(insn_kind, is_byte, is_signed, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + // We need to construct the right instruction type + let config = match insn_kind { + InsnKind::LH => LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LHU => LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LB => LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LBU => LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = if is_byte { + [0, 1, -1, -3] + } else { + [0, 2, -2, -6] + }; + + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 500000; + + // Compute rd_after based on load type + let shift = mem_addr & 3; + let bit_1 = (shift >> 1) & 1; + let bit_0 = shift & 1; + let target_limb: u16 = if bit_1 == 0 { + (mem_val & 0xFFFF) as u16 + } else { + (mem_val >> 16) as u16 + }; + let rd_after = 
match insn_kind { + InsnKind::LH => { + (target_limb as i16) as i32 as u32 + } + InsnKind::LHU => target_limb as u32, + InsnKind::LB => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + (byte as i8) as i32 as u32 + } + InsnKind::LBU => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + byte as u32 + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, rd_after), + mem_read_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::LH => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::LHU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::LB => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::LBU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_load_sub_column_map(&config, num_witin, is_byte, is_signed); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as 
*const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let load_width: u32 = if is_byte { 8 } else { 16 }; + let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; + let gpu_result = hal + .witgen_load_sub(&col_map, &gpu_records, &indices_u32, shard_offset, load_width, is_signed_u32, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 79badd87a..2d0277524 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -35,6 +35,8 @@ pub mod sh; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sb; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod load_sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index e918b60a9..f4f9b58a9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ 
b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -58,6 +58,8 @@ pub enum GpuWitgenKind { Sh, #[cfg(feature = "u16limb_circuit")] Sb, + #[cfg(feature = "u16limb_circuit")] + LoadSub { load_width: u32, is_signed: u32 }, Lw, } @@ -627,6 +629,27 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { load_width, is_signed } => { + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + let is_byte = load_width == 8; + let is_signed_bool = is_signed != 0; + let col_map = info_span!("col_map") + .in_scope(|| super::load_sub::extract_load_sub_column_map(load_config, num_witin, is_byte, is_signed_bool)); + info_span!("hal_witgen_load_sub").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_load_sub(&col_map, gpu_records, &indices_u32, shard_offset, load_width, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index efe5b8a3b..28193e4f4 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -268,7 +268,15 @@ impl Instruction for LoadInstruction { use crate::instructions::riscv::gpu::witgen_gpu; - if I::INST_KIND == InsnKind::LW { + let gpu_kind = match I::INST_KIND { + InsnKind::LW => Some(witgen_gpu::GpuWitgenKind::Lw), + InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 1 }), + InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 0 }), + InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 1 }), + InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 
8, is_signed: 0 }), + _ => None, + }; + if let Some(kind) = gpu_kind { if let Some(result) = witgen_gpu::try_gpu_assign_instances::( config, shard_ctx, @@ -276,7 +284,7 @@ impl Instruction for LoadInstruction Date: Fri, 6 Mar 2026 09:50:28 +0800 Subject: [PATCH 22/73] batch-17-MUL --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 301 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 ++ ceno_zkvm/src/instructions/riscv/mulh.rs | 2 +- .../riscv/mulh/mulh_circuit_v2.rs | 58 +++- 5 files changed, 376 insertions(+), 8 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/mul.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 2d0277524..4b4c177b2 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -37,6 +37,8 @@ pub mod sb; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod load_sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod mul; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs new file mode 100644 index 000000000..283e3ed23 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -0,0 +1,301 @@ +use ceno_gpu::common::witgen_types::MulColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig; + +/// Extract column map from a constructed MulhConfig. 
+/// mul_kind: 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU +pub fn extract_mul_column_map( + config: &MulhConfig, + num_witin: usize, + mul_kind: u32, +) -> MulColumnMap { + let r = &config.r_insn; + + // R-type base + let pc = r.vm_state.pc.id as u32; + let ts = r.vm_state.ts.id as u32; + + let rs1_id = r.rs1.id.id as u32; + let rs1_prev_ts = r.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &r.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = r.rs2.id.id as u32; + let rs2_prev_ts = r.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &r.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rd_id = r.rd.id.id as u32; + let rd_prev_ts = r.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &r.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Mul-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_low: [u32; 2] = [config.rd_low[0].id as u32, config.rd_low[1].id as u32]; + + // MULH/MULHU/MULHSU have rd_high + extensions + let (rd_high, rs1_ext, rs2_ext) = if mul_kind != 0 { + let h = config.rd_high.as_ref().expect("MULH variants must have rd_high"); + ( + Some([h[0].id as u32, h[1].id as u32]), + Some(config.rs1_ext.expect("MULH variants must have rs1_ext").id as u32), + Some(config.rs2_ext.expect("MULH variants must have rs2_ext").id as u32), + ) + } else { + (None, None, None) + }; + + MulColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + 
rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_low, + rd_high, + rs1_ext, + rs2_ext, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::mulh::{MulInstruction, MulhInstruction, MulhsuInstruction, MulhuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &MulColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_mul_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mul"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 0); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_none()); + } + + #[test] + fn test_extract_mulh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 1); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + fn test_extract_mulhu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulhu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + 
MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 2); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + fn test_extract_mulhsu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulhsu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 3); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_mul_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::MUL, 0, "MUL"), + (InsnKind::MULH, 1, "MULH"), + (InsnKind::MULHU, 2, "MULHU"), + (InsnKind::MULHSU, 3, "MULHSU"), + ]; + + for &(insn_kind, mul_kind, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::MUL => MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MULH => MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MULHU => MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MULHSU => MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: 
Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // Use varied values including negative interpretations + let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2_val = (i as u32).wrapping_mul(54321).wrapping_add(13); + let rd_after = match insn_kind { + InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), + InsnKind::MULH => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) as u32 + } + InsnKind::MULHU => { + ((rs1_val as u64).wrapping_mul(rs2_val as u64) >> 32) as u32 + } + InsnKind::MULHSU => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i64) >> 32) as u32 + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::MUL => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::MULH => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::MULHU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::MULHSU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_mul_column_map(&config, num_witin, mul_kind); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] 
= unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_mul(&col_map, &gpu_records, &indices_u32, shard_offset, mul_kind, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index f4f9b58a9..c5fc70258 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -60,6 +60,8 @@ pub enum GpuWitgenKind { Sb, #[cfg(feature = "u16limb_circuit")] LoadSub { load_width: u32, is_signed: u32 }, + #[cfg(feature = "u16limb_circuit")] + Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU Lw, } @@ -650,6 +652,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(mul_kind) => { + let mul_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| 
super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); + info_span!("hal_witgen_mul").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_mul(&col_map, gpu_records, &indices_u32, shard_offset, mul_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_mul failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 4a4c8065b..61279fbb6 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod mulh_circuit; #[cfg(feature = "u16limb_circuit")] -mod mulh_circuit_v2; +pub(crate) mod mulh_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] use mulh_circuit::MulhInstructionBase; diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index f3bddff1b..d42f9c7d8 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -23,16 +23,23 @@ use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct MulhInstructionBase(PhantomData<(E, I)>); pub struct MulhConfig { - rs1_read: UInt, - rs2_read: UInt, - r_insn: RInstructionConfig, - rd_low: [WitIn; UINT_LIMBS], - rd_high: Option<[WitIn; UINT_LIMBS]>, - rs1_ext: Option, - rs2_ext: Option, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + pub(crate) r_insn: RInstructionConfig, + pub(crate) rd_low: [WitIn; UINT_LIMBS], + pub(crate) rd_high: Option<[WitIn; UINT_LIMBS]>, + pub(crate) 
rs1_ext: Option, + pub(crate) rs2_ext: Option, phantom: PhantomData, } @@ -327,6 +334,43 @@ impl Instruction for MulhInstructionBas Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let mul_kind = match I::INST_KIND { + InsnKind::MUL => 0u32, + InsnKind::MULH => 1u32, + InsnKind::MULHU => 2u32, + InsnKind::MULHSU => 3u32, + _ => { + return crate::instructions::cpu_assign_instances::( + config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + ); + } + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Mul(mul_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + ) + } } fn run_mulh( From cb48e7ac20eae84867de865dec13095651a31189 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:58:42 +0800 Subject: [PATCH 23/73] batch-18-DIV --- ceno_zkvm/src/instructions/riscv/div.rs | 2 +- .../instructions/riscv/div/div_circuit_v2.rs | 72 +++- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 359 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 + 5 files changed, 443 insertions(+), 13 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/div.rs diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 981995452..829f6140c 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -3,7 +3,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = 
"u16limb_circuit"))] mod div_circuit; #[cfg(feature = "u16limb_circuit")] -mod div_circuit_v2; +pub(crate) mod div_circuit_v2; use super::RIVInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index eb1a5d0f9..a124c6768 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -30,18 +30,18 @@ pub struct DivRemConfig { pub(crate) remainder: UInt, pub(crate) r_insn: RInstructionConfig, - dividend_sign: WitIn, - divisor_sign: WitIn, - quotient_sign: WitIn, - remainder_zero: WitIn, - divisor_zero: WitIn, - divisor_sum_inv: WitIn, - remainder_sum_inv: WitIn, - remainder_inv: [WitIn; UINT_LIMBS], - sign_xor: WitIn, - remainder_prime: UInt, // r' - lt_marker: [WitIn; UINT_LIMBS], - lt_diff: WitIn, + pub(crate) dividend_sign: WitIn, + pub(crate) divisor_sign: WitIn, + pub(crate) quotient_sign: WitIn, + pub(crate) remainder_zero: WitIn, + pub(crate) divisor_zero: WitIn, + pub(crate) divisor_sum_inv: WitIn, + pub(crate) remainder_sum_inv: WitIn, + pub(crate) remainder_inv: [WitIn; UINT_LIMBS], + pub(crate) sign_xor: WitIn, + pub(crate) remainder_prime: UInt, // r' + pub(crate) lt_marker: [WitIn; UINT_LIMBS], + pub(crate) lt_diff: WitIn, } pub struct ArithInstruction(PhantomData<(E, I)>); @@ -376,6 +376,54 @@ impl Instruction for ArithInstruction Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + let div_kind = match I::INST_KIND { + InsnKind::DIV => 0u32, + InsnKind::DIVU => 1u32, + InsnKind::REM => 2u32, + InsnKind::REMU => 3u32, + _ => { + return crate::instructions::cpu_assign_instances::( + config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + ); + } + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + 
num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Div(div_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } + fn assign_instance( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs new file mode 100644 index 000000000..6e1aca75f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -0,0 +1,359 @@ +use ceno_gpu::common::witgen_types::DivColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::div::div_circuit_v2::DivRemConfig; + +/// Extract column map from a constructed DivRemConfig. +/// div_kind: 0=DIV, 1=DIVU, 2=REM, 3=REMU +pub fn extract_div_column_map( + config: &DivRemConfig, + num_witin: usize, +) -> DivColumnMap { + let r = &config.r_insn; + + // R-type base + let pc = r.vm_state.pc.id as u32; + let ts = r.vm_state.ts.id as u32; + + let rs1_id = r.rs1.id.id as u32; + let rs1_prev_ts = r.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &r.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = r.rs2.id.id as u32; + let rs2_prev_ts = r.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &r.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rd_id = r.rd.id.id as u32; + let rd_prev_ts = r.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &r.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Div-specific: operand limbs + let dividend: [u32; 2] = { + let l = config.dividend.wits_in().expect("dividend WitIns"); + 
assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let divisor: [u32; 2] = { + let l = config.divisor.wits_in().expect("divisor WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let quotient: [u32; 2] = { + let l = config.quotient.wits_in().expect("quotient WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let remainder: [u32; 2] = { + let l = config.remainder.wits_in().expect("remainder WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // Sign/control bits + let dividend_sign = config.dividend_sign.id as u32; + let divisor_sign = config.divisor_sign.id as u32; + let quotient_sign = config.quotient_sign.id as u32; + let remainder_zero = config.remainder_zero.id as u32; + let divisor_zero = config.divisor_zero.id as u32; + + // Inverse witnesses + let divisor_sum_inv = config.divisor_sum_inv.id as u32; + let remainder_sum_inv = config.remainder_sum_inv.id as u32; + let remainder_inv: [u32; 2] = [ + config.remainder_inv[0].id as u32, + config.remainder_inv[1].id as u32, + ]; + + // sign_xor + let sign_xor = config.sign_xor.id as u32; + + // remainder_prime + let remainder_prime: [u32; 2] = { + let l = config.remainder_prime.wits_in().expect("remainder_prime WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // lt_marker + let lt_marker: [u32; 2] = [ + config.lt_marker[0].id as u32, + config.lt_marker[1].id as u32, + ]; + + // lt_diff + let lt_diff = config.lt_diff.id as u32; + + DivColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + dividend, + divisor, + quotient, + remainder, + dividend_sign, + divisor_sign, + quotient_sign, + remainder_zero, + divisor_zero, + divisor_sum_inv, + remainder_sum_inv, + remainder_inv, + sign_xor, + remainder_prime, + lt_marker, + lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { 
+ use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::div::{DivInstruction, DivuInstruction, RemInstruction, RemuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &DivColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_div_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_div"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_divu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_divu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_rem_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_rem"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_remu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_remu"); + let mut cb = CircuitBuilder::new(&mut 
cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_div_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::DIV, 0, "DIV"), + (InsnKind::DIVU, 1, "DIVU"), + (InsnKind::REM, 2, "REM"), + (InsnKind::REMU, 3, "REMU"), + ]; + + for &(insn_kind, div_kind, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::DIV => DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::DIVU => DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::REM => RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::REMU => RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // Use varied values; include zero divisor and edge cases + let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2_val = if i % 50 == 0 { + 0 // test zero divisor + } else { + (i as u32).wrapping_mul(54321).wrapping_add(13) + }; + let rd_after = match insn_kind { + InsnKind::DIV => { + if rs2_val == 0 { + u32::MAX // -1 as u32 + } else { + (rs1_val as 
i32).wrapping_div(rs2_val as i32) as u32 + } + } + InsnKind::DIVU => { + if rs2_val == 0 { + u32::MAX + } else { + rs1_val / rs2_val + } + } + InsnKind::REM => { + if rs2_val == 0 { + rs1_val + } else { + (rs1_val as i32).wrapping_rem(rs2_val as i32) as u32 + } + } + InsnKind::REMU => { + if rs2_val == 0 { + rs1_val + } else { + rs1_val % rs2_val + } + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::DIV => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::DIVU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::REM => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::REMU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_div_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + 
.witgen_div(&col_map, &gpu_records, &indices_u32, shard_offset, div_kind, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 4b4c177b2..ad093a423 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -39,6 +39,8 @@ pub mod load_sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod mul; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod div; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index c5fc70258..6f1ea430b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -62,6 +62,8 @@ pub enum GpuWitgenKind { LoadSub { load_width: u32, is_signed: u32 }, #[cfg(feature = "u16limb_circuit")] Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU + #[cfg(feature = "u16limb_circuit")] + Div(u32), // 0=DIV, 1=DIVU, 2=REM, 3=REMU Lw, } @@ -671,6 +673,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = 
"u16limb_circuit")] + GpuWitgenKind::Div(div_kind) => { + let div_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::div::div_circuit_v2::DivRemConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); + info_span!("hal_witgen_div").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_div(&col_map, gpu_records, &indices_u32, shard_offset, div_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_div failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { From 6c43c6ae71353351453372d2f5d8e687562c06d6 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 10:34:49 +0800 Subject: [PATCH 24/73] dev: non-witgen-overlap --- ceno_zkvm/src/e2e.rs | 160 +++++++++++++++++++++---------------------- 1 file changed, 80 insertions(+), 80 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 4fc66df5a..0d7676074 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1888,86 +1888,86 @@ fn create_proofs_streaming< ) -> Vec> { let ctx = prover.pk.program_ctx.as_ref().unwrap(); let proofs = info_span!("[ceno] app_prove.inner").in_scope(|| { - #[cfg(feature = "gpu")] - { - use crossbeam::channel; - let (tx, rx) = channel::bounded(0); - std::thread::scope(|s| { - // pipeline cpu/gpu workload - // cpu producer - s.spawn({ - move || { - let wit_iter = generate_witness( - &ctx.system_config, - emulation_result, - ctx.program.clone(), - &ctx.platform, - init_mem_state, - target_shard_id, - ); - - let wit_iter = if let Some(target_shard_id) = target_shard_id { - Box::new(wit_iter.skip(target_shard_id)) as Box> - } else { - Box::new(wit_iter) - }; - - for proof_input in wit_iter { - if tx.send(proof_input).is_err() { - tracing::warn!( - "witness consumer dropped; stopping witness generation early" - ); - break; - } - } - 
} - }); - - // gpu consumer - { - let mut proofs = Vec::new(); - let mut proof_err = None; - let rx = rx; - while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { - if is_mock_proving { - MockProver::assert_satisfied_full( - &shard_ctx, - &ctx.system_config.zkvm_cs, - ctx.zkvm_fixed_traces.clone(), - &zkvm_witness, - &pi, - &ctx.program, - ); - tracing::info!("Mock proving passed"); - } - - let transcript = Transcript::new(b"riscv"); - let start = std::time::Instant::now(); - match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { - Ok(zkvm_proof) => { - tracing::debug!( - "{}th shard proof created in {:?}", - shard_ctx.shard_id, - start.elapsed() - ); - proofs.push(zkvm_proof); - } - Err(err) => { - proof_err = Some(err); - break; - } - } - } - drop(rx); - if let Some(err) = proof_err { - panic!("create_proof failed: {err:?}"); - } - proofs - } - }) - } - - #[cfg(not(feature = "gpu"))] + // #[cfg(feature = "gpu")] + // { + // use crossbeam::channel; + // let (tx, rx) = channel::bounded(0); + // std::thread::scope(|s| { + // // pipeline cpu/gpu workload + // // cpu producer + // s.spawn({ + // move || { + // let wit_iter = generate_witness( + // &ctx.system_config, + // emulation_result, + // ctx.program.clone(), + // &ctx.platform, + // init_mem_state, + // target_shard_id, + // ); + + // let wit_iter = if let Some(target_shard_id) = target_shard_id { + // Box::new(wit_iter.skip(target_shard_id)) as Box> + // } else { + // Box::new(wit_iter) + // }; + + // for proof_input in wit_iter { + // if tx.send(proof_input).is_err() { + // tracing::warn!( + // "witness consumer dropped; stopping witness generation early" + // ); + // break; + // } + // } + // } + // }); + + // // gpu consumer + // { + // let mut proofs = Vec::new(); + // let mut proof_err = None; + // let rx = rx; + // while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { + // if is_mock_proving { + // MockProver::assert_satisfied_full( + // &shard_ctx, + // &ctx.system_config.zkvm_cs, 
+ // ctx.zkvm_fixed_traces.clone(), + // &zkvm_witness, + // &pi, + // &ctx.program, + // ); + // tracing::info!("Mock proving passed"); + // } + + // let transcript = Transcript::new(b"riscv"); + // let start = std::time::Instant::now(); + // match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { + // Ok(zkvm_proof) => { + // tracing::debug!( + // "{}th shard proof created in {:?}", + // shard_ctx.shard_id, + // start.elapsed() + // ); + // proofs.push(zkvm_proof); + // } + // Err(err) => { + // proof_err = Some(err); + // break; + // } + // } + // } + // drop(rx); + // if let Some(err) = proof_err { + // panic!("create_proof failed: {err:?}"); + // } + // proofs + // } + // }) + // } + + // #[cfg(not(feature = "gpu"))] { // Generate witness let wit_iter = generate_witness( From 4ba08c034ad4a351102aca5af24dd0dac1b74bff Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 10:40:57 +0800 Subject: [PATCH 25/73] test coverage: compare all column --- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 4 +--- 
ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 4 +--- 21 files changed, 21 insertions(+), 63 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index ef3339d82..80e8b94d8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -176,11 +176,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 9ba8fe9b4..32146fcb5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -170,11 +170,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index acd311265..6aebfc0e4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -196,11 +196,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + 
for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 6d1621633..3cd13663e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -189,11 +189,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index 6e1aca75f..cfcf183c8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -334,11 +334,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 815f7b809..96e64b548 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -148,11 +148,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * 
num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 8da4e6ba4..83b749abd 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -192,11 +192,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 9c42e909b..8544d6159 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -366,11 +366,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index ce4999c10..c47cf219a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -182,11 +182,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = 
cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 1abc851fa..91b65b23c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -178,11 +178,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 5cd798073..e953ff5f3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -160,11 +160,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index fdef5a686..b2d4ccd67 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -221,11 +221,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git 
a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index 283e3ed23..dd5a61aec 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -276,11 +276,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 8b5c8689e..3cbdebb9b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -246,11 +246,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index 1e4c103f5..721aea2e4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -223,11 +223,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs 
b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index f58bcf497..aaf1cf4bf 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -186,11 +186,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index e286936e8..183fc3571 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -196,11 +196,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index 985f5dc8b..f4f31b46f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -196,11 +196,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 713f9c19e..7bce1a84a 100644 --- 
a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -182,11 +182,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index a3aaaa9a1..6218b308f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -200,11 +200,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 81143a803..f6b88721b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -210,11 +210,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { From e943735108d9c4615eb9609ca3f2b0e7c0465125 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 10:43:54 +0800 Subject: [PATCH 26/73] test coverage: edge cases --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 18 ++++++++-- 
ceno_zkvm/src/instructions/riscv/gpu/div.rs | 33 ++++++++++++++++--- .../src/instructions/riscv/gpu/logic_i.rs | 18 ++++++++-- .../src/instructions/riscv/gpu/logic_r.rs | 18 ++++++++-- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 26 +++++++++++++-- .../src/instructions/riscv/gpu/shift_i.rs | 19 +++++++++-- .../src/instructions/riscv/gpu/shift_r.rs | 20 +++++++++-- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 18 ++++++++-- 8 files changed, 149 insertions(+), 21 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 2ee07edf7..e4e6e8f77 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -115,11 +115,25 @@ mod tests { type E = BabyBearExt4; fn make_test_steps(n: usize) -> Vec { + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (0, 1), + (1, 0), + (u32::MAX, 1), // overflow + (u32::MAX, u32::MAX), // double overflow + (0x80000000, 0x80000000), // INT_MIN + INT_MIN + (0x7FFFFFFF, 1), // INT_MAX + 1 + (0xFFFF0000, 0x0000FFFF), // limb carry + ]; + let pc_start = 0x1000u32; (0..n) .map(|i| { - let rs1 = (i as u32) % 1000 + 1; - let rs2 = (i as u32) % 500 + 3; + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 1, (i as u32) % 500 + 3) + }; let rd_before = (i as u32) % 200; let rd_after = rs1.wrapping_add(rs2); let cycle = 4 + (i as u64) * 4; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index cfcf183c8..9c867aa60 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -236,15 +236,38 @@ mod tests { let num_structural_witin = cb.cs.num_structural_witin as usize; let n = 1024; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 1), // 0 / 1 + (1, 1), // 1 / 1 + (0, 0), // 0 / 0 (zero divisor) + (12345, 0), // non-zero / 0 (zero divisor) + (u32::MAX, 0), // max / 0 (zero divisor) + (0x80000000, 0), // 
INT_MIN / 0 (zero divisor) + (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 + (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 + (0x80000000, 1), // INT_MIN / 1 + (0x80000000, 2), // INT_MIN / 2 + (u32::MAX, u32::MAX), // max / max + (u32::MAX, 1), // max / 1 + (1, u32::MAX), // 1 / max + ]; + let steps: Vec = (0..n) .map(|i| { let pc = ByteAddr(0x1000 + (i as u32) * 4); - // Use varied values; include zero divisor and edge cases - let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); - let rs2_val = if i % 50 == 0 { - 0 // test zero divisor + // Use edge cases first, then varied values with zero divisor + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] } else { - (i as u32).wrapping_mul(54321).wrapping_add(13) + let rs1 = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2 = if i % 50 == 0 { + 0 // test zero divisor + } else { + (i as u32).wrapping_mul(54321).wrapping_add(13) + }; + (rs1, rs2) }; let rd_after = match insn_kind { InsnKind::DIV => { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index c47cf219a..bf3e2b791 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -134,11 +134,25 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (u32::MAX, 0xFFF), // all bits AND max imm + (u32::MAX, 0), + (0, 0xFFF), + (0xAAAAAAAA, 0x555), // alternating + (0xFFFF0000, 0xFFF), + (0x12345678, 0x000), + (0xDEADBEEF, 0xABC), + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff; - let imm = (i as u32) % 4096; // 0..4095 (12-bit unsigned imm) + let (rs1, imm) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, (i as u32) % 4096) + }; 
let rd_after = rs1 & imm; // ANDI let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 91b65b23c..661f3b85a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -132,11 +132,25 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (u32::MAX, u32::MAX), + (u32::MAX, 0), + (0, u32::MAX), + (0xAAAAAAAA, 0x55555555), // alternating bits + (0xFFFF0000, 0x0000FFFF), // no overlap + (0xDEADBEEF, 0xFFFFFFFF), // identity + (0x12345678, 0x00000000), // zero + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = 0xDEAD_0000u32 | (i as u32); - let rs2 = 0x00FF_FF00u32 | ((i as u32) << 8); + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + (0xDEAD_0000u32 | (i as u32), 0x00FF_FF00u32 | ((i as u32) << 8)) + }; let rd_after = rs1 & rs2; // AND let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index dd5a61aec..fb845777c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -199,13 +199,33 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), // zero * zero + (0, 12345), // zero * non-zero + (12345, 0), // non-zero * zero + (1, 1), // identity + (u32::MAX, 1), // max * 1 + (1, u32::MAX), // 1 * max + (u32::MAX, u32::MAX), // max * max + (0x80000000, 2), // INT_MIN * 2 (for MULH) + (2, 0x80000000), // 2 * INT_MIN + (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed + (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) + (0x7FFFFFFF, 
0x7FFFFFFF), // INT_MAX * INT_MAX + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { let pc = ByteAddr(0x1000 + (i as u32) * 4); - // Use varied values including negative interpretations - let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); - let rs2_val = (i as u32).wrapping_mul(54321).wrapping_add(13); + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + (i as u32).wrapping_mul(12345).wrapping_add(7), + (i as u32).wrapping_mul(54321).wrapping_add(13), + ) + }; let rd_after = match insn_kind { InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), InsnKind::MULH => { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index aaf1cf4bf..c413bc518 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -138,11 +138,26 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32).wrapping_mul(0x01010101); - let shamt = (i as i32) % 32; // 0..31 + let (rs1, shamt) = if i < EDGE_CASES.len() { + let (r, s) = EDGE_CASES[i]; + (r, s as i32) + } else { + ((i as u32).wrapping_mul(0x01010101), (i as i32) % 32) + }; let rd_after = rs1 << (shamt as u32); let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 183fc3571..ea3707e3e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -152,12 +152,26 @@ mod tests { let num_witin = 
cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32).wrapping_mul(0x01010101); - let rs2 = (i as u32) % 32; - let rd_after = rs1 << rs2; + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32).wrapping_mul(0x01010101), (i as u32) % 32) + }; + let rd_after = rs1 << (rs2 & 0x1F); let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 6218b308f..5f2424060 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -154,11 +154,25 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (0, 1), // underflow + (1, 0), + (0, u32::MAX), // underflow + (u32::MAX, u32::MAX), + (0x80000000, 1), // INT_MIN - 1 + (0, 0x80000000), // 0 - INT_MIN + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX - (-1) + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32) % 1000 + 500; - let rs2 = (i as u32) % 300 + 1; + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 500, (i as u32) % 300 + 1) + }; let rd_after = rs1.wrapping_sub(rs2); let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); From bcdc2a3d90cbb21112731e20aaa3cfd792ac49c3 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 11:17:01 +0800 Subject: [PATCH 27/73] gpu witgen: col-major --- 
ceno_zkvm/src/instructions/riscv/gpu/add.rs | 8 ++--- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 2 +- .../src/instructions/riscv/gpu/branch_cmp.rs | 2 +- .../src/instructions/riscv/gpu/branch_eq.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 2 +- .../src/instructions/riscv/gpu/load_sub.rs | 2 +- .../src/instructions/riscv/gpu/logic_i.rs | 2 +- .../src/instructions/riscv/gpu/logic_r.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 2 +- .../src/instructions/riscv/gpu/shift_i.rs | 2 +- .../src/instructions/riscv/gpu/shift_r.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 2 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 36 +++++++++++++++---- 23 files changed, 55 insertions(+), 31 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index e4e6e8f77..e6478a7a4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -230,19 +230,19 @@ mod tests { .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap(); - // D2H copy + // D2H copy (GPU output is column-major) let gpu_data: Vec<::BaseField> = gpu_result.device_buffer.to_vec().unwrap(); - // Compare element by element + // Compare element by element (GPU is column-major, CPU is row-major) let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); let mut mismatches = 0; for row in 0..n { for col in 0..num_witin { - 
let gpu_val = gpu_data[row * num_witin + col]; - let cpu_val = cpu_data[row * num_witin + col]; + let gpu_val = gpu_data[col * n + row]; // column-major + let cpu_val = cpu_data[row * num_witin + col]; // row-major if gpu_val != cpu_val { if mismatches < 10 { eprintln!( diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 80e8b94d8..07df75aae 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -179,7 +179,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 32146fcb5..80d2b67cc 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -173,7 +173,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index 6aebfc0e4..39674a9eb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -199,7 +199,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 3cd13663e..2b312cfc8 100644 --- 
a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -192,7 +192,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index 9c867aa60..7d89d4ec6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -360,7 +360,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 96e64b548..8a9be12de 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -151,7 +151,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 83b749abd..11bce8623 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -195,7 +195,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs 
b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 8544d6159..163efa681 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -369,7 +369,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index bf3e2b791..3a54fe8e1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -199,7 +199,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 661f3b85a..b798b1115 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -195,7 +195,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index e953ff5f3..fa52596a7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -163,7 +163,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val 
!= cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index b2d4ccd67..5a4dd1b91 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -224,7 +224,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index fb845777c..9cf8ec04a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -299,7 +299,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 3cbdebb9b..3638b56c4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -249,7 +249,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index 721aea2e4..225ca91d3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -226,7 +226,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = 
cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index c413bc518..61342d270 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -204,7 +204,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index ea3707e3e..447a519a7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -213,7 +213,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index f4f31b46f..2412a5dc8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -199,7 +199,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 7bce1a84a..9c7cf262f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -185,7 +185,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = 
gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 5f2424060..4531cfe38 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -217,7 +217,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index f6b88721b..5fb0a799b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -213,7 +213,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 6f1ea430b..1b62697d0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -5,7 +5,7 @@ /// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) /// 3. 
Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31}; +use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::util::max_usable_threads; @@ -255,9 +255,10 @@ fn gpu_assign_instances_inner>( } raw_structural.padding_by_strategy(); - // Step 4: Convert GPU witness to RowMajorMatrix - let mut raw_witin = info_span!("d2h_copy").in_scope(|| { + // Step 4: Transpose (column-major → row-major) on GPU, then D2H copy to RowMajorMatrix + let mut raw_witin = info_span!("transpose_d2h").in_scope(|| { gpu_witness_to_rmm::( + hal, gpu_witness, total_instances, num_witin, @@ -764,8 +765,12 @@ fn collect_side_effects>( Ok(lk_multiplicity) } -/// Convert GPU device buffer to RowMajorMatrix via D2H copy. +/// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. +/// +/// GPU witgen kernels output column-major layout for better memory coalescing. +/// This function transposes to row-major on GPU before copying to host. fn gpu_witness_to_rmm( + hal: &CudaHalBB31, gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< ceno_gpu::common::BufferImpl<'static, ::BaseField>, >, @@ -773,8 +778,27 @@ fn gpu_witness_to_rmm( num_cols: usize, padding: InstancePaddingStrategy, ) -> Result, ZKVMError> { - let gpu_data: Vec<::BaseField> = gpu_result - .device_buffer + // Transpose from column-major to row-major on GPU. + // Column-major (num_rows x num_cols) is stored as num_cols groups of num_rows elements, + // which is equivalent to a (num_cols x num_rows) row-major matrix. + // Transposing with cols=num_rows, rows=num_cols produces (num_rows x num_cols) row-major. 
+ let mut rmm_buffer = hal + .alloc_elems_on_device(num_rows * num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.device_buffer, + num_rows, + num_cols, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) + })?; + + let gpu_data: Vec<::BaseField> = rmm_buffer .to_vec() .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?; From 1e21d37fd5e231533ff6f6f2f92e8dbf5559bf11 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 9 Mar 2026 00:47:28 +0800 Subject: [PATCH 28/73] phase5 --- ceno_zkvm/src/e2e.rs | 359 ++++- ceno_zkvm/src/instructions.rs | 110 ++ ceno_zkvm/src/instructions/riscv/arith.rs | 53 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 31 + ceno_zkvm/src/instructions/riscv/auipc.rs | 60 + ceno_zkvm/src/instructions/riscv/b_insn.rs | 29 +- .../riscv/branch/branch_circuit_v2.rs | 44 + .../instructions/riscv/div/div_circuit_v2.rs | 120 +- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 86 +- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 12 +- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 36 +- .../src/instructions/riscv/gpu/branch_cmp.rs | 43 +- .../src/instructions/riscv/gpu/branch_eq.rs | 49 +- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 120 +- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 24 +- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 36 +- .../src/instructions/riscv/gpu/load_sub.rs | 92 +- .../src/instructions/riscv/gpu/logic_i.rs | 33 +- .../src/instructions/riscv/gpu/logic_r.rs | 122 +- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 12 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 63 +- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 32 +- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 118 +- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 34 +- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 34 +- 
.../src/instructions/riscv/gpu/shift_i.rs | 62 +- .../src/instructions/riscv/gpu/shift_r.rs | 78 +- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 27 +- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 18 +- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 29 +- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 28 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 1078 ++++++++++++--- ceno_zkvm/src/instructions/riscv/i_insn.rs | 29 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 31 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 260 ++++ ceno_zkvm/src/instructions/riscv/j_insn.rs | 27 +- .../src/instructions/riscv/jump/jal_v2.rs | 42 + .../src/instructions/riscv/jump/jalr_v2.rs | 40 + .../instructions/riscv/logic/logic_circuit.rs | 50 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 47 +- ceno_zkvm/src/instructions/riscv/lui.rs | 38 +- .../src/instructions/riscv/memory/load.rs | 60 + .../src/instructions/riscv/memory/load_v2.rs | 80 +- .../src/instructions/riscv/memory/store_v2.rs | 54 + .../riscv/mulh/mulh_circuit_v2.rs | 124 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 31 +- ceno_zkvm/src/instructions/riscv/s_insn.rs | 32 +- .../riscv/shift/shift_circuit_v2.rs | 129 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 42 + .../riscv/slti/slti_circuit_v2.rs | 41 + ceno_zkvm/src/instructions/side_effects.rs | 1157 +++++++++++++++++ ceno_zkvm/src/structs.rs | 8 + 52 files changed, 4866 insertions(+), 528 deletions(-) create mode 100644 ceno_zkvm/src/instructions/side_effects.rs diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 0d7676074..63be3240b 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -28,6 +28,8 @@ use ceno_emul::{ StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; +#[cfg(feature = "gpu")] +use ceno_gpu::CudaHal; use clap::ValueEnum; use either::Either; use ff_ext::{ExtensionField, SmallField}; @@ -251,6 +253,39 @@ impl<'a> Default for 
ShardContext<'a> { /// for example, if there are 10 shards and 3 provers, /// the shard counts will be distributed as 3, 3, and 4, ensuring an even workload across all provers. impl<'a> ShardContext<'a> { + /// Create a new ShardContext with the same shard metadata but empty record storage. + /// Useful for debug comparisons against the actual shard context. + pub fn new_empty_like(&self) -> ShardContext<'static> { + let max_threads = max_usable_threads(); + ShardContext { + shard_id: self.shard_id, + num_shards: self.num_shards, + max_cycle: self.max_cycle, + addr_future_accesses: self.addr_future_accesses.clone(), + addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), + read_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + write_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, + max_num_cross_shard_accesses: self.max_num_cross_shard_accesses, + prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), + prev_shard_heap_range: self.prev_shard_heap_range.clone(), + prev_shard_hint_range: self.prev_shard_hint_range.clone(), + platform: self.platform.clone(), + shard_heap_addr_range: self.shard_heap_addr_range.clone(), + shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), + } + } + pub fn get_forked(&mut self) -> Vec> { match ( &mut self.read_records_tbs, @@ -391,9 +426,39 @@ impl<'a> ShardContext<'a> { }) } + #[inline(always)] + pub fn insert_read_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .read_records_tbs + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn insert_write_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .write_records_tbs + 
.as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn push_addr_accessed(&mut self, addr: WordAddr) { + let addr_accessed = self + .addr_accessed_tbs + .as_mut() + .right() + .expect("illegal type"); + addr_accessed.push(addr); + } + #[inline(always)] #[allow(clippy::too_many_arguments)] - pub fn send( + pub fn record_send_without_touch( &mut self, ram_type: crate::structs::RAMType, addr: WordAddr, @@ -410,15 +475,9 @@ impl<'a> ShardContext<'a> { let addr_raw = addr.baddr().0; let is_heap = self.platform.heap.contains(&addr_raw); let is_hint = self.platform.hints.contains(&addr_raw); - // 1. checking reads from the external bus if prev_cycle > 0 || (prev_cycle == 0 && (!is_heap && !is_hint)) { let prev_shard_id = self.extract_shard_id_by_cycle(prev_cycle); - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -437,22 +496,15 @@ impl<'a> ShardContext<'a> { prev_cycle == 0 && (is_heap || is_hint), "addr {addr_raw:x} prev_cycle {prev_cycle}, is_heap {is_heap}, is_hint {is_hint}", ); - // 2. handle heap/hint initial reads outside the shard range. 
let prev_shard_id = if is_heap && !self.shard_heap_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_heap_addr(addr_raw)) } else if is_hint && !self.shard_hint_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_hint_addr(addr_raw)) } else { - // dynamic init in current shard, skip and do nothing None }; if let Some(prev_shard_id) = prev_shard_id { - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -470,18 +522,12 @@ impl<'a> ShardContext<'a> { } } - // check write to external mem bus if let Some(future_touch_cycle) = self.find_future_next_access(cycle, addr) && self.after_current_shard_cycle(future_touch_cycle) && self.is_in_current_shard(cycle) { let shard_cycle = self.aligned_current_ts(cycle); - let ram_record = self - .write_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_write_record( addr, RAMRecord { ram_type, @@ -496,13 +542,22 @@ impl<'a> ShardContext<'a> { }, ); } + } - let addr_accessed = self - .addr_accessed_tbs - .as_mut() - .right() - .expect("illegal type"); - addr_accessed.push(addr); + #[inline(always)] + #[allow(clippy::too_many_arguments)] + pub fn send( + &mut self, + ram_type: crate::structs::RAMType, + addr: WordAddr, + id: u64, + cycle: Cycle, + prev_cycle: Cycle, + value: Word, + prev_value: Option, + ) { + self.record_send_without_touch(ram_type, addr, id, cycle, prev_cycle, value, prev_value); + self.push_addr_accessed(addr); } /// merge addr accessed in different threads @@ -1341,6 +1396,9 @@ pub fn generate_witness<'a, E: ExtensionField>( } let time = std::time::Instant::now(); + let debug_compare_e2e_shard = + std::env::var_os("CENO_GPU_DEBUG_COMPARE_E2E_SHARD").is_some(); + let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); system_config .config .assign_opcode_circuit( @@ -1355,7 +1413,16 @@ pub 
fn generate_witness<'a, E: ExtensionField>( // Free GPU shard_steps cache after all opcode circuits are done. #[cfg(feature = "gpu")] - crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + { + crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + if std::env::var_os("CENO_GPU_TRIM_AFTER_WITGEN").is_some() { + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + + let cuda_hal = get_cuda_hal().unwrap(); + cuda_hal.inner().trim_mem_pool().unwrap(); + cuda_hal.inner().synchronize().unwrap(); + } + } let time = std::time::Instant::now(); system_config @@ -1371,6 +1438,50 @@ pub fn generate_witness<'a, E: ExtensionField>( tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); zkvm_witness.finalize_lk_multiplicities(); + if let Some(mut cpu_shard_ctx) = debug_shard_ctx_template { + let mut cpu_witness = ZKVMWitnesses::default(); + let mut cpu_dispatch_ctx = system_config.inst_dispatch_builder.to_dispatch_ctx(); + cpu_dispatch_ctx.begin_shard(); + for (step_idx, step) in shard_steps.iter().enumerate() { + cpu_dispatch_ctx.ingest_step(step_idx, step); + } + + // Force CPU path for the debug comparison (thread-local, no env var races). + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(true); + + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &mut cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + cpu_witness.finalize_lk_multiplicities(); + + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(false); + + log_shard_ctx_diff("post_opcode_assignment", &cpu_shard_ctx, &shard_ctx); + + // Compare combined_lk_mlt (the merged LK after finalize_lk_multiplicities). 
+ // This catches issues where per-chip LK appears correct but the merge differs. + log_combined_lk_diff(&cpu_witness, &zkvm_witness); + } + // Memory record routing (per address / waddr) // // Legend: @@ -2057,6 +2168,194 @@ pub fn run_e2e_verify>( } } +fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { + let mut cloned = ShardContext::default(); + cloned.shard_id = src.shard_id; + cloned.num_shards = src.num_shards; + cloned.max_cycle = src.max_cycle; + cloned.addr_future_accesses = src.addr_future_accesses.clone(); + cloned.cur_shard_cycle_range = src.cur_shard_cycle_range.clone(); + cloned.expected_inst_per_shard = src.expected_inst_per_shard; + cloned.max_num_cross_shard_accesses = src.max_num_cross_shard_accesses; + cloned.prev_shard_cycle_range = src.prev_shard_cycle_range.clone(); + cloned.prev_shard_heap_range = src.prev_shard_heap_range.clone(); + cloned.prev_shard_hint_range = src.prev_shard_hint_range.clone(); + cloned.platform = src.platform.clone(); + cloned.shard_heap_addr_range = src.shard_heap_addr_range.clone(); + cloned.shard_hint_addr_range = src.shard_hint_addr_range.clone(); + cloned.syscall_witnesses = src.syscall_witnesses.clone(); + cloned +} + +fn flatten_ram_records( + records: &[BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +fn log_shard_ctx_diff(kind: &str, cpu: &ShardContext, gpu: &ShardContext) { + let cpu_addr = cpu.get_addr_accessed(); + let gpu_addr = gpu.get_addr_accessed(); + if cpu_addr != gpu_addr { + tracing::error!( + "[GPU e2e debug] {} addr_accessed cpu={} gpu={}", + kind, + cpu_addr.len(), + gpu_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu.read_records()); + let gpu_reads = 
flatten_ram_records(gpu.read_records()); + if cpu_reads != gpu_reads { + tracing::error!( + "[GPU e2e debug] {} read_records cpu={} gpu={}", + kind, + cpu_reads.len(), + gpu_reads.len() + ); + } + + let cpu_writes = flatten_ram_records(cpu.write_records()); + let gpu_writes = flatten_ram_records(gpu.write_records()); + if cpu_writes != gpu_writes { + tracing::error!( + "[GPU e2e debug] {} write_records cpu={} gpu={}", + kind, + cpu_writes.len(), + gpu_writes.len() + ); + } +} + +fn log_combined_lk_diff( + cpu_witness: &ZKVMWitnesses, + gpu_witness: &ZKVMWitnesses, +) { + let cpu_combined = cpu_witness.combined_lk_mlt().expect("cpu combined_lk_mlt"); + let gpu_combined = gpu_witness.combined_lk_mlt().expect("gpu combined_lk_mlt"); + + let table_names = [ + "Dynamic", "DoubleU8", "And", "Or", "Xor", "Ltu", "Pow", "Instruction", + ]; + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, gpu_table)) in + cpu_combined.iter().zip(gpu_combined.iter()).enumerate() + { + let mut keys: Vec = cpu_table + .keys() + .chain(gpu_table.keys()) + .copied() + .collect(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = 0usize; + for &key in &keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let gpu_count = gpu_table.get(&key).copied().unwrap_or(0); + if cpu_count != gpu_count { + table_diffs += 1; + if table_diffs <= 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] combined_lk table={} key={} cpu={} gpu={}", + name, + key, + cpu_count, + gpu_count + ); + } + } + } + total_diffs += table_diffs; + if table_diffs > 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] combined_lk table={} total_diffs={} (showing first 8)", + name, + table_diffs + ); + } + } + + // Also compare per-chip LK multiplicities + let cpu_lk_keys: std::collections::BTreeSet<_> = cpu_witness.lk_mlts().keys().collect(); + let gpu_lk_keys: 
std::collections::BTreeSet<_> = gpu_witness.lk_mlts().keys().collect(); + if cpu_lk_keys != gpu_lk_keys { + tracing::error!( + "[GPU e2e debug] lk_mlts key mismatch: cpu_only={:?} gpu_only={:?}", + cpu_lk_keys.difference(&gpu_lk_keys).collect::>(), + gpu_lk_keys.difference(&cpu_lk_keys).collect::>(), + ); + } + for name in cpu_lk_keys.intersection(&gpu_lk_keys) { + let cpu_lk = cpu_witness.lk_mlts().get(*name).unwrap(); + let gpu_lk = gpu_witness.lk_mlts().get(*name).unwrap(); + let mut chip_diffs = 0usize; + for (t_idx, (ct, gt)) in cpu_lk.iter().zip(gpu_lk.iter()).enumerate() { + let mut ks: Vec = ct.keys().chain(gt.keys()).copied().collect(); + ks.sort_unstable(); + ks.dedup(); + for &k in &ks { + let cv = ct.get(&k).copied().unwrap_or(0); + let gv = gt.get(&k).copied().unwrap_or(0); + if cv != gv { + chip_diffs += 1; + if chip_diffs <= 4 { + let tname = table_names.get(t_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} table={} key={} cpu={} gpu={}", + name, + tname, + k, + cv, + gv + ); + } + } + } + } + if chip_diffs > 0 { + total_diffs += chip_diffs; + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} total_diffs={}", + name, + chip_diffs + ); + } + } + + if total_diffs == 0 { + tracing::info!( + "[GPU e2e debug] combined_lk_mlt + per_chip_lk: CPU/GPU match (tables={}, chips={})", + cpu_combined.len(), + cpu_lk_keys.len() + ); + } else { + tracing::error!( + "[GPU e2e debug] TOTAL LK DIFFS = {} (combined + per-chip)", + total_diffs + ); + } +} + #[cfg(debug_assertions)] fn debug_memory_ranges<'a, T: Tracer, I: Iterator>( vm: &VMState, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 89deb6fbc..71467d370 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -20,11 +20,14 @@ use rayon::{ use witness::{InstancePaddingStrategy, RowMajorMatrix}; pub mod riscv; +pub mod side_effects; pub trait Instruction { type InstructionConfig: Send + Sync; type InsnType: 
Clone + Copy; + const GPU_SIDE_EFFECTS: bool = false; + fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default } @@ -96,6 +99,36 @@ pub trait Instruction { step: &StepRecord, ) -> Result<(), ZKVMError>; + fn collect_side_effects_instance( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement lightweight side effects collection", + Self::name() + ) + .into(), + )) + } + + fn collect_shard_side_effects_instance( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement shard-only side effects collection", + Self::name() + ) + .into(), + )) + } + fn assign_instances( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, @@ -262,3 +295,80 @@ pub fn cpu_assign_instances>( lk_multiplicity.into_finalize_result(), )) } + +/// CPU-only side-effect collection for GPU-enabled instructions. +/// +/// This path deliberately avoids scratch witness buffers and calls only the +/// instruction-specific side-effect collector. +pub fn cpu_collect_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, false) +} + +/// CPU-side `send()` / `addr_accessed` collection for GPU-assisted lk paths. +/// +/// Implementations may still increment fetch multiplicity on CPU, but all other +/// lookup multiplicities are expected to come from the GPU path. 
+pub fn cpu_collect_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, true) +} + +fn cpu_collect_side_effects_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + shard_only: bool, +) -> Result, ZKVMError> { + let nthreads = max_usable_threads(); + let total = step_indices.len(); + let batch_size = if total > 256 { + total.div_ceil(nthreads) + } else { + total + } + .max(1); + + let lk_multiplicity = LkMultiplicity::default(); + let shard_ctx_vec = shard_ctx.get_forked(); + + step_indices + .par_chunks(batch_size) + .zip(shard_ctx_vec) + .flat_map(|(indices, mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + indices + .iter() + .copied() + .map(|step_idx| { + if shard_only { + I::collect_shard_side_effects_instance( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } else { + I::collect_side_effects_instance( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } + }) + .collect::>() + }) + .collect::>()?; + + Ok(lk_multiplicity.into_finalize_result()) +} diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 7ce605971..5e9291499 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,8 +2,16 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::Instruction, structs::ProgramParams, uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + instructions::{ + Instruction, + 
side_effects::{CpuSideEffectSink, emit_u16_limbs}, + }, + structs::ProgramParams, + uint::Value, + witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -43,6 +51,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::ADD | InsnKind::SUB); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -140,6 +150,45 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + match I::INST_KIND { + InsnKind::ADD => { + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + } + InsnKind::SUB => { + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + emit_u16_limbs(&mut sink, step.rs1().unwrap().value); + } + _ => unreachable!("Unsupported instruction kind"), + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index b29edfb25..53219490c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -6,6 +6,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_u16_limbs}, 
}, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -41,6 +42,8 @@ impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::ADDI] } @@ -112,6 +115,34 @@ impl Instruction for AddiInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 46573f09d..d7ada6ff6 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -12,6 +12,10 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, + side_effects::{ + CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + emit_const_range_op, + }, }, structs::ProgramParams, tables::InsnRecord, @@ -46,6 +50,8 @@ impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static 
[Self::InsnType] { &[InsnKind::AUIPC] } @@ -193,6 +199,60 @@ impl Instruction for AuipcInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + + let pc = split_to_u8(step.pc().before.0); + // Only iterate over the middle limbs that have witness columns (pc_limbs has UINT_BYTE_LIMBS-2 elements). + // The MSB limb is range-checked via XOR below, the LSB is shared with rd_written[0]. + for val in pc.iter().skip(1).take(config.pc_limbs.len()) { + emit_const_range_op(&mut sink, *val as u64, 8); + } + + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; + for val in split_to_u8::(imm) + .into_iter() + .take(config.imm_limbs.len()) + { + emit_const_range_op(&mut sink, val as u64, 8); + } + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: pc[3], + b: additional_bits as u8, + }); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs 
b/ceno_zkvm/src/instructions/riscv/b_insn.rs index cdc1db56d..c33ca6037 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -7,7 +7,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::{LkMultiplicity, set_val}, }; @@ -111,4 +114,28 @@ impl BInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 8bec503bb..0951ab7ba 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -11,6 +11,7 @@ use crate::{ b_insn::BInstructionConfig, constants::{UINT_LIMBS, UInt}, }, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, witness::LkMultiplicity, @@ -41,6 +42,8 @@ impl Instruction for BranchCircuit; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -205,6 +208,47 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut 
ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .b_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + if !matches!(I::INST_KIND, InsnKind::BEQ | InsnKind::BNE) { + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let rs2_limbs = rs2_value.as_u16_limbs(); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::BLT | InsnKind::BGE), + &rs1_limbs, + &rs2_limbs, + ); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .b_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index a124c6768..474174ed9 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -14,7 +14,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::{Instruction, riscv::constants::LIMB_BITS}, + instructions::{ + Instruction, + riscv::constants::LIMB_BITS, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink, emit_u16_limbs}, + }, structs::ProgramParams, uint::Value, witness::{LkMultiplicity, set_val}, @@ -50,6 +54,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -399,7 +405,12 @@ impl Instruction for ArithInstruction 3u32, _ => { return 
crate::instructions::cpu_assign_instances::( - config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, ); } }; @@ -570,6 +581,111 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let dividend = step.rs1().unwrap().value; + let divisor = step.rs2().unwrap().value; + let dividend_value = Value::new_unchecked(dividend); + let divisor_value = Value::new_unchecked(divisor); + let dividend_limbs = dividend_value.as_u16_limbs(); + let divisor_limbs = divisor_value.as_u16_limbs(); + + let signed = matches!(I::INST_KIND, InsnKind::DIV | InsnKind::REM); + let (quotient, remainder, dividend_sign, divisor_sign, quotient_sign, case) = + run_divrem(signed, &u32_to_limbs(÷nd), &u32_to_limbs(&divisor)); + + emit_u16_limbs(&mut sink, limbs_to_u32("ient)); + emit_u16_limbs(&mut sink, limbs_to_u32(&remainder)); + + let carries = run_mul_carries( + signed, + &u32_to_limbs(&divisor), + "ient, + &remainder, + quotient_sign, + ); + for i in 0..UINT_LIMBS { + sink.emit_lk(LkOp::DynamicRange { + value: carries[i] as u64, + bits: (LIMB_BITS + 2) as u32, + }); + sink.emit_lk(LkOp::DynamicRange { + value: carries[i + UINT_LIMBS] as u64, + bits: (LIMB_BITS + 2) as u32, + }); + } + + let sign_xor = dividend_sign ^ divisor_sign; + let remainder_prime = if sign_xor { + negate(&remainder) + } else { + remainder + }; + let remainder_zero = + remainder.iter().all(|&v| v == 0) && case != DivRemCoreSpecialCase::ZeroDivisor; + + if signed { + let dividend_sign_mask = if dividend_sign { + 1 << (LIMB_BITS - 1) + } else { + 0 + }; + let divisor_sign_mask = if divisor_sign { + 1 << (LIMB_BITS - 1) + } else { + 0 + }; + 
sink.emit_lk(LkOp::DynamicRange { + value: ((dividend_limbs[UINT_LIMBS - 1] as u64 - dividend_sign_mask) << 1), + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: ((divisor_limbs[UINT_LIMBS - 1] as u64 - divisor_sign_mask) << 1), + bits: 16, + }); + } + + if case == DivRemCoreSpecialCase::None && !remainder_zero { + let idx = run_sltu_diff_idx(&u32_to_limbs(&divisor), &remainder_prime, divisor_sign); + let val = if divisor_sign { + remainder_prime[idx] - divisor_limbs[idx] as u32 + } else { + divisor_limbs[idx] as u32 - remainder_prime[idx] + }; + sink.emit_lk(LkOp::DynamicRange { + value: val as u64 - 1, + bits: 16, + }); + } else { + sink.emit_lk(LkOp::DynamicRange { value: 0, bits: 16 }); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } } #[derive(Debug, Eq, PartialEq)] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index e6478a7a4..1e2f30ad6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -114,15 +114,44 @@ mod tests { type E = BabyBearExt4; + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + fn make_test_steps(n: usize) -> Vec { const EDGE_CASES: &[(u32, u32)] 
= &[ (0, 0), (0, 1), (1, 0), - (u32::MAX, 1), // overflow - (u32::MAX, u32::MAX), // double overflow + (u32::MAX, 1), // overflow + (u32::MAX, u32::MAX), // double overflow (0x80000000, 0x80000000), // INT_MIN + INT_MIN - (0x7FFFFFFF, 1), // INT_MAX + 1 + (0x7FFFFFFF, 1), // INT_MAX + 1 (0xFFFF0000, 0x0000FFFF), // limb carry ]; @@ -203,15 +232,16 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; // GPU path (AOS with indirect indexing) @@ -255,5 +285,37 @@ mod tests { } } assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + AddInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Add, + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 07df75aae..5b61d38ee 100644 --- 
a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -103,7 +103,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -151,7 +154,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 80d2b67cc..c0663d880 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -39,14 +39,19 @@ pub fn extract_auipc_column_map( // AUIPC-specific let rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + let l = config + .rd_written + .wits_in() + .expect("rd_written UInt8 WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; - let pc_limbs: [u32; 2] = [ - config.pc_limbs[0].id as u32, - config.pc_limbs[1].id as u32, - ]; + let pc_limbs: [u32; 2] = [config.pc_limbs[0].id as u32, config.pc_limbs[1].id as u32]; let imm_limbs: [u32; 3] = [ config.imm_limbs[0].id as u32, config.imm_limbs[1].id as u32, @@ -96,7 +101,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -143,11 +151,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut 
shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_auipc_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index 39674a9eb..dfb9cd775 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -9,18 +9,12 @@ pub fn extract_branch_cmp_column_map( num_witin: usize, ) -> BranchCmpColumnMap { let rs1_limbs: [u32; 2] = { - let limbs = config - .read_rs1 - .wits_in() - .expect("rs1 WitIn"); + let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; let rs2_limbs: [u32; 2] = { - let limbs = config - .read_rs2 - .wits_in() - .expect("rs2 WitIn"); + let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -105,7 +99,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -157,16 +154,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + 
&indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_branch_cmp_column_map(&config, num_witin); @@ -181,14 +177,7 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_cmp( - &col_map, - &gpu_records, - &indices_u32, - shard_offset, - 1, - None, - ) + .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 2b312cfc8..a44eaafa0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -9,18 +9,12 @@ pub fn extract_branch_eq_column_map( num_witin: usize, ) -> BranchEqColumnMap { let rs1_limbs: [u32; 2] = { - let limbs = config - .read_rs1 - .wits_in() - .expect("rs1 WitIn"); + let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; let rs2_limbs: [u32; 2] = { - let limbs = config - .read_rs2 - .wits_in() - .expect("rs2 WitIn"); + let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -98,7 +92,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -128,7 +125,11 @@ mod tests { let steps: Vec = (0..n) .map(|i| { let rs1 = ((i as u32) * 137) ^ 0xABCD; - let rs2 = if i % 3 == 0 { rs1 } else { ((i as u32) * 89) ^ 0x1234 }; + let rs2 = if i % 3 == 0 { + rs1 + } else { + ((i as u32) * 89) ^ 0x1234 + }; let taken = rs1 == rs2; let pc = ByteAddr(0x1000 + (i as u32) * 4); let pc_after = if taken { @@ -150,16 +151,15 @@ mod 
tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_branch_eq_column_map(&config, num_witin); @@ -174,14 +174,7 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_eq( - &col_map, - &gpu_records, - &indices_u32, - shard_offset, - 1, - None, - ) + .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index 7d89d4ec6..dd708cb59 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -86,16 +86,16 @@ pub fn extract_div_column_map( // remainder_prime let remainder_prime: [u32; 2] = { - let l = config.remainder_prime.wits_in().expect("remainder_prime WitIns"); + let l = config + .remainder_prime + .wits_in() + .expect("remainder_prime WitIns"); assert_eq!(l.len(), 2); [l[0].id as u32, l[1].id as u32] }; // lt_marker - let lt_marker: [u32; 2] = [ - config.lt_marker[0].id as u32, - config.lt_marker[1].id as u32, - ]; + let lt_marker: [u32; 2] = [config.lt_marker[0].id as u32, config.lt_marker[1].id as u32]; // lt_diff let lt_diff = config.lt_diff.id as u32; @@ -154,7 +154,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut 
seen = std::collections::HashSet::new(); @@ -226,10 +229,22 @@ mod tests { let mut cb = CircuitBuilder::new(&mut cs); let config = match insn_kind { - InsnKind::DIV => DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::DIVU => DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::REM => RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::REMU => RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::DIV => { + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::DIVU => { + DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REM => { + RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REMU => { + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } _ => unreachable!(), }; let num_witin = cb.cs.num_witin as usize; @@ -238,20 +253,20 @@ mod tests { let n = 1024; const EDGE_CASES: &[(u32, u32)] = &[ - (0, 1), // 0 / 1 - (1, 1), // 1 / 1 - (0, 0), // 0 / 0 (zero divisor) - (12345, 0), // non-zero / 0 (zero divisor) - (u32::MAX, 0), // max / 0 (zero divisor) - (0x80000000, 0), // INT_MIN / 0 (zero divisor) - (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) - (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 - (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 - (0x80000000, 1), // INT_MIN / 1 - (0x80000000, 2), // INT_MIN / 2 - (u32::MAX, u32::MAX), // max / max - (u32::MAX, 1), // max / 1 - (1, u32::MAX), // 1 / max + (0, 1), // 0 / 1 + (1, 1), // 1 / 1 + (0, 0), // 0 / 0 (zero divisor) + (12345, 0), // non-zero / 0 (zero divisor) + (u32::MAX, 0), // max / 0 (zero divisor) + (0x80000000, 0), // INT_MIN / 0 (zero divisor) + (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) 
+ (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 + (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 + (0x80000000, 1), // INT_MIN / 1 + (0x80000000, 2), // INT_MIN / 2 + (u32::MAX, u32::MAX), // max / max + (u32::MAX, 1), // max / 1 + (1, u32::MAX), // 1 / max ]; let steps: Vec = (0..n) @@ -321,17 +336,45 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = match insn_kind { InsnKind::DIV => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::DIVU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::DIVU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } InsnKind::REM => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::REMU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::REMU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } _ => unreachable!(), }; let cpu_witness = &cpu_rmms[0]; @@ -349,7 +392,14 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_div(&col_map, &gpu_records, &indices_u32, shard_offset, div_kind, None) + .witgen_div( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 
div_kind, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 8a9be12de..d33b575e6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -31,9 +31,17 @@ pub fn extract_jal_column_map( // JAL-specific: rd u8 bytes let rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + let l = config + .rd_written + .wits_in() + .expect("rd_written UInt8 WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; JalColumnMap { @@ -75,7 +83,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -123,7 +134,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 11bce8623..804218293 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -51,13 +51,21 @@ pub fn extract_jalr_column_map( // jump_pc_addr: MemAddr has addr (UInt = 2 limbs) + low_bits (Vec) let jump_pc_addr: [u32; 2] = { - let l = config.jump_pc_addr.addr.wits_in().expect("jump_pc_addr WitIns"); + let l = config + .jump_pc_addr + .addr + .wits_in() + .expect("jump_pc_addr WitIns"); assert_eq!(l.len(), 2); [l[0].id as u32, l[1].id as u32] }; let 
jump_pc_addr_bit: [u32; 2] = { let bits = &config.jump_pc_addr.low_bits; - assert_eq!(bits.len(), 2, "JALR MemAddr with n_zeros=0 must have 2 low_bits"); + assert_eq!( + bits.len(), + 2, + "JALR MemAddr with n_zeros=0 must have 2 low_bits" + ); [bits[0].id as u32, bits[1].id as u32] }; @@ -111,7 +119,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -160,16 +171,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_jalr_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 163efa681..f7d48c772 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -7,8 +7,8 @@ use crate::instructions::riscv::memory::load_v2::LoadConfig; pub fn extract_load_sub_column_map( config: &LoadConfig, num_witin: usize, - is_byte: bool, // true for LB/LBU - is_signed: bool, // true for LH/LB + is_byte: bool, // true for LB/LBU + is_signed: bool, // true for LH/LB ) -> LoadSubColumnMap { let im = &config.im_insn; @@ -146,7 +146,10 @@ mod tests { use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{Instruction, riscv::memory::{LhInstruction, LhuInstruction, LbInstruction, LbuInstruction}}, + instructions::{ + Instruction, + riscv::memory::{LbInstruction, 
LbuInstruction, LhInstruction, LhuInstruction}, + }, structs::ProgramParams, }; use ff_ext::BabyBearExt4; @@ -159,7 +162,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -222,9 +228,7 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_witgen_load_sub_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -245,10 +249,22 @@ mod tests { // We need to construct the right instruction type let config = match insn_kind { - InsnKind::LH => LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::LHU => LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::LB => LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::LBU => LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LH => { + LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LHU => { + LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LB => { + LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LBU => { + LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } _ => unreachable!(), }; let num_witin = cb.cs.num_witin as usize; @@ -279,9 +295,7 @@ mod tests { (mem_val >> 16) as u16 }; let rd_after = match insn_kind { - InsnKind::LH => { - (target_limb as i16) as i32 as u32 - } + InsnKind::LH => (target_limb as 
i16) as i32 as u32, InsnKind::LHU => target_limb as u32, InsnKind::LB => { let byte = if bit_0 == 0 { @@ -328,17 +342,41 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = match insn_kind { InsnKind::LH => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), InsnKind::LHU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), InsnKind::LB => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), InsnKind::LBU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), _ => unreachable!(), }; let cpu_witness = &cpu_rmms[0]; @@ -358,7 +396,15 @@ mod tests { let load_width: u32 = if is_byte { 8 } else { 16 }; let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; let gpu_result = hal - .witgen_load_sub(&col_map, &gpu_records, &indices_u32, shard_offset, load_width, is_signed_u32, None) + .witgen_load_sub( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed_u32, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index 3a54fe8e1..c16e4f95f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ 
b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -41,14 +41,24 @@ pub fn extract_logic_i_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // rd u8 bytes let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // imm_lo u8 bytes (UIntLimbs<16,8> = 2 x u8) @@ -109,7 +119,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -136,7 +149,7 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (u32::MAX, 0xFFF), // all bits AND max imm + (u32::MAX, 0xFFF), // all bits AND max imm (u32::MAX, 0), (0, 0xFFF), (0xAAAAAAAA, 0x555), // alternating @@ -151,7 +164,10 @@ mod tests { let (rs1, imm) = if i < EDGE_CASES.len() { EDGE_CASES[i] } else { - ((i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, (i as u32) % 4096) + ( + (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, + (i as u32) % 4096, + ) }; let rd_after = rs1 & imm; // ANDI let cycle = 4 + (i as u64) * 4; @@ -171,7 +187,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 
b798b1115..17933915d 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -34,7 +34,12 @@ pub fn extract_logic_r_column_map( let rd_id = config.r_insn.rd.id.id as u32; let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -48,17 +53,32 @@ pub fn extract_logic_r_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rs2_bytes: [u32; 4] = { let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; LogicRColumnMap { @@ -93,6 +113,35 @@ mod tests { type E = BabyBearExt4; + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + 
.collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + #[test] fn test_extract_logic_r_column_map() { let mut cs = ConstraintSystem::::new(|| "test_logic_r"); @@ -107,7 +156,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -149,15 +201,23 @@ mod tests { let (rs1, rs2) = if i < EDGE_CASES.len() { EDGE_CASES[i] } else { - (0xDEAD_0000u32 | (i as u32), 0x00FF_FF00u32 | ((i as u32) << 8)) + ( + 0xDEAD_0000u32 | (i as u32), + 0x00FF_FF00u32 | ((i as u32) << 8), + ) }; let rd_after = rs1 & rs2; // AND let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -165,10 +225,16 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ) - .unwrap(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; // GPU path @@ -209,5 +275,37 @@ mod tests { } } assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + AndInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + 
crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::LogicR(0), + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index fa52596a7..348b5b8b4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -87,7 +87,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -135,7 +138,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 5a4dd1b91..19f38a7a6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -109,6 +109,35 @@ mod tests { type E = BabyBearExt4; type LwInstruction = crate::instructions::riscv::LwInstruction; + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, 
record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + fn make_lw_test_steps(n: usize) -> Vec { let pc_start = 0x1000u32; // Use varying immediates including negative values to test imm_field encoding @@ -188,7 +217,7 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::( + let (cpu_rmms, cpu_lkm) = crate::instructions::cpu_assign_instances::( &config, &mut shard_ctx, num_witin, @@ -238,5 +267,37 @@ mod tests { } } assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + LwInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Lw, + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index ad093a423..51c4ba33f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -5,8 +5,18 @@ pub mod 
addi; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod auipc; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_cmp; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_eq; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod div; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod jal; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jalr; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod load_sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod logic_i; #[cfg(feature = "gpu")] pub mod logic_r; @@ -15,6 +25,12 @@ pub mod lui; #[cfg(feature = "gpu")] pub mod lw; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod mul; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sb; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sh; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod shift_i; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod shift_r; @@ -25,22 +41,6 @@ pub mod slti; #[cfg(feature = "gpu")] pub mod sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod branch_cmp; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod jalr; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sw; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod sh; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod sb; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod load_sub; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod mul; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod div; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index 9cf8ec04a..1a9b8f902 
100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -60,7 +60,10 @@ pub fn extract_mul_column_map( // MULH/MULHU/MULHSU have rd_high + extensions let (rd_high, rs1_ext, rs2_ext) = if mul_kind != 0 { - let h = config.rd_high.as_ref().expect("MULH variants must have rd_high"); + let h = config + .rd_high + .as_ref() + .expect("MULH variants must have rd_high"); ( Some([h[0].id as u32, h[1].id as u32]), Some(config.rs1_ext.expect("MULH variants must have rs1_ext").id as u32), @@ -114,7 +117,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -190,28 +196,40 @@ mod tests { let mut cb = CircuitBuilder::new(&mut cs); let config = match insn_kind { - InsnKind::MUL => MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::MULH => MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::MULHU => MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::MULHSU => MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MUL => { + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULH => { + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHU => { + MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHSU => { + MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } _ => unreachable!(), }; let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; const EDGE_CASES: &[(u32, u32)] = &[ - (0, 0), // zero * zero - (0, 12345), // zero * non-zero - 
(12345, 0), // non-zero * zero - (1, 1), // identity - (u32::MAX, 1), // max * 1 - (1, u32::MAX), // 1 * max - (u32::MAX, u32::MAX), // max * max - (0x80000000, 2), // INT_MIN * 2 (for MULH) - (2, 0x80000000), // 2 * INT_MIN - (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed - (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) - (0x7FFFFFFF, 0x7FFFFFFF), // INT_MAX * INT_MAX + (0, 0), // zero * zero + (0, 12345), // zero * non-zero + (12345, 0), // non-zero * zero + (1, 1), // identity + (u32::MAX, 1), // max * 1 + (1, u32::MAX), // 1 * max + (u32::MAX, u32::MAX), // max * max + (0x80000000, 2), // INT_MIN * 2 (for MULH) + (2, 0x80000000), // 2 * INT_MIN + (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed + (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) + (0x7FFFFFFF, 0x7FFFFFFF), // INT_MAX * INT_MAX ]; let n = 1024; @@ -229,7 +247,8 @@ mod tests { let rd_after = match insn_kind { InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), InsnKind::MULH => { - ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) as u32 + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) + as u32 } InsnKind::MULHU => { ((rs1_val as u64).wrapping_mul(rs2_val as u64) >> 32) as u32 @@ -260,17 +279,47 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = match insn_kind { InsnKind::MUL => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::MULH => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::MULHU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::MULHSU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + 
num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::MULH => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHSU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } _ => unreachable!(), }; let cpu_witness = &cpu_rmms[0]; @@ -288,7 +337,14 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_mul(&col_map, &gpu_records, &indices_u32, shard_offset, mul_kind, None) + .witgen_mul( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + mul_kind, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 3638b56c4..346be925e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -72,7 +72,11 @@ pub fn extract_sb_column_map( }; // SB-specific: 2 low_bits (bit_0, bit_1) - assert_eq!(config.memory_addr.low_bits.len(), 2, "SB should have 2 low_bits"); + assert_eq!( + config.memory_addr.low_bits.len(), + 2, + "SB should have 2 low_bits" + ); let mem_addr_bit_0 = config.memory_addr.low_bits[0].id as u32; let mem_addr_bit_1 = config.memory_addr.low_bits[1].id as u32; @@ -146,7 +150,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -159,9 +166,7 @@ mod tests { 
#[cfg(feature = "gpu")] fn test_gpu_witgen_sb_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -214,16 +219,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_sb_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index 225ca91d3..e35d94bf0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -72,7 +72,11 @@ pub fn extract_sh_column_map( }; // SH-specific: 1 low_bit (bit_1 for halfword select) - assert_eq!(config.memory_addr.low_bits.len(), 1, "SH should have 1 low_bit"); + assert_eq!( + config.memory_addr.low_bits.len(), + 1, + "SH should have 1 low_bit" + ); let mem_addr_bit_1 = config.memory_addr.low_bits[0].id as u32; ShColumnMap { @@ -123,7 +127,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -136,9 +143,7 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_witgen_sh_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, 
encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -191,16 +196,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_sh_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index 61342d270..e1555fcaf 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -25,7 +25,12 @@ pub fn extract_shift_i_column_map( let rd_id = config.i_insn.rd.id.id as u32; let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.i_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + let limbs = config + .i_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -39,30 +44,37 @@ pub fn extract_shift_i_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id 
as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // Immediate let imm = config.imm.id as u32; // ShiftBase - let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_marker[i].id as u32 - }); - let limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { - config.shift_base_config.limb_shift_marker[i].id as u32 - }); + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; let b_sign = config.shift_base_config.b_sign.id as u32; - let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_carry[i].id as u32 - }); + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); ShiftIColumnMap { pc, @@ -113,7 +125,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -140,13 +155,13 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (1, 0), // shift by 0 - (1, 31), // shift to MSB - (u32::MAX, 0), // no shift - (u32::MAX, 16), // shift half - (u32::MAX, 31), // shift max - (0x80000000, 1), // INT_MIN << 1 - (0xDEADBEEF, 4), // nibble shift + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift ]; let n = 1024; @@ -176,7 +191,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = 
crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 447a519a7..d6efa771c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -34,7 +34,12 @@ pub fn extract_shift_r_column_map( let rd_id = config.r_insn.rd.id.id as u32; let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -48,32 +53,44 @@ pub fn extract_shift_r_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rs2_bytes: [u32; 4] = { let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // ShiftBase - let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_marker[i].id as u32 - }); - let limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { - 
config.shift_base_config.limb_shift_marker[i].id as u32 - }); + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; let b_sign = config.shift_base_config.b_sign.id as u32; - let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_carry[i].id as u32 - }); + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); ShiftRColumnMap { pc, @@ -127,7 +144,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -154,13 +174,13 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (1, 0), // shift by 0 - (1, 31), // shift to MSB - (u32::MAX, 0), // no shift - (u32::MAX, 16), // shift half - (u32::MAX, 31), // shift max - (0x80000000, 1), // INT_MIN << 1 - (0xDEADBEEF, 4), // nibble shift + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift ]; let n = 1024; @@ -176,8 +196,13 @@ mod tests { let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -185,7 +210,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let 
(cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index 2412a5dc8..e39e8acab 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -126,7 +126,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -157,13 +160,22 @@ mod tests { // Mix positive, negative, equal cases let rs1 = ((i as i32) * 137 - 500) as u32; let rs2 = ((i as i32) * 89 - 300) as u32; - let rd_after = if (rs1 as i32) < (rs2 as i32) { 1u32 } else { 0u32 }; + let rd_after = if (rs1 as i32) < (rs2 as i32) { + 1u32 + } else { + 0u32 + }; let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLT, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -171,7 +183,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 9c7cf262f..42d454507 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ 
b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -109,7 +109,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -139,7 +142,11 @@ mod tests { .map(|i| { let rs1 = ((i as i32) * 137 - 500) as u32; let imm = ((i as i32) % 2048 - 1024) as i32; // -1024..1023 - let rd_after = if (rs1 as i32) < (imm as i32) { 1u32 } else { 0u32 }; + let rd_after = if (rs1 as i32) < (imm as i32) { + 1u32 + } else { + 0u32 + }; let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm); @@ -157,7 +164,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 4531cfe38..80bc9b0ad 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -129,7 +129,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -156,12 +159,12 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (0, 1), // underflow + (0, 1), // underflow (1, 0), - (0, u32::MAX), // underflow + (0, u32::MAX), // underflow (u32::MAX, u32::MAX), - (0x80000000, 1), // INT_MIN - 1 - (0, 0x80000000), // 0 - INT_MIN + (0x80000000, 1), // INT_MIN - 1 + (0, 0x80000000), // 0 - INT_MIN (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX - (-1) ]; @@ -178,8 
+181,13 @@ mod tests { let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SUB, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -188,7 +196,12 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 5fb0a799b..4bdfafa5c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -118,7 +118,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -131,9 +134,7 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_witgen_sw_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -178,16 +179,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, 
+ &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_sw_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 1b62697d0..42a48fb1c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -2,24 +2,24 @@ /// /// This module provides `try_gpu_assign_instances` which: /// 1. Runs the GPU kernel to fill the witness matrix (fast) -/// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) +/// 2. Runs a lightweight CPU loop to collect side effects without witness replay /// 3. Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose}; +use ceno_gpu::{ + Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, +}; use ff_ext::ExtensionField; -use gkr_iop::utils::lk_multiplicity::Multiplicity; -use multilinear_extensions::util::max_usable_threads; +use gkr_iop::{tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; -use rayon::{ - iter::{IndexedParallelIterator, ParallelIterator}, - slice::ParallelSlice, -}; -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ - e2e::ShardContext, error::ZKVMError, instructions::Instruction, tables::RMMCollections, + e2e::ShardContext, + error::ZKVMError, + instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, + tables::RMMCollections, witness::LkMultiplicity, }; @@ -27,9 +27,9 @@ use crate::{ pub enum GpuWitgenKind { Add, Sub, - LogicR, + LogicR(u32), // 0=AND, 1=OR, 2=XOR #[cfg(feature = "u16limb_circuit")] - LogicI, + 
LogicI(u32), // 0=AND, 1=OR, 2=XOR #[cfg(feature = "u16limb_circuit")] Addi, #[cfg(feature = "u16limb_circuit")] @@ -43,7 +43,7 @@ pub enum GpuWitgenKind { #[cfg(feature = "u16limb_circuit")] ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI #[cfg(feature = "u16limb_circuit")] - Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) #[cfg(feature = "u16limb_circuit")] Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) #[cfg(feature = "u16limb_circuit")] @@ -59,7 +59,10 @@ pub enum GpuWitgenKind { #[cfg(feature = "u16limb_circuit")] Sb, #[cfg(feature = "u16limb_circuit")] - LoadSub { load_width: u32, is_signed: u32 }, + LoadSub { + load_width: u32, + is_signed: u32, + }, #[cfg(feature = "u16limb_circuit")] Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU #[cfg(feature = "u16limb_circuit")] @@ -80,6 +83,18 @@ struct ShardStepsCache { thread_local! { static SHARD_STEPS_DEVICE: RefCell> = const { RefCell::new(None) }; + /// Thread-local flag to force CPU path (used by debug comparison code). + static FORCE_CPU_PATH: Cell = const { Cell::new(false) }; +} + +/// Force the current thread to use CPU path for all GPU witgen calls. +/// Used by debug comparison code in e2e.rs to run a CPU-only reference. +pub fn set_force_cpu_path(force: bool) { + FORCE_CPU_PATH.with(|f| f.set(force)); +} + +fn is_force_cpu_path() -> bool { + FORCE_CPU_PATH.with(|f| f.get()) } /// Upload shard_steps to GPU, reusing cached device buffer if the same data. @@ -150,6 +165,23 @@ pub fn invalidate_shard_steps_cache() { }); } +/// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. +/// The value is cached at first access so it's immune to runtime env var manipulation. 
+fn is_gpu_witgen_disabled() -> bool { + use std::sync::OnceLock; + static DISABLED: OnceLock = OnceLock::new(); + *DISABLED.get_or_init(|| { + let val = std::env::var_os("CENO_GPU_DISABLE_WITGEN"); + let disabled = val.is_some(); + // Use eprintln to bypass tracing filters — always visible on stderr + eprintln!( + "[GPU witgen] CENO_GPU_DISABLE_WITGEN={:?} → disabled={}", + val, disabled + ); + disabled + }) +} + /// Try to run GPU witness generation for the given instruction. /// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). /// @@ -171,6 +203,18 @@ pub(crate) fn try_gpu_assign_instances>( ) -> Result, Multiplicity)>, ZKVMError> { use gkr_iop::gpu::get_cuda_hal; + if is_gpu_witgen_disabled() || is_force_cpu_path() { + return Ok(None); + } + + if !I::GPU_SIDE_EFFECTS { + return Ok(None); + } + + if is_kind_disabled(kind) { + return Ok(None); + } + let total_instances = step_indices.len(); if total_instances == 0 { // Empty: just return empty matrices @@ -226,8 +270,8 @@ fn gpu_assign_instances_inner>( let num_structural_witin = num_structural_witin.max(1); let total_instances = step_indices.len(); - // Step 1: GPU fills witness matrix - let gpu_witness = info_span!("gpu_kernel").in_scope(|| { + // Step 1: GPU fills witness matrix (+ LK counters for merged kinds) + let (gpu_witness, gpu_lk_counters) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -239,10 +283,29 @@ fn gpu_assign_instances_inner>( ) })?; - // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) - let lk_multiplicity = info_span!("cpu_side_effects").in_scope(|| { - collect_side_effects::(config, shard_ctx, num_witin, shard_steps, step_indices) - })?; + // Step 2: Collect side effects + // For verified GPU kinds: LK from GPU, shard records from CPU + // For unverified kinds: full CPU side effects (GPU witness still used) + let lk_multiplicity = if gpu_lk_counters.is_some() && 
kind_has_verified_lk(kind) { + let lk_multiplicity = info_span!("gpu_lk_d2h").in_scope(|| { + gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) + })?; + // CPU: collect shard records only (send/addr_accessed). + // We call collect_shard_side_effects which also computes fetch, but we + // discard its returned Multiplicity since GPU already has all LK + fetch. + info_span!("cpu_shard_records").in_scope(|| { + let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; + Ok::<(), ZKVMError>(()) + })?; + lk_multiplicity + } else { + // GPU LK counters missing or unverified — fall back to full CPU side effects + info_span!("cpu_side_effects").in_scope(|| { + collect_side_effects::(config, shard_ctx, shard_steps, step_indices) + })? + }; + debug_compare_final_lk::(config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, kind, &lk_multiplicity)?; + debug_compare_shard_side_effects::(config, shard_ctx, shard_steps, step_indices, kind)?; // Step 3: Build structural witness (just selector = ONE) let mut raw_structural = RowMajorMatrix::::new( @@ -266,14 +329,50 @@ fn gpu_assign_instances_inner>( ) })?; raw_witin.padding_by_strategy(); + debug_compare_witness::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &raw_witin, + )?; - Ok(( - [raw_witin, raw_structural], - lk_multiplicity.into_finalize_result(), - )) + Ok(([raw_witin, raw_structural], lk_multiplicity)) +} + +type WitBuf = ceno_gpu::common::BufferImpl< + 'static, + ::BaseField, +>; +type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; +type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; +type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; + +/// Compute fetch counter parameters from step data. 
+fn compute_fetch_params( + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> (u32, usize) { + let mut min_pc = u32::MAX; + let mut max_pc = 0u32; + for &idx in step_indices { + let pc = shard_steps[idx].pc().before.0; + min_pc = min_pc.min(pc); + max_pc = max_pc.max(pc); + } + if min_pc > max_pc { + return (0, 0); + } + let fetch_base_pc = min_pc; + let fetch_num_slots = ((max_pc - min_pc) / 4 + 1) as usize; + (fetch_base_pc, fetch_num_slots) } /// GPU kernel dispatch based on instruction kind. +/// All kinds return witness + LK counters (merged into single GPU kernel). fn gpu_fill_witness>( hal: &CudaHalBB31, config: &I::InstructionConfig, @@ -282,12 +381,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result< - ceno_gpu::common::witgen_types::GpuWitnessResult< - ceno_gpu::common::BufferImpl<'static, ::BaseField>, - >, - ZKVMError, -> { +) -> Result<(WitResult, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -298,6 +392,17 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters)) + macro_rules! 
split_full { + ($result:expr) => {{ + let full = $result?; + Ok((full.witness, Some(full.lk_counters))) + }}; + } + + // Compute fetch params for all GPU kinds (LK counters are merged into all kernels) + let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); + match kind { GpuWitgenKind::Add => { let arith_config = unsafe { @@ -308,10 +413,19 @@ fn gpu_fill_witness>( .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_add(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) - }) + })) }) }) } @@ -324,14 +438,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); info_span!("hal_witgen_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sub(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU witgen_sub failed: {e}").into()) - }) + })) }) }) } - GpuWitgenKind::LogicR => { + GpuWitgenKind::LogicR(logic_kind) => { let logic_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) @@ -340,17 +463,27 @@ fn gpu_fill_witness>( .in_scope(|| super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_logic_r(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + 
shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_logic_r failed: {e}").into(), ) - }) + })) }) }) } #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LogicI => { + GpuWitgenKind::LogicI(logic_kind) => { let logic_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) @@ -359,12 +492,22 @@ fn gpu_fill_witness>( .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_logic_i(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_logic_i failed: {e}").into(), ) - }) + })) }) }) } @@ -378,12 +521,19 @@ fn gpu_fill_witness>( .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); info_span!("hal_witgen_addi").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_addi(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_addi failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_addi failed: {e}").into()) + })) }) }) } @@ -397,12 +547,19 @@ fn gpu_fill_witness>( .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); info_span!("hal_witgen_lui").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_lui(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + 
fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_lui failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) + })) }) }) } @@ -416,12 +573,21 @@ fn gpu_fill_witness>( .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); info_span!("hal_witgen_auipc").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_auipc(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_auipc failed: {e}").into(), ) - }) + })) }) }) } @@ -435,12 +601,19 @@ fn gpu_fill_witness>( .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); info_span!("hal_witgen_jal").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_jal(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_jal failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) + })) }) }) } @@ -448,18 +621,30 @@ fn gpu_fill_witness>( GpuWitgenKind::ShiftR(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig) + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig< + E, + >) }; let col_map = info_span!("col_map") .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_shift_r(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, 
None) + split_full!(hal + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_shift_r failed: {e}").into(), ) - }) + })) }) }) } @@ -467,18 +652,30 @@ fn gpu_fill_witness>( GpuWitgenKind::ShiftI(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig) + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig< + E, + >) }; let col_map = info_span!("col_map") .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_shift_i(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, None) + split_full!(hal + .witgen_shift_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_shift_i failed: {e}").into(), ) - }) + })) }) }) } @@ -492,12 +689,22 @@ fn gpu_fill_witness>( .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); info_span!("hal_witgen_slt").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_slt(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + split_full!(hal + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_slt failed: {e}").into(), ) - }) + })) }) }) } @@ -511,12 +718,22 @@ fn gpu_fill_witness>( .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); info_span!("hal_witgen_slti").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_slti(&col_map, gpu_records, &indices_u32, shard_offset, 
is_signed, None) + split_full!(hal + .witgen_slti( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_slti failed: {e}").into(), ) - }) + })) }) }) } @@ -524,18 +741,31 @@ fn gpu_fill_witness>( GpuWitgenKind::BranchEq(is_beq) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) }; - let col_map = info_span!("col_map") - .in_scope(|| super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin) + }); info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_branch_eq(&col_map, gpu_records, &indices_u32, shard_offset, is_beq, None) + split_full!(hal + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_branch_eq failed: {e}").into(), ) - }) + })) }) }) } @@ -543,18 +773,31 @@ fn gpu_fill_witness>( GpuWitgenKind::BranchCmp(is_signed) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) }; - let col_map = info_span!("col_map") - .in_scope(|| super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin) + }); info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_shard_steps(|gpu_records| 
{ - hal.witgen_branch_cmp(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + split_full!(hal + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_branch_cmp failed: {e}").into(), ) - }) + })) }) }) } @@ -568,12 +811,19 @@ fn gpu_fill_witness>( .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); info_span!("hal_witgen_jalr").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_jalr(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_jalr failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) + })) }) }) } @@ -583,16 +833,25 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::store_v2::StoreConfig) }; + let mem_max_bits = sw_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); info_span!("hal_witgen_sw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sw(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sw failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) + })) }) }) } @@ -602,16 +861,25 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::store_v2::StoreConfig) }; + let mem_max_bits = 
sh_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); info_span!("hal_witgen_sh").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sh(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sh failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) + })) }) }) } @@ -621,37 +889,68 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::store_v2::StoreConfig) }; + let mem_max_bits = sb_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); info_span!("hal_witgen_sb").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sb(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sb failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) + })) }) }) } #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LoadSub { load_width, is_signed } => { + GpuWitgenKind::LoadSub { + load_width, + is_signed, + } => { let load_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load_v2::LoadConfig) }; let is_byte = load_width == 8; let is_signed_bool = is_signed != 0; - let col_map = info_span!("col_map") - .in_scope(|| super::load_sub::extract_load_sub_column_map(load_config, num_witin, is_byte, is_signed_bool)); + let col_map = 
info_span!("col_map").in_scope(|| { + super::load_sub::extract_load_sub_column_map( + load_config, + num_witin, + is_byte, + is_signed_bool, + ) + }); + let mem_max_bits = load_config.memory_addr.max_bits as u32; info_span!("hal_witgen_load_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_load_sub(&col_map, gpu_records, &indices_u32, shard_offset, load_width, is_signed, None) + split_full!(hal + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_load_sub failed: {e}").into(), ) - }) + })) }) }) } @@ -665,12 +964,22 @@ fn gpu_fill_witness>( .in_scope(|| super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_mul(&col_map, gpu_records, &indices_u32, shard_offset, mul_kind, None) + split_full!(hal + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_mul failed: {e}").into(), ) - }) + })) }) }) } @@ -684,12 +993,22 @@ fn gpu_fill_witness>( .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); info_span!("hal_witgen_div").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_div(&col_map, gpu_records, &indices_u32, shard_offset, div_kind, None) + split_full!(hal + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_div failed: {e}").into(), ) - }) + })) }) }) } @@ -704,14 +1023,25 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load::LoadConfig) }; + let mem_max_bits = 
load_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_lw(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) - }) + })) }) }) } @@ -723,46 +1053,504 @@ fn gpu_fill_witness>( fn collect_side_effects>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects::(config, shard_ctx, shard_steps, step_indices) +} + +fn collect_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices) +} + +fn kind_tag(kind: GpuWitgenKind) -> &'static str { + match kind { + GpuWitgenKind::Add => "add", + GpuWitgenKind::Sub => "sub", + GpuWitgenKind::LogicR(_) => "logic_r", + GpuWitgenKind::Lw => "lw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => "logic_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => "addi", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => "lui", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => "auipc", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => "jal", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => "shift_r", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => "shift_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => "slt", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => "slti", 
+ #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => "branch_eq", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => "branch_cmp", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => "jalr", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => "sw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => "sh", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => "sb", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. } => "load_sub", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => "mul", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => "div", + } +} + +/// Returns true if the GPU CUDA kernel for this kind has been verified to produce +/// correct LK multiplicity counters matching the CPU baseline. +/// Unverified kinds fall back to CPU full side effects (GPU still handles witness). +/// +/// Override with `CENO_GPU_DISABLE_LK_KINDS=add,sub,...` to force specific kinds +/// back to CPU LK (for binary-search debugging). +/// Set `CENO_GPU_DISABLE_LK_KINDS=all` to disable GPU LK for ALL kinds. 
+fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { + if is_lk_kind_disabled(kind) { + return false; + } + match kind { + // Phase B verified (Add/Sub/LogicR/Lw) + GpuWitgenKind::Add => true, + GpuWitgenKind::Sub => true, + GpuWitgenKind::LogicR(_) => true, + GpuWitgenKind::Lw => true, + // Phase C verified via debug_compare_final_lk + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => true, + // Phase C CUDA kernel fixes applied + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => true, + // Remaining kinds enabled + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. } => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => true, + _ => false, + } +} + +/// Check if GPU LK is disabled for a specific kind via CENO_GPU_DISABLE_LK_KINDS env var. 
+/// Format: CENO_GPU_DISABLE_LK_KINDS=add,sub,lw (comma-separated kind tags) +/// Special value: CENO_GPU_DISABLE_LK_KINDS=all (disables GPU LK for ALL kinds) +fn is_lk_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_LK_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Check if a specific GPU witgen kind is disabled via CENO_GPU_DISABLE_KINDS env var. +/// Format: CENO_GPU_DISABLE_KINDS=add,sub,lw (comma-separated kind tags) +fn is_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +fn debug_compare_final_lk>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, num_witin: usize, + num_structural_witin: usize, shard_steps: &[StepRecord], step_indices: &[StepIndex], -) -> Result { - let nthreads = max_usable_threads(); - let total = step_indices.len(); - let batch_size = if total > 256 { - total.div_ceil(nthreads) - } else { - total + kind: GpuWitgenKind, + mixed_lk: &Multiplicity, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_none() { + return Ok(()); } - .max(1); - - let lk_multiplicity = LkMultiplicity::default(); - let shard_ctx_vec = 
shard_ctx.get_forked(); - - step_indices - .par_chunks(batch_size) - .zip(shard_ctx_vec) - .flat_map(|(indices, mut shard_ctx)| { - let mut lk_multiplicity = lk_multiplicity.clone(); - let mut scratch = vec![E::BaseField::ZERO; num_witin]; - indices - .iter() - .copied() - .map(|step_idx| { - scratch.fill(E::BaseField::ZERO); - I::assign_instance( - config, - &mut shard_ctx, - &mut scratch, - &mut lk_multiplicity, - &shard_steps[step_idx], - ) - }) - .collect::>() - }) - .collect::>()?; - Ok(lk_multiplicity) + // Compare against cpu_assign_instances (the true baseline using assign_instance) + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (_, cpu_assign_lk) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + tracing::info!("[GPU lk debug] kind={kind:?} comparing mixed_lk vs cpu_assign_instances lk"); + log_lk_diff(kind, &cpu_assign_lk, mixed_lk); + Ok(()) +} + +fn log_lk_diff(kind: GpuWitgenKind, cpu_lk: &Multiplicity, actual_lk: &Multiplicity) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_LK_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(32); + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, actual_table)) in cpu_lk.iter().zip(actual_lk.iter()).enumerate() { + let mut keys = cpu_table + .keys() + .chain(actual_table.keys()) + .copied() + .collect::>(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = Vec::new(); + for key in keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let actual_count = actual_table.get(&key).copied().unwrap_or(0); + if cpu_count != actual_count { + table_diffs.push((key, cpu_count, actual_count)); + } + } + + if !table_diffs.is_empty() { + total_diffs += table_diffs.len(); + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} diff_count={}", + lookup_table_name(table_idx), + table_diffs.len() + ); + for (key, cpu_count, actual_count) in 
table_diffs.into_iter().take(limit) { + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} key={} cpu={} gpu={}", + lookup_table_name(table_idx), + key, + cpu_count, + actual_count + ); + } + } + } + + if total_diffs == 0 { + tracing::info!("[GPU lk debug] kind={kind:?} CPU/GPU lookup multiplicities match"); + } +} + +fn debug_compare_witness>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + gpu_witness: &RowMajorMatrix, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, _) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + let cpu_witness = &cpu_rmms[0]; + let cpu_vals = cpu_witness.values(); + let gpu_vals = gpu_witness.values(); + if cpu_vals == gpu_vals { + return Ok(()); + } + + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + let cpu_num_cols = cpu_witness.n_col(); + let cpu_num_rows = cpu_vals.len() / cpu_num_cols; + let mut mismatches = 0usize; + for row in 0..cpu_num_rows { + for col in 0..cpu_num_cols { + let idx = row * cpu_num_cols + col; + if cpu_vals[idx] != gpu_vals[idx] { + mismatches += 1; + if mismatches <= limit { + tracing::error!( + "[GPU witness debug] kind={kind:?} row={} col={} cpu={:?} gpu={:?}", + row, + col, + cpu_vals[idx], + gpu_vals[idx] + ); + } + } + } + } + tracing::error!( + "[GPU witness debug] kind={kind:?} total_mismatches={}", + mismatches + ); + Ok(()) +} + +fn debug_compare_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result<(), ZKVMError> { + 
if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let _ = cpu_collect_side_effects::(config, &mut cpu_ctx, shard_steps, step_indices)?; + + let mut mixed_ctx = shard_ctx.new_empty_like(); + let _ = + cpu_collect_shard_side_effects::(config, &mut mixed_ctx, shard_steps, step_indices)?; + + let cpu_addr = cpu_ctx.get_addr_accessed(); + let mixed_addr = mixed_ctx.get_addr_accessed(); + if cpu_addr != mixed_addr { + tracing::error!( + "[GPU shard debug] kind={kind:?} addr_accessed cpu={} gpu={}", + cpu_addr.len(), + mixed_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu_ctx.read_records()); + let mixed_reads = flatten_ram_records(mixed_ctx.read_records()); + if cpu_reads != mixed_reads { + log_ram_record_diff(kind, "read_records", &cpu_reads, &mixed_reads); + } + + let cpu_writes = flatten_ram_records(cpu_ctx.write_records()); + let mixed_writes = flatten_ram_records(mixed_ctx.write_records()); + if cpu_writes != mixed_writes { + log_ram_record_diff(kind, "write_records", &cpu_writes, &mixed_writes); + } + + Ok(()) +} + +fn flatten_ram_records( + records: &[std::collections::BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +fn log_ram_record_diff( + kind: GpuWitgenKind, + label: &str, + cpu_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], + mixed_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], +) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_SHARD_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + tracing::error!( + "[GPU shard debug] kind={kind:?} {} cpu={} gpu={}", + label, + cpu_records.len(), + mixed_records.len() + ); + let max_len = 
cpu_records.len().max(mixed_records.len()); + let mut logged = 0usize; + for idx in 0..max_len { + let cpu = cpu_records.get(idx); + let gpu = mixed_records.get(idx); + if cpu != gpu { + tracing::error!( + "[GPU shard debug] kind={kind:?} {} idx={} cpu={:?} gpu={:?}", + label, + idx, + cpu, + gpu + ); + logged += 1; + if logged >= limit { + break; + } + } + } +} + +fn lookup_table_name(table_idx: usize) -> &'static str { + match table_idx { + x if x == LookupTable::Dynamic as usize => "Dynamic", + x if x == LookupTable::DoubleU8 as usize => "DoubleU8", + x if x == LookupTable::And as usize => "And", + x if x == LookupTable::Or as usize => "Or", + x if x == LookupTable::Xor as usize => "Xor", + x if x == LookupTable::Ltu as usize => "Ltu", + x if x == LookupTable::Pow as usize => "Pow", + x if x == LookupTable::Instruction as usize => "Instruction", + _ => "Unknown", + } +} + +fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { + let mut lk = LkMultiplicity::default(); + merge_dense_counter_table( + &mut lk, + LookupTable::Dynamic, + &counters.dynamic.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) + })?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::DoubleU8, + &counters.double_u8.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU double_u8 lk D2H failed: {e}").into()) + })?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::And, + &counters + .and_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU and lk D2H failed: {e}").into()))?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::Or, + &counters + .or_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU or lk D2H failed: {e}").into()))?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::Xor, + &counters + .xor_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU xor lk D2H failed: {e}").into()))?, + ); + 
merge_dense_counter_table( + &mut lk, + LookupTable::Ltu, + &counters + .ltu_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU ltu lk D2H failed: {e}").into()))?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::Pow, + &counters + .pow_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU pow lk D2H failed: {e}").into()))?, + ); + // Merge fetch (Instruction) table if present + if let Some(fetch_buf) = counters.fetch { + let base_pc = counters.fetch_base_pc; + let fetch_counts = fetch_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) + })?; + for (slot_idx, &count) in fetch_counts.iter().enumerate() { + if count != 0 { + let pc = base_pc as u64 + (slot_idx as u64) * 4; + lk.set_count(LookupTable::Instruction, pc, count as usize); + } + } + } + Ok(lk.into_finalize_result()) +} + +fn merge_dense_counter_table(lk: &mut LkMultiplicity, table: LookupTable, counts: &[u32]) { + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + lk.set_count(table, key as u64, count as usize); + } + } } /// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. 
@@ -794,9 +1582,7 @@ fn gpu_witness_to_rmm( num_rows, num_cols, ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) - })?; + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; let gpu_data: Vec<::BaseField> = rmm_buffer .to_vec() diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index c726f8a88..ff3fc72a3 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -76,4 +79,28 @@ impl IInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 26b8ce7b9..cafd104aa 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -2,7 +2,10 @@ use crate::{ chip_handler::{AddressExpr, MemoryExpr, RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + 
instructions::{ + riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -85,4 +88,30 @@ impl IMInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + self.mem_read.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + self.mem_read.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 69ea105b7..51e84be17 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -13,6 +13,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, + instructions::side_effects::{LkOp, SendEvent, SideEffectSink, emit_assert_lt_ops}, structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, @@ -141,6 +142,47 @@ impl ReadRS1 { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs1().expect("rs1 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: 
op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS1, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs1().expect("rs1 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -209,6 +251,47 @@ impl ReadRS2 { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs2().expect("rs2 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS2, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs2().expect("rs2 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -295,6 +378,66 @@ impl WriteRD { Ok(()) } + + pub fn collect_op_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - 
shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RD, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: cycle + Tracer::SUBCYCLE_RD, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rd().expect("rd op"); + self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rd().expect("rd op"); + self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + } + + pub fn collect_op_shard_effects( + &self, + shard_ctx: &mut ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + cycle + Tracer::SUBCYCLE_RD, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -361,6 +504,47 @@ impl ReadMEM { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + 
pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + step.cycle() + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -434,6 +618,66 @@ impl WriteMEM { Ok(()) } + + pub fn collect_op_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: cycle + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + } + + pub fn collect_op_shard_effects( + &self, + shard_ctx: &mut ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -584,6 +828,22 @@ impl MemAddr { Ok(()) } + pub fn 
collect_side_effects(&self, sink: &mut impl SideEffectSink, addr: Word) { + let mid_u14 = ((addr & 0xffff) >> Self::N_LOW_BITS) as u16; + sink.emit_lk(LkOp::AssertU14 { value: mid_u14 }); + + for i in 1..UINT_LIMBS { + let high_u16 = ((addr >> (i * 16)) & 0xffff) as u64; + let bits = (self.max_bits - i * 16).min(16); + if bits > 1 { + sink.emit_lk(LkOp::DynamicRange { + value: high_u16, + bits: bits as u32, + }); + } + } + } + fn n_zeros(&self) -> usize { Self::N_LOW_BITS - self.low_bits.len() } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 84cb84679..eb3f7b693 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -68,4 +71,26 @@ impl JInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rd.collect_shard_effects(shard_ctx, step); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rd.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 85db8e91f..130dac8fc 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -12,6 +12,7 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, }, + side_effects::{CpuSideEffectSink, LkOp, 
SideEffectSink, emit_byte_decomposition_ops}, }, structs::ProgramParams, utils::split_to_u8, @@ -51,6 +52,8 @@ impl Instruction for JalInstruction { type InstructionConfig = JalConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JAL] } @@ -129,6 +132,45 @@ impl Instruction for JalInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .j_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: rd_written[3], + b: additional_bits as u8, + }); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index e4c838253..644ba2a45 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -14,6 +14,7 @@ use crate::{ i_insn::IInstructionConfig, insn_base::{MemAddr, ReadRS1, 
StateInOut, WriteRD}, }, + side_effects::{CpuSideEffectSink, emit_const_range_op}, }, structs::ProgramParams, tables::InsnRecord, @@ -51,6 +52,8 @@ impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JALR] } @@ -196,6 +199,43 @@ impl Instruction for JalrInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_value = Value::new_unchecked(step.rd().unwrap().value.after); + let rd_limb = rd_value.as_u16_limbs(); + emit_const_range_op(&mut sink, rd_limb[0] as u64, 16); + emit_const_range_op(&mut sink, rd_limb[1] as u64, PC_BITS - 16); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let jump_pc = step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32); + config.jump_pc_addr.collect_side_effects(&mut sink, jump_pc); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index f6ce31288..aae61aef5 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ 
b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -11,6 +11,7 @@ use crate::{ instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, }, structs::ProgramParams, utils::split_to_u8, @@ -38,6 +39,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -80,6 +83,37 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config.collect_side_effects(&mut sink, shard_ctx_view, step); + emit_logic_u8_ops::( + &mut sink, + step.rs1().unwrap().value as u64, + step.rs2().unwrap().value as u64, + 4, + ); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -97,7 +131,12 @@ impl Instruction for LogicInstruction { num_structural_witin, shard_steps, step_indices, - witgen_gpu::GpuWitgenKind::LogicR, + witgen_gpu::GpuWitgenKind::LogicR(match I::INST_KIND { + InsnKind::AND => 0, + InsnKind::OR => 1, + InsnKind::XOR => 2, + kind => unreachable!("unsupported logic GPU kind: {kind:?}"), + }), )? 
{ return Ok(result); } @@ -169,4 +208,13 @@ impl LogicConfig { Ok(()) } + + fn collect_side_effects( + &self, + sink: &mut impl crate::instructions::side_effects::SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + self.r_insn.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 6f20710e3..2eb89036c 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -16,6 +16,7 @@ use crate::{ i_insn::IInstructionConfig, logic_imm::LogicOp, }, + side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, }, structs::ProgramParams, tables::InsnRecord, @@ -40,6 +41,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -132,6 +135,43 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_lo = step.rs1().unwrap().value & LIMB_MASK; + let rs1_hi = (step.rs1().unwrap().value >> LIMB_BITS) & LIMB_MASK; + let imm_lo = InsnRecord::::imm_internal(&step.insn()).0 as u32 & LIMB_MASK; + let imm_hi = (InsnRecord::::imm_signed_internal(&step.insn()).0 as u32 + >> LIMB_BITS) + & LIMB_MASK; + + emit_logic_u8_ops::(&mut sink, rs1_lo.into(), imm_lo.into(), 
2); + emit_logic_u8_ops::(&mut sink, rs1_hi.into(), imm_hi.into(), 2); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -149,7 +189,12 @@ impl Instruction for LogicInstruction { num_structural_witin, shard_steps, step_indices, - witgen_gpu::GpuWitgenKind::LogicI, + witgen_gpu::GpuWitgenKind::LogicI(match I::INST_KIND { + InsnKind::ANDI => 0, + InsnKind::ORI => 1, + InsnKind::XORI => 2, + kind => unreachable!("unsupported logic_imm GPU kind: {kind:?}"), + }), )? { return Ok(result); } diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 38882b78e..bc661d7a7 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -12,6 +12,7 @@ use crate::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_const_range_op}, }, structs::ProgramParams, tables::InsnRecord, @@ -43,6 +44,8 @@ impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::LUI] } @@ -110,7 +113,7 @@ impl Instruction for LuiInstruction { .i_insn .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; - let rd_written = split_to_u8(step.rd().unwrap().value.after); + let rd_written = split_to_u8::(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { lk_multiplicity.assert_ux::<8>(*val as u64); set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); @@ -121,6 +124,39 @@ impl Instruction for LuiInstruction { Ok(()) } + fn 
collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in rd_written.iter().skip(1) { + emit_const_range_op(&mut sink, *val as u64, 8); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 08ca6c878..fe8d3b2e2 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -9,6 +9,7 @@ use crate::{ riscv::{ RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, insn_base::MemAddr, }, + side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -40,6 +41,8 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -227,6 +230,63 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = + unsafe { 
CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "lightweight side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "shard-only side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 28193e4f4..850d2dffc 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -12,6 +12,7 @@ use crate::{ im_insn::IMInstructionConfig, insn_base::MemAddr, }, + side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -44,6 +45,8 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -252,6 +255,63 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let 
mut sink = + unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "lightweight side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "shard-only side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -270,10 +330,22 @@ impl Instruction for LoadInstruction Some(witgen_gpu::GpuWitgenKind::Lw), - InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 1 }), - InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 0 }), - InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 1 }), - InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 0 }), + InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 1, + }), + InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 0, + }), + InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 1, + }), + InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 0, + 
}), _ => None, }; if let Some(kind) = gpu_kind { diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index 84f6a87ce..ddb1dffb7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -12,6 +12,7 @@ use crate::{ memory::gadget::MemWordUtil, s_insn::SInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_const_range_op, emit_u16_limbs}, }, structs::ProgramParams, tables::InsnRecord, @@ -51,6 +52,8 @@ impl Instruction type InstructionConfig = StoreConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -179,6 +182,57 @@ impl Instruction Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .s_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + emit_u16_limbs(&mut sink, step.memory_op().unwrap().value.before); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, addr.into()); + + if N_ZEROS == 0 { + let memory_op = step.memory_op().unwrap(); + let prev_value = Value::new_unchecked(memory_op.value.before); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let prev_limb = prev_value.as_u16_limbs()[((addr.shift() >> 1) & 1) as usize]; + let rs2_limb = rs2_value.as_u16_limbs()[0]; + + for byte in prev_limb.to_le_bytes() { + emit_const_range_op(&mut sink, byte as u64, 8); + } + for byte in rs2_limb.to_le_bytes() { 
+ emit_const_range_op(&mut sink, byte as u64, 8); + } + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .s_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index d42f9c7d8..a04359256 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -8,6 +8,7 @@ use crate::{ constants::{LIMB_BITS, UINT_LIMBS, UInt}, r_insn::RInstructionConfig, }, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink}, }, structs::ProgramParams, uint::Value, @@ -47,6 +48,8 @@ impl Instruction for MulhInstructionBas type InstructionConfig = MulhConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -335,6 +338,113 @@ impl Instruction for MulhInstructionBas Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = step.rs1().unwrap().value; + let rs1_val = Value::new_unchecked(rs1); + let rs2 = step.rs2().unwrap().value; + let rs2_val = Value::new_unchecked(rs2); + + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( + I::INST_KIND, + rs1_val + .as_u16_limbs() + .iter() + 
.map(|x| *x as u32) + .collect::>() + .as_slice(), + rs2_val + .as_u16_limbs() + .iter() + .map(|x| *x as u32) + .collect::>() + .as_slice(), + ); + + for (rd_low, carry_low) in rd_low.iter().zip(carry[0..UINT_LIMBS].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_low as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_low as u64, + bits: 18, + }); + } + + match I::INST_KIND { + InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => { + for (rd_high, carry_high) in rd_high.iter().zip(carry[UINT_LIMBS..].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_high as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_high as u64, + bits: 18, + }); + } + } + _ => {} + } + + let sign_mask = 1 << (LIMB_BITS - 1); + let ext = (1 << LIMB_BITS) - 1; + let rs1_sign = rs1_ext / ext; + let rs2_sign = rs2_ext / ext; + let rs1_limbs = rs1_val.as_u16_limbs(); + let rs2_limbs = rs2_val.as_u16_limbs(); + + match I::INST_KIND { + InsnKind::MULH => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask)) as u64, + bits: 16, + }); + } + InsnKind::MULHSU => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask) as u64, + bits: 16, + }); + } + _ => {} + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -352,7 
+462,12 @@ impl Instruction for MulhInstructionBas InsnKind::MULHSU => 3u32, _ => { return crate::instructions::cpu_assign_instances::( - config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, ); } }; @@ -368,7 +483,12 @@ impl Instruction for MulhInstructionBas return Ok(result); } crate::instructions::cpu_assign_instances::( - config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, ) } } diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index a4b9bb128..b0e8089d0 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -81,4 +84,30 @@ impl RInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + } } diff --git 
a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index 3ffa77f3f..9b5f8d88e 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -3,7 +3,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -91,4 +94,31 @@ impl SInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + self.mem_write.collect_shard_effects(shard_ctx, step); + } + + #[allow(dead_code)] + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + self.mem_write.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 3093943d7..38f6b7758 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,10 +1,6 @@ use crate::e2e::ShardContext; #[cfg(feature = "gpu")] use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; /// constrain implementation follow from 
https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -15,12 +11,20 @@ use crate::{ i_insn::IInstructionConfig, r_insn::RInstructionConfig, }, + side_effects::{ + CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + emit_const_range_op, + }, }, structs::ProgramParams, utils::{split_to_limb, split_to_u8}, }; use ceno_emul::InsnKind; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; use ff_ext::{ExtensionField, FieldInto}; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -212,6 +216,45 @@ impl }) } + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + kind: InsnKind, + b: u32, + c: u32, + ) { + let b = split_to_limb::(b); + let c = split_to_limb::(c); + let (_, limb_shift, bit_shift) = run_shift::( + kind, + &b.clone().try_into().unwrap(), + &c.clone().try_into().unwrap(), + ); + + let bit_shift_carry: [u32; NUM_LIMBS] = array::from_fn(|i| match kind { + InsnKind::SLL | InsnKind::SLLI => b[i] >> (LIMB_BITS - bit_shift), + _ => b[i] % (1 << bit_shift), + }); + for val in bit_shift_carry { + sink.emit_lk(LkOp::DynamicRange { + value: val as u64, + bits: bit_shift as u32, + }); + } + + let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); + let carry_quotient = + (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u64; + emit_const_range_op(sink, carry_quotient, LIMB_BITS - num_bits_log as usize); + + if matches!(kind, InsnKind::SRA | InsnKind::SRAI) { + sink.emit_lk(LkOp::Xor { + a: b[NUM_LIMBS - 1] as u8, + b: (1 << (LIMB_BITS - 1)) as u8, + }); + } + } + pub fn assign_instances( &self, instance: &mut [::BaseField], @@ -284,6 +327,8 @@ impl Instruction for ShiftLogicalInstru type InstructionConfig = ShiftRTypeConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn 
inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -370,6 +415,43 @@ impl Instruction for ShiftLogicalInstru Ok(()) } + fn collect_side_effects_instance( + config: &ShiftRTypeConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + config.shift_base_config.collect_side_effects( + &mut sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.rs2().unwrap().value, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -422,6 +504,8 @@ impl Instruction for ShiftImmInstructio type InstructionConfig = ShiftImmConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -510,6 +594,43 @@ impl Instruction for ShiftImmInstructio Ok(()) } + fn collect_side_effects_instance( + config: &ShiftImmConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { 
CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + config.shift_base_config.collect_side_effects( + &mut sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.insn().imm as i16 as u16 as u32, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index b5e41f6ac..15e5c104b 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -7,6 +7,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, witness::LkMultiplicity, @@ -39,6 +40,8 @@ impl Instruction for SetLessThanInstruc type InstructionConfig = SetLessThanConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -121,6 +124,45 @@ impl Instruction for SetLessThanInstruc Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { 
CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let rs2_limbs = rs2_value.as_u16_limbs(); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::SLT), + &rs1_limbs, + &rs2_limbs, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 471a70866..da60ca953 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -11,6 +11,7 @@ use crate::{ constants::{UINT_LIMBS, UInt}, i_insn::IInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -50,6 +51,8 @@ impl Instruction for SetLessThanImmInst type InstructionConfig = SetLessThanImmConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -141,6 +144,44 @@ impl Instruction for SetLessThanImmInst Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { 
CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::SLTI), + &rs1_limbs, + &imm_sign_extend, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/side_effects.rs b/ceno_zkvm/src/instructions/side_effects.rs new file mode 100644 index 000000000..97695c526 --- /dev/null +++ b/ceno_zkvm/src/instructions/side_effects.rs @@ -0,0 +1,1157 @@ +use ceno_emul::{Cycle, Word, WordAddr}; +use gkr_iop::{ + gadgets::{AssertLtConfig, cal_lt_diff}, + tables::{LookupTable, OpsTable}, +}; +use smallvec::SmallVec; +use std::marker::PhantomData; + +use crate::{ + e2e::ShardContext, + instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS}, + structs::RAMType, + witness::LkMultiplicity, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LkOp { + AssertU16 { value: u16 }, + DynamicRange { value: u64, bits: u32 }, + AssertU14 { value: u16 }, + Fetch { pc: u32 }, + DoubleU8 { a: u8, b: u8 }, + And { a: u8, b: u8 }, + Or { a: u8, b: u8 }, + Xor { a: u8, b: u8 }, + Ltu { a: u8, b: u8 }, + Pow2 { value: u8 }, + ShrByte { shift: u8, carry: u8, bits: u8 }, +} + +impl LkOp { + pub fn encode_all(&self) -> SmallVec<[(LookupTable, u64); 2]> { + match *self { + LkOp::AssertU16 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 16) + value as u64)]) + } + 
LkOp::DynamicRange { value, bits } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << bits) + value)]) + } + LkOp::AssertU14 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 14) + value as u64)]) + } + LkOp::Fetch { pc } => SmallVec::from_slice(&[(LookupTable::Instruction, pc as u64)]), + LkOp::DoubleU8 { a, b } => { + SmallVec::from_slice(&[(LookupTable::DoubleU8, ((a as u64) << 8) + b as u64)]) + } + LkOp::And { a, b } => { + SmallVec::from_slice(&[(LookupTable::And, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Or { a, b } => { + SmallVec::from_slice(&[(LookupTable::Or, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Xor { a, b } => { + SmallVec::from_slice(&[(LookupTable::Xor, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Ltu { a, b } => { + SmallVec::from_slice(&[(LookupTable::Ltu, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Pow2 { value } => { + SmallVec::from_slice(&[(LookupTable::Pow, 2u64 | ((value as u64) << 8))]) + } + LkOp::ShrByte { shift, carry, bits } => SmallVec::from_slice(&[ + ( + LookupTable::DoubleU8, + ((shift as u64) << 8) + ((shift as u64) << bits), + ), + ( + LookupTable::DoubleU8, + ((carry as u64) << 8) + ((carry as u64) << (8 - bits)), + ), + ]), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SendEvent { + pub ram_type: RAMType, + pub addr: WordAddr, + pub id: u64, + pub cycle: Cycle, + pub prev_cycle: Cycle, + pub value: Word, + pub prev_value: Option, +} + +pub trait SideEffectSink { + fn emit_lk(&mut self, op: LkOp); + fn emit_send(&mut self, event: SendEvent); + fn touch_addr(&mut self, addr: WordAddr); +} + +pub struct CpuSideEffectSink<'ctx, 'shard, 'lk> { + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + _marker: PhantomData<&'ctx mut ShardContext<'shard>>, +} + +impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { + pub unsafe fn from_raw( + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + ) -> Self { + Self { + shard_ctx, + lk, + 
_marker: PhantomData, + } + } + + fn shard_ctx(&mut self) -> &mut ShardContext<'shard> { + // Safety: `from_raw` is only constructed from a live `&mut ShardContext` + // for the duration of side-effect collection. + unsafe { &mut *self.shard_ctx } + } +} + +impl SideEffectSink for CpuSideEffectSink<'_, '_, '_> { + fn emit_lk(&mut self, op: LkOp) { + for (table, key) in op.encode_all() { + self.lk.increment(table, key); + } + } + + fn emit_send(&mut self, event: SendEvent) { + self.shard_ctx().record_send_without_touch( + event.ram_type, + event.addr, + event.id, + event.cycle, + event.prev_cycle, + event.value, + event.prev_value, + ); + } + + fn touch_addr(&mut self, addr: WordAddr) { + self.shard_ctx().push_addr_accessed(addr); + } +} + +pub fn emit_assert_lt_ops( + sink: &mut impl SideEffectSink, + lt_cfg: &AssertLtConfig, + lhs: u64, + rhs: u64, +) { + let max_bits = lt_cfg.0.max_bits; + let diff = cal_lt_diff(lhs < rhs, max_bits, lhs, rhs); + for i in 0..(max_bits / u16::BITS as usize) { + let value = ((diff >> (i * u16::BITS as usize)) & 0xffff) as u16; + sink.emit_lk(LkOp::AssertU16 { value }); + } + let remain_bits = max_bits % u16::BITS as usize; + if remain_bits > 1 { + let value = (diff >> ((lt_cfg.0.diff.len() - 1) * u16::BITS as usize)) & 0xffff; + sink.emit_lk(LkOp::DynamicRange { + value, + bits: remain_bits as u32, + }); + } +} + +pub fn emit_u16_limbs(sink: &mut impl SideEffectSink, value: u32) { + sink.emit_lk(LkOp::AssertU16 { + value: (value & 0xffff) as u16, + }); + sink.emit_lk(LkOp::AssertU16 { + value: (value >> 16) as u16, + }); +} + +pub fn emit_const_range_op(sink: &mut impl SideEffectSink, value: u64, bits: usize) { + match bits { + 0 | 1 => {} + 14 => sink.emit_lk(LkOp::AssertU14 { + value: value as u16, + }), + 16 => sink.emit_lk(LkOp::AssertU16 { + value: value as u16, + }), + _ => sink.emit_lk(LkOp::DynamicRange { + value, + bits: bits as u32, + }), + } +} + +pub fn emit_byte_decomposition_ops(sink: &mut impl SideEffectSink, bytes: 
&[u8]) { + for chunk in bytes.chunks(2) { + match chunk { + [a, b] => sink.emit_lk(LkOp::DoubleU8 { a: *a, b: *b }), + [a] => emit_const_range_op(sink, *a as u64, 8), + _ => unreachable!(), + } + } +} + +pub fn emit_signed_extend_op(sink: &mut impl SideEffectSink, n_bits: usize, value: u64) { + let msb = value >> (n_bits - 1); + sink.emit_lk(LkOp::DynamicRange { + value: 2 * value - (msb << n_bits), + bits: n_bits as u32, + }); +} + +pub fn emit_logic_u8_ops( + sink: &mut impl SideEffectSink, + lhs: u64, + rhs: u64, + num_bytes: usize, +) { + for i in 0..num_bytes { + let a = ((lhs >> (i * 8)) & 0xff) as u8; + let b = ((rhs >> (i * 8)) & 0xff) as u8; + let op = match OP::ROM_TYPE { + LookupTable::And => LkOp::And { a, b }, + LookupTable::Or => LkOp::Or { a, b }, + LookupTable::Xor => LkOp::Xor { a, b }, + LookupTable::Ltu => LkOp::Ltu { a, b }, + rom_type => unreachable!("unsupported logic table: {rom_type:?}"), + }; + sink.emit_lk(op); + } +} + +pub fn emit_uint_limbs_lt_ops( + sink: &mut impl SideEffectSink, + is_sign_comparison: bool, + a: &[u16], + b: &[u16], +) { + assert_eq!(a.len(), UINT_LIMBS); + assert_eq!(b.len(), UINT_LIMBS); + + let last = UINT_LIMBS - 1; + let sign_mask = 1 << (LIMB_BITS - 1); + let is_a_neg = is_sign_comparison && (a[last] & sign_mask) != 0; + let is_b_neg = is_sign_comparison && (b[last] & sign_mask) != 0; + + let (cmp_lt, diff_idx) = (0..UINT_LIMBS) + .rev() + .find(|&i| a[i] != b[i]) + .map(|i| ((a[i] < b[i]) ^ is_a_neg ^ is_b_neg, i)) + .unwrap_or((false, UINT_LIMBS)); + + let a_msb_range = if is_a_neg { + a[last] - sign_mask + } else { + a[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + let b_msb_range = if is_b_neg { + b[last] - sign_mask + } else { + b[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + + let to_signed = |value: u16, is_neg: bool| -> i32 { + if is_neg { + value as i32 - (1 << LIMB_BITS) + } else { + value as i32 + } + }; + let diff_val = if diff_idx == UINT_LIMBS { + 0 + } else if 
diff_idx == last { + let a_signed = to_signed(a[last], is_a_neg); + let b_signed = to_signed(b[last], is_b_neg); + if cmp_lt { + (b_signed - a_signed) as u16 + } else { + (a_signed - b_signed) as u16 + } + } else if cmp_lt { + b[diff_idx] - a[diff_idx] + } else { + a[diff_idx] - b[diff_idx] + }; + + emit_const_range_op( + sink, + if diff_idx == UINT_LIMBS { + 0 + } else { + (diff_val - 1) as u64 + }, + LIMB_BITS, + ); + emit_const_range_op(sink, a_msb_range as u64, LIMB_BITS); + emit_const_range_op(sink, b_msb_range as u64, LIMB_BITS); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{ + Instruction, cpu_assign_instances, cpu_collect_shard_side_effects, + cpu_collect_side_effects, + riscv::{ + AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, + branch::{BeqInstruction, BltInstruction}, + div::{DivInstruction, RemuInstruction}, + logic::AndInstruction, + mulh::{MulInstruction, MulhInstruction}, + shift::SraInstruction, + shift_imm::SlliInstruction, + slt::SltInstruction, + slti::SltiInstruction, + }, + }, + structs::ProgramParams, + }; + use ceno_emul::{ + ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, + encode_rv32, + }; + use ff_ext::GoldilocksExt2; + use gkr_iop::tables::LookupTable; + + type E = GoldilocksExt2; + + fn assert_side_effects_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_side_effects::(config, &mut collect_ctx, steps, &indices).unwrap(); + + assert_eq!(flatten_lk(&expected_lk), 
flatten_lk(&actual_lk)); + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn assert_shard_side_effects_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_shard_side_effects::(config, &mut collect_ctx, steps, &indices) + .unwrap(); + + assert_eq!( + expected_lk[LookupTable::Instruction as usize], + actual_lk[LookupTable::Instruction as usize] + ); + for (table_idx, table) in actual_lk.iter().enumerate() { + if table_idx != LookupTable::Instruction as usize { + assert!( + table.is_empty(), + "unexpected non-fetch shard-only multiplicity in table {table_idx}: {table:?}" + ); + } + } + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut 
entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + #[test] + fn test_add_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x1000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x2000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_add_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, 
&ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 84 + (i as u64) * 4, + ByteAddr(0x5000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 100 + (i as u64) * 4, + ByteAddr(0x5100 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabc0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + 
addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x3000 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1400u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabd0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 116 + (i as u64) * 4, + ByteAddr(0x5200 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_beq_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "beq_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [ + (true, 0x1122_3344, 0x1122_3344), + (false, 0x5566_7788, 0x99aa_bbcc), + ]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4000 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + 
StepRecord::new_b_instruction( + 4 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BEQ, 8 + i as u32, 16 + i as u32, 0, 8), + lhs, + rhs, + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_blt_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "blt_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(true, (-2i32) as u32, 1u32), (false, 7u32, (-3i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4100 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 12 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BLT, 4 + i as u32, 5 + i as u32, 0, -8), + lhs, + rhs, + 10, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jal_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jal_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let offsets = [8, -8]; + let steps: Vec<_> = offsets + .into_iter() + .enumerate() + .map(|(i, offset)| { + let pc = ByteAddr(0x4200 + i as u32 * 4); + StepRecord::new_j_instruction( + 20 + i as u64 * 4, + Change::new(pc, ByteAddr(pc.0.wrapping_add_signed(offset))), + encode_rv32(InsnKind::JAL, 0, 0, 3 + i as u32, offset), + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as 
usize, + &steps, + ); + } + + #[test] + fn test_jalr_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jalr_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(100u32, 3), (0x4010u32, -5)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4300 + i as u32 * 4); + let next_pc = ByteAddr(rs1.wrapping_add_signed(imm) & !1); + StepRecord::new_i_instruction( + 28 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::JALR, 2 + i as u32, 0, 6 + i as u32, imm), + rs1, + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slt_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slt_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [((-1i32) as u32, 0u32), (5u32, (-2i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let insn = + encode_rv32(InsnKind::SLT, 9 + i as u32, 10 + i as u32, 11 + i as u32, 0); + StepRecord::new_r_instruction( + 36 + i as u64 * 4, + ByteAddr(0x4400 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, ((lhs as i32) < (rhs as i32)) as u32), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slti_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slti_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, 
&ProgramParams::default()).unwrap(); + let cases = [(0u32, -1), ((-2i32) as u32, 1)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let insn = encode_rv32(InsnKind::SLTI, 12 + i as u32, 0, 13 + i as u32, imm); + let pc = ByteAddr(0x4500 + i as u32 * 4); + StepRecord::new_i_instruction( + 44 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + insn, + rs1, + Change::new(0, ((rs1 as i32) < imm) as u32), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sra_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sra_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SraInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x8765_4321u32, 4u32), (0xf000_0000u32, 31u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let shift = rhs & 31; + let rd = ((lhs as i32) >> shift) as u32; + StepRecord::new_r_instruction( + 52 + i as u64 * 4, + ByteAddr(0x4600 + i as u32 * 4), + encode_rv32(InsnKind::SRA, 6 + i as u32, 7 + i as u32, 8 + i as u32, 0), + lhs, + rhs, + Change::new(0, rd), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slli_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slli_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x1234_5678u32, 3), (0x0000_0001u32, 31)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4700 + i as u32 * 4); + StepRecord::new_i_instruction( + 60 + i as u64 * 4, + Change::new(pc, pc + 
PC_STEP_SIZE), + encode_rv32(InsnKind::SLLI, 9 + i as u32, 0, 10 + i as u32, imm), + rs1, + Change::new(0, rs1.wrapping_shl((imm & 31) as u32)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sb_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sb_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..2) + .map(|i| { + let rs1 = 0x4800u32 + i * 16; + let rs2 = 0x1234_5600u32 | i; + let imm = i as i32 - 1; + let addr = ByteAddr::from(rs1.wrapping_add_signed(imm)); + let prev = 0x4030_2010u32 + i; + let shift = (addr.shift() * 8) as usize; + let mut next = prev & !(0xff << shift); + next |= (rs2 & 0xff) << shift; + StepRecord::new_s_instruction( + 68 + i as u64 * 4, + ByteAddr(0x4800 + i * 4), + encode_rv32(InsnKind::SB, 11 + i, 12 + i, 0, imm), + rs1, + rs2, + WriteOp { + addr: addr.waddr(), + value: Change::new(prev, next), + previous_cycle: 4, + }, + 8, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mul_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mul_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2u32, 11u32), (u32::MAX, 17u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + StepRecord::new_r_instruction( + 76 + i as u64 * 4, + ByteAddr(0x4900 + i as u32 * 4), + encode_rv32( + InsnKind::MUL, + 13 + i as u32, + 14 + i as u32, + 15 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, lhs.wrapping_mul(rhs)), + 0, + ) + }) + .collect(); + + 
assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mulh_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mulh_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2i32, -11i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let outcome = ((lhs as i64).wrapping_mul(rhs as i64) >> 32) as u32; + StepRecord::new_r_instruction( + 84 + i as u64 * 4, + ByteAddr(0x4a00 + i as u32 * 4), + encode_rv32( + InsnKind::MULH, + 16 + i as u32, + 17 + i as u32, + 18 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, outcome), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_div_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "div_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17i32, -3i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { + -1i32 + } else { + lhs.wrapping_div(rhs) + } as u32; + StepRecord::new_r_instruction( + 92 + i as u64 * 4, + ByteAddr(0x4b00 + i as u32 * 4), + encode_rv32( + InsnKind::DIV, + 19 + i as u32, + 20 + i as u32, + 21 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_remu_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| 
"remu_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17u32, 3u32), (0x8000_0001u32, 0u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { lhs } else { lhs % rhs }; + StepRecord::new_r_instruction( + 100 + i as u64 * 4, + ByteAddr(0x4c00 + i as u32 * 4), + encode_rv32( + InsnKind::REMU, + 22 + i as u32, + 23 + i as u32, + 24 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lk_op_encodings_match_cpu_multiplicity() { + let ops = [ + LkOp::AssertU16 { value: 7 }, + LkOp::DynamicRange { value: 11, bits: 8 }, + LkOp::AssertU14 { value: 5 }, + LkOp::Fetch { pc: 0x1234 }, + LkOp::DoubleU8 { a: 1, b: 2 }, + LkOp::And { a: 3, b: 4 }, + LkOp::Or { a: 5, b: 6 }, + LkOp::Xor { a: 7, b: 8 }, + LkOp::Ltu { a: 9, b: 10 }, + LkOp::Pow2 { value: 12 }, + LkOp::ShrByte { + shift: 3, + carry: 17, + bits: 2, + }, + ]; + + let mut lk = LkMultiplicity::default(); + for op in ops { + for (table, key) in op.encode_all() { + lk.increment(table, key); + } + } + + let finalized = lk.into_finalize_result(); + assert_eq!(finalized[LookupTable::Dynamic as usize].len(), 3); + assert_eq!(finalized[LookupTable::Instruction as usize].len(), 1); + assert_eq!(finalized[LookupTable::DoubleU8 as usize].len(), 3); + assert_eq!(finalized[LookupTable::And as usize].len(), 1); + assert_eq!(finalized[LookupTable::Or as usize].len(), 1); + assert_eq!(finalized[LookupTable::Xor as usize].len(), 1); + assert_eq!(finalized[LookupTable::Ltu as usize].len(), 1); + assert_eq!(finalized[LookupTable::Pow as usize].len(), 1); + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f6847140..1f433d0e2 100644 --- 
a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -363,6 +363,14 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } + pub fn combined_lk_mlt(&self) -> Option<&Vec>> { + self.combined_lk_mlt.as_ref() + } + + pub fn lk_mlts(&self) -> &BTreeMap> { + &self.lk_mlts + } + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, From 8307ba10d207a35528d256a76251e19101357f27 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Thu, 12 Mar 2026 10:17:25 +0800 Subject: [PATCH 29/73] shard-1 --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 329 ++++++++++++++++-- 1 file changed, 300 insertions(+), 29 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 42a48fb1c..56bdd8690 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -4,19 +4,21 @@ /// 1. Runs the GPU kernel to fill the witness matrix (fast) /// 2. Runs a lightweight CPU loop to collect side effects without witness replay /// 3. 
Returns the GPU-generated witness + CPU-collected side effects -use ceno_emul::{StepIndex, StepRecord}; +use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::{ Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; +use ceno_gpu::bb31::ShardDeviceBuffers; +use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars, RAM_SLOTS_PER_INST}; use ff_ext::ExtensionField; -use gkr_iop::{tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; use std::cell::{Cell, RefCell}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ - e2e::ShardContext, + e2e::{RAMRecord, ShardContext}, error::ZKVMError, instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, tables::RMMCollections, @@ -165,6 +167,250 @@ pub fn invalidate_shard_steps_cache() { }); } +/// Cached shard metadata device buffers for GPU shard records. +/// Invalidated when shard_id changes; shared across all kernel invocations in one shard. +struct ShardMetadataCache { + shard_id: usize, + device_bufs: ShardDeviceBuffers, +} + +thread_local! { + static SHARD_META_CACHE: RefCell> = + const { RefCell::new(None) }; +} + +/// Build and cache shard metadata device buffers for GPU shard records. +/// On success the buffers are held in thread-local storage; borrow them via `with_cached_shard_meta`. 
+fn ensure_shard_metadata_cached( + hal: &CudaHalBB31, + shard_ctx: &ShardContext, +) -> Result<(), ZKVMError> { + let shard_id = shard_ctx.shard_id; + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.shard_id == shard_id { + return Ok(()); // cache hit + } + } + + // Build sorted future-access arrays from HashMap + let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = { + let mut entries: Vec<(u64, u32, u64)> = Vec::new(); + for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push((*cycle, addr.0, next_cycle)); + } + } + entries.sort_unstable(); + let mut cycles = Vec::with_capacity(entries.len()); + let mut addrs = Vec::with_capacity(entries.len()); + let mut nexts = Vec::with_capacity(entries.len()); + for (c, a, n) in entries { + cycles.push(c); + addrs.push(a); + nexts.push(n); + } + (cycles, addrs, nexts) + }; + + // Build GpuShardScalars + let scalars = GpuShardScalars { + shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, + shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, + shard_offset_cycle: shard_ctx.current_shard_offset_cycle(), + shard_id: shard_id as u32, + heap_start: shard_ctx.platform.heap.start, + heap_end: shard_ctx.platform.heap.end, + hint_start: shard_ctx.platform.hints.start, + hint_end: shard_ctx.platform.hints.end, + shard_heap_start: shard_ctx.shard_heap_addr_range.start, + shard_heap_end: shard_ctx.shard_heap_addr_range.end, + shard_hint_start: shard_ctx.shard_hint_addr_range.start, + shard_hint_end: shard_ctx.shard_hint_addr_range.end, + fa_count: fa_cycles_vec.len() as u32, + num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, + num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, + num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, + }; + + // H2D copy scalar struct + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as 
*const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) + })?; + + // H2D copy arrays (use empty slice [0] sentinel for empty arrays) + let fa_cycles_device = hal + .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; + let fa_addrs_device = hal + .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; + let fa_next_device = hal + .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; + + let pshi = &shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; + + let mb = (fa_cycles_vec.len() * 8 * 2 + fa_addrs_vec.len() * 4) as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU shard] built ShardMetadataCache: shard_id={}, fa_entries={}, {:.2} MB", + shard_id, fa_cycles_vec.len(), mb, + ); + + *cache = Some(ShardMetadataCache { + shard_id, + device_bufs: ShardDeviceBuffers { + scalars: scalars_device, + fa_cycles: 
fa_cycles_device, + fa_addrs: fa_addrs_device, + fa_next_cycles: fa_next_device, + prev_shard_cycle_range: pscr_device, + prev_shard_heap_range: pshr_device, + prev_shard_hint_range: pshi_device, + }, + }); + Ok(()) + }) +} + +/// Borrow the cached shard device buffers for kernel launch. +fn with_cached_shard_meta(f: impl FnOnce(&ShardDeviceBuffers) -> R) -> R { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not uploaded"); + f(&c.device_bufs) + }) +} + +/// Invalidate the shard metadata cache (call when shard processing is complete). +pub fn invalidate_shard_meta_cache() { + SHARD_META_CACHE.with(|cache| { + *cache.borrow_mut() = None; + }); +} + +/// CPU-side lightweight scan of GPU-produced RAM record slots. +/// +/// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, +/// replacing the previous `collect_shard_side_effects()` CPU loop. +fn gpu_collect_shard_records( + shard_ctx: &mut ShardContext, + slots: &[GpuRamRecordSlot], +) { + let current_shard_id = shard_ctx.shard_id; + + for slot in slots { + // Check was_sent flag (bit 4): this slot corresponds to a send() call + if slot.flags & (1 << 4) != 0 { + shard_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + + // Check active flag (bit 0): this slot has a read or write record + if slot.flags & 1 == 0 { + continue; + } + + let ram_type = match (slot.flags >> 5) & 0x7 { + 1 => RAMType::Register, + 2 => RAMType::Memory, + _ => continue, + }; + let has_prev_value = slot.flags & (1 << 3) != 0; + let prev_value = if has_prev_value { Some(slot.prev_value) } else { None }; + let addr = WordAddr(slot.addr); + + // Insert read record (bit 1) + if slot.flags & (1 << 1) != 0 { + shard_ctx.insert_read_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: 0, + prev_value, + value: slot.value, + shard_id: slot.read_shard_id as usize, + }, 
+ ); + } + + // Insert write record (bit 2) + if slot.flags & (1 << 2) != 0 { + shard_ctx.insert_write_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: slot.shard_cycle, + prev_value, + value: slot.value, + shard_id: current_shard_id, + }, + ); + } + } +} + +/// Returns true if GPU shard records are verified for this kind. +fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { + if is_shard_kind_disabled(kind) { + return false; + } + match kind { + GpuWitgenKind::Add => true, + _ => false, + } +} + +/// Check if GPU shard records are disabled for a specific kind via env var. +fn is_shard_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_SHARD_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + /// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. /// The value is cached at first access so it's immune to runtime env var manipulation. 
fn is_gpu_witgen_disabled() -> bool { @@ -270,8 +516,8 @@ fn gpu_assign_instances_inner>( let num_structural_witin = num_structural_witin.max(1); let total_instances = step_indices.len(); - // Step 1: GPU fills witness matrix (+ LK counters for merged kinds) - let (gpu_witness, gpu_lk_counters) = info_span!("gpu_kernel").in_scope(|| { + // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) + let (gpu_witness, gpu_lk_counters, gpu_ram_slots) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -284,19 +530,36 @@ fn gpu_assign_instances_inner>( })?; // Step 2: Collect side effects - // For verified GPU kinds: LK from GPU, shard records from CPU - // For unverified kinds: full CPU side effects (GPU witness still used) + // Priority: GPU shard records > CPU shard records > full CPU side effects let lk_multiplicity = if gpu_lk_counters.is_some() && kind_has_verified_lk(kind) { let lk_multiplicity = info_span!("gpu_lk_d2h").in_scope(|| { gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) })?; - // CPU: collect shard records only (send/addr_accessed). - // We call collect_shard_side_effects which also computes fetch, but we - // discard its returned Multiplicity since GPU already has all LK + fetch. 
- info_span!("cpu_shard_records").in_scope(|| { - let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; - Ok::<(), ZKVMError>(()) - })?; + + if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { + // GPU shard records path: D2H + lightweight CPU scan + info_span!("gpu_shard_records").in_scope(|| { + let ram_buf = gpu_ram_slots.unwrap(); + let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) + })?; + // Reinterpret u32 buffer as GpuRamRecordSlot slice + let slots: &[GpuRamRecordSlot] = unsafe { + std::slice::from_raw_parts( + slot_bytes.as_ptr() as *const GpuRamRecordSlot, + slot_bytes.len() * 4 / std::mem::size_of::(), + ) + }; + gpu_collect_shard_records(shard_ctx, slots); + Ok::<(), ZKVMError>(()) + })?; + } else { + // CPU: collect shard records only (send/addr_accessed). + info_span!("cpu_shard_records").in_scope(|| { + let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; + Ok::<(), ZKVMError>(()) + })?; + } lk_multiplicity } else { // GPU LK counters missing or unverified — fall back to full CPU side effects @@ -348,6 +611,7 @@ type WitBuf = ceno_gpu::common::BufferImpl< ::BaseField, >; type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; +type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; @@ -381,7 +645,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result<(WitResult, Option), ZKVMError> { +) -> Result<(WitResult, Option, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -396,7 +660,7 @@ fn gpu_fill_witness>( macro_rules! 
split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters))) + Ok((full.witness, Some(full.lk_counters), None)) }}; } @@ -411,21 +675,28 @@ fn gpu_fill_witness>( }; let col_map = info_span!("col_map") .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); + ensure_shard_metadata_cached(hal, shard_ctx)?; info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_add( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + let full = hal + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_add failed: {e}").into(), + ) + })?; + Ok((full.witness, Some(full.lk_counters), full.ram_slots)) + }) }) }) } From a24c51c367118c4b08db73b36958210c95d26dd6 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Thu, 12 Mar 2026 10:18:00 +0800 Subject: [PATCH 30/73] phase6-2: dispatch all 22 GPU kinds with shard metadata + enable all verified shard kinds --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 718 ++++++++++-------- 1 file changed, 407 insertions(+), 311 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 56bdd8690..acaca0947 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -9,7 +9,7 @@ use ceno_gpu::{ Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; use ceno_gpu::bb31::ShardDeviceBuffers; -use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars, RAM_SLOTS_PER_INST}; +use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, 
GpuShardScalars}; use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; @@ -383,7 +383,30 @@ fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { return false; } match kind { - GpuWitgenKind::Add => true, + GpuWitgenKind::Add + | GpuWitgenKind::Sub + | GpuWitgenKind::LogicR(_) + | GpuWitgenKind::Lw => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) + | GpuWitgenKind::Addi + | GpuWitgenKind::Lui + | GpuWitgenKind::Auipc + | GpuWitgenKind::Jal + | GpuWitgenKind::ShiftR(_) + | GpuWitgenKind::ShiftI(_) + | GpuWitgenKind::Slt(_) + | GpuWitgenKind::Slti(_) + | GpuWitgenKind::BranchEq(_) + | GpuWitgenKind::BranchCmp(_) + | GpuWitgenKind::Jalr + | GpuWitgenKind::Sw + | GpuWitgenKind::Sh + | GpuWitgenKind::Sb + | GpuWitgenKind::LoadSub { .. } + | GpuWitgenKind::Mul(_) + | GpuWitgenKind::Div(_) => true, + #[cfg(not(feature = "u16limb_circuit"))] _ => false, } } @@ -550,7 +573,12 @@ fn gpu_assign_instances_inner>( slot_bytes.len() * 4 / std::mem::size_of::(), ) }; - gpu_collect_shard_records(shard_ctx, slots); + // Use a forked sub-context (Right variant) since + // insert_read_record/insert_write_record/push_addr_accessed + // require per-thread mutable references. + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + gpu_collect_shard_records(thread_ctx, slots); Ok::<(), ZKVMError>(()) })?; } else { @@ -656,17 +684,20 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); - // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters)) + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots) macro_rules! 
split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters), None)) + Ok((full.witness, Some(full.lk_counters), full.ram_slots)) }}; } // Compute fetch params for all GPU kinds (LK counters are merged into all kernels) let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); + // Ensure shard metadata is cached for GPU shard records (shared across all kernel kinds) + ensure_shard_metadata_cached(hal, shard_ctx)?; + match kind { GpuWitgenKind::Add => { let arith_config = unsafe { @@ -675,11 +706,10 @@ fn gpu_fill_witness>( }; let col_map = info_span!("col_map") .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); - ensure_shard_metadata_cached(hal, shard_ctx)?; info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - let full = hal + split_full!(hal .witgen_add( &col_map, gpu_records, @@ -694,8 +724,7 @@ fn gpu_fill_witness>( ZKVMError::InvalidWitness( format!("GPU witgen_add failed: {e}").into(), ) - })?; - Ok((full.witness, Some(full.lk_counters), full.ram_slots)) + })) }) }) }) @@ -709,19 +738,24 @@ fn gpu_fill_witness>( .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); info_span!("hal_witgen_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sub( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sub failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sub failed: {e}").into(), + ) + })) + }) }) }) } @@ -734,22 +768,25 @@ fn gpu_fill_witness>( .in_scope(|| 
super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_logic_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_r failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), + ) + })) + }) }) }) } @@ -763,22 +800,25 @@ fn gpu_fill_witness>( .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_logic_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_i failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), + ) + })) + }) }) }) } @@ -792,19 +832,24 @@ fn gpu_fill_witness>( .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); info_span!("hal_witgen_addi").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_addi( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_addi failed: 
{e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), + ) + })) + }) }) }) } @@ -818,19 +863,22 @@ fn gpu_fill_witness>( .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); info_span!("hal_witgen_lui").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_lui( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) + })) + }) }) }) } @@ -844,21 +892,24 @@ fn gpu_fill_witness>( .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); info_span!("hal_witgen_auipc").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_auipc( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_auipc failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), + ) + })) + }) }) }) } @@ -872,19 +923,22 @@ fn gpu_fill_witness>( .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); 
info_span!("hal_witgen_jal").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_jal( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) + })) + }) }) }) } @@ -900,22 +954,25 @@ fn gpu_fill_witness>( .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_shift_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_r failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), + ) + })) + }) }) }) } @@ -931,22 +988,25 @@ fn gpu_fill_witness>( .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_shift_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_i failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_shift_i( + 
&col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), + ) + })) + }) }) }) } @@ -960,22 +1020,25 @@ fn gpu_fill_witness>( .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); info_span!("hal_witgen_slt").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_slt( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slt failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), + ) + })) + }) }) }) } @@ -989,22 +1052,25 @@ fn gpu_fill_witness>( .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); info_span!("hal_witgen_slti").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_slti( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slti failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_slti( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), + ) + })) + }) }) }) } @@ -1021,22 +1087,25 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - 
.witgen_branch_eq( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_beq, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_branch_eq failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), + ) + })) + }) }) }) } @@ -1053,22 +1122,25 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_branch_cmp( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_branch_cmp failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), + ) + })) + }) }) }) } @@ -1082,19 +1154,22 @@ fn gpu_fill_witness>( .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); info_span!("hal_witgen_jalr").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_jalr( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + 
ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) + })) + }) }) }) } @@ -1109,20 +1184,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); info_span!("hal_witgen_sw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) + })) + }) }) }) } @@ -1137,20 +1215,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); info_span!("hal_witgen_sh").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sh( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) + })) + }) }) }) } @@ -1165,20 +1246,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); info_span!("hal_witgen_sb").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sb( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - 
.map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) + })) + }) }) }) } @@ -1204,24 +1288,27 @@ fn gpu_fill_witness>( let mem_max_bits = load_config.memory_addr.max_bits as u32; info_span!("hal_witgen_load_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_load_sub( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - load_width, - is_signed, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_load_sub failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), + ) + })) + }) }) }) } @@ -1235,22 +1322,25 @@ fn gpu_fill_witness>( .in_scope(|| super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_mul( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mul_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_mul failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + 
format!("GPU witgen_mul failed: {e}").into(), + ) + })) + }) }) }) } @@ -1264,22 +1354,25 @@ fn gpu_fill_witness>( .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); info_span!("hal_witgen_div").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_div( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - div_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_div failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_div failed: {e}").into(), + ) + })) + }) }) }) } @@ -1299,20 +1392,23 @@ fn gpu_fill_witness>( .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_lw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) + })) + }) }) }) } From 45a359e87c04d8b57716a4494d19c3230e219ca0 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Thu, 12 Mar 2026 10:16:11 +0800 Subject: [PATCH 31/73] fa_sort --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 114 ++++++++++-------- 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs 
b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index acaca0947..a91a41165 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -11,6 +11,7 @@ use ceno_gpu::{ use ceno_gpu::bb31::ShardDeviceBuffers; use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars}; use ff_ext::ExtensionField; +use rayon::prelude::*; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; use std::cell::{Cell, RefCell}; @@ -195,24 +196,25 @@ fn ensure_shard_metadata_cached( } // Build sorted future-access arrays from HashMap - let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = { - let mut entries: Vec<(u64, u32, u64)> = Vec::new(); - for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { - for &(addr, next_cycle) in pairs.iter() { - entries.push((*cycle, addr.0, next_cycle)); + let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = + tracing::info_span!("fa_sort").in_scope(|| { + let mut entries: Vec<(u64, u32, u64)> = Vec::new(); + for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push((*cycle, addr.0, next_cycle)); + } } - } - entries.sort_unstable(); - let mut cycles = Vec::with_capacity(entries.len()); - let mut addrs = Vec::with_capacity(entries.len()); - let mut nexts = Vec::with_capacity(entries.len()); - for (c, a, n) in entries { - cycles.push(c); - addrs.push(a); - nexts.push(n); - } - (cycles, addrs, nexts) - }; + entries.par_sort_unstable(); + let mut cycles = Vec::with_capacity(entries.len()); + let mut addrs = Vec::with_capacity(entries.len()); + let mut nexts = Vec::with_capacity(entries.len()); + for (c, a, n) in entries { + cycles.push(c); + addrs.push(a); + nexts.push(n); + } + (cycles, addrs, nexts) + }); // Build GpuShardScalars let scalars = GpuShardScalars { @@ -234,42 +236,48 @@ fn ensure_shard_metadata_cached( num_prev_hint_ranges: 
shard_ctx.prev_shard_hint_range.len() as u32, }; - // H2D copy scalar struct - let scalars_bytes: &[u8] = unsafe { - std::slice::from_raw_parts( - &scalars as *const GpuShardScalars as *const u8, - std::mem::size_of::(), - ) - }; - let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { - ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) - })?; + // H2D uploads + let (scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, + pscr_device, pshr_device, pshi_device) = + tracing::info_span!("shard_meta_h2d").in_scope(|| -> Result<_, ZKVMError> { + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) + })?; - // H2D copy arrays (use empty slice [0] sentinel for empty arrays) - let fa_cycles_device = hal - .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; - let fa_addrs_device = hal - .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; - let fa_next_device = hal - .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; - - let pscr = &shard_ctx.prev_shard_cycle_range; - let pscr_device = hal - .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; - - let pshr = &shard_ctx.prev_shard_heap_range; - let pshr_device = hal - .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr 
}, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; - - let pshi = &shard_ctx.prev_shard_hint_range; - let pshi_device = hal - .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; + let fa_cycles_device = hal + .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; + let fa_addrs_device = hal + .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; + let fa_next_device = hal + .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; + + let pshi = &shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; + + Ok((scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, + pscr_device, pshr_device, pshi_device)) + })?; let mb = (fa_cycles_vec.len() * 8 * 2 + fa_addrs_vec.len() * 4) as f64 / (1024.0 * 1024.0); tracing::info!( @@ -696,7 +704,7 @@ fn gpu_fill_witness>( let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, 
step_indices); // Ensure shard metadata is cached for GPU shard records (shared across all kernel kinds) - ensure_shard_metadata_cached(hal, shard_ctx)?; + info_span!("ensure_shard_meta").in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx))?; match kind { GpuWitgenKind::Add => { From e5395057e36c4fc60fa7d9e6b0680a596d5559e5 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 32/73] perf-preflight --- ceno_emul/src/lib.rs | 3 +- ceno_emul/src/tracer.rs | 63 ++++++- ceno_zkvm/src/e2e.rs | 59 ++++++- .../src/instructions/riscv/gpu/witgen_gpu.rs | 159 ++++++++++-------- 4 files changed, 209 insertions(+), 75 deletions(-) diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 915edd18f..268fe78e8 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -11,7 +11,8 @@ pub use platform::{CENO_PLATFORM, Platform}; mod tracer; pub use tracer::{ Change, FullTracer, FullTracerConfig, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, - PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, StepCellExtractor, StepIndex, + PackedNextAccessEntry, PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, + StepCellExtractor, StepIndex, StepRecord, Tracer, WriteOp, }; diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 4aa7b3080..164f1c6c4 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -86,6 +86,60 @@ pub trait StepCellExtractor { pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; pub type NextCycleAccess = FxHashMap; +/// Packed next-access entry (16 bytes, u128-aligned). +/// Stores (cycle, addr, next_cycle) with 40-bit cycles for GPU bulk H2D upload. +/// Must be layout-compatible with CUDA `PackedNextAccessEntry` in shard_helpers.cuh. 
+#[repr(C, align(16))] +#[derive(Debug, Clone, Copy, Default)] +pub struct PackedNextAccessEntry { + pub cycles_lo: u32, + pub addr: u32, + pub nexts_lo: u32, + pub cycles_hi: u8, + pub nexts_hi: u8, + pub _reserved: u16, +} + +impl PackedNextAccessEntry { + #[inline] + pub fn new(cycle: u64, addr: u32, next_cycle: u64) -> Self { + Self { + cycles_lo: cycle as u32, + addr, + nexts_lo: next_cycle as u32, + cycles_hi: (cycle >> 32) as u8, + nexts_hi: (next_cycle >> 32) as u8, + _reserved: 0, + } + } +} + +impl Eq for PackedNextAccessEntry {} + +impl PartialEq for PackedNextAccessEntry { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.cycles_hi == other.cycles_hi + && self.cycles_lo == other.cycles_lo + && self.addr == other.addr + } +} + +impl Ord for PackedNextAccessEntry { + #[inline] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (self.cycles_hi, self.cycles_lo, self.addr) + .cmp(&(other.cycles_hi, other.cycles_lo, other.addr)) + } +} + +impl PartialOrd for PackedNextAccessEntry { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + fn init_mmio_min_max_access( platform: &Platform, ) -> BTreeMap { @@ -1051,6 +1105,7 @@ pub struct PreflightTracer { mmio_min_max_access: Option>, latest_accesses: LatestAccesses, next_accesses: NextCycleAccess, + next_accesses_vec: Vec, register_reads_tracked: u8, planner: Option, current_shard_start_cycle: Cycle, @@ -1075,6 +1130,7 @@ impl fmt::Debug for PreflightTracer { .field("mmio_min_max_access", &self.mmio_min_max_access) .field("latest_accesses", &self.latest_accesses) .field("next_accesses", &self.next_accesses) + .field("next_accesses_vec_len", &self.next_accesses_vec.len()) .field("register_reads_tracked", &self.register_reads_tracked) .field("planner", &self.planner) .field("current_shard_start_cycle", &self.current_shard_start_cycle) @@ -1172,6 +1228,7 @@ impl PreflightTracer { mmio_min_max_access: Some(init_mmio_min_max_access(platform)), latest_accesses: 
LatestAccesses::new(platform), next_accesses: FxHashMap::default(), + next_accesses_vec: Vec::new(), register_reads_tracked: 0, planner: Some(ShardPlanBuilder::new( max_cell_per_shard, @@ -1184,14 +1241,14 @@ impl PreflightTracer { tracer } - pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess) { + pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess, Vec) { let Some(mut planner) = self.planner else { panic!("shard planner missing") }; if !planner.finalized { planner.finalize(self.cycle); } - (planner, self.next_accesses) + (planner, self.next_accesses, self.next_accesses_vec) } #[inline(always)] @@ -1312,6 +1369,8 @@ impl Tracer for PreflightTracer { .entry(prev_cycle) .or_default() .push((addr, cur_cycle)); + self.next_accesses_vec + .push(PackedNextAccessEntry::new(prev_cycle, addr.0, cur_cycle)); } prev_cycle } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 63be3240b..642d982de 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,7 +24,8 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, - NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, RegIdx, + NextCycleAccess, PackedNextAccessEntry, Platform, PreflightTracer, PreflightTracerConfig, Program, + RegIdx, StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; @@ -41,6 +42,7 @@ use itertools::MinMaxResult; use itertools::{Itertools, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; +use rayon::prelude::*; use rustc_hash::FxHashSet; use serde::Serialize; #[cfg(debug_assertions)] @@ -181,11 +183,18 @@ impl Default for MultiProver { } } +/// Pre-sorted packed future access entries for GPU bulk H2D upload. +/// Sorted by (cycle, addr) composite key. 
+pub struct SortedNextAccesses { + pub packed: Vec, +} + pub struct ShardContext<'a> { pub shard_id: usize, num_shards: usize, max_cycle: Cycle, pub addr_future_accesses: Arc, + pub sorted_next_accesses: Arc, addr_accessed_tbs: Either>, &'a mut Vec>, read_records_tbs: Either>, &'a mut BTreeMap>, @@ -217,6 +226,9 @@ impl<'a> Default for ShardContext<'a> { num_shards: 1, max_cycle: Cycle::MAX, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { + packed: vec![], + }), addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), read_records_tbs: Either::Left( (0..max_threads) @@ -262,6 +274,7 @@ impl<'a> ShardContext<'a> { num_shards: self.num_shards, max_cycle: self.max_cycle, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), read_records_tbs: Either::Left( (0..max_threads) @@ -305,6 +318,7 @@ impl<'a> ShardContext<'a> { num_shards: self.num_shards, max_cycle: self.max_cycle, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), addr_accessed_tbs: Either::Right(addr_accessed_tbs), read_records_tbs: Either::Right(read), write_records_tbs: Either::Right(write), @@ -663,6 +677,7 @@ impl ShardStepSummary { pub struct ShardContextBuilder { pub cur_shard_id: usize, addr_future_accesses: Arc, + sorted_next_accesses: Arc, prev_shard_cycle_range: Vec, prev_shard_heap_range: Vec, prev_shard_hint_range: Vec, @@ -676,6 +691,9 @@ impl Default for ShardContextBuilder { ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { + packed: vec![], + }), prev_shard_cycle_range: vec![], prev_shard_heap_range: vec![], prev_shard_hint_range: vec![], @@ -693,12 +711,45 @@ impl ShardContextBuilder { shard_cycle_boundaries: Arc>, max_cycle: Cycle, 
addr_future_accesses: NextCycleAccess, + next_accesses_vec: Vec, ) -> Self { assert_eq!(multi_prover.max_provers, 1); assert_eq!(multi_prover.prover_id, 0); + + let sorted_next_accesses = + info_span!("next_access_presort").in_scope(|| { + let source = std::env::var("CENO_NEXT_ACCESS_SOURCE").unwrap_or_default(); + let mut entries = if source == "hashmap" { + tracing::info!("[next-access presort] converting from HashMap"); + info_span!("next_access_from_hashmap").in_scope(|| { + let mut entries = Vec::new(); + for (cycle, pairs) in addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push(PackedNextAccessEntry::new(*cycle, addr.0, next_cycle)); + } + } + entries + }) + } else { + tracing::info!( + "[next-access presort] using preflight-appended vec ({} entries)", + next_accesses_vec.len() + ); + next_accesses_vec + }; + let len = entries.len(); + info_span!("next_access_par_sort", n = len).in_scope(|| { + entries.par_sort_unstable(); + }); + tracing::info!("[next-access presort] sorted {} entries ({:.2} MB)", + len, len * 16 / (1024 * 1024)); + Arc::new(SortedNextAccesses { packed: entries }) + }); + ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(addr_future_accesses), + sorted_next_accesses, prev_shard_cycle_range: vec![0], prev_shard_heap_range: vec![0], prev_shard_hint_range: vec![0], @@ -785,6 +836,7 @@ impl ShardContextBuilder { cur_shard_cycle_range: summary.first_cycle as usize ..(summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN) as usize, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), prev_shard_heap_range: self.prev_shard_heap_range.clone(), prev_shard_hint_range: self.prev_shard_hint_range.clone(), @@ -1103,7 +1155,7 @@ pub fn emulate_program<'a>( } let tracer = vm.take_tracer(); - let (plan_builder, next_accesses) = tracer.into_shard_plan(); + let (plan_builder, 
next_accesses, next_accesses_vec) = tracer.into_shard_plan(); let max_step_shard = plan_builder.max_step_shard(); let shard_cycle_boundaries = Arc::new(plan_builder.into_cycle_boundaries()); let shard_ctx_builder = ShardContextBuilder::from_plan( @@ -1112,6 +1164,7 @@ pub fn emulate_program<'a>( shard_cycle_boundaries.clone(), max_cycle, next_accesses, + next_accesses_vec, ); tracing::info!( "num_shards: {}, max_cycle {}, shard_cycle_boundaries {:?}", @@ -2174,6 +2227,7 @@ fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { cloned.num_shards = src.num_shards; cloned.max_cycle = src.max_cycle; cloned.addr_future_accesses = src.addr_future_accesses.clone(); + cloned.sorted_next_accesses = src.sorted_next_accesses.clone(); cloned.cur_shard_cycle_range = src.cur_shard_cycle_range.clone(); cloned.expected_inst_per_shard = src.expected_inst_per_shard; cloned.max_num_cross_shard_accesses = src.max_num_cross_shard_accesses; @@ -2498,6 +2552,7 @@ mod tests { shard_cycle_boundaries, max_cycle, NextCycleAccess::default(), + Vec::new(), ); struct TestReplay { steps: Vec, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index a91a41165..c03f3b816 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -11,7 +11,6 @@ use ceno_gpu::{ use ceno_gpu::bb31::ShardDeviceBuffers; use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars}; use ff_ext::ExtensionField; -use rayon::prelude::*; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; use std::cell::{Cell, RefCell}; @@ -181,7 +180,10 @@ thread_local! { } /// Build and cache shard metadata device buffers for GPU shard records. -/// Returns a reference to the cached `ShardDeviceBuffers`. 
+/// +/// FA (future access) device buffers are global and identical across all shards, +/// so they are uploaded once and reused via move. Only per-shard data (scalars + +/// prev_shard_ranges) is re-uploaded when the shard changes. fn ensure_shard_metadata_cached( hal: &CudaHalBB31, shard_ctx: &ShardContext, @@ -195,28 +197,50 @@ fn ensure_shard_metadata_cached( } } - // Build sorted future-access arrays from HashMap - let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = - tracing::info_span!("fa_sort").in_scope(|| { - let mut entries: Vec<(u64, u32, u64)> = Vec::new(); - for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { - for &(addr, next_cycle) in pairs.iter() { - entries.push((*cycle, addr.0, next_cycle)); + // Move FA device buffer from previous cache (reuse across shards). + // FA data is global — identical across all shards — so we reuse, not re-upload. + let existing_fa = cache.take().map(|c| { + let ShardDeviceBuffers { + next_access_packed, + scalars: _, + prev_shard_cycle_range: _, + prev_shard_heap_range: _, + prev_shard_hint_range: _, + } = c.device_bufs; + next_access_packed + }); + + let next_access_packed_device = if let Some(fa) = existing_fa { + fa // Reuse existing GPU memory — zero cost pointer move + } else { + // First shard: bulk H2D upload packed FA entries (no sort here) + let sorted = &shard_ctx.sorted_next_accesses; + tracing::info_span!("next_access_h2d").in_scope(|| -> Result<_, ZKVMError> { + let packed_bytes: &[u8] = if sorted.packed.is_empty() { + &[0u8; 16] // sentinel for empty + } else { + unsafe { + std::slice::from_raw_parts( + sorted.packed.as_ptr() as *const u8, + sorted.packed.len() * std::mem::size_of::(), + ) } - } - entries.par_sort_unstable(); - let mut cycles = Vec::with_capacity(entries.len()); - let mut addrs = Vec::with_capacity(entries.len()); - let mut nexts = Vec::with_capacity(entries.len()); - for (c, a, n) in entries { - cycles.push(c); - addrs.push(a); - nexts.push(n); - } - (cycles, addrs, nexts) - 
}); + }; + let buf = hal.inner.htod_copy_stream(None, packed_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) + })?; + let next_access_device = ceno_gpu::common::buffer::BufferImpl::new(buf); + let mb = packed_bytes.len() as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU shard] FA uploaded once: {} entries, {:.2} MB (packed)", + sorted.packed.len(), + mb, + ); + Ok(next_access_device) + })? + }; - // Build GpuShardScalars + // Per-shard: always re-upload scalars + prev_shard_ranges let scalars = GpuShardScalars { shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, @@ -230,68 +254,63 @@ fn ensure_shard_metadata_cached( shard_heap_end: shard_ctx.shard_heap_addr_range.end, shard_hint_start: shard_ctx.shard_hint_addr_range.start, shard_hint_end: shard_ctx.shard_hint_addr_range.end, - fa_count: fa_cycles_vec.len() as u32, + next_access_count: shard_ctx.sorted_next_accesses.packed.len() as u32, num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, }; - // H2D uploads - let (scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, - pscr_device, pshr_device, pshi_device) = - tracing::info_span!("shard_meta_h2d").in_scope(|| -> Result<_, ZKVMError> { - let scalars_bytes: &[u8] = unsafe { - std::slice::from_raw_parts( - &scalars as *const GpuShardScalars as *const u8, - std::mem::size_of::(), - ) - }; - let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { - ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) + let (scalars_device, pscr_device, pshr_device, pshi_device) = + tracing::info_span!("shard_scalars_h2d").in_scope(|| -> Result<_, ZKVMError> { + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const 
GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = + hal.inner + .htod_copy_stream(None, scalars_bytes) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("shard scalars H2D failed: {e}").into(), + ) + })?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()) + })?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()) + })?; + + let pshi = &shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()) + })?; + + Ok((scalars_device, pscr_device, pshr_device, pshi_device)) })?; - let fa_cycles_device = hal - .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; - let fa_addrs_device = hal - .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; - let fa_next_device = hal - .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; - - let pscr = &shard_ctx.prev_shard_cycle_range; - let pscr_device = hal - .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; - - let pshr = &shard_ctx.prev_shard_heap_range; - let pshr_device = hal - .alloc_u32_from_host(if 
pshr.is_empty() { &[0u32] } else { pshr }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; - - let pshi = &shard_ctx.prev_shard_hint_range; - let pshi_device = hal - .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; - - Ok((scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, - pscr_device, pshr_device, pshi_device)) - })?; - - let mb = (fa_cycles_vec.len() * 8 * 2 + fa_addrs_vec.len() * 4) as f64 / (1024.0 * 1024.0); tracing::info!( - "[GPU shard] built ShardMetadataCache: shard_id={}, fa_entries={}, {:.2} MB", - shard_id, fa_cycles_vec.len(), mb, + "[GPU shard] shard_id={}: per-shard scalars updated", + shard_id, ); *cache = Some(ShardMetadataCache { shard_id, device_bufs: ShardDeviceBuffers { scalars: scalars_device, - fa_cycles: fa_cycles_device, - fa_addrs: fa_addrs_device, - fa_next_cycles: fa_next_device, + next_access_packed: next_access_packed_device, prev_shard_cycle_range: pscr_device, prev_shard_heap_range: pshr_device, prev_shard_hint_range: pshi_device, From 12cd5eeb177a8c06bc26a79c04c5f14c4d1c4bc0 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 33/73] shardram: ec --- ceno_zkvm/src/e2e.rs | 26 +++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 218 +++++++++++++++++- ceno_zkvm/src/scheme/septic_curve.rs | 78 +++++++ ceno_zkvm/src/structs.rs | 82 ++++++- 4 files changed, 391 insertions(+), 13 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 642d982de..ea92ff6f7 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -212,6 +212,10 @@ pub struct ShardContext<'a> { pub shard_hint_addr_range: Range, /// Syscall witnesses for StepRecord::syscall() lookups. pub syscall_witnesses: Arc>, + /// GPU-produced compact EC shard records (raw bytes of GpuShardRamRecord). + /// Each record is GPU_SHARD_RAM_RECORD_SIZE bytes. 
These bypass BTreeMap and + /// are converted to ShardRamInput in assign_shared_circuit. + pub gpu_ec_records: Vec, } impl<'a> Default for ShardContext<'a> { @@ -250,10 +254,14 @@ impl<'a> Default for ShardContext<'a> { shard_heap_addr_range: CENO_PLATFORM.heap.clone(), shard_hint_addr_range: CENO_PLATFORM.hints.clone(), syscall_witnesses: Arc::new(Vec::new()), + gpu_ec_records: vec![], } } } +/// Size of a single GpuShardRamRecord in bytes (must match CUDA struct). +pub const GPU_SHARD_RAM_RECORD_SIZE: usize = 104; + /// `prover_id` and `num_provers` in MultiProver are exposed as arguments /// to specify the number of physical provers in a cluster, /// each mark with a prover_id. @@ -296,6 +304,7 @@ impl<'a> ShardContext<'a> { shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], } } @@ -332,6 +341,7 @@ impl<'a> ShardContext<'a> { shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], }) .collect_vec(), _ => panic!("invalid type"), @@ -470,6 +480,22 @@ impl<'a> ShardContext<'a> { addr_accessed.push(addr); } + /// Extend GPU EC records with raw bytes from GpuShardRamRecord slice. + /// Called from the GPU EC path to accumulate records across kernel invocations. + pub fn extend_gpu_ec_records_raw(&mut self, raw_bytes: &[u8]) { + self.gpu_ec_records.extend_from_slice(raw_bytes); + } + + /// Returns true if GPU EC records have been collected. + pub fn has_gpu_ec_records(&self) -> bool { + !self.gpu_ec_records.is_empty() + } + + /// Take GPU EC records, leaving the field empty. 
+ pub fn take_gpu_ec_records(&mut self) -> Vec { + std::mem::take(&mut self.gpu_ec_records) + } + #[inline(always)] #[allow(clippy::too_many_arguments)] pub fn record_send_without_touch( diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index c03f3b816..821d2a320 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -9,7 +9,7 @@ use ceno_gpu::{ Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; use ceno_gpu::bb31::ShardDeviceBuffers; -use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars}; +use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord, GpuShardScalars}; use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; @@ -206,6 +206,7 @@ fn ensure_shard_metadata_cached( prev_shard_cycle_range: _, prev_shard_heap_range: _, prev_shard_hint_range: _, + gpu_ec_shard_id: _, } = c.device_bufs; next_access_packed }); @@ -314,6 +315,7 @@ fn ensure_shard_metadata_cached( prev_shard_cycle_range: pscr_device, prev_shard_heap_range: pshr_device, prev_shard_hint_range: pshi_device, + gpu_ec_shard_id: Some(shard_id as u64), }, }); Ok(()) @@ -404,6 +406,34 @@ fn gpu_collect_shard_records( } } +/// D2H the compact EC result: read count, then partial-D2H only that many records. 
+fn gpu_compact_ec_d2h( + compact: &CompactEcResult, +) -> Result, ZKVMError> { + // D2H the count (1 u32) + let count_vec: Vec = compact.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()) + })?; + let count = count_vec[0] as usize; + if count == 0 { + return Ok(vec![]); + } + + // D2H the buffer (all u32s), then reinterpret as GpuShardRamRecord + let buf_vec: Vec = compact.buffer.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) + })?; + + let record_u32s = std::mem::size_of::() / 4; // 26 + let total_u32s = count * record_u32s; + let records: Vec = unsafe { + let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; + std::slice::from_raw_parts(ptr, count).to_vec() + }; + tracing::debug!("GPU EC compact D2H: {} active records ({} bytes)", count, total_u32s * 4); + Ok(records) +} + /// Returns true if GPU shard records are verified for this kind. fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { if is_shard_kind_disabled(kind) { @@ -567,7 +597,7 @@ fn gpu_assign_instances_inner>( let total_instances = step_indices.len(); // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) - let (gpu_witness, gpu_lk_counters, gpu_ram_slots) = info_span!("gpu_kernel").in_scope(|| { + let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -586,23 +616,59 @@ fn gpu_assign_instances_inner>( gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) })?; - if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { - // GPU shard records path: D2H + lightweight CPU scan + if gpu_compact_ec.is_some() && kind_has_verified_shard(kind) { + // GPU EC path: compact records already have EC points computed on device. + // D2H only the active records (much smaller than full N*3 slot buffer). 
+ info_span!("gpu_ec_shard").in_scope(|| { + let compact = gpu_compact_ec.unwrap(); + let compact_records = gpu_compact_ec_d2h(&compact)?; + debug_compare_ec_points(&compact_records, kind); + + // Still need addr_accessed from the old ram_slots path + // (WAS_SENT flag indicates send() calls for addr_accessed tracking). + if gpu_ram_slots.is_some() { + let ram_buf = gpu_ram_slots.unwrap(); + let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) + })?; + let slots: &[GpuRamRecordSlot] = unsafe { + std::slice::from_raw_parts( + slot_bytes.as_ptr() as *const GpuRamRecordSlot, + slot_bytes.len() * 4 / std::mem::size_of::(), + ) + }; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + // Only collect addr_accessed (WAS_SENT) and BTreeMap records + // from slot-based path, for compatibility. + gpu_collect_shard_records(thread_ctx, slots); + } + + // Store raw GPU EC records for downstream assign_shared_circuit. + // Records are stored as raw bytes and converted to ShardRamInput later. 
+ let raw_bytes = unsafe { + std::slice::from_raw_parts( + compact_records.as_ptr() as *const u8, + compact_records.len() * std::mem::size_of::(), + ) + }; + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + + Ok::<(), ZKVMError>(()) + })?; + } else if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { + // GPU shard records path (no EC): D2H + lightweight CPU scan info_span!("gpu_shard_records").in_scope(|| { let ram_buf = gpu_ram_slots.unwrap(); let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) })?; - // Reinterpret u32 buffer as GpuRamRecordSlot slice let slots: &[GpuRamRecordSlot] = unsafe { std::slice::from_raw_parts( slot_bytes.as_ptr() as *const GpuRamRecordSlot, slot_bytes.len() * 4 / std::mem::size_of::(), ) }; - // Use a forked sub-context (Right variant) since - // insert_read_record/insert_write_record/push_addr_accessed - // require per-thread mutable references. let mut forked = shard_ctx.get_forked(); let thread_ctx = &mut forked[0]; gpu_collect_shard_records(thread_ctx, slots); @@ -669,6 +735,7 @@ type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; +type CompactEcBuf = ceno_gpu::common::witgen_types::CompactEcResult; /// Compute fetch counter parameters from step data. fn compute_fetch_params( @@ -700,7 +767,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result<(WitResult, Option, Option), ZKVMError> { +) -> Result<(WitResult, Option, Option, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). 
let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -711,11 +778,11 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); - // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots) + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec) macro_rules! split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters), full.ram_slots)) + Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec)) }}; } @@ -1563,6 +1630,7 @@ fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { GpuWitgenKind::Mul(_) => true, #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Div(_) => true, + #[cfg(not(feature = "u16limb_circuit"))] _ => false, } } @@ -1795,6 +1863,134 @@ fn debug_compare_shard_side_effects>( Ok(()) } +/// Compare GPU-produced EC points against CPU to_ec_point() for correctness. +/// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. +/// Limit output with CENO_GPU_DEBUG_COMPARE_EC_LIMIT (default: 16). 
+fn debug_compare_ec_points( + compact_records: &[GpuShardRamRecord], + kind: GpuWitgenKind, +) { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { + return; + } + + println!("debug_compare_ec_points"); + + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + use crate::tables::{ECPoint, ShardRamRecord}; + use ff_ext::{BabyBearExt4 as E, PoseidonField, SmallField}; + use p3::babybear::BabyBear; + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + + let perm = BabyBear::get_default_perm(); + + let mut mismatches = 0usize; + let mut field_mismatches = 0usize; + let mut nonce_mismatches = 0usize; + + for (i, gpu_rec) in compact_records.iter().enumerate() { + // Reconstruct ShardRamRecord from GPU record fields + let cpu_record = ShardRamRecord { + addr: gpu_rec.addr, + ram_type: if gpu_rec.ram_type == 1 { + RAMType::Register + } else { + RAMType::Memory + }, + value: gpu_rec.value, + shard: gpu_rec.shard, + local_clk: gpu_rec.local_clk, + global_clk: gpu_rec.global_clk, + is_to_write_set: gpu_rec.is_to_write_set != 0, + }; + + // CPU computes EC point + let cpu_ec: ECPoint = cpu_record.to_ec_point(&perm); + + // GPU EC point (from canonical u32) + let gpu_x = SepticExtension( + gpu_rec.point_x.map(|v| BabyBear::from_canonical_u32(v)), + ); + let gpu_y = SepticExtension( + gpu_rec.point_y.map(|v| BabyBear::from_canonical_u32(v)), + ); + // Verify point is on curve (optional sanity check) + let _gpu_point = SepticPoint::from_affine(gpu_x, gpu_y); + + let mut has_diff = false; + + // Compare nonce + if gpu_rec.nonce != cpu_ec.nonce { + nonce_mismatches += 1; + has_diff = true; + if mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} rec[{i}] nonce mismatch: gpu={} cpu={}", + gpu_rec.nonce, + cpu_ec.nonce + ); + } + } + + // Compare x coordinates + for j in 0..7 { + let gpu_v = gpu_rec.point_x[j]; + let cpu_v = cpu_ec.point.x.0[j].to_canonical_u64() as 
u32; + if gpu_v != cpu_v { + field_mismatches += 1; + has_diff = true; + if mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} rec[{i}] x[{j}] mismatch: gpu={gpu_v} cpu={cpu_v}" + ); + } + } + } + + // Compare y coordinates + for j in 0..7 { + let gpu_v = gpu_rec.point_y[j]; + let cpu_v = cpu_ec.point.y.0[j].to_canonical_u64() as u32; + if gpu_v != cpu_v { + field_mismatches += 1; + has_diff = true; + if mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} rec[{i}] y[{j}] mismatch: gpu={gpu_v} cpu={cpu_v} \ + (addr={} ram_type={} value={} shard={} clk={} is_write={})", + gpu_rec.addr, + gpu_rec.ram_type, + gpu_rec.value, + gpu_rec.shard, + gpu_rec.global_clk, + gpu_rec.is_to_write_set + ); + } + } + } + + if has_diff { + mismatches += 1; + } + } + + if mismatches == 0 { + tracing::info!( + "[GPU EC debug] kind={kind:?} ALL {} EC points match CPU", + compact_records.len() + ); + } else { + tracing::error!( + "[GPU EC debug] kind={kind:?} {mismatches}/{} records have mismatches \ + (nonce_diffs={nonce_mismatches} field_diffs={field_mismatches})", + compact_records.len() + ); + } +} + fn flatten_ram_records( records: &[std::collections::BTreeMap], ) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index f9b6b4f76..4bacfe011 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1171,4 +1171,82 @@ mod tests { assert!(j4.is_on_curve()); assert_eq!(j4.into_affine(), p4); } + + /// GPU vs CPU EC point computation test. + /// Launches the `test_septic_ec_point` CUDA kernel with known inputs, + /// then compares GPU outputs against CPU `ShardRamRecord::to_ec_point()`. 
+ #[test] + #[cfg(feature = "gpu")] + fn test_gpu_ec_point_matches_cpu() { + use crate::tables::{ECPoint, ShardRamRecord}; + use ceno_gpu::bb31::test_impl::{TestEcInput, run_gpu_ec_test}; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::{PoseidonField, SmallField}; + use gkr_iop::RAMType; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Test cases: various write/read, register/memory, edge cases + let test_inputs = vec![ + TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 1, shard: 1, global_clk: 100 }, + TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 0, shard: 0, global_clk: 50 }, + TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 1, shard: 2, global_clk: 200 }, + TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 0, shard: 1, global_clk: 150 }, + TestEcInput { addr: 0, ram_type: 1, value: 0, is_write: 1, shard: 0, global_clk: 1 }, + TestEcInput { addr: 31, ram_type: 1, value: 0xFFFFFFFF, is_write: 0, shard: 3, global_clk: 999 }, + TestEcInput { addr: 0x40000000, ram_type: 2, value: 42, is_write: 1, shard: 5, global_clk: 500000 }, + TestEcInput { addr: 10, ram_type: 1, value: 1, is_write: 0, shard: 100, global_clk: 1000000 }, + ]; + + let gpu_results = run_gpu_ec_test(&hal, &test_inputs); + + let mut mismatches = 0; + for (i, (input, gpu_rec)) in test_inputs.iter().zip(gpu_results.iter()).enumerate() { + // Build CPU ShardRamRecord and compute EC point + let cpu_record = ShardRamRecord { + addr: input.addr, + ram_type: if input.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + value: input.value, + shard: input.shard, + local_clk: if input.is_write != 0 { input.global_clk } else { 0 }, + global_clk: input.global_clk, + is_to_write_set: input.is_write != 0, + }; + let cpu_ec: ECPoint = cpu_record.to_ec_point(&perm); + + let mut has_diff = false; + + if gpu_rec.nonce != cpu_ec.nonce { + eprintln!("[{i}] nonce: gpu={} cpu={}", gpu_rec.nonce, 
cpu_ec.nonce); + has_diff = true; + } + + for j in 0..7 { + let gpu_x = gpu_rec.point_x[j]; + let cpu_x = cpu_ec.point.x.0[j].to_canonical_u64() as u32; + if gpu_x != cpu_x { + eprintln!("[{i}] x[{j}]: gpu={gpu_x} cpu={cpu_x}"); + has_diff = true; + } + let gpu_y = gpu_rec.point_y[j]; + let cpu_y = cpu_ec.point.y.0[j].to_canonical_u64() as u32; + if gpu_y != cpu_y { + eprintln!("[{i}] y[{j}]: gpu={gpu_y} cpu={cpu_y}"); + has_diff = true; + } + } + + if has_diff { + eprintln!("MISMATCH [{i}]: addr={} ram_type={} value={:#x} is_write={} shard={} clk={}", + input.addr, input.ram_type, input.value, input.is_write, input.shard, input.global_clk); + mismatches += 1; + } + } + + assert_eq!(mismatches, 0, + "{mismatches}/{} test cases had GPU/CPU EC point mismatches", + test_inputs.len()); + eprintln!("All {} GPU EC point test cases match CPU!", test_inputs.len()); + } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f433d0e2..8b24e20bd 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,9 +1,9 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::{E2EProgramCtx, ShardContext}, + e2e::{E2EProgramCtx, GPU_SHARD_RAM_RECORD_SIZE, ShardContext}, error::ZKVMError, instructions::Instruction, - scheme::septic_curve::SepticPoint, + scheme::septic_curve::{SepticExtension, SepticPoint}, state::StateCircuit, tables::{ ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord, @@ -541,6 +541,13 @@ impl ZKVMWitnesses { }) .collect::>(); + // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) + let gpu_ec_inputs = if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + } else { + vec![] + }; + let global_input = shard_ctx .write_records() .par_iter() @@ -570,6 +577,7 @@ impl ZKVMWitnesses { } }) })) + .chain(gpu_ec_inputs.into_par_iter()) .collect::>(); if tracing::enabled!(Level::DEBUG) { @@ -848,3 
+856,73 @@ where // mainly used for debugging pub circuit_index_to_name: BTreeMap, } + +/// Convert raw GPU EC record bytes to ShardRamInput. +/// The raw bytes are from `GpuShardRamRecord` structs (104 bytes each). +/// EC points are already computed on GPU — no Poseidon2/SepticCurve needed. +fn gpu_ec_records_to_shard_ram_inputs( + raw: &[u8], +) -> Vec> { + use gkr_iop::RAMType; + use p3::field::FieldAlgebra; + + // GpuShardRamRecord layout (104 bytes, 8-byte aligned): + // addr: u32 (0), ram_type: u32 (4), value: u32 (8), _pad: u32 (12), + // shard: u64 (16), local_clk: u64 (24), global_clk: u64 (32), + // is_to_write_set: u32 (40), nonce: u32 (44), + // point_x: [u32;7] (48..76), point_y: [u32;7] (76..104) + + assert!(raw.len() % GPU_SHARD_RAM_RECORD_SIZE == 0); + let count = raw.len() / GPU_SHARD_RAM_RECORD_SIZE; + + (0..count).map(|i| { + let base = i * GPU_SHARD_RAM_RECORD_SIZE; + let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; + + let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); + let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); + let value = u32::from_le_bytes(r[8..12].try_into().unwrap()); + let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); + let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); + let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); + let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; + let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + let xoff = 48 + j * 4; + let yoff = 76 + j * 4; + point_x_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[xoff..xoff+4].try_into().unwrap()) + ); + point_y_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[yoff..yoff+4].try_into().unwrap()) + ); + } + + let record = ShardRamRecord { + addr, + ram_type: if ram_type_val == 1 { 
RAMType::Register } else { RAMType::Memory }, + value, + shard, + local_clk, + global_clk, + is_to_write_set, + }; + + let x = SepticExtension(point_x_arr); + let y = SepticExtension(point_y_arr); + let point = SepticPoint::from_affine(x, y); + + ShardRamInput { + name: if is_to_write_set { + "current_shard_external_write" + } else { + "current_shard_external_read" + }, + record, + ec_point: ECPoint { nonce, point }, + } + }).collect() +} From 4cd72812d78694a3478d5ef89f81e4abe538e430 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 34/73] api --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 4 +- .../src/instructions/riscv/gpu/branch_cmp.rs | 4 +- .../src/instructions/riscv/gpu/branch_eq.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 5 +- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 4 +- .../src/instructions/riscv/gpu/load_sub.rs | 6 +- .../src/instructions/riscv/gpu/logic_i.rs | 4 +- .../src/instructions/riscv/gpu/logic_r.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 5 +- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 4 +- .../src/instructions/riscv/gpu/shift_i.rs | 4 +- .../src/instructions/riscv/gpu/shift_r.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 4 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 2 - ceno_zkvm/src/scheme/septic_curve.rs | 134 +++++++++++++++++- 24 files changed, 184 insertions(+), 44 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 
1e2f30ad6..4df630312 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -257,12 +257,12 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); // D2H copy (GPU output is column-major) let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); // Compare element by element (GPU is column-major, CPU is row-major) let cpu_data = cpu_witness.values(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 5b61d38ee..485eee423 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -176,11 +176,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index c0663d880..431d1b257 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -174,11 +174,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = 
indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index dfb9cd775..572e5bdbb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -177,11 +177,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index a44eaafa0..178b16fab 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -174,11 +174,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let 
gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index dd708cb59..f7420445c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -398,12 +398,15 @@ mod tests { &indices_u32, shard_offset, div_kind, + 0, + 0, + None, None, ) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index d33b575e6..61710ef80 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -156,11 +156,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 804218293..03f6c510c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -194,11 +194,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let 
indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index f7d48c772..787f091c2 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -403,12 +403,16 @@ mod tests { shard_offset, load_width, is_signed_u32, + 0, + 0, + 0, + None, None, ) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index c16e4f95f..36e33f4e2 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -209,11 +209,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git 
a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 17933915d..cd8c52375 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -250,11 +250,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 348b5b8b4..0c644a808 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -160,11 +160,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 19f38a7a6..8e686d0cb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -241,11 +241,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, 
steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index 1a9b8f902..efafd6bd1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -343,12 +343,15 @@ mod tests { &indices_u32, shard_offset, mul_kind, + 0, + 0, + None, None, ) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 346be925e..10775d984 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -242,11 +242,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs 
b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index e35d94bf0..72ea316f6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -219,11 +219,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index e1555fcaf..22dee5dab 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -213,11 +213,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index d6efa771c..7498b84a8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -232,11 +232,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let 
indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index e39e8acab..a8023edbd 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -205,11 +205,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 42d454507..d0fcbca32 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -186,11 +186,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - 
gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 80bc9b0ad..fd729b996 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -219,11 +219,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 4bdfafa5c..2142af2f1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -202,11 +202,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 
821d2a320..67a841a83 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -1874,8 +1874,6 @@ fn debug_compare_ec_points( return; } - println!("debug_compare_ec_points"); - use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; use crate::tables::{ECPoint, ShardRamRecord}; use ff_ext::{BabyBearExt4 as E, PoseidonField, SmallField}; diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 4bacfe011..23a7ad0e3 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1112,7 +1112,7 @@ impl SepticJacobianPoint { mod tests { use super::SepticExtension; use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint}; - use p3::{babybear::BabyBear, field::Field}; + use p3::{babybear::BabyBear, field::{Field, FieldAlgebra}}; use rand::thread_rng; type F = BabyBear; @@ -1249,4 +1249,136 @@ mod tests { test_inputs.len()); eprintln!("All {} GPU EC point test cases match CPU!", test_inputs.len()); } + + /// Verify GPU Poseidon2 permutation matches CPU on the exact sponge packing + /// used by to_ec_point(). This isolates Montgomery encoding correctness. 
+ #[test] + #[cfg(feature = "gpu")] + fn test_gpu_poseidon2_sponge_matches_cpu() { + use ceno_gpu::bb31::test_impl::{run_gpu_poseidon2_sponge, SPONGE_WIDTH}; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::{PoseidonField, SmallField}; + use p3::symmetric::Permutation; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Build sponge inputs matching to_ec_point packing: + // [addr, ram_type, value_lo16, value_hi16, shard, global_clk, nonce, 0..0] + let test_cases: Vec<[u32; SPONGE_WIDTH]> = vec![ + // Case 1: typical write record + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 5; // addr + s[1] = 1; // ram_type (Register) + s[2] = 0x5678; // value lo16 + s[3] = 0x1234; // value hi16 + s[4] = 1; // shard + s[5] = 100; // global_clk + s[6] = 0; // nonce + s + }, + // Case 2: memory read, different values + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 0x80000; + s[1] = 2; // Memory + s[2] = 0xBEEF; + s[3] = 0xDEAD; + s[4] = 2; + s[5] = 200; + s[6] = 3; // nonce=3 + s + }, + // Case 3: all zeros (edge case) + [0u32; SPONGE_WIDTH], + ]; + + let count = test_cases.len(); + let flat_input: Vec = test_cases.iter().flat_map(|s| s.iter().copied()).collect(); + let gpu_output = run_gpu_poseidon2_sponge(&hal, &flat_input, count); + + let mut mismatches = 0; + for (i, input) in test_cases.iter().enumerate() { + // CPU Poseidon2 + let cpu_input: Vec = input.iter().map(|&v| F::from_canonical_u32(v)).collect(); + let cpu_output = perm.permute(cpu_input); + + for j in 0..SPONGE_WIDTH { + let gpu_v = gpu_output[i * SPONGE_WIDTH + j]; + let cpu_v = cpu_output[j].to_canonical_u64() as u32; + if gpu_v != cpu_v { + eprintln!("[case {i}] sponge[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + + assert_eq!(mismatches, 0, "{mismatches} Poseidon2 output elements differ between GPU and CPU"); + eprintln!("All {} Poseidon2 sponge test cases match!", count); + } + + /// Verify GPU septic_point_from_x matches CPU SepticPoint::from_x. 
+ /// Tests the full GF(p^7) sqrt (Cipolla) + curve equation. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_septic_from_x_matches_cpu() { + use ceno_gpu::bb31::test_impl::run_gpu_septic_from_x; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::SmallField; + + let hal = CudaHalBB31::new(0).unwrap(); + + // Generate test x-coordinates by hashing known inputs + // (ensures we get a mix of points-exist and points-don't-exist cases) + let test_xs: Vec<[u32; 7]> = vec![ + // x from Poseidon2([5,1,0x5678,0x1234,1,100,0, 0..]) — known to have a point + [1594766074, 868528894, 1733778006, 1242721508, 1690833816, 1437202757, 1753525271], + // Simple: x = [1,0,0,0,0,0,0] + [1, 0, 0, 0, 0, 0, 0], + // x = [0,0,0,0,0,0,0] (zero) + [0, 0, 0, 0, 0, 0, 0], + // x = [42, 17, 999, 0, 0, 0, 0] + [42, 17, 999, 0, 0, 0, 0], + // Random-ish values + [1000000007, 123456789, 987654321, 111111111, 222222222, 333333333, 444444444], + ]; + + let count = test_xs.len(); + let flat_x: Vec = test_xs.iter().flat_map(|x| x.iter().copied()).collect(); + let (gpu_y, gpu_flags) = run_gpu_septic_from_x(&hal, &flat_x, count); + + let mut mismatches = 0; + for (i, x_arr) in test_xs.iter().enumerate() { + // CPU: SepticPoint::from_x + let x = SepticExtension(x_arr.map(|v| F::from_canonical_u32(v))); + let cpu_result = SepticPoint::::from_x(x); + + let gpu_found = gpu_flags[i] != 0; + let cpu_found = cpu_result.is_some(); + + if gpu_found != cpu_found { + eprintln!("[{i}] from_x existence: gpu={gpu_found} cpu={cpu_found}"); + mismatches += 1; + continue; + } + + if let Some(cpu_pt) = cpu_result { + // Compare y coordinates (GPU returns canonical, before any negation) + for j in 0..7 { + let gpu_v = gpu_y[i * 7 + j]; + let cpu_v = cpu_pt.y.0[j].to_canonical_u64() as u32; + // from_x returns the "natural" sqrt; they should match exactly + if gpu_v != cpu_v { + eprintln!("[{i}] y[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + } + + assert_eq!(mismatches, 0, + "{mismatches} septic_from_x 
results differ between GPU and CPU"); + eprintln!("All {} septic_from_x test cases match!", count); + } } From 2974a8ae9ad4945d06e9e85416bdccc0720b2aee Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 35/73] debug --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 383 +++++++++++++----- ceno_zkvm/src/structs.rs | 46 ++- 2 files changed, 332 insertions(+), 97 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 67a841a83..49c7fe9e6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -435,7 +435,12 @@ fn gpu_compact_ec_d2h( } /// Returns true if GPU shard records are verified for this kind. +/// Set CENO_GPU_DISABLE_SHARD_KINDS=all to force ALL kinds back to CPU shard path. fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { + // Global kill switch: force pure CPU shard path for baseline testing + if std::env::var_os("CENO_GPU_CPU_SHARD").is_some() { + return false; + } if is_shard_kind_disabled(kind) { return false; } @@ -622,30 +627,46 @@ fn gpu_assign_instances_inner>( info_span!("gpu_ec_shard").in_scope(|| { let compact = gpu_compact_ec.unwrap(); let compact_records = gpu_compact_ec_d2h(&compact)?; - debug_compare_ec_points(&compact_records, kind); - // Still need addr_accessed from the old ram_slots path - // (WAS_SENT flag indicates send() calls for addr_accessed tracking). - if gpu_ram_slots.is_some() { + // D2H ram_slots for addr_accessed (WAS_SENT flags only). + // Do NOT insert into BTreeMap — gpu_ec_records replace BTreeMap records. + let slots_vec: Option> = if gpu_ram_slots.is_some() { let ram_buf = gpu_ram_slots.unwrap(); - let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + Some(ram_buf.to_vec().map_err(|e| { ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) - })?; - let slots: &[GpuRamRecordSlot] = unsafe { + })?) 
+ } else { + None + }; + let slots: &[GpuRamRecordSlot] = if let Some(ref sv) = slots_vec { + unsafe { std::slice::from_raw_parts( - slot_bytes.as_ptr() as *const GpuRamRecordSlot, - slot_bytes.len() * 4 / std::mem::size_of::(), + sv.as_ptr() as *const GpuRamRecordSlot, + sv.len() * 4 / std::mem::size_of::(), ) - }; + } + } else { + &[] + }; + + // Debug: compare GPU shard_ctx vs CPU shard_ctx independently + debug_compare_shard_ec::( + &compact_records, slots, config, shard_ctx, + shard_steps, step_indices, kind, + ); + + // Populate shard_ctx: addr_accessed from ram_slots + if !slots.is_empty() { let mut forked = shard_ctx.get_forked(); let thread_ctx = &mut forked[0]; - // Only collect addr_accessed (WAS_SENT) and BTreeMap records - // from slot-based path, for compatibility. - gpu_collect_shard_records(thread_ctx, slots); + for slot in slots { + if slot.flags & (1 << 4) != 0 { + thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + } } - // Store raw GPU EC records for downstream assign_shared_circuit. - // Records are stored as raw bytes and converted to ShardRamInput later. + // Populate shard_ctx: gpu_ec_records (raw bytes for assign_shared_circuit) let raw_bytes = unsafe { std::slice::from_raw_parts( compact_records.as_ptr() as *const u8, @@ -1863,11 +1884,27 @@ fn debug_compare_shard_side_effects>( Ok(()) } -/// Compare GPU-produced EC points against CPU to_ec_point() for correctness. +/// Compare GPU shard context vs CPU shard context, field by field. +/// +/// Both paths are independent and produce equivalent ShardContext state: +/// CPU path: cpu_collect_shard_side_effects → addr_accessed + write_records + read_records +/// GPU path: compact_records → shard records (④gpu_ec_records) +/// ram_slots WAS_SENT → addr_accessed (①) +/// (②write_records and ③read_records stay empty for GPU EC kernels) +/// +/// This function builds both independently and compares: +/// A. addr_accessed sets +/// B. 
shard records (sorted, normalized to ShardRamRecord) +/// C. EC points (nonce + SepticPoint x,y) +/// /// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. -/// Limit output with CENO_GPU_DEBUG_COMPARE_EC_LIMIT (default: 16). -fn debug_compare_ec_points( +fn debug_compare_shard_ec>( compact_records: &[GpuShardRamRecord], + ram_slots: &[GpuRamRecordSlot], + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], kind: GpuWitgenKind, ) { if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { @@ -1876,115 +1913,279 @@ fn debug_compare_ec_points( use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; use crate::tables::{ECPoint, ShardRamRecord}; - use ff_ext::{BabyBearExt4 as E, PoseidonField, SmallField}; - use p3::babybear::BabyBear; + use ff_ext::{PoseidonField, SmallField}; + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(16); - let perm = BabyBear::get_default_perm(); + // ========== Build CPU shard context (independent, isolated) ========== + let mut cpu_ctx = shard_ctx.new_empty_like(); + if let Err(e) = cpu_collect_shard_side_effects::( + config, &mut cpu_ctx, shard_steps, step_indices, + ) { + tracing::error!("[GPU EC debug] kind={kind:?} CPU shard side effects failed: {e:?}"); + return; + } - let mut mismatches = 0usize; - let mut field_mismatches = 0usize; - let mut nonce_mismatches = 0usize; - - for (i, gpu_rec) in compact_records.iter().enumerate() { - // Reconstruct ShardRamRecord from GPU record fields - let cpu_record = ShardRamRecord { - addr: gpu_rec.addr, - ram_type: if gpu_rec.ram_type == 1 { - RAMType::Register - } else { - RAMType::Memory - }, - value: gpu_rec.value, - shard: gpu_rec.shard, - local_clk: gpu_rec.local_clk, - global_clk: gpu_rec.global_clk, - is_to_write_set: gpu_rec.is_to_write_set != 0, - }; + let perm = ::get_default_perm(); - // CPU computes EC point - let cpu_ec: ECPoint = 
cpu_record.to_ec_point(&perm); + // CPU: addr_accessed + let cpu_addr = cpu_ctx.get_addr_accessed(); - // GPU EC point (from canonical u32) - let gpu_x = SepticExtension( - gpu_rec.point_x.map(|v| BabyBear::from_canonical_u32(v)), - ); - let gpu_y = SepticExtension( - gpu_rec.point_y.map(|v| BabyBear::from_canonical_u32(v)), + // CPU: shard records (BTreeMap → ShardRamRecord + ECPoint) + let mut cpu_entries: Vec<(ShardRamRecord, ECPoint)> = Vec::new(); + for records in cpu_ctx.write_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, true).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + for records in cpu_ctx.read_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, false).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + cpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Build GPU shard context (independent, from D2H data only) ========== + + // GPU: addr_accessed (from ram_slots WAS_SENT flags) + let gpu_addr: rustc_hash::FxHashSet = ram_slots + .iter() + .filter(|s| s.flags & (1 << 4) != 0) + .map(|s| WordAddr(s.addr)) + .collect(); + + // GPU: shard records (compact_records → ShardRamRecord + ECPoint) + let mut gpu_entries: Vec<(ShardRamRecord, ECPoint)> = compact_records + .iter() + .map(|g| { + let rec = ShardRamRecord { + addr: g.addr, + ram_type: if g.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + value: g.value, + shard: g.shard, + local_clk: g.local_clk, + global_clk: g.global_clk, + is_to_write_set: g.is_to_write_set != 0, + }; + let x = SepticExtension(g.point_x.map(|v| E::BaseField::from_canonical_u32(v))); + let y = SepticExtension(g.point_y.map(|v| E::BaseField::from_canonical_u32(v))); + let point = SepticPoint::from_affine(x, y); + let ec = ECPoint:: { nonce: g.nonce, point }; + (rec, ec) + }) + .collect(); + gpu_entries.sort_by_key(|(r, 
_)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Compare A: addr_accessed ========== + if cpu_addr != gpu_addr { + let cpu_only: Vec<_> = cpu_addr.difference(&gpu_addr).collect(); + let gpu_only: Vec<_> = gpu_addr.difference(&cpu_addr).collect(); + tracing::error!( + "[GPU EC debug] kind={kind:?} ADDR_ACCESSED MISMATCH: cpu={} gpu={} \ + cpu_only={} gpu_only={}", + cpu_addr.len(), gpu_addr.len(), cpu_only.len(), gpu_only.len() ); - // Verify point is on curve (optional sanity check) - let _gpu_point = SepticPoint::from_affine(gpu_x, gpu_y); + for (i, addr) in cpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", addr.0); + } + for (i, addr) in gpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", addr.0); + } + } + + // ========== Compare B+C: shard records + EC points ========== - let mut has_diff = false; + // Check counts + if cpu_entries.len() != gpu_entries.len() { + tracing::error!( + "[GPU EC debug] kind={kind:?} RECORD COUNT MISMATCH: cpu={} gpu={}", + cpu_entries.len(), gpu_entries.len() + ); + let cpu_keys: std::collections::BTreeSet<_> = cpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let gpu_keys: std::collections::BTreeSet<_> = gpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let mut logged = 0usize; + for key in cpu_keys.difference(&gpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + for key in gpu_keys.difference(&cpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + } - // Compare nonce - if gpu_rec.nonce != cpu_ec.nonce { - nonce_mismatches += 1; - has_diff = true; - if mismatches < limit { + // 
Check GPU duplicates (BTreeMap deduplicates, atomicAdd doesn't) + let mut gpu_dup_count = 0usize; + for w in gpu_entries.windows(2) { + if w[0].0.addr == w[1].0.addr + && w[0].0.is_to_write_set == w[1].0.is_to_write_set + && w[0].0.ram_type == w[1].0.ram_type + { + gpu_dup_count += 1; + if gpu_dup_count <= limit { tracing::error!( - "[GPU EC debug] kind={kind:?} rec[{i}] nonce mismatch: gpu={} cpu={}", - gpu_rec.nonce, - cpu_ec.nonce + "[GPU EC debug] kind={kind:?} GPU DUPLICATE: addr={} is_write={} ram_type={:?}", + w[0].0.addr, w[0].0.is_to_write_set, w[0].0.ram_type ); } } + } - // Compare x coordinates - for j in 0..7 { - let gpu_v = gpu_rec.point_x[j]; - let cpu_v = cpu_ec.point.x.0[j].to_canonical_u64() as u32; - if gpu_v != cpu_v { - field_mismatches += 1; - has_diff = true; - if mismatches < limit { + // Merge-walk sorted lists + let mut ci = 0usize; + let mut gi = 0usize; + let mut record_mismatches = 0usize; + let mut ec_mismatches = 0usize; + let mut matched = 0usize; + + while ci < cpu_entries.len() && gi < gpu_entries.len() { + let (cr, ce) = &cpu_entries[ci]; + let (gr, ge) = &gpu_entries[gi]; + let ck = (cr.addr, cr.is_to_write_set as u8, cr.ram_type as u8); + let gk = (gr.addr, gr.is_to_write_set as u8, gr.ram_type as u8); + + match ck.cmp(&gk) { + std::cmp::Ordering::Less => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + cr.addr, cr.is_to_write_set, cr.ram_type, cr.value, cr.shard, cr.global_clk + ); + } + record_mismatches += 1; + ci += 1; + continue; + } + std::cmp::Ordering::Greater => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + gr.addr, gr.is_to_write_set, gr.ram_type, gr.value, gr.shard, gr.global_clk + ); + } + record_mismatches += 1; + gi += 1; + continue; + } + std::cmp::Ordering::Equal => {} + } + + // Keys match — 
compare record fields + let mut field_diff = false; + for (name, cv, gv) in [ + ("value", cr.value as u64, gr.value as u64), + ("shard", cr.shard, gr.shard), + ("local_clk", cr.local_clk, gr.local_clk), + ("global_clk", cr.global_clk, gr.global_clk), + ] { + if cv != gv { + field_diff = true; + if record_mismatches < limit { tracing::error!( - "[GPU EC debug] kind={kind:?} rec[{i}] x[{j}] mismatch: gpu={gpu_v} cpu={cpu_v}" + "[GPU EC debug] kind={kind:?} addr={} {name}: cpu={cv} gpu={gv}", + cr.addr ); } } } + if field_diff { + record_mismatches += 1; + } - // Compare y coordinates + // Compare EC points + let mut ec_diff = false; + if ce.nonce != ge.nonce { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} nonce: cpu={} gpu={}", + cr.addr, ce.nonce, ge.nonce + ); + } + } + for j in 0..7 { + let cv = ce.point.x.0[j].to_canonical_u64() as u32; + let gv = ge.point.x.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", cr.addr + ); + } + } + } for j in 0..7 { - let gpu_v = gpu_rec.point_y[j]; - let cpu_v = cpu_ec.point.y.0[j].to_canonical_u64() as u32; - if gpu_v != cpu_v { - field_mismatches += 1; - has_diff = true; - if mismatches < limit { + let cv = ce.point.y.0[j].to_canonical_u64() as u32; + let gv = ge.point.y.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { tracing::error!( - "[GPU EC debug] kind={kind:?} rec[{i}] y[{j}] mismatch: gpu={gpu_v} cpu={cpu_v} \ - (addr={} ram_type={} value={} shard={} clk={} is_write={})", - gpu_rec.addr, - gpu_rec.ram_type, - gpu_rec.value, - gpu_rec.shard, - gpu_rec.global_clk, - gpu_rec.is_to_write_set + "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", cr.addr ); } } } + if ec_diff { + ec_mismatches += 1; + } + + matched += 1; + ci += 1; + gi += 1; + } - if has_diff { - mismatches += 
1; + // Remaining unmatched + while ci < cpu_entries.len() { + if record_mismatches < limit { + let (cr, _) = &cpu_entries[ci]; + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU (tail): addr={} is_write={} val={}", + cr.addr, cr.is_to_write_set, cr.value + ); + } + record_mismatches += 1; + ci += 1; + } + while gi < gpu_entries.len() { + if record_mismatches < limit { + let (gr, _) = &gpu_entries[gi]; + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU (tail): addr={} is_write={} val={}", + gr.addr, gr.is_to_write_set, gr.value + ); } + record_mismatches += 1; + gi += 1; } - if mismatches == 0 { + // ========== Summary ========== + let addr_ok = cpu_addr == gpu_addr; + if addr_ok && record_mismatches == 0 && ec_mismatches == 0 && gpu_dup_count == 0 { tracing::info!( - "[GPU EC debug] kind={kind:?} ALL {} EC points match CPU", - compact_records.len() + "[GPU EC debug] kind={kind:?} ALL MATCH: {} records, {} addr_accessed, EC points OK", + matched, cpu_addr.len() ); } else { tracing::error!( - "[GPU EC debug] kind={kind:?} {mismatches}/{} records have mismatches \ - (nonce_diffs={nonce_mismatches} field_diffs={field_mismatches})", - compact_records.len() + "[GPU EC debug] kind={kind:?} MISMATCH: matched={matched} record_diffs={record_mismatches} \ + ec_diffs={ec_mismatches} gpu_dups={gpu_dup_count} addr_ok={addr_ok} \ + (cpu_records={} gpu_records={} cpu_addrs={} gpu_addrs={})", + cpu_entries.len(), gpu_entries.len(), cpu_addr.len(), gpu_addr.len() ); } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 8b24e20bd..ee2087aa8 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -542,11 +542,16 @@ impl ZKVMWitnesses { .collect::>(); // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) - let gpu_ec_inputs = if shard_ctx.has_gpu_ec_records() { - gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) - } else { - vec![] - }; + // Partition into writes and 
reads to maintain the ordering invariant required by + // ShardRamCircuit::assign_instances (writes first, reads after). + let (gpu_ec_writes, gpu_ec_reads): (Vec<_>, Vec<_>) = + if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + .into_iter() + .partition(|input| input.record.is_to_write_set) + } else { + (vec![], vec![]) + }; let global_input = shard_ctx .write_records() @@ -565,6 +570,7 @@ impl ZKVMWitnesses { }) .chain(first_shard_access_later_records.into_par_iter()) .chain(current_shard_access_later.into_par_iter()) + .chain(gpu_ec_writes.into_par_iter()) .chain(shard_ctx.read_records().par_iter().flat_map(|records| { // global read -> local write records.par_iter().map(|(vma, record)| { @@ -577,7 +583,7 @@ impl ZKVMWitnesses { } }) })) - .chain(gpu_ec_inputs.into_par_iter()) + .chain(gpu_ec_reads.into_par_iter()) .collect::>(); if tracing::enabled!(Level::DEBUG) { @@ -608,6 +614,34 @@ impl ZKVMWitnesses { } } + // Invariant: all writes (is_to_write_set=true) must precede all reads. + // ShardRamCircuit::assign_instances uses take_while to count writes. + // Activate with CENO_DEBUG_SHARD_RAM_ORDER=1. 
+ if std::env::var_os("CENO_DEBUG_SHARD_RAM_ORDER").is_some() { + let mut seen_read = false; + for (i, input) in global_input.iter().enumerate() { + if input.record.is_to_write_set { + if seen_read { + tracing::error!( + "[SHARD_RAM_ORDER] BUG: write after read at index={i} \ + addr={} ram_type={:?} shard={} global_clk={} \ + (total={} writes={} reads={})", + input.record.addr, + input.record.ram_type, + shard_ctx.shard_id, + input.record.global_clk, + global_input.len(), + global_input.iter().filter(|x| x.record.is_to_write_set).count(), + global_input.iter().filter(|x| !x.record.is_to_write_set).count(), + ); + break; + } + } else { + seen_read = true; + } + } + } + assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); let circuit_inputs = global_input From 018ad73df40cb6bd04a62e5c7a651547884fd9f3 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 36/73] perf --- Cargo.lock | 1 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 240 ++++++++++-------- ceno_zkvm/src/structs.rs | 6 +- ceno_zkvm/src/tables/mod.rs | 4 +- ceno_zkvm/src/tables/ops/ops_circuit.rs | 5 +- ceno_zkvm/src/tables/ops/ops_impl.rs | 4 +- ceno_zkvm/src/tables/program.rs | 5 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 11 +- ceno_zkvm/src/tables/range/range_circuit.rs | 7 +- ceno_zkvm/src/tables/range/range_impl.rs | 6 +- ceno_zkvm/src/tables/shard_ram.rs | 5 +- gkr_iop/Cargo.toml | 1 + gkr_iop/src/utils/lk_multiplicity.rs | 4 +- 13 files changed, 163 insertions(+), 136 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f47740ca1..79f9fb07b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2451,6 +2451,7 @@ dependencies = [ "p3", "rand 0.8.5", "rayon", + "rustc-hash", "serde", "smallvec", "strum", diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 49c7fe9e6..5048eddd1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ 
b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -13,6 +13,7 @@ use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShard use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; +use rustc_hash::FxHashMap; use std::cell::{Cell, RefCell}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -419,13 +420,13 @@ fn gpu_compact_ec_d2h( return Ok(vec![]); } - // D2H the buffer (all u32s), then reinterpret as GpuShardRamRecord - let buf_vec: Vec = compact.buffer.to_vec().map_err(|e| { + // Partial D2H: only transfer the first `count` records (not the full allocation) + let record_u32s = std::mem::size_of::() / 4; // 26 + let total_u32s = count * record_u32s; + let buf_vec: Vec = compact.buffer.to_vec_n(total_u32s).map_err(|e| { ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) })?; - let record_u32s = std::mem::size_of::() / 4; // 26 - let total_u32s = count * record_u32s; let records: Vec = unsafe { let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; std::slice::from_raw_parts(ptr, count).to_vec() @@ -602,7 +603,7 @@ fn gpu_assign_instances_inner>( let total_instances = step_indices.len(); // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) - let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec) = info_span!("gpu_kernel").in_scope(|| { + let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec, gpu_compact_addr) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -626,44 +627,71 @@ fn gpu_assign_instances_inner>( // D2H only the active records (much smaller than full N*3 slot buffer). info_span!("gpu_ec_shard").in_scope(|| { let compact = gpu_compact_ec.unwrap(); - let compact_records = gpu_compact_ec_d2h(&compact)?; - - // D2H ram_slots for addr_accessed (WAS_SENT flags only). 
- // Do NOT insert into BTreeMap — gpu_ec_records replace BTreeMap records. - let slots_vec: Option> = if gpu_ram_slots.is_some() { - let ram_buf = gpu_ram_slots.unwrap(); - Some(ram_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) - })?) - } else { - None - }; - let slots: &[GpuRamRecordSlot] = if let Some(ref sv) = slots_vec { - unsafe { - std::slice::from_raw_parts( - sv.as_ptr() as *const GpuRamRecordSlot, - sv.len() * 4 / std::mem::size_of::(), - ) + let compact_records = info_span!("compact_d2h") + .in_scope(|| gpu_compact_ec_d2h(&compact))?; + + // D2H ram_slots lazily (only for debug or fallback). + // Avoid the 68 MB D2H in the common case. + let ram_slots_d2h = || -> Result, ZKVMError> { + if let Some(ref ram_buf) = gpu_ram_slots { + let sv: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("ram_slots D2H failed: {e}").into(), + ) + })?; + Ok(unsafe { + let ptr = sv.as_ptr() as *const GpuRamRecordSlot; + let len = sv.len() * 4 / std::mem::size_of::(); + std::slice::from_raw_parts(ptr, len).to_vec() + }) + } else { + Ok(vec![]) } - } else { - &[] }; - // Debug: compare GPU shard_ctx vs CPU shard_ctx independently - debug_compare_shard_ec::( - &compact_records, slots, config, shard_ctx, - shard_steps, step_indices, kind, - ); - - // Populate shard_ctx: addr_accessed from ram_slots - if !slots.is_empty() { - let mut forked = shard_ctx.get_forked(); - let thread_ctx = &mut forked[0]; - for slot in slots { - if slot.flags & (1 << 4) != 0 { - thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + // D2H compact addr_accessed (GPU-side compaction via atomicAdd). + // Much smaller than full ram_slots D2H (4 bytes/addr vs 48 bytes/slot). 
+ info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { + if let Some(ref ca) = gpu_compact_addr { + let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr_count D2H failed: {e}").into(), + ) + })?; + let n = count_vec[0] as usize; + if n > 0 { + let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr D2H failed: {e}").into(), + ) + })?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + } else { + // Fallback: D2H full ram_slots for addr_accessed + let slots = ram_slots_d2h()?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for slot in &slots { + if slot.flags & (1 << 4) != 0 { + thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + } } } + Ok(()) + })?; + + // Debug: compare GPU shard_ctx vs CPU shard_ctx independently + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_some() { + let slots = ram_slots_d2h()?; + debug_compare_shard_ec::( + &compact_records, &slots, config, shard_ctx, + shard_steps, step_indices, kind, + ); } // Populate shard_ctx: gpu_ec_records (raw bytes for assign_shared_circuit) @@ -788,7 +816,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result<(WitResult, Option, Option, Option), ZKVMError> { +) -> Result<(WitResult, Option, Option, Option, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). 
let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -799,11 +827,11 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); - // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec) + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec, compact_addr) macro_rules! split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec)) + Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec, full.compact_addr)) }}; } @@ -2263,83 +2291,75 @@ fn lookup_table_name(table_idx: usize) -> &'static str { } fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { - let mut lk = LkMultiplicity::default(); - merge_dense_counter_table( - &mut lk, - LookupTable::Dynamic, - &counters.dynamic.to_vec().map_err(|e| { + let mut tables: [FxHashMap; 8] = Default::default(); + + // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) + info_span!("lk_dynamic_d2h").in_scope(|| { + let counts: Vec = counters.dynamic.to_vec().map_err(|e| { ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) - })?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::DoubleU8, - &counters.double_u8.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU double_u8 lk D2H failed: {e}").into()) - })?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::And, - &counters - .and_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU and lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Or, - &counters - .or_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU or lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Xor, - &counters - .xor_table - 
.to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU xor lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Ltu, - &counters - .ltu_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU ltu lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Pow, - &counters - .pow_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU pow lk D2H failed: {e}").into()))?, - ); - // Merge fetch (Instruction) table if present - if let Some(fetch_buf) = counters.fetch { - let base_pc = counters.fetch_base_pc; - let fetch_counts = fetch_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) })?; - for (slot_idx, &count) in fetch_counts.iter().enumerate() { + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Dynamic as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { if count != 0 { - let pc = base_pc as u64 + (slot_idx as u64) * 4; - lk.set_count(LookupTable::Instruction, pc, count as usize); + map.insert(key as u64, count as usize); } } - } - Ok(lk.into_finalize_result()) -} + Ok::<(), ZKVMError>(()) + })?; -fn merge_dense_counter_table(lk: &mut LkMultiplicity, table: LookupTable, counts: &[u32]) { - for (key, &count) in counts.iter().enumerate() { - if count != 0 { - lk.set_count(table, key as u64, count as usize); + // Dense tables: same pattern, skip None + info_span!("lk_dense_d2h").in_scope(|| { + let dense: &[(LookupTable, &Option)] = &[ + (LookupTable::DoubleU8, &counters.double_u8), + (LookupTable::And, &counters.and_table), + (LookupTable::Or, &counters.or_table), + (LookupTable::Xor, &counters.xor_table), + (LookupTable::Ltu, &counters.ltu_table), + (LookupTable::Pow, &counters.pow_table), + ]; + for &(table, ref buf_opt) in dense { + if let Some(buf) = buf_opt { + let counts: Vec = buf.to_vec().map_err(|e| { + 
ZKVMError::InvalidWitness( + format!("GPU {:?} lk D2H failed: {e}", table).into(), + ) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[table as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + } } + Ok::<(), ZKVMError>(()) + })?; + + // Fetch (Instruction table) + if let Some(fetch_buf) = counters.fetch { + info_span!("lk_fetch_d2h").in_scope(|| { + let base_pc = counters.fetch_base_pc; + let counts = fetch_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Instruction as usize]; + map.reserve(nnz); + for (slot_idx, &count) in counts.iter().enumerate() { + if count != 0 { + let pc = base_pc as u64 + (slot_idx as u64) * 4; + map.insert(pc, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; } + + Ok(Multiplicity(tables)) } /// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. 
diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index ee2087aa8..4fcf19fa4 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -20,7 +20,7 @@ use rayon::{ iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, prelude::ParallelSlice, }; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, @@ -351,7 +351,7 @@ impl ChipInput { pub struct ZKVMWitnesses { pub witnesses: BTreeMap>>, lk_mlts: BTreeMap>, - combined_lk_mlt: Option>>, + combined_lk_mlt: Option>>, } impl ZKVMWitnesses { @@ -363,7 +363,7 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } - pub fn combined_lk_mlt(&self) -> Option<&Vec>> { + pub fn combined_lk_mlt(&self) -> Option<&Vec>> { self.combined_lk_mlt.as_ref() } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 0acfed059..db32dfe66 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -7,7 +7,7 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::ToExpr; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::RowMajorMatrix; mod shard_ram; @@ -94,7 +94,7 @@ pub trait TableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], input: &Self::WitnessInput<'_>, ) -> Result, ZKVMError>; } diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index b1216f5ae..05948c00c 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -2,7 +2,8 @@ use super::ops_impl::OpTableConfig; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -47,7 +48,7 @@ impl TableCircuit for OpsTableCircuit config: &Self::TableConfig, num_witin: usize, 
num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[OP::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 72b80a548..7046de304 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -4,7 +4,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; use crate::{ @@ -70,7 +70,7 @@ impl OpTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, length: usize, ) -> Result, CircuitBuilderError> { assert_eq!(num_structural_witin, 1); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3894828d9..96c4356a0 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -15,7 +15,8 @@ use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use witness::{ InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, }; @@ -268,7 +269,7 @@ impl TableCircuit for ProgramTableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], program: &Program, ) -> Result, ZKVMError> { assert!(!program.instructions.is_empty()); diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..dcf2142fa 100644 --- 
a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -19,7 +19,8 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; -use std::{collections::HashMap, marker::PhantomData, ops::Range}; +use rustc_hash::FxHashMap; +use std::{marker::PhantomData, ops::Range}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] @@ -110,7 +111,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_v: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -167,7 +168,7 @@ impl TableCirc config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_mem: &[MemFinalRecord], ) -> Result, ZKVMError> { // assume returned table is well-formed including padding @@ -294,7 +295,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], data: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -380,7 +381,7 @@ impl TableCircuit for LocalFinalRamC config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], (shard_ctx, final_mem): &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index d98161fea..7bdec456b 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -1,6 +1,7 @@ //! Range tables as circuits with trait TableCircuit. 
-use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -68,7 +69,7 @@ impl TableCircuit config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[LookupTable::Dynamic as usize]; @@ -149,7 +150,7 @@ impl], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[R::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index a95664085..fa3901a6b 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; use crate::{ @@ -56,7 +56,7 @@ impl DynamicRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, max_bits: usize, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (max_bits + 1); @@ -158,7 +158,7 @@ impl DoubleRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (self.range_a_bits + self.range_b_bits); let mut witness: RowMajorMatrix = diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..f2748a4f9 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -1,4 +1,5 @@ -use std::{collections::HashMap, iter::repeat_n, 
marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::{iter::repeat_n, marker::PhantomData}; use crate::{ Value, @@ -474,7 +475,7 @@ impl TableCircuit for ShardRamCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], steps: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { if steps.is_empty() { diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index f93ef4335..cffea08eb 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -22,6 +22,7 @@ once_cell.workspace = true p3.workspace = true rand.workspace = true rayon.workspace = true +rustc-hash.workspace = true serde.workspace = true smallvec.workspace = true strum.workspace = true diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index 7dded4e70..62de189aa 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -1,8 +1,8 @@ use ff_ext::SmallField; use itertools::izip; +use rustc_hash::FxHashMap; use std::{ cell::RefCell, - collections::HashMap, fmt::Debug, hash::Hash, mem::{self}, @@ -16,7 +16,7 @@ use crate::tables::{ ops::{AndTable, LtuTable, OrTable, PowTable, XorTable}, }; -pub type MultiplicityRaw = [HashMap; mem::variant_count::()]; +pub type MultiplicityRaw = [FxHashMap; mem::variant_count::()]; #[derive(Clone, Default, Debug)] pub struct Multiplicity(pub MultiplicityRaw); From 39078ce077171d6e4170e15b3a0678f89ada4213 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 37/73] profile --- ceno_zkvm/src/e2e.rs | 226 ++++++++++----------- ceno_zkvm/src/instructions/riscv/rv32im.rs | 83 ++++---- 2 files changed, 156 insertions(+), 153 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index ea92ff6f7..404a28b14 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1413,17 +1413,17 @@ pub fn generate_witness<'a, E: ExtensionField>( shard_id = 
shard_ctx_builder.cur_shard_id ) .in_scope(|| { - let time = std::time::Instant::now(); instrunction_dispatch_ctx.begin_shard(); let (mut shard_ctx, shard_summary) = - match shard_ctx_builder.position_next_shard( - &mut step_iter, - |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), - ) { + match info_span!("position_next_shard").in_scope(|| { + shard_ctx_builder.position_next_shard( + &mut step_iter, + |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), + ) + }) { Some(result) => result, None => return None, }; - tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); let shard_steps = step_iter.shard_steps(); shard_ctx.syscall_witnesses = Arc::new(step_iter.syscall_witnesses().to_vec()); @@ -1474,21 +1474,20 @@ pub fn generate_witness<'a, E: ExtensionField>( } } - let time = std::time::Instant::now(); let debug_compare_e2e_shard = std::env::var_os("CENO_GPU_DEBUG_COMPARE_E2E_SHARD").is_some(); let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); - system_config - .config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut shard_ctx, - &mut instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, - ) - .unwrap(); - tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); + info_span!("assign_opcode_circuits").in_scope(|| { + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &mut instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); // Free GPU shard_steps cache after all opcode circuits are done. 
#[cfg(feature = "gpu")] @@ -1503,19 +1502,20 @@ pub fn generate_witness<'a, E: ExtensionField>( } } - let time = std::time::Instant::now(); - system_config - .dummy_config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut shard_ctx, - &instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, - ) - .unwrap(); - tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); - zkvm_witness.finalize_lk_multiplicities(); + info_span!("assign_dummy_circuits").in_scope(|| { + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); + info_span!("finalize_lk_multiplicities").in_scope(|| { + zkvm_witness.finalize_lk_multiplicities(); + }); if let Some(mut cpu_shard_ctx) = debug_shard_ctx_template { let mut cpu_witness = ZKVMWitnesses::default(); @@ -1594,110 +1594,102 @@ pub fn generate_witness<'a, E: ExtensionField>( // ├─ later rw? NO -> ShardRAM + LocalFinalize // └─ later rw? 
YES -> ShardRAM - let time = std::time::Instant::now(); - system_config - .config - .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_table_circuits").in_scope(|| { + system_config + .config + .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) + }).unwrap(); + + info_span!("assign_init_table").in_scope(|| { + if shard_ctx.is_first_shard() { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.stack, + ) + } else { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &[], + &[], + &[], + &[], + ) + } + }).unwrap(); - if shard_ctx.is_first_shard() { - let time = std::time::Instant::now(); + info_span!("assign_dynamic_init_table").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_dynamic_init_table_circuit( &system_config.zkvm_cs, &mut zkvm_witness, &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.heap, ) - .unwrap(); - tracing::debug!("assign_init_table_circuit finish in {:?}", time.elapsed()); - } else { + }).unwrap(); + + info_span!("assign_continuation").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_continuation_circuit( &system_config.zkvm_cs, + &shard_ctx, &mut zkvm_witness, &pi, - &[], - &[], - &[], - &[], + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, ) - .unwrap(); - } 
+ }).unwrap(); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_dynamic_init_table_circuit( - &system_config.zkvm_cs, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!( - "assign_dynamic_init_table_circuit finish in {:?}", - time.elapsed() - ); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_continuation_circuit( - &system_config.zkvm_cs, - &shard_ctx, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.stack, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!("assign_continuation_circuit finish in {:?}", time.elapsed()); - - let time = std::time::Instant::now(); - zkvm_witness - .assign_table_circuit::>( - &system_config.zkvm_cs, - &system_config.prog_config, - &program, - ) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_program_table").in_scope(|| { + zkvm_witness + .assign_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + &program, + ) + }).unwrap(); if let Some(shard_ram_witnesses) = zkvm_witness.get_witness(&ShardRamCircuit::::name()) { - let time = std::time::Instant::now(); - let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses - .iter() - .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) - .map(|shard_ram_witness| { - ShardRamCircuit::::extract_ec_sum( - &system_config.mmu_config.ram_bus_circuit, - &shard_ram_witness.witness_rmms[0], - ) - }) - .sum(); - - let xy = shard_ram_ec_sum - .x - .0 - .iter() - .chain(shard_ram_ec_sum.y.0.iter()); - for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { - *v = f.to_canonical_u64() as u32; - } - tracing::debug!("update pi shard_rw_sum finish in {:?}", time.elapsed()); + 
info_span!("shard_ram_ec_sum").in_scope(|| { + let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses + .iter() + .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) + .map(|shard_ram_witness| { + ShardRamCircuit::::extract_ec_sum( + &system_config.mmu_config.ram_bus_circuit, + &shard_ram_witness.witness_rmms[0], + ) + }) + .sum(); + + let xy = shard_ram_ec_sum + .x + .0 + .iter() + .chain(shard_ram_ec_sum.y.0.iter()); + for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { + *v = f.to_canonical_u64() as u32; + } + }); } Some((zkvm_witness, shard_ctx, pi)) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 091ce3000..1227032f0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -69,6 +69,7 @@ use std::{ collections::{BTreeMap, HashMap}, }; use strum::{EnumCount, IntoEnumIterator}; +use tracing::info_span; pub mod mmu; @@ -681,13 +682,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_kinds::() .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -696,13 +701,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_ecall_code($code) .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -846,22 +855,20 @@ impl Rv32imConfig { cs: &ZKVMConstraintSystem, witness: &mut 
ZKVMWitnesses, ) -> Result<(), ZKVMError> { - witness.assign_table_circuit::>( - cs, - &self.dynamic_range_config, - &(), - )?; - witness.assign_table_circuit::>( - cs, - &self.double_u8_range_config, - &(), - )?; - witness.assign_table_circuit::>(cs, &self.and_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.or_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.xor_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.ltu_config, &())?; + macro_rules! assign_table { + ($table:ty, $config:expr) => { + info_span!("assign_table", table = %<$table>::name()) + .in_scope(|| witness.assign_table_circuit::<$table>(cs, $config, &()))?; + }; + } + assign_table!(DynamicRangeTableCircuit, &self.dynamic_range_config); + assign_table!(DoubleU8TableCircuit, &self.double_u8_range_config); + assign_table!(AndTableCircuit, &self.and_table_config); + assign_table!(OrTableCircuit, &self.or_table_config); + assign_table!(XorTableCircuit, &self.xor_table_config); + assign_table!(LtuTableCircuit, &self.ltu_config); #[cfg(not(feature = "u16limb_circuit"))] - witness.assign_table_circuit::>(cs, &self.pow_config, &())?; + assign_table!(PowTableCircuit, &self.pow_config); Ok(()) } @@ -1016,13 +1023,17 @@ impl DummyExtraConfig { let phantom_log_pc_cycle_records = instrunction_dispatch_ctx .records_for_ecall_code(LogPcCycleSpec::CODE) .unwrap_or(&[]); - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.phantom_log_pc_cycle, - shard_steps, - phantom_log_pc_cycle_records, - )?; + let n = phantom_log_pc_cycle_records.len(); + info_span!("assign_chip", chip = %LargeEcallDummy::::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.phantom_log_pc_cycle, + shard_steps, + phantom_log_pc_cycle_records, + ) + })?; Ok(()) } } From a8cc5a30b4b67bfbbfa03a5a45fbb67d6679c3f6 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 17 Mar 2026 09:56:37 +0800 Subject: [PATCH 38/73] batch_continuation_ec --- 
.../src/instructions/riscv/gpu/witgen_gpu.rs | 116 +++++++++++ ceno_zkvm/src/structs.rs | 191 ++++++++++-------- 2 files changed, 219 insertions(+), 88 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 5048eddd1..c5b2e5502 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -2290,6 +2290,122 @@ fn lookup_table_name(table_idx: usize) -> &'static str { } } +/// Batch compute EC points for continuation circuit ShardRamRecords on GPU. +/// +/// Converts ShardRamRecords to GPU format, launches the `batch_continuation_ec` +/// kernel to compute Poseidon2 + SepticCurve on device, and converts results +/// back to ShardRamInput (with EC points). +/// +/// Returns (write_inputs, read_inputs) maintaining the write-before-read ordering +/// invariant required by ShardRamCircuit::assign_instances. +pub fn gpu_batch_continuation_ec( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result< + ( + Vec>, + Vec>, + ), + ZKVMError, +> { + use crate::tables::ShardRamInput; + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let total = write_records.len() + read_records.len(); + if total == 0 { + return Ok((vec![], vec![])); + } + + // Convert ShardRamRecords to GpuShardRamRecord format + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(shard_ram_record_to_gpu(rec)); + } + + // GPU batch EC computation + let result = info_span!("gpu_batch_ec", n = total).in_scope(|| { + hal.batch_continuation_ec(&gpu_records) + }).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()) + })?; + + // Convert back to 
ShardRamInput, split into writes and reads + let mut write_inputs = Vec::with_capacity(write_records.len()); + let mut read_inputs = Vec::with_capacity(read_records.len()); + + for (i, gpu_rec) in result.iter().enumerate() { + let (rec, name) = if i < write_records.len() { + (&write_records[i].0, write_records[i].1) + } else { + let ri = i - write_records.len(); + (&read_records[ri].0, read_records[ri].1) + }; + + let ec_point = gpu_shard_ram_record_to_ec_point::(gpu_rec); + let input = ShardRamInput { + name, + record: rec.clone(), + ec_point, + }; + + if i < write_records.len() { + write_inputs.push(input); + } else { + read_inputs.push(input); + } + } + + Ok((write_inputs, read_inputs)) +} + +/// Convert a ShardRamRecord to GpuShardRamRecord (metadata only, EC fields zeroed). +fn shard_ram_record_to_gpu(rec: &crate::tables::ShardRamRecord) -> GpuShardRamRecord { + GpuShardRamRecord { + addr: rec.addr, + ram_type: match rec.ram_type { + RAMType::Register => 1, + RAMType::Memory => 2, + _ => 0, + }, + value: rec.value, + _pad0: 0, + shard: rec.shard, + local_clk: rec.local_clk, + global_clk: rec.global_clk, + is_to_write_set: if rec.is_to_write_set { 1 } else { 0 }, + nonce: 0, + point_x: [0; 7], + point_y: [0; 7], + } +} + +/// Convert a GPU-computed GpuShardRamRecord to ECPoint. 
+fn gpu_shard_ram_record_to_ec_point( + gpu_rec: &GpuShardRamRecord, +) -> crate::tables::ECPoint { + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + point_x_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_x[j]); + point_y_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_y[j]); + } + + let x = SepticExtension(point_x_arr); + let y = SepticExtension(point_y_arr); + let point = SepticPoint::from_affine(x, y); + + crate::tables::ECPoint { + nonce: gpu_rec.nonce, + point, + } +} + fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { let mut tables: [FxHashMap; 8] = Default::default(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 4fcf19fa4..c3d21c19b 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -477,49 +477,46 @@ impl ZKVMWitnesses { ), config: & as TableCircuit>::TableConfig, ) -> Result<(), ZKVMError> { - let perm = ::get_default_perm(); let addr_accessed = shard_ctx.get_addr_accessed(); - // future shard needed records := shard_ctx.write_records ∪ // - // (shard_ctx.after_current_shard_cycle(mem_record.cycle) && !addr_accessed.contains(&waddr)) + // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) + // Partition into writes and reads to maintain the ordering invariant required by + // ShardRamCircuit::assign_instances (writes first, reads after). + let (gpu_ec_writes, gpu_ec_reads): (Vec<_>, Vec<_>) = + if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + .into_iter() + .partition(|input| input.record.is_to_write_set) + } else { + (vec![], vec![]) + }; - // 1. 
process final mem which - // 1.1 init in first shard - // 1.2 not accessed in first shard - // 1.3 accessed in future shard - let first_shard_access_later_records = if shard_ctx.is_first_shard() { - final_mem - .par_iter() - // only process no range restriction memory record - // for range specified it means dynamic init across different shard - .filter(|(_, range, _)| range.is_none()) - .flat_map(|(mem_name, _, final_mem)| { - final_mem.par_iter().filter_map(|mem_record| { - let (waddr, addr) = Self::mem_addresses(mem_record); - Self::make_cross_shard_input( - mem_name, - mem_record, - waddr, - addr, - shard_ctx, - &addr_accessed, - &perm, - ) + // Collect cross-shard records (filter only, no EC computation yet) + let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_none()) + .flat_map(|(mem_name, _, final_mem)| { + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) }) - }) - .collect() - } else { - vec![] - }; + .collect() + } else { + vec![] + }; - // 2. process records which - // 2.1 init within current shard - // 2.2 not accessed in current shard - // 2.3 access by later shards. 
- let current_shard_access_later = final_mem + let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem .par_iter() - // only process range-restricted memory record - // for range specified it means dynamic init across different shard .filter(|(_, range, _)| range.is_some()) .flat_map(|(mem_name, range, final_mem)| { let range = range.as_ref().unwrap(); @@ -528,63 +525,85 @@ impl ZKVMWitnesses { if !range.contains(&addr) { return None; } - Self::make_cross_shard_input( + Self::make_cross_shard_record( mem_name, mem_record, waddr, addr, shard_ctx, &addr_accessed, - &perm, ) }) }) - .collect::>(); - - // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) - // Partition into writes and reads to maintain the ordering invariant required by - // ShardRamCircuit::assign_instances (writes first, reads after). - let (gpu_ec_writes, gpu_ec_reads): (Vec<_>, Vec<_>) = - if shard_ctx.has_gpu_ec_records() { - gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) - .into_iter() - .partition(|input| input.record.is_to_write_set) - } else { - (vec![], vec![]) - }; + .collect(); - let global_input = shard_ctx + // Collect write_records and read_records as (ShardRamRecord, name) pairs + let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx .write_records() - .par_iter() + .iter() .flat_map(|records| { - // global write -> local reads - records.par_iter().map(|(vma, record)| { - let global_write: ShardRamRecord = (vma, record, true).into(); - let ec_point: ECPoint = global_write.to_ec_point(&perm); - ShardRamInput { - name: "current_shard_external_write", - record: global_write, - ec_point, - } + records.iter().map(|(vma, record)| { + ((vma, record, true).into(), "current_shard_external_write") }) }) - .chain(first_shard_access_later_records.into_par_iter()) - .chain(current_shard_access_later.into_par_iter()) - .chain(gpu_ec_writes.into_par_iter()) - 
.chain(shard_ctx.read_records().par_iter().flat_map(|records| { - // global read -> local write - records.par_iter().map(|(vma, record)| { - let global_read: ShardRamRecord = (vma, record, false).into(); - let ec_point: ECPoint = global_read.to_ec_point(&perm); - ShardRamInput { - name: "current_shard_external_read", - record: global_read, - ec_point, - } + .chain(first_shard_access_later_recs) + .chain(current_shard_access_later_recs) + .collect(); + + let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .read_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, false).into(), "current_shard_external_read") }) - })) - .chain(gpu_ec_reads.into_par_iter()) - .collect::>(); + }) + .collect(); + + // Compute EC points: GPU path (fast) or CPU fallback + let global_input = { + #[cfg(feature = "gpu")] + let ec_result = { + use crate::instructions::riscv::gpu::witgen_gpu::gpu_batch_continuation_ec; + gpu_batch_continuation_ec::(&write_record_pairs, &read_record_pairs) + .ok() + }; + #[cfg(not(feature = "gpu"))] + let ec_result: Option<(Vec>, Vec>)> = None; + + if let Some((computed_writes, computed_reads)) = ec_result { + // GPU path: chain computed EC with pre-computed GPU EC records + computed_writes + .into_iter() + .chain(gpu_ec_writes) + .chain(computed_reads) + .chain(gpu_ec_reads) + .collect::>() + } else { + // CPU fallback: compute EC points with Poseidon2 permutation + let perm = ::get_default_perm(); + let cpu_writes: Vec> = write_record_pairs + .into_par_iter() + .map(|(record, name)| { + let ec_point = record.to_ec_point(&perm); + ShardRamInput { name, record, ec_point } + }) + .collect(); + let cpu_reads: Vec> = read_record_pairs + .into_par_iter() + .map(|(record, name)| { + let ec_point = record.to_ec_point(&perm); + ShardRamInput { name, record, ec_point } + }) + .collect(); + cpu_writes + .into_iter() + .chain(gpu_ec_writes) + .chain(cpu_reads) + .chain(gpu_ec_reads) + .collect::>() + } + 
}; if tracing::enabled!(Level::DEBUG) { let total = global_input.len() as f64; @@ -707,16 +726,17 @@ impl ZKVMWitnesses { } } + /// Filter and construct a cross-shard ShardRamRecord without EC computation. + /// Used by the GPU path where EC is computed in batch on device. #[inline(always)] - fn make_cross_shard_input( + fn make_cross_shard_record( mem_name: &'static str, mem_record: &MemFinalRecord, waddr: WordAddr, addr: u32, shard_ctx: &ShardContext, addr_accessed: &FxHashSet, - perm: &<::BaseField as PoseidonField>::P, - ) -> Option> { + ) -> Option<(ShardRamRecord, &'static str)> { if addr_accessed.contains(&waddr) || !shard_ctx.after_current_shard_cycle(mem_record.cycle) { return None; @@ -735,12 +755,7 @@ impl ZKVMWitnesses { global_clk: 0, is_to_write_set: true, }; - let ec_point: ECPoint = global_write.to_ec_point(perm); - Some(ShardRamInput { - name: mem_name, - record: global_write, - ec_point, - }) + Some((global_write, mem_name)) } } From 1fe145bee426b96fd50e6edf8db7adad0a4afab2 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 17 Mar 2026 09:56:37 +0800 Subject: [PATCH 39/73] try-perf --- ceno_zkvm/src/e2e.rs | 5 +- .../src/instructions/riscv/rv32im/mmu.rs | 24 +- ceno_zkvm/src/structs.rs | 256 ++++++++++-------- 3 files changed, 158 insertions(+), 127 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 404a28b14..bfb4805ba 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -602,15 +602,16 @@ impl<'a> ShardContext<'a> { /// merge addr accessed in different threads pub fn get_addr_accessed(&self) -> FxHashSet { - let mut merged = FxHashSet::default(); if let Either::Left(addr_accessed_tbs) = &self.addr_accessed_tbs { + let total: usize = addr_accessed_tbs.iter().map(|v| v.len()).sum(); + let mut merged = FxHashSet::with_capacity_and_hasher(total, Default::default()); for addrs in addr_accessed_tbs { merged.extend(addrs.iter().copied()); } + merged } else { panic!("invalid type"); } - merged } /// Splits a total 
count `num_shards` into up to `num_provers` non-empty parts, distributing as evenly as possible. diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index b3f96feb0..3d4f1efc0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -199,16 +199,20 @@ impl MmuConfig { .filter(|(_, _, record)| !record.is_empty()) .collect_vec(); - witness.assign_table_circuit::>( - cs, - &self.local_final_circuit, - &(shard_ctx, all_records.as_slice()), - )?; - witness.assign_shared_circuit( - cs, - &(shard_ctx, all_records.as_slice()), - &self.ram_bus_circuit, - )?; + tracing::info_span!("local_final_circuit").in_scope(|| { + witness.assign_table_circuit::>( + cs, + &self.local_final_circuit, + &(shard_ctx, all_records.as_slice()), + ) + })?; + tracing::info_span!("shared_circuit").in_scope(|| { + witness.assign_shared_circuit( + cs, + &(shard_ctx, all_records.as_slice()), + &self.ram_bus_circuit, + ) + })?; Ok(()) } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index c3d21c19b..26ce927f3 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -477,88 +477,97 @@ impl ZKVMWitnesses { ), config: & as TableCircuit>::TableConfig, ) -> Result<(), ZKVMError> { - let addr_accessed = shard_ctx.get_addr_accessed(); + use tracing::info_span; + + let addr_accessed = info_span!("get_addr_accessed").in_scope(|| { + shard_ctx.get_addr_accessed() + }); // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) // Partition into writes and reads to maintain the ordering invariant required by // ShardRamCircuit::assign_instances (writes first, reads after). 
let (gpu_ec_writes, gpu_ec_reads): (Vec<_>, Vec<_>) = - if shard_ctx.has_gpu_ec_records() { - gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) - .into_iter() - .partition(|input| input.record.is_to_write_set) - } else { - (vec![], vec![]) - }; + info_span!("gpu_ec_convert", n = shard_ctx.gpu_ec_records.len() / 104).in_scope(|| { + if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + .into_iter() + .partition(|input| input.record.is_to_write_set) + } else { + (vec![], vec![]) + } + }); // Collect cross-shard records (filter only, no EC computation yet) - let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = - if shard_ctx.is_first_shard() { - final_mem - .par_iter() - .filter(|(_, range, _)| range.is_none()) - .flat_map(|(mem_name, _, final_mem)| { - final_mem.par_iter().filter_map(|mem_record| { - let (waddr, addr) = Self::mem_addresses(mem_record); - Self::make_cross_shard_record( - mem_name, - mem_record, - waddr, - addr, - shard_ctx, - &addr_accessed, - ) + let (write_record_pairs, read_record_pairs) = info_span!("collect_records").in_scope(|| { + let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_none()) + .flat_map(|(mem_name, _, final_mem)| { + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) }) - }) - .collect() - } else { - vec![] - }; + .collect() + } else { + vec![] + }; - let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem - .par_iter() - .filter(|(_, range, _)| range.is_some()) - .flat_map(|(mem_name, range, final_mem)| { - let range = range.as_ref().unwrap(); - final_mem.par_iter().filter_map(|mem_record| { - let (waddr, addr) = 
Self::mem_addresses(mem_record); - if !range.contains(&addr) { - return None; - } - Self::make_cross_shard_record( - mem_name, - mem_record, - waddr, - addr, - shard_ctx, - &addr_accessed, - ) + let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem + .par_iter() + .filter(|(_, range, _)| range.is_some()) + .flat_map(|(mem_name, range, final_mem)| { + let range = range.as_ref().unwrap(); + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + if !range.contains(&addr) { + return None; + } + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) }) - }) - .collect(); + .collect(); - // Collect write_records and read_records as (ShardRamRecord, name) pairs - let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx - .write_records() - .iter() - .flat_map(|records| { - records.iter().map(|(vma, record)| { - ((vma, record, true).into(), "current_shard_external_write") + let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .write_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, true).into(), "current_shard_external_write") + }) }) - }) - .chain(first_shard_access_later_recs) - .chain(current_shard_access_later_recs) - .collect(); + .chain(first_shard_access_later_recs) + .chain(current_shard_access_later_recs) + .collect(); - let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx - .read_records() - .iter() - .flat_map(|records| { - records.iter().map(|(vma, record)| { - ((vma, record, false).into(), "current_shard_external_read") + let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .read_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, false).into(), "current_shard_external_read") + }) }) - }) - .collect(); + .collect(); + + (write_record_pairs, 
read_record_pairs) + }); // Compute EC points: GPU path (fast) or CPU fallback let global_input = { @@ -663,29 +672,32 @@ impl ZKVMWitnesses { assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); - let circuit_inputs = global_input - .par_chunks(shard_ctx.max_num_cross_shard_accesses) - .map(|shard_accesses| { - let witness = ShardRamCircuit::assign_instances( - config, - cs.zkvm_v1_css.num_witin as usize, - cs.zkvm_v1_css.num_structural_witin as usize, - self.combined_lk_mlt.as_ref().unwrap(), - shard_accesses, - )?; - let num_reads = shard_accesses - .par_iter() - .filter(|access| access.record.is_to_write_set) - .count(); - let num_writes = shard_accesses.len() - num_reads; - - Ok(ChipInput::new( - ShardRamCircuit::::name(), - witness, - vec![num_reads, num_writes], - )) - }) - .collect::, ZKVMError>>()?; + let n_global = global_input.len(); + let circuit_inputs = info_span!("shard_ram_assign_instances", n = n_global).in_scope(|| { + global_input + .par_chunks(shard_ctx.max_num_cross_shard_accesses) + .map(|shard_accesses| { + let witness = ShardRamCircuit::assign_instances( + config, + cs.zkvm_v1_css.num_witin as usize, + cs.zkvm_v1_css.num_structural_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + shard_accesses, + )?; + let num_reads = shard_accesses + .par_iter() + .filter(|access| access.record.is_to_write_set) + .count(); + let num_writes = shard_accesses.len() - num_reads; + + Ok(ChipInput::new( + ShardRamCircuit::::name(), + witness, + vec![num_reads, num_writes], + )) + }) + .collect::, ZKVMError>>() + })?; // set num_read, num_write as separate instance assert!( self.witnesses @@ -915,39 +927,53 @@ fn gpu_ec_records_to_shard_ram_inputs( use gkr_iop::RAMType; use p3::field::FieldAlgebra; - // GpuShardRamRecord layout (104 bytes, 8-byte aligned): - // addr: u32 (0), ram_type: u32 (4), value: u32 (8), _pad: u32 (12), - // shard: u64 (16), local_clk: u64 (24), global_clk: u64 (32), - // 
is_to_write_set: u32 (40), nonce: u32 (44), - // point_x: [u32;7] (48..76), point_y: [u32;7] (76..104) - assert!(raw.len() % GPU_SHARD_RAM_RECORD_SIZE == 0); let count = raw.len() / GPU_SHARD_RAM_RECORD_SIZE; - (0..count).map(|i| { + // Reinterpret raw bytes as GpuShardRamRecord slice (zero-copy). + // GpuShardRamRecord is #[repr(C)], 104 bytes, 8-byte aligned. + #[cfg(feature = "gpu")] + let as_gpu_record = |i: usize| -> &ceno_gpu::common::witgen_types::GpuShardRamRecord { + use ceno_gpu::common::witgen_types::GpuShardRamRecord; let base = i * GPU_SHARD_RAM_RECORD_SIZE; - let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; + let ptr = raw[base..].as_ptr() as *const GpuShardRamRecord; + unsafe { &*ptr } + }; + + (0..count).map(|i| { + #[cfg(feature = "gpu")] + let (addr, ram_type_val, value, shard, local_clk, global_clk, is_to_write_set, nonce, point_x, point_y) = { + let g = as_gpu_record(i); + (g.addr, g.ram_type, g.value, g.shard, g.local_clk, g.global_clk, + g.is_to_write_set != 0, g.nonce, g.point_x, g.point_y) + }; - let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); - let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); - let value = u32::from_le_bytes(r[8..12].try_into().unwrap()); - let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); - let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); - let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); - let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; - let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); + #[cfg(not(feature = "gpu"))] + let (addr, ram_type_val, value, shard, local_clk, global_clk, is_to_write_set, nonce, point_x, point_y) = { + let base = i * GPU_SHARD_RAM_RECORD_SIZE; + let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; + let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); + let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); + let value = 
u32::from_le_bytes(r[8..12].try_into().unwrap()); + let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); + let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); + let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); + let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; + let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); + let mut px = [0u32; 7]; + let mut py = [0u32; 7]; + for j in 0..7 { + px[j] = u32::from_le_bytes(r[48 + j*4..52 + j*4].try_into().unwrap()); + py[j] = u32::from_le_bytes(r[76 + j*4..80 + j*4].try_into().unwrap()); + } + (addr, ram_type_val, value, shard, local_clk, global_clk, is_to_write_set, nonce, px, py) + }; let mut point_x_arr = [E::BaseField::ZERO; 7]; let mut point_y_arr = [E::BaseField::ZERO; 7]; for j in 0..7 { - let xoff = 48 + j * 4; - let yoff = 76 + j * 4; - point_x_arr[j] = E::BaseField::from_canonical_u32( - u32::from_le_bytes(r[xoff..xoff+4].try_into().unwrap()) - ); - point_y_arr[j] = E::BaseField::from_canonical_u32( - u32::from_le_bytes(r[yoff..yoff+4].try_into().unwrap()) - ); + point_x_arr[j] = E::BaseField::from_canonical_u32(point_x[j]); + point_y_arr[j] = E::BaseField::from_canonical_u32(point_y[j]); } let record = ShardRamRecord { From 2d7f3e9e7fd1986b4894b71982855f87f18366dc Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 17 Mar 2026 09:56:37 +0800 Subject: [PATCH 40/73] perf-minor --- ceno_zkvm/src/e2e.rs | 17 +++++++ ceno_zkvm/src/structs.rs | 105 ++++++++++++++++++--------------------- 2 files changed, 66 insertions(+), 56 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index bfb4805ba..5f1d1de50 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -614,6 +614,23 @@ impl<'a> ShardContext<'a> { } } + /// merge addr accessed into a sorted Vec for fast binary search lookups. + /// Much faster than FxHashSet for large sets (avoids hashing overhead). 
+ pub fn get_addr_accessed_sorted(&self) -> Vec { + if let Either::Left(addr_accessed_tbs) = &self.addr_accessed_tbs { + let total: usize = addr_accessed_tbs.iter().map(|v| v.len()).sum(); + let mut merged = Vec::with_capacity(total); + for addrs in addr_accessed_tbs { + merged.extend_from_slice(addrs); + } + merged.par_sort_unstable(); + merged.dedup(); + merged + } else { + panic!("invalid type"); + } + } + /// Splits a total count `num_shards` into up to `num_provers` non-empty parts, distributing as evenly as possible. /// /// # Behavior diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 26ce927f3..97856155e 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -20,7 +20,7 @@ use rayon::{ iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, prelude::ParallelSlice, }; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, @@ -480,18 +480,16 @@ impl ZKVMWitnesses { use tracing::info_span; let addr_accessed = info_span!("get_addr_accessed").in_scope(|| { - shard_ctx.get_addr_accessed() + shard_ctx.get_addr_accessed_sorted() }); // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) // Partition into writes and reads to maintain the ordering invariant required by // ShardRamCircuit::assign_instances (writes first, reads after). 
- let (gpu_ec_writes, gpu_ec_reads): (Vec<_>, Vec<_>) = + let (gpu_ec_writes, gpu_ec_reads) = info_span!("gpu_ec_convert", n = shard_ctx.gpu_ec_records.len() / 104).in_scope(|| { if shard_ctx.has_gpu_ec_records() { gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) - .into_iter() - .partition(|input| input.record.is_to_write_set) } else { (vec![], vec![]) } @@ -747,9 +745,10 @@ impl ZKVMWitnesses { waddr: WordAddr, addr: u32, shard_ctx: &ShardContext, - addr_accessed: &FxHashSet, + addr_accessed: &[WordAddr], ) -> Option<(ShardRamRecord, &'static str)> { - if addr_accessed.contains(&waddr) || !shard_ctx.after_current_shard_cycle(mem_record.cycle) + if addr_accessed.binary_search(&waddr).is_ok() + || !shard_ctx.after_current_shard_cycle(mem_record.cycle) { return None; } @@ -921,59 +920,45 @@ where /// Convert raw GPU EC record bytes to ShardRamInput. /// The raw bytes are from `GpuShardRamRecord` structs (104 bytes each). /// EC points are already computed on GPU — no Poseidon2/SepticCurve needed. +/// Returns (writes, reads) pre-partitioned using parallel iteration. fn gpu_ec_records_to_shard_ram_inputs( raw: &[u8], -) -> Vec> { - use gkr_iop::RAMType; - use p3::field::FieldAlgebra; - +) -> (Vec>, Vec>) { assert!(raw.len() % GPU_SHARD_RAM_RECORD_SIZE == 0); let count = raw.len() / GPU_SHARD_RAM_RECORD_SIZE; - // Reinterpret raw bytes as GpuShardRamRecord slice (zero-copy). - // GpuShardRamRecord is #[repr(C)], 104 bytes, 8-byte aligned. 
- #[cfg(feature = "gpu")] - let as_gpu_record = |i: usize| -> &ceno_gpu::common::witgen_types::GpuShardRamRecord { - use ceno_gpu::common::witgen_types::GpuShardRamRecord; - let base = i * GPU_SHARD_RAM_RECORD_SIZE; - let ptr = raw[base..].as_ptr() as *const GpuShardRamRecord; - unsafe { &*ptr } - }; - - (0..count).map(|i| { - #[cfg(feature = "gpu")] - let (addr, ram_type_val, value, shard, local_clk, global_clk, is_to_write_set, nonce, point_x, point_y) = { - let g = as_gpu_record(i); - (g.addr, g.ram_type, g.value, g.shard, g.local_clk, g.global_clk, - g.is_to_write_set != 0, g.nonce, g.point_x, g.point_y) - }; + #[inline(always)] + fn convert_record(raw: &[u8], i: usize) -> ShardRamInput { + use gkr_iop::RAMType; + use p3::field::FieldAlgebra; - #[cfg(not(feature = "gpu"))] - let (addr, ram_type_val, value, shard, local_clk, global_clk, is_to_write_set, nonce, point_x, point_y) = { - let base = i * GPU_SHARD_RAM_RECORD_SIZE; - let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; - let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); - let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); - let value = u32::from_le_bytes(r[8..12].try_into().unwrap()); - let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); - let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); - let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); - let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; - let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); - let mut px = [0u32; 7]; - let mut py = [0u32; 7]; - for j in 0..7 { - px[j] = u32::from_le_bytes(r[48 + j*4..52 + j*4].try_into().unwrap()); - py[j] = u32::from_le_bytes(r[76 + j*4..80 + j*4].try_into().unwrap()); - } - (addr, ram_type_val, value, shard, local_clk, global_clk, is_to_write_set, nonce, px, py) - }; + let base = i * GPU_SHARD_RAM_RECORD_SIZE; + let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; + + // Read fields directly from the 
byte buffer. + // Layout matches GpuShardRamRecord (104 bytes, #[repr(C)]): + // 0: addr(u32), 4: ram_type(u32), 8: value(u32), 12: _pad(u32), + // 16: shard(u64), 24: local_clk(u64), 32: global_clk(u64), + // 40: is_to_write_set(u32), 44: nonce(u32), + // 48: point_x[7](u32×7), 76: point_y[7](u32×7) + let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); + let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); + let value = u32::from_le_bytes(r[8..12].try_into().unwrap()); + let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); + let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); + let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); + let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; + let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); let mut point_x_arr = [E::BaseField::ZERO; 7]; let mut point_y_arr = [E::BaseField::ZERO; 7]; for j in 0..7 { - point_x_arr[j] = E::BaseField::from_canonical_u32(point_x[j]); - point_y_arr[j] = E::BaseField::from_canonical_u32(point_y[j]); + point_x_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[48 + j*4..52 + j*4].try_into().unwrap()), + ); + point_y_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[76 + j*4..80 + j*4].try_into().unwrap()), + ); } let record = ShardRamRecord { @@ -986,10 +971,6 @@ fn gpu_ec_records_to_shard_ram_inputs( is_to_write_set, }; - let x = SepticExtension(point_x_arr); - let y = SepticExtension(point_y_arr); - let point = SepticPoint::from_affine(x, y); - ShardRamInput { name: if is_to_write_set { "current_shard_external_write" @@ -997,7 +978,19 @@ fn gpu_ec_records_to_shard_ram_inputs( "current_shard_external_read" }, record, - ec_point: ECPoint { nonce, point }, + ec_point: ECPoint { + nonce, + point: SepticPoint::from_affine( + SepticExtension(point_x_arr), + SepticExtension(point_y_arr), + ), + }, } - }).collect() + } + + // Parallel convert + partition in one pass + 
(0..count) + .into_par_iter() + .map(|i| convert_record::(raw, i)) + .partition(|input| input.record.is_to_write_set) } From 6a21816c29eb2162a5352be92d72eaf32390c335 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 17 Mar 2026 09:56:37 +0800 Subject: [PATCH 41/73] shardram-1 --- ceno_zkvm/src/gadgets/poseidon2.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/shard_ram.rs | 134 +++++++++ ceno_zkvm/src/tables/shard_ram.rs | 258 +++++++++++++++++- 4 files changed, 384 insertions(+), 14 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 1dbfcccb5..fdb6be1a8 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -78,8 +78,8 @@ pub struct Poseidon2Config< const HALF_FULL_ROUNDS: usize, const PARTIAL_ROUNDS: usize, > { - p3_cols: Vec, // columns in the plonky3-air - post_linear_layer_cols: Vec, /* additional columns to hold the state after linear layers */ + pub(crate) p3_cols: Vec, // columns in the plonky3-air + pub(crate) post_linear_layer_cols: Vec, /* additional columns to hold the state after linear layers */ constants: RoundConstants, } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 51c4ba33f..03696d4af 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -43,4 +43,6 @@ pub mod sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sw; #[cfg(feature = "gpu")] +pub mod shard_ram; +#[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs b/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs new file mode 100644 index 000000000..1e08b96e9 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs @@ -0,0 +1,134 @@ +use ceno_gpu::common::witgen_types::ShardRamColumnMap; +use 
ff_ext::ExtensionField; + +use crate::tables::ShardRamConfig; + +/// Extract column map from a constructed ShardRamConfig. +/// +/// This reads all WitIn.id values from the config and packs them +/// into a ShardRamColumnMap suitable for GPU kernel dispatch. +pub fn extract_shard_ram_column_map( + config: &ShardRamConfig, + num_witin: usize, +) -> ShardRamColumnMap { + let addr = config.addr.id as u32; + let is_ram_register = config.is_ram_register.id as u32; + + let value_limbs = config + .value + .wits_in() + .expect("value should have WitIn limbs"); + assert_eq!(value_limbs.len(), 2, "Expected 2 value limbs"); + let value = [value_limbs[0].id as u32, value_limbs[1].id as u32]; + + let shard = config.shard.id as u32; + let global_clk = config.global_clk.id as u32; + let local_clk = config.local_clk.id as u32; + let nonce = config.nonce.id as u32; + let is_global_write = config.is_global_write.id as u32; + + let mut x = [0u32; 7]; + let mut y = [0u32; 7]; + let mut slope = [0u32; 7]; + for i in 0..7 { + x[i] = config.x[i].id as u32; + y[i] = config.y[i].id as u32; + slope[i] = config.slope[i].id as u32; + } + + // Poseidon2 columns: p3_cols are contiguous, followed by post_linear_layer_cols + let poseidon2_base_col = config.perm_config.p3_cols[0].id as u32; + let num_p3_cols = config.perm_config.p3_cols.len() as u32; + let num_post_linear = config.perm_config.post_linear_layer_cols.len() as u32; + let num_poseidon2_cols = num_p3_cols + num_post_linear; + + // Verify contiguity: p3_cols should be contiguous + for (i, col) in config.perm_config.p3_cols.iter().enumerate() { + debug_assert_eq!( + col.id as u32, + poseidon2_base_col + i as u32, + "p3_cols not contiguous at index {}", + i + ); + } + // post_linear_layer_cols should be contiguous after p3_cols + let post_base = poseidon2_base_col + num_p3_cols; + for (i, col) in config.perm_config.post_linear_layer_cols.iter().enumerate() { + debug_assert_eq!( + col.id as u32, + post_base + i as u32, + 
"post_linear_layer_cols not contiguous at index {}", + i + ); + } + + ShardRamColumnMap { + addr, + is_ram_register, + value, + shard, + global_clk, + local_clk, + nonce, + is_global_write, + x, + y, + slope, + poseidon2_base_col, + num_poseidon2_cols, + num_p3_cols, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + structs::ProgramParams, + tables::{ShardRamCircuit, TableCircuit}, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shard_ram_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let (config, _gkr_circuit) = + ShardRamCircuit::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + + let col_map = + extract_shard_ram_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + // Basic columns should be in range + // (excluding poseidon2 meta entries which are counts, not column IDs) + for (i, &col) in flat[..30].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (flat index {}) out of range: {} >= {}", + col, + i, + col, + col_map.num_cols + ); + } + + // Check uniqueness of actual column IDs (first 30 entries) + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..30] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + + // Verify Poseidon2 column counts are reasonable + assert_eq!(col_map.num_p3_cols, 299, "Expected 299 p3 cols"); + assert_eq!( + col_map.num_poseidon2_cols, 344, + "Expected 344 total Poseidon2 cols" + ); + } +} diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index f2748a4f9..3fc8db2d2 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -166,21 +166,21 @@ impl ShardRamRecord { /// 3. 
For a local memory write record which will be read in the future, /// the shard ram circuit will insert a local read record and write it to the **global set**. pub struct ShardRamConfig { - addr: WitIn, - is_ram_register: WitIn, - value: UInt, - shard: WitIn, - global_clk: WitIn, - local_clk: WitIn, - nonce: WitIn, + pub(crate) addr: WitIn, + pub(crate) is_ram_register: WitIn, + pub(crate) value: UInt, + pub(crate) shard: WitIn, + pub(crate) global_clk: WitIn, + pub(crate) local_clk: WitIn, + pub(crate) nonce: WitIn, // if it's write to global set, then insert a local read record // s.t. local offline memory checking can cancel out // serves as propagating local write to global. - is_global_write: WitIn, - x: Vec, - y: Vec, - slope: Vec, - perm_config: Poseidon2Config, + pub(crate) is_global_write: WitIn, + pub(crate) x: Vec, + pub(crate) y: Vec, + pub(crate) slope: Vec, + pub(crate) perm_config: Poseidon2Config, } impl ShardRamConfig { @@ -484,6 +484,13 @@ impl TableCircuit for ShardRamCircuit { witness::RowMajorMatrix::empty(), ]); } + + #[cfg(feature = "gpu")] + { + if let Some(result) = Self::try_gpu_assign_instances(config, num_witin, num_structural_witin, steps)? { + return Ok(result); + } + } // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` @@ -644,6 +651,233 @@ impl TableCircuit for ShardRamCircuit { } } +#[cfg(feature = "gpu")] +impl ShardRamCircuit { + /// Try to run assign_instances on GPU. Returns None if GPU is unavailable. 
+ fn try_gpu_assign_instances( + config: &ShardRamConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[ShardRamInput], + ) -> Result>, ZKVMError> { + use ceno_gpu::{ + Buffer, CudaHal, + bb31::CudaHalBB31, + common::{ + transpose::matrix_transpose, + witgen_types::GpuShardRamRecord, + }, + }; + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + use p3::field::PrimeField32; + + type BB = ::BaseField; + + // GPU only supports BabyBear + if std::any::TypeId::of::() != std::any::TypeId::of::() { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(h) => h, + Err(_) => return Ok(None), + }; + + let num_local_reads = steps + .iter() + .take_while(|s| s.record.is_to_write_set) + .count(); + + let n = next_pow2_instance_padding(steps.len()); + let num_rows_padded = 2 * n; + + // 1. Convert ShardRamInput → GpuShardRamRecord + let gpu_records: Vec = steps + .iter() + .map(|step| { + let r = &step.record; + let ec = &step.ec_point; + let mut rec = GpuShardRamRecord::default(); + rec.addr = r.addr; + rec.ram_type = r.ram_type as u32; + rec.value = r.value; + rec.shard = r.shard; + rec.local_clk = r.local_clk; + rec.global_clk = r.global_clk; + rec.is_to_write_set = if r.is_to_write_set { 1 } else { 0 }; + rec.nonce = ec.nonce; + for i in 0..7 { + // Safe: TypeId check above guarantees E::BaseField == BB + let px: BB = unsafe { *(&ec.point.x.0[i] as *const E::BaseField as *const BB) }; + let py: BB = unsafe { *(&ec.point.y.0[i] as *const E::BaseField as *const BB) }; + rec.point_x[i] = px.as_canonical_u32(); + rec.point_y[i] = py.as_canonical_u32(); + } + rec + }) + .collect(); + + // 2. Extract column map + let col_map = crate::instructions::riscv::gpu::shard_ram::extract_shard_ram_column_map( + config, num_witin, + ); + + // 3. 
GPU Phase 1: per-row assignment + let (gpu_witness, gpu_structural) = hal + .witgen_shard_ram_per_row( + &col_map, + &gpu_records, + num_local_reads as u32, + num_witin as u32, + num_structural_witin as u32, + num_rows_padded as u32, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU shard_ram per-row kernel failed: {e}").into(), + ) + })?; + + // 4. GPU Phase 2: EC binary tree + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; + + // Build initial layer points (padded to n) using BB (BabyBear) directly + let mut init_x = vec![BB::ZERO; n * 7]; + let mut init_y = vec![BB::ZERO; n * 7]; + for (i, step) in steps.iter().enumerate() { + for j in 0..7 { + // E::BaseField == BB at runtime (checked above), safe to transmute + init_x[i * 7 + j] = + unsafe { *(&step.ec_point.point.x.0[j] as *const E::BaseField as *const BB) }; + init_y[i * 7 + j] = + unsafe { *(&step.ec_point.point.y.0[j] as *const E::BaseField as *const BB) }; + } + } + + let mut cur_x = hal.alloc_elems_from_host(&init_x, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_x failed: {e}").into()) + })?; + let mut cur_y = hal.alloc_elems_from_host(&init_y, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_y failed: {e}").into()) + })?; + + let mut witness_buf = gpu_witness.device_buffer; + let mut offset = num_rows_padded / 2; // n + let mut current_layer_len = n; + + loop { + if current_layer_len <= 1 { + break; + } + + let (next_x, next_y) = hal + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut witness_buf, + current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree layer failed: {e}").into(), + ) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; + } + + 
// 5. GPU transpose: column-major → row-major + D2H + let wit_num_rows = num_rows_padded; + let wit_num_cols = num_witin; + let mut rmm_buf = hal + .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buf, + &witness_buf, + wit_num_rows, + wit_num_cols, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) + })?; + let wit_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + let struct_num_cols = num_structural_witin; + let mut struct_rmm_buf = hal + .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for struct transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut struct_rmm_buf, + &gpu_structural.device_buffer, + wit_num_rows, + struct_num_cols, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU struct transpose failed: {e}").into()) + })?; + + let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) + })?; + let struct_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + let raw_witin = witness::RowMajorMatrix::new_by_values( + wit_data, + wit_num_cols, + InstancePaddingStrategy::Default, + ); + let raw_structural_witin = witness::RowMajorMatrix::new_by_values( + struct_data, + struct_num_cols, + InstancePaddingStrategy::Default, + ); + + 
tracing::debug!( + "GPU shard_ram assign_instances: {} records, {} padded rows", + steps.len(), + num_rows_padded + ); + + Ok(Some([raw_witin, raw_structural_witin])) + } +} + #[cfg(test)] mod tests { use either::Either; From 1d27ae1b0316e0427e9c91803f4649e3c9f43573 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 18 Mar 2026 09:39:10 +0800 Subject: [PATCH 42/73] profile --- ceno_zkvm/src/tables/shard_ram.rs | 231 ++++++++++++++++-------------- 1 file changed, 125 insertions(+), 106 deletions(-) diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 3fc8db2d2..edddd7928 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -692,7 +692,9 @@ impl ShardRamCircuit { let num_rows_padded = 2 * n; // 1. Convert ShardRamInput → GpuShardRamRecord - let gpu_records: Vec = steps + let gpu_records: Vec = tracing::info_span!( + "gpu_shard_ram_pack_records", n = steps.len() + ).in_scope(|| steps .iter() .map(|step| { let r = &step.record; @@ -715,7 +717,7 @@ impl ShardRamCircuit { } rec }) - .collect(); + .collect()); // 2. Extract column map let col_map = crate::instructions::riscv::gpu::shard_ram::extract_shard_ram_column_map( @@ -723,7 +725,12 @@ impl ShardRamCircuit { ); // 3. GPU Phase 1: per-row assignment - let (gpu_witness, gpu_structural) = hal + let (gpu_witness, gpu_structural) = tracing::info_span!( + "gpu_shard_ram_per_row", + n = steps.len(), + num_rows_padded, + num_witin, + ).in_scope(|| hal .witgen_shard_ram_per_row( &col_map, &gpu_records, @@ -737,139 +744,151 @@ impl ShardRamCircuit { ZKVMError::InvalidWitness( format!("GPU shard_ram per-row kernel failed: {e}").into(), ) - })?; + }))?; // 4. 
GPU Phase 2: EC binary tree - let col_offsets = col_map.to_flat(); - let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) - })?; + let witness_buf = tracing::info_span!( + "gpu_shard_ram_ec_tree", n + ).in_scope(|| -> Result<_, ZKVMError> { + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; - // Build initial layer points (padded to n) using BB (BabyBear) directly - let mut init_x = vec![BB::ZERO; n * 7]; - let mut init_y = vec![BB::ZERO; n * 7]; - for (i, step) in steps.iter().enumerate() { - for j in 0..7 { - // E::BaseField == BB at runtime (checked above), safe to transmute - init_x[i * 7 + j] = - unsafe { *(&step.ec_point.point.x.0[j] as *const E::BaseField as *const BB) }; - init_y[i * 7 + j] = - unsafe { *(&step.ec_point.point.y.0[j] as *const E::BaseField as *const BB) }; + // Build initial layer points (padded to n) using BB (BabyBear) directly + let mut init_x = vec![BB::ZERO; n * 7]; + let mut init_y = vec![BB::ZERO; n * 7]; + for (i, step) in steps.iter().enumerate() { + for j in 0..7 { + // E::BaseField == BB at runtime (checked above), safe to transmute + init_x[i * 7 + j] = + unsafe { *(&step.ec_point.point.x.0[j] as *const E::BaseField as *const BB) }; + init_y[i * 7 + j] = + unsafe { *(&step.ec_point.point.y.0[j] as *const E::BaseField as *const BB) }; + } } - } - let mut cur_x = hal.alloc_elems_from_host(&init_x, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc init_x failed: {e}").into()) - })?; - let mut cur_y = hal.alloc_elems_from_host(&init_y, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc init_y failed: {e}").into()) - })?; + let mut cur_x = hal.alloc_elems_from_host(&init_x, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_x failed: 
{e}").into()) + })?; + let mut cur_y = hal.alloc_elems_from_host(&init_y, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_y failed: {e}").into()) + })?; - let mut witness_buf = gpu_witness.device_buffer; - let mut offset = num_rows_padded / 2; // n - let mut current_layer_len = n; + let mut witness_buf = gpu_witness.device_buffer; + let mut offset = num_rows_padded / 2; // n + let mut current_layer_len = n; - loop { - if current_layer_len <= 1 { - break; + loop { + if current_layer_len <= 1 { + break; + } + + let (next_x, next_y) = hal + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut witness_buf, + current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree layer failed: {e}").into(), + ) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; } - let (next_x, next_y) = hal - .shard_ram_ec_tree_layer( - &gpu_cols, - &cur_x, - &cur_y, - &mut witness_buf, - current_layer_len, - offset, - num_rows_padded, - None, + Ok(witness_buf) + })?; + + // 5. 
GPU transpose: column-major → row-major + D2H + let (wit_data, struct_data) = tracing::info_span!( + "gpu_shard_ram_transpose_d2h", num_rows_padded, num_witin, + ).in_scope(|| -> Result<_, ZKVMError> { + let wit_num_rows = num_rows_padded; + let wit_num_cols = num_witin; + let mut rmm_buf = hal + .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buf, + &witness_buf, + wit_num_rows, + wit_num_cols, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) + })?; + let wit_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), ) + }; + + let struct_num_cols = num_structural_witin; + let mut struct_rmm_buf = hal + .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) .map_err(|e| { ZKVMError::InvalidWitness( - format!("GPU EC tree layer failed: {e}").into(), + format!("GPU alloc for struct transpose failed: {e}").into(), ) })?; - - current_layer_len /= 2; - offset += current_layer_len; - cur_x = next_x; - cur_y = next_y; - } - - // 5. 
GPU transpose: column-major → row-major + D2H - let wit_num_rows = num_rows_padded; - let wit_num_cols = num_witin; - let mut rmm_buf = hal - .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) + matrix_transpose::( + &hal.inner, + &mut struct_rmm_buf, + &gpu_structural.device_buffer, + wit_num_rows, + struct_num_cols, + ) .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + ZKVMError::InvalidWitness(format!("GPU struct transpose failed: {e}").into()) })?; - matrix_transpose::( - &hal.inner, - &mut rmm_buf, - &witness_buf, - wit_num_rows, - wit_num_cols, - ) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; - let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) - })?; - let wit_data: Vec = unsafe { - let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); - Vec::from_raw_parts( - data.as_mut_ptr() as *mut E::BaseField, - data.len(), - data.capacity(), - ) - }; - - let struct_num_cols = num_structural_witin; - let mut struct_rmm_buf = hal - .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU alloc for struct transpose failed: {e}").into(), - ) + let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) })?; - matrix_transpose::( - &hal.inner, - &mut struct_rmm_buf, - &gpu_structural.device_buffer, - wit_num_rows, - struct_num_cols, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU struct transpose failed: {e}").into()) - })?; + let struct_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; - let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { - 
ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) + Ok((wit_data, struct_data)) })?; - let struct_data: Vec = unsafe { - let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); - Vec::from_raw_parts( - data.as_mut_ptr() as *mut E::BaseField, - data.len(), - data.capacity(), - ) - }; let raw_witin = witness::RowMajorMatrix::new_by_values( wit_data, - wit_num_cols, + num_witin, InstancePaddingStrategy::Default, ); let raw_structural_witin = witness::RowMajorMatrix::new_by_values( struct_data, - struct_num_cols, + num_structural_witin, InstancePaddingStrategy::Default, ); - tracing::debug!( - "GPU shard_ram assign_instances: {} records, {} padded rows", + tracing::info!( + "GPU shard_ram assign_instances done: {} records, {} padded rows", steps.len(), num_rows_padded ); From c4ab9ff450c9f2508597c78c430a24840f8050be Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 18 Mar 2026 09:39:10 +0800 Subject: [PATCH 43/73] shard-1 --- ceno_zkvm/src/e2e.rs | 9 ++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 137 +++++++++++++++++- 2 files changed, 144 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 5f1d1de50..01b4b0f92 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1507,6 +1507,15 @@ pub fn generate_witness<'a, E: ExtensionField>( ) }).unwrap(); + // Flush shared EC/addr buffers from GPU after all opcode circuits are done. + // This batch-D2Hs accumulated EC records and addr_accessed into shard_ctx. + #[cfg(feature = "gpu")] + info_span!("flush_shared_ec").in_scope(|| { + crate::instructions::riscv::gpu::witgen_gpu::flush_shared_ec_buffers( + &mut shard_ctx, + ) + }).unwrap(); + // Free GPU shard_steps cache after all opcode circuits are done. 
#[cfg(feature = "gpu")] { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index c5b2e5502..f18c1457d 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -173,6 +173,14 @@ pub fn invalidate_shard_steps_cache() { struct ShardMetadataCache { shard_id: usize, device_bufs: ShardDeviceBuffers, + /// Shared EC record buffer (owns the GPU memory, pointer stored in device_bufs). + shared_ec_buf: Option>, + /// Shared EC record count buffer (single u32 counter). + shared_ec_count: Option>, + /// Shared addr_accessed buffer (u32 word addresses). + shared_addr_buf: Option>, + /// Shared addr_accessed count buffer (single u32 counter). + shared_addr_count: Option>, } thread_local! { @@ -188,6 +196,7 @@ thread_local! { fn ensure_shard_metadata_cached( hal: &CudaHalBB31, shard_ctx: &ShardContext, + n_total_steps: usize, ) -> Result<(), ZKVMError> { let shard_id = shard_ctx.shard_id; SHARD_META_CACHE.with(|cache| { @@ -208,6 +217,10 @@ fn ensure_shard_metadata_cached( prev_shard_heap_range: _, prev_shard_hint_range: _, gpu_ec_shard_id: _, + shared_ec_out_ptr: _, + shared_ec_count_ptr: _, + shared_addr_out_ptr: _, + shared_addr_count_ptr: _, } = c.device_bufs; next_access_packed }); @@ -308,6 +321,42 @@ fn ensure_shard_metadata_cached( shard_id, ); + // Allocate shared EC/addr compact buffers for this shard. + // Each step produces up to 3 EC records and 3 addr_accessed entries. + // These buffers persist across all kernel invocations within the shard. 
+ let ec_capacity = n_total_steps * 4; // extra headroom + let ec_u32s = ec_capacity * 26; // 26 u32s per GpuShardRamRecord (104 bytes) + let addr_capacity = n_total_steps * 4; + + let shared_ec_buf = hal + .alloc_u32_zeroed(ec_u32s, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf alloc: {e}").into()))?; + let shared_ec_count = hal + .alloc_u32_zeroed(1, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_count alloc: {e}").into()) + })?; + let shared_addr_buf = hal + .alloc_u32_zeroed(addr_capacity, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_buf alloc: {e}").into()) + })?; + let shared_addr_count = hal + .alloc_u32_zeroed(1, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count alloc: {e}").into()) + })?; + + let shared_ec_out_ptr = shared_ec_buf.device_ptr() as u64; + let shared_ec_count_ptr = shared_ec_count.device_ptr() as u64; + let shared_addr_out_ptr = shared_addr_buf.device_ptr() as u64; + let shared_addr_count_ptr = shared_addr_count.device_ptr() as u64; + + tracing::info!( + "[GPU shard] shard_id={}: shared buffers allocated: ec_capacity={}, addr_capacity={}", + shard_id, ec_capacity, addr_capacity, + ); + *cache = Some(ShardMetadataCache { shard_id, device_bufs: ShardDeviceBuffers { @@ -317,7 +366,15 @@ fn ensure_shard_metadata_cached( prev_shard_heap_range: pshr_device, prev_shard_hint_range: pshi_device, gpu_ec_shard_id: Some(shard_id as u64), + shared_ec_out_ptr, + shared_ec_count_ptr, + shared_addr_out_ptr, + shared_addr_count_ptr, }, + shared_ec_buf: Some(shared_ec_buf), + shared_ec_count: Some(shared_ec_count), + shared_addr_buf: Some(shared_addr_buf), + shared_addr_count: Some(shared_addr_count), }); Ok(()) }) @@ -339,6 +396,77 @@ pub fn invalidate_shard_meta_cache() { }); } +/// Batch D2H of shared EC records and addr_accessed buffers after all kernel invocations. 
+/// +/// Called once per shard after all opcode `gpu_assign_instances_inner` calls complete. +/// Transfers accumulated EC records and addresses from shared GPU buffers into `shard_ctx`. +pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVMError> { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().ok_or_else(|| { + ZKVMError::InvalidWitness("shard metadata not cached".into()) + })?; + + // D2H EC record count + let ec_count_buf = c.shared_ec_count.as_ref().ok_or_else(|| { + ZKVMError::InvalidWitness("shared_ec_count not allocated".into()) + })?; + let ec_count_vec: Vec = ec_count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) + })?; + let ec_count = ec_count_vec[0] as usize; + + if ec_count > 0 { + // D2H EC records (only the active portion) + let ec_buf = c.shared_ec_buf.as_ref().unwrap(); + let ec_u32s = ec_count * 26; // 26 u32s per GpuShardRamRecord + let raw_u32: Vec = ec_buf.to_vec_n(ec_u32s).map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_buf D2H: {e}").into()) + })?; + let raw_bytes = unsafe { + std::slice::from_raw_parts( + raw_u32.as_ptr() as *const u8, + raw_u32.len() * 4, + ) + }; + tracing::info!( + "[GPU shard] flush_shared_ec_buffers: {} EC records, {:.2} MB", + ec_count, + raw_bytes.len() as f64 / (1024.0 * 1024.0), + ); + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + } + + // D2H addr_accessed count + let addr_count_buf = c.shared_addr_count.as_ref().ok_or_else(|| { + ZKVMError::InvalidWitness("shared_addr_count not allocated".into()) + })?; + let addr_count_vec: Vec = addr_count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()) + })?; + let addr_count = addr_count_vec[0] as usize; + + if addr_count > 0 { + let addr_buf = c.shared_addr_buf.as_ref().unwrap(); + let addrs: Vec = addr_buf.to_vec_n(addr_count).map_err(|e| { + 
ZKVMError::InvalidWitness(format!("shared_addr_buf D2H: {e}").into()) + })?; + tracing::info!( + "[GPU shard] flush_shared_ec_buffers: {} addr_accessed, {:.2} MB", + addr_count, + addr_count as f64 * 4.0 / (1024.0 * 1024.0), + ); + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + + Ok(()) + }) +} + /// CPU-side lightweight scan of GPU-produced RAM record slots. /// /// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, @@ -622,7 +750,11 @@ fn gpu_assign_instances_inner>( gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) })?; - if gpu_compact_ec.is_some() && kind_has_verified_shard(kind) { + if gpu_compact_ec.is_none() && gpu_compact_addr.is_none() && kind_has_verified_shard(kind) { + // Shared buffer path: EC records + addr_accessed accumulated on device + // in shared buffers across all kernel invocations. Skip per-kernel D2H. + // Data will be consumed in batch by assign_shared_circuit. + } else if gpu_compact_ec.is_some() && kind_has_verified_shard(kind) { // GPU EC path: compact records already have EC points computed on device. // D2H only the active records (much smaller than full N*3 slot buffer). 
info_span!("gpu_ec_shard").in_scope(|| { @@ -839,7 +971,8 @@ fn gpu_fill_witness>( let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); // Ensure shard metadata is cached for GPU shard records (shared across all kernel kinds) - info_span!("ensure_shard_meta").in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx))?; + info_span!("ensure_shard_meta") + .in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx, shard_steps.len()))?; match kind { GpuWitgenKind::Add => { From be58c5e1f19ba80da97fb09a0e746ff5bc394f5f Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 18 Mar 2026 09:39:10 +0800 Subject: [PATCH 44/73] shard-2 --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 93 +++++- ceno_zkvm/src/structs.rs | 265 ++++++++++++++++++ ceno_zkvm/src/tables/shard_ram.rs | 201 +++++++++++++ 3 files changed, 552 insertions(+), 7 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index f18c1457d..0703c0c88 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -396,21 +396,100 @@ pub fn invalidate_shard_meta_cache() { }); } +/// Take ownership of shared EC and addr_accessed device buffers from the cache. +/// +/// Returns (shared_ec_buf, ec_count, shared_addr_buf, addr_count) or None if unavailable. +/// The cache is invalidated after this call — must be called at most once per shard. +pub fn take_shared_device_buffers() -> Option { + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + let c = cache.as_mut()?; + + let ec_buf = c.shared_ec_buf.take()?; + let ec_count = c.shared_ec_count.take()?; + let addr_buf = c.shared_addr_buf.take()?; + let addr_count = c.shared_addr_count.take()?; + + Some(SharedDeviceBufferSet { + ec_buf, + ec_count, + addr_buf, + addr_count, + }) + }) +} + +/// Shared device buffers taken from the shard metadata cache. 
+pub struct SharedDeviceBufferSet { + pub ec_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub ec_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub addr_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub addr_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, +} + +/// Batch compute EC points for continuation records, keeping results on device. +/// +/// Returns (device_buf_as_u32, num_records) where the device buffer contains +/// GpuShardRamRecord entries with EC points computed. +pub fn gpu_batch_continuation_ec_on_device( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result<(ceno_gpu::common::buffer::BufferImpl<'static, u32>, usize, usize), ZKVMError> { + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let n_writes = write_records.len(); + let n_reads = read_records.len(); + let total = n_writes + n_reads; + if total == 0 { + let empty = hal.alloc_u32_zeroed(1, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("alloc: {e}").into()) + })?; + return Ok((empty, 0, 0)); + } + + // Convert to GpuShardRamRecord format (writes first, reads after) + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(shard_ram_record_to_gpu(rec)); + } + + // GPU batch EC, results stay on device + let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total).in_scope(|| { + hal.batch_continuation_ec_on_device(&gpu_records, None) + }).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) + })?; + + Ok((device_buf, n_writes, n_reads)) +} + /// Batch D2H of shared EC records and addr_accessed buffers after all kernel invocations. 
/// /// Called once per shard after all opcode `gpu_assign_instances_inner` calls complete. /// Transfers accumulated EC records and addresses from shared GPU buffers into `shard_ctx`. +/// +/// If the shared buffers have already been taken by `take_shared_device_buffers` +/// (for the full GPU pipeline), this is a no-op. pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVMError> { SHARD_META_CACHE.with(|cache| { let cache = cache.borrow(); - let c = cache.as_ref().ok_or_else(|| { - ZKVMError::InvalidWitness("shard metadata not cached".into()) - })?; + let c = match cache.as_ref() { + Some(c) => c, + None => return Ok(()), // cache already invalidated — no-op + }; - // D2H EC record count - let ec_count_buf = c.shared_ec_count.as_ref().ok_or_else(|| { - ZKVMError::InvalidWitness("shared_ec_count not allocated".into()) - })?; + // If buffers have been taken by take_shared_device_buffers, skip D2H + let ec_count_buf = match c.shared_ec_count.as_ref() { + Some(b) => b, + None => { + tracing::debug!("[GPU shard] flush_shared_ec_buffers: buffers already taken, no-op"); + return Ok(()); + } + }; let ec_count_vec: Vec = ec_count_buf.to_vec().map_err(|e| { ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) })?; diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 97856155e..495aa41b8 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -479,6 +479,22 @@ impl ZKVMWitnesses { ) -> Result<(), ZKVMError> { use tracing::info_span; + // Try the full GPU pipeline: keep data on device, minimal CPU roundtrips. + // Falls back to the traditional path on failure. 
+ #[cfg(feature = "gpu")] + { + let gpu_result = self.try_assign_shared_circuit_gpu( + cs, shard_ctx, final_mem, config, + ); + match gpu_result { + Ok(true) => return Ok(()), // GPU pipeline succeeded + Ok(false) => {} // GPU pipeline unavailable, fall through + Err(e) => { + tracing::warn!("GPU full pipeline failed, falling back: {e:?}"); + } + } + } + let addr_accessed = info_span!("get_addr_accessed").in_scope(|| { shard_ctx.get_addr_accessed_sorted() }); @@ -706,6 +722,255 @@ impl ZKVMWitnesses { Ok(()) } + /// Full GPU pipeline for assign_shared_circuit: keep data on device, minimal CPU roundtrips. + /// + /// Returns Ok(true) if successful, Ok(false) if unavailable (no shared device buffers). + #[cfg(feature = "gpu")] + fn try_assign_shared_circuit_gpu( + &mut self, + cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, + final_mem: &[(&'static str, Option>, &[MemFinalRecord])], + config: & as TableCircuit>::TableConfig, + ) -> Result { + use crate::instructions::riscv::gpu::witgen_gpu::{ + gpu_batch_continuation_ec_on_device, take_shared_device_buffers, + }; + use ceno_gpu::Buffer; + use gkr_iop::gpu::get_cuda_hal; + use tracing::info_span; + + // 1. Take shared device buffers (if available) + let mut shared = match take_shared_device_buffers() { + Some(s) => s, + None => return Ok(false), + }; + + let hal = match get_cuda_hal() { + Ok(h) => h, + Err(_) => return Ok(false), + }; + + tracing::info!("[GPU full pipeline] starting device-resident assign_shared_circuit"); + + // 2. 
D2H the EC count and addr count + let ec_count = { + let cv: Vec = shared.ec_count.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) + })?; + cv[0] as usize + }; + let addr_count = { + let cv: Vec = shared.addr_count.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()) + })?; + cv[0] as usize + }; + + tracing::info!( + "[GPU full pipeline] shared buffers: {} EC records, {} addr_accessed", + ec_count, addr_count, + ); + + // 3. GPU sort addr_accessed + dedup, then D2H sorted unique addrs + let addr_accessed: Vec = if addr_count > 0 { + info_span!("gpu_sort_addr").in_scope(|| { + let (deduped, unique_count) = hal + .sort_and_dedup_u32(&mut shared.addr_buf, addr_count, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU sort addr: {e}").into()) + })?; + if unique_count == 0 { + return Ok::, ZKVMError>(vec![]); + } + // GPU-sorted + CPU-deduped; convert to WordAddr + let addrs: Vec = deduped.into_iter().map(WordAddr).collect(); + tracing::info!( + "[GPU full pipeline] sorted {} addrs → {} unique", + addr_count, unique_count, + ); + Ok(addrs) + })? + } else { + vec![] + }; + + // 4. 
CPU collect_records (24ms, uses sorted unique addrs) + let (write_record_pairs, read_record_pairs) = info_span!("collect_records").in_scope(|| { + // This is the same logic as the existing path + let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_none()) + .flat_map(|(mem_name, _, final_mem)| { + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect() + } else { + vec![] + }; + + let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem + .par_iter() + .filter(|(_, range, _)| range.is_some()) + .flat_map(|(mem_name, range, final_mem)| { + let range = range.as_ref().unwrap(); + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + if !range.contains(&addr) { + return None; + } + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect(); + + let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .write_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, true).into(), "current_shard_external_write") + }) + }) + .chain(first_shard_access_later_recs) + .chain(current_shard_access_later_recs) + .collect(); + + let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .read_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, false).into(), "current_shard_external_read") + }) + }) + .collect(); + + (write_record_pairs, read_record_pairs) + }); + + // 5. 
GPU batch EC on device for continuation records (25ms, results stay on GPU) + let (cont_ec_buf, cont_n_writes, cont_n_reads) = + info_span!("gpu_batch_ec_on_device").in_scope(|| { + gpu_batch_continuation_ec_on_device(&write_record_pairs, &read_record_pairs) + })?; + let cont_total = cont_n_writes + cont_n_reads; + + tracing::info!( + "[GPU full pipeline] batch EC on device: {} writes + {} reads = {} continuation records", + cont_n_writes, cont_n_reads, cont_total, + ); + + // 6. GPU merge shared_ec + batch_ec, then partition by is_to_write_set + let (partitioned_buf, num_writes, total_records) = + info_span!("gpu_merge_partition").in_scope(|| { + hal.merge_and_partition_records( + &shared.ec_buf, + ec_count, + &cont_ec_buf, + cont_total, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU merge+partition: {e}").into()) + }) + })?; + + tracing::info!( + "[GPU full pipeline] merged+partitioned: {} total ({} writes, {} reads)", + total_records, num_writes, total_records - num_writes, + ); + + // 7. 
GPU assign_instances from device buffer (chunked by max_cross_shard) + assert!(self.combined_lk_mlt.is_some()); + let cs_inner = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); + let num_witin = cs_inner.zkvm_v1_css.num_witin as usize; + let num_structural_witin = cs_inner.zkvm_v1_css.num_structural_witin as usize; + let max_chunk = shard_ctx.max_num_cross_shard_accesses; + + // Record sizes needed for chunking + let record_u32s = std::mem::size_of::() / 4; + + let circuit_inputs = info_span!("shard_ram_assign_from_device", n = total_records) + .in_scope(|| { + // Process chunks sequentially (each chunk uses GPU exclusively) + let mut inputs = Vec::new(); + let mut records_offset = 0usize; + let mut writes_remaining = num_writes; + + while records_offset < total_records { + let chunk_size = max_chunk.min(total_records - records_offset); + let chunk_writes = writes_remaining.min(chunk_size); + writes_remaining = writes_remaining.saturating_sub(chunk_size); + + // Create a view into the partitioned buffer for this chunk. + // SAFETY: chunk_buf borrows from partitioned_buf and is dropped + // at the end of each loop iteration, before partitioned_buf goes + // out of scope. The 'static lifetime is required by the HAL API. 
+ let chunk_byte_start = records_offset * record_u32s * 4; + let chunk_byte_end = (records_offset + chunk_size) * record_u32s * 4; + let chunk_view = partitioned_buf.as_slice_range(chunk_byte_start..chunk_byte_end); + let chunk_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32> = + unsafe { std::mem::transmute(ceno_gpu::common::buffer::BufferImpl::::new_from_view(chunk_view)) }; + + let witness = ShardRamCircuit::::try_gpu_assign_instances_from_device( + config, + num_witin, + num_structural_witin, + &chunk_buf, + chunk_size, + chunk_writes, + )?; + + let witness = witness.ok_or_else(|| { + ZKVMError::InvalidWitness("GPU shard_ram from_device returned None".into()) + })?; + + let num_reads = chunk_size - chunk_writes; + inputs.push(ChipInput::new( + ShardRamCircuit::::name(), + witness, + vec![chunk_writes, num_reads], + )); + + records_offset += chunk_size; + } + Ok::<_, ZKVMError>(inputs) + })?; + + assert!( + self.witnesses + .insert(ShardRamCircuit::::name(), circuit_inputs) + .is_none() + ); + + tracing::info!( + "[GPU full pipeline] assign_shared_circuit complete: {} total records", + total_records, + ); + + Ok(true) + } + pub fn get_witnesses_name_instance(&self) -> Vec<(String, Vec)> { self.witnesses .iter() diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index edddd7928..77d121148 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -895,6 +895,207 @@ impl ShardRamCircuit { Ok(Some([raw_witin, raw_structural_witin])) } + + /// GPU assign_instances from a device buffer of GpuShardRamRecord. + /// + /// Avoids the ShardRamInput → GpuShardRamRecord conversion and H2D transfer + /// by accepting records that are already on device (from opcode witgen + batch EC). 
+ #[cfg(feature = "gpu")] + pub fn try_gpu_assign_instances_from_device( + config: &ShardRamConfig, + num_witin: usize, + num_structural_witin: usize, + device_records: &ceno_gpu::common::buffer::BufferImpl<'static, u32>, + num_records: usize, + num_local_writes: usize, + ) -> Result>, ZKVMError> { + use ceno_gpu::{ + Buffer, CudaHal, + bb31::CudaHalBB31, + common::transpose::matrix_transpose, + }; + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + + type BB = ::BaseField; + + if std::any::TypeId::of::() != std::any::TypeId::of::() { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(h) => h, + Err(_) => return Ok(None), + }; + + let n = next_pow2_instance_padding(num_records); + let num_rows_padded = 2 * n; + + // 1. Extract column map (same as regular path) + let col_map = crate::instructions::riscv::gpu::shard_ram::extract_shard_ram_column_map( + config, num_witin, + ); + + // 2. GPU Phase 1: per-row assignment (records already on device) + let (gpu_witness, gpu_structural) = tracing::info_span!( + "gpu_shard_ram_per_row_from_device", + n = num_records, + num_rows_padded, + num_witin, + ).in_scope(|| hal + .witgen_shard_ram_per_row_from_device( + &col_map, + device_records, + num_records, + num_local_writes as u32, + num_witin as u32, + num_structural_witin as u32, + num_rows_padded as u32, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU shard_ram per-row (from_device) kernel failed: {e}").into(), + ) + }))?; + + // 3. 
GPU: extract EC points from device records (replaces CPU loop) + let witness_buf = tracing::info_span!( + "gpu_shard_ram_ec_tree_from_device", n + ).in_scope(|| -> Result<_, ZKVMError> { + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; + + // Extract point_x/y from device records into flat arrays + let (mut cur_x, mut cur_y) = hal + .extract_ec_points_from_device(device_records, num_records, n, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU extract_ec_points failed: {e}").into(), + ) + })?; + + let mut witness_buf = gpu_witness.device_buffer; + let mut offset = num_rows_padded / 2; // n + let mut current_layer_len = n; + + loop { + if current_layer_len <= 1 { + break; + } + + let (next_x, next_y) = hal + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut witness_buf, + current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree layer failed: {e}").into(), + ) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; + } + + Ok(witness_buf) + })?; + + // 4. 
GPU transpose + D2H (same as regular path) + let (wit_data, struct_data) = tracing::info_span!( + "gpu_shard_ram_transpose_d2h_from_device", num_rows_padded, num_witin, + ).in_scope(|| -> Result<_, ZKVMError> { + let wit_num_rows = num_rows_padded; + let wit_num_cols = num_witin; + let mut rmm_buf = hal + .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buf, + &witness_buf, + wit_num_rows, + wit_num_cols, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) + })?; + let wit_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + let struct_num_cols = num_structural_witin; + let mut struct_rmm_buf = hal + .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for struct transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut struct_rmm_buf, + &gpu_structural.device_buffer, + wit_num_rows, + struct_num_cols, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU struct transpose failed: {e}").into()) + })?; + + let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) + })?; + let struct_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok((wit_data, struct_data)) + })?; + + let raw_witin = witness::RowMajorMatrix::new_by_values( + wit_data, + num_witin, + 
InstancePaddingStrategy::Default, + ); + let raw_structural_witin = witness::RowMajorMatrix::new_by_values( + struct_data, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + tracing::info!( + "GPU shard_ram assign_instances (from_device) done: {} records, {} padded rows", + num_records, + num_rows_padded + ); + + Ok(Some([raw_witin, raw_structural_witin])) + } } #[cfg(test)] From 8fe9d32b2ce5f50cbad6e21d87fa11915645ab45 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 15:31:49 +0800 Subject: [PATCH 45/73] minor: LB LH --- ceno_zkvm/src/instructions/riscv/memory/load_v2.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 850d2dffc..8bc1eb1f2 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -45,7 +45,10 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + const GPU_SIDE_EFFECTS: bool = matches!( + I::INST_KIND, + InsnKind::LW | InsnKind::LB | InsnKind::LBU | InsnKind::LH | InsnKind::LHU + ); fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] From 23f04f06ffd61206141a509633c7a5ae76572e95 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 15:31:49 +0800 Subject: [PATCH 46/73] keccak --- Cargo.lock | 11 - Cargo.toml | 22 +- ceno_zkvm/src/instructions/riscv/ecall.rs | 2 +- .../src/instructions/riscv/ecall/keccak.rs | 23 +- .../src/instructions/riscv/gpu/keccak.rs | 206 ++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 293 ++++++++++++++++++ ceno_zkvm/src/precompiles/mod.rs | 2 +- 8 files changed, 533 insertions(+), 28 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/keccak.rs diff --git a/Cargo.lock b/Cargo.lock index 79f9fb07b..883f5dc37 100644 
--- a/Cargo.lock +++ b/Cargo.lock @@ -2273,7 +2273,6 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "once_cell", "p3", @@ -3294,7 +3293,6 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "bincode 1.3.3", "clap", @@ -3318,7 +3316,6 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "either", "ff_ext", @@ -4609,7 +4606,6 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "p3-air", "p3-baby-bear", @@ -5177,7 +5173,6 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "ff_ext", "p3", @@ -6147,7 +6142,6 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "cfg-if", "dashu", @@ -6291,7 +6285,6 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "either", "ff_ext", @@ -6309,7 +6302,6 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = 
"git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "itertools 0.13.0", "p3", @@ -6716,7 +6708,6 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -7022,7 +7013,6 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "bincode 1.3.3", "clap", @@ -7318,7 +7308,6 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index ab7009cfe..09e299842 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,17 +130,17 @@ lto = "thin" [patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } -#[patch."https://github.com/scroll-tech/gkr-backend"] -#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } -#mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } -#multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } -#p3 = { path = "../gkr-backend/crates/p3", package = "p3" } -#poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } -#sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } -#sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } -#transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } -#whir = { path = "../gkr-backend/crates/whir", package 
= "whir" } -#witness = { path = "../gkr-backend/crates/witness", package = "witness" } +[patch."https://github.com/scroll-tech/gkr-backend"] +ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } +mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } +multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } +p3 = { path = "../gkr-backend/crates/p3", package = "p3" } +poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } +sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } +sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } +transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } +whir = { path = "../gkr-backend/crates/whir", package = "whir" } +witness = { path = "../gkr-backend/crates/witness", package = "witness" } # [patch."https://github.com/scroll-tech/openvm.git"] # openvm = { path = "../openvm-scroll-tech/crates/toolchain/openvm", default-features = false } diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index 84a836bf3..b743e8bb3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -2,7 +2,7 @@ mod fptower_fp; mod fptower_fp2_add; mod fptower_fp2_mul; mod halt; -mod keccak; +pub(crate) mod keccak; mod sha_extend; mod uint256; mod weierstrass_add; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index f9c9f1712..ac164b942 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -44,10 +44,10 @@ use crate::{ #[derive(Debug)] pub struct EcallKeccakConfig { pub layout: KeccakLayout, - vm_state: StateInOut, - ecall_id: OpFixedRS, - state_ptr: (OpFixedRS, MemAddr), - mem_rw: Vec, + pub(crate) vm_state: StateInOut, + pub(crate) ecall_id: OpFixedRS, 
+ pub(crate) state_ptr: (OpFixedRS, MemAddr), + pub(crate) mem_rw: Vec, } /// KeccakInstruction can handle any instruction and produce its side-effects. @@ -178,6 +178,21 @@ impl Instruction for KeccakInstruction { steps: &[StepRecord], step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + #[cfg(feature = "gpu")] + { + use crate::instructions::riscv::gpu::witgen_gpu::gpu_assign_keccak_instances; + if let Some(result) = gpu_assign_keccak_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + )? { + return Ok(result); + } + } + let mut lk_multiplicity = LkMultiplicity::default(); if step_indices.is_empty() { return Ok(( diff --git a/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs b/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs new file mode 100644 index 000000000..88af07403 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs @@ -0,0 +1,206 @@ +use ceno_emul::{StepIndex, StepRecord}; +use ceno_gpu::common::witgen_types::{GpuKeccakInstance, GpuKeccakWriteOp, KeccakColumnMap}; +use ff_ext::ExtensionField; +use std::sync::Arc; + +use crate::{ + instructions::riscv::ecall::keccak::EcallKeccakConfig, + precompiles::lookup_keccakf::KECCAK_INPUT32_SIZE, +}; + +use ceno_emul::SyscallWitness; + +/// Extract column map from a constructed EcallKeccakConfig. +/// +/// VM state columns are listed individually. Keccak math columns use +/// a single base offset since they're allocated contiguously via transmute. 
+pub fn extract_keccak_column_map( + config: &EcallKeccakConfig, + num_witin: usize, +) -> KeccakColumnMap { + // StateInOut + let pc = config.vm_state.pc.id as u32; + let ts = config.vm_state.ts.id as u32; + + // OpFixedRS - ecall_id + let ecall_prev_ts = config.ecall_id.prev_ts.id as u32; + let ecall_lt_diff = { + let diffs = &config.ecall_id.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for ecall_id"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // MemAddr - state_ptr address limbs + let addr_limbs = { + let limbs = config + .state_ptr + .1 + .addr + .wits_in() + .expect("MemAddr should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 addr limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // OpFixedRS - state_ptr register write + let sptr_prev_ts = config.state_ptr.0.prev_ts.id as u32; + let sptr_prev_val = { + let limbs = config + .state_ptr + .0 + .prev_value + .as_ref() + .expect("state_ptr should have prev_value") + .wits_in() + .expect("prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let sptr_lt_diff = { + let diffs = &config.state_ptr.0.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for state_ptr"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteMEM x50: prev_ts + lt_diff[2] each + let mut mem_prev_ts = [0u32; 50]; + let mut mem_lt_diff_0 = [0u32; 50]; + let mut mem_lt_diff_1 = [0u32; 50]; + for (i, writer) in config.mem_rw.iter().enumerate() { + mem_prev_ts[i] = writer.prev_ts.id as u32; + let diffs = &writer.lt_cfg.0.diff; + assert_eq!( + diffs.len(), + 2, + "Expected 2 AssertLt diff limbs for mem_rw[{}]", + i + ); + mem_lt_diff_0[i] = diffs[0].id as u32; + mem_lt_diff_1[i] = diffs[1].id as u32; + } + + // Keccak math columns base offset (contiguous block) + let keccak_base_col = config.layout.layer_exprs.wits.input8[0].id as u32; + + // Verify 
contiguity of keccak math columns + #[cfg(debug_assertions)] + { + let base = keccak_base_col as usize; + let expected_size = std::mem::size_of::>(); + // Check that the last keccak column is at base + expected_size - 1 + let last_rc = config.layout.layer_exprs.wits.rc.last().unwrap(); + assert_eq!( + last_rc.id as usize, + base + expected_size - 1, + "Keccak math columns not contiguous: last rc id {} != expected {}", + last_rc.id, + base + expected_size - 1 + ); + } + + KeccakColumnMap { + pc, + ts, + ecall_prev_ts, + ecall_lt_diff, + addr_limbs, + sptr_prev_ts, + sptr_prev_val, + sptr_lt_diff, + mem_prev_ts, + mem_lt_diff_0, + mem_lt_diff_1, + keccak_base_col, + num_cols: num_witin as u32, + } +} + +/// Pack step records + syscall witnesses into flat GPU-transferable instances. +pub fn pack_keccak_instances( + steps: &[StepRecord], + step_indices: &[StepIndex], + syscall_witnesses: &Arc>, +) -> Vec { + step_indices + .iter() + .map(|&idx| { + let step = &steps[idx]; + let sw = step + .syscall(syscall_witnesses) + .expect("keccak step must have syscall witness"); + + // Register op (state_ptr) + let reg_op = &sw.reg_ops[0]; + let gpu_reg_op = GpuKeccakWriteOp { + addr: reg_op.addr.0, + value_before: reg_op.value.before, + value_after: reg_op.value.after, + _pad: 0, + previous_cycle: reg_op.previous_cycle, + }; + + // Memory ops (50 read-writes) + let mut mem_ops = [GpuKeccakWriteOp::default(); 50]; + for (i, op) in sw.mem_ops.iter().enumerate() { + mem_ops[i] = GpuKeccakWriteOp { + addr: op.addr.0, + value_before: op.value.before, + value_after: op.value.after, + _pad: 0, + previous_cycle: op.previous_cycle, + }; + } + + GpuKeccakInstance { + pc: step.pc().before.0, + _pad0: 0, + cycle: step.cycle(), + ecall_prev_cycle: step.rs1().unwrap().previous_cycle, + reg_op: gpu_reg_op, + mem_ops, + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, 
riscv::ecall::keccak::KeccakInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_keccak_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let (config, _gkr_circuit) = + KeccakInstruction::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + + let col_map = extract_keccak_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + // All column IDs should be within range + // Note: keccak_base_col and num_cols are metadata, not column indices + let metadata_indices = [flat.len() - 1, flat.len() - 2]; // num_cols, keccak_base_col + for (i, &col) in flat.iter().enumerate() { + if metadata_indices.contains(&i) { + continue; + } + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 03696d4af..c488455c6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -43,6 +43,8 @@ pub mod sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sw; #[cfg(feature = "gpu")] +pub mod keccak; +#[cfg(feature = "gpu")] pub mod shard_ram; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 0703c0c88..c29988d3f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -2690,6 +2690,299 @@ fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: 
&[StepIndex], +) -> Result, Multiplicity)>, ZKVMError> { + use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; + use gkr_iop::gpu::get_cuda_hal; + + // Guard: disabled or force-CPU + if is_gpu_witgen_disabled() || is_force_cpu_path() { + return Ok(None); + } + + // GPU only supports BabyBear field + if std::any::TypeId::of::() + != std::any::TypeId::of::<::BaseField>() + { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(hal) => hal, + Err(_) => return Ok(None), + }; + + // Empty step_indices: return empty matrices + if step_indices.is_empty() { + let rotation = KECCAK_ROUNDS_CEIL_LOG2; + let raw_witin = RowMajorMatrix::::new_by_rotation( + 0, + rotation, + num_witin, + InstancePaddingStrategy::Default, + ); + let raw_structural = RowMajorMatrix::::new_by_rotation( + 0, + rotation, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + let lk = LkMultiplicity::default(); + return Ok(Some(( + [raw_witin, raw_structural], + lk.into_finalize_result(), + ))); + } + + let num_instances = step_indices.len(); + tracing::debug!( + "[GPU witgen] keccak with {} instances", + num_instances + ); + + info_span!("gpu_witgen_keccak", n = num_instances).in_scope(|| { + gpu_assign_keccak_inner::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + &hal, + ) + .map(Some) + }) +} + +#[cfg(feature = "gpu")] +fn gpu_assign_keccak_inner( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], + hal: &CudaHalBB31, +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; + + let num_instances = step_indices.len(); + let num_padded_instances = num_instances.next_power_of_two().max(2); + let num_padded_rows = num_padded_instances * 32; // 2^5 = 32 rows per instance + let rotation = KECCAK_ROUNDS_CEIL_LOG2; // = 5 + + // Step 1: 
Extract column map + let col_map = info_span!("col_map") + .in_scope(|| super::keccak::extract_keccak_column_map(config, num_witin)); + + // Step 2: Pack instances + let packed_instances = info_span!("pack_instances") + .in_scope(|| { + super::keccak::pack_keccak_instances( + steps, + step_indices, + &shard_ctx.syscall_witnesses, + ) + }); + + // Step 3: Compute fetch params + let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(steps, step_indices); + + // Step 4: Ensure shard metadata cached + info_span!("ensure_shard_meta") + .in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx, steps.len()))?; + + // Step 5: Launch GPU kernel + let gpu_result = info_span!("gpu_kernel").in_scope(|| { + with_cached_shard_meta(|shard_bufs| { + hal.witgen_keccak( + &col_map, + &packed_instances, + num_padded_rows, + shard_ctx.current_shard_offset_cycle(), + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_keccak failed: {e}").into(), + ) + }) + }) + })?; + + // Step 6: Collect LK multiplicity + let lk_multiplicity = info_span!("gpu_lk_d2h") + .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_result.lk_counters))?; + + // Step 7: Handle compact EC records (shared buffer path) + if gpu_result.compact_ec.is_none() && gpu_result.compact_addr.is_none() { + // Shared buffer path: EC records + addr_accessed accumulated on device + // in shared buffers across all kernel invocations. Skip per-kernel D2H. 
+ } else if let Some(compact) = gpu_result.compact_ec { + info_span!("gpu_ec_shard").in_scope(|| { + let compact_records = info_span!("compact_d2h") + .in_scope(|| gpu_compact_ec_d2h(&compact))?; + + // D2H compact addr_accessed + info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { + if let Some(ref ca) = gpu_result.compact_addr { + let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr_count D2H failed: {e}").into(), + ) + })?; + let n = count_vec[0] as usize; + if n > 0 { + let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr D2H failed: {e}").into(), + ) + })?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + } + Ok(()) + })?; + + // Populate shard_ctx with GPU EC records + let raw_bytes = unsafe { + std::slice::from_raw_parts( + compact_records.as_ptr() as *const u8, + compact_records.len() * std::mem::size_of::(), + ) + }; + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + + Ok::<(), ZKVMError>(()) + })?; + } + + // Step 8: Transpose GPU witness (column-major -> row-major) + D2H + let raw_witin = info_span!("transpose_d2h").in_scope(|| { + let mut rmm_buffer = hal + .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.witness.device_buffer, + num_padded_rows, + num_witin, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) + })?; + + let gpu_data: Vec<::BaseField> = + rmm_buffer.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()) + })?; + + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + let data: Vec = unsafe { + 
let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok::<_, ZKVMError>(RowMajorMatrix::::from_values_with_rotation( + data, + num_witin, + rotation, + num_instances, + InstancePaddingStrategy::Default, + )) + })?; + + // Step 9: Build structural witness on CPU with selector indices + let raw_structural = info_span!("structural_witness").in_scope(|| { + let mut raw_structural = RowMajorMatrix::::new_by_rotation( + num_instances, + rotation, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + // Get selector column IDs from config + let sel_first = config + .layout + .selector_type_layout + .sel_first + .as_ref() + .expect("sel_first must be Some"); + let sel_last = config + .layout + .selector_type_layout + .sel_last + .as_ref() + .expect("sel_last must be Some"); + + let sel_first_id = sel_first.selector_expr().id(); + let sel_last_id = sel_last.selector_expr().id(); + let sel_all_id = config + .layout + .selector_type_layout + .sel_all + .selector_expr() + .id(); + + let sel_first_indices = sel_first.sparse_indices(); + let sel_last_indices = sel_last.sparse_indices(); + let sel_all_indices = config + .layout + .selector_type_layout + .sel_all + .sparse_indices(); + + for instance_chunk in raw_structural.iter_mut() { + // instance_chunk is a &mut [F] of size 32 * num_structural_witin + for &idx in sel_first_indices { + instance_chunk[idx * num_structural_witin + sel_first_id] = + E::BaseField::ONE; + } + for &idx in sel_last_indices { + instance_chunk[idx * num_structural_witin + sel_last_id] = + E::BaseField::ONE; + } + for &idx in sel_all_indices { + instance_chunk[idx * num_structural_witin + sel_all_id] = + E::BaseField::ONE; + } + } + raw_structural.padding_by_strategy(); + + raw_structural + }); + + Ok(([raw_witin, raw_structural], lk_multiplicity)) +} + /// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU 
transpose + D2H copy. /// /// GPU witgen kernels output column-major layout for better memory coalescing. diff --git a/ceno_zkvm/src/precompiles/mod.rs b/ceno_zkvm/src/precompiles/mod.rs index 3d9a6e545..7e0e6d049 100644 --- a/ceno_zkvm/src/precompiles/mod.rs +++ b/ceno_zkvm/src/precompiles/mod.rs @@ -1,6 +1,6 @@ mod bitwise_keccakf; mod fptower; -mod lookup_keccakf; +pub(crate) mod lookup_keccakf; mod sha256; mod uint256; mod utils; From 58684909668b184ba9c51867964ed5019b02a428 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 15:31:49 +0800 Subject: [PATCH 47/73] fix-1 --- ceno_emul/src/test_utils.rs | 9 +- .../src/instructions/riscv/gpu/keccak.rs | 145 +++++++++++++++++- .../src/instructions/riscv/gpu/witgen_gpu.rs | 43 +++++- 3 files changed, 185 insertions(+), 12 deletions(-) diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 625ed52c5..2c906c7b4 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -2,6 +2,7 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, Platform, Program, StepRecord, VMState, encode_rv32, encode_rv32u, syscalls::{KECCAK_PERMUTE, SyscallWitness}, + tracer::FullTracerConfig, }; use anyhow::Result; @@ -24,7 +25,13 @@ pub fn keccak_step() -> (StepRecord, Vec, Vec) { instructions.clone(), Default::default(), ); - let mut vm = VMState::new(CENO_PLATFORM.clone(), program.into()); + let mut vm: VMState = VMState::new_with_tracer_config( + CENO_PLATFORM.clone(), + program.into(), + FullTracerConfig { + max_step_shard: 10, + }, + ); vm.iter_until_halt().collect::>>().unwrap(); let steps = vm.tracer().recorded_steps(); let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs b/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs index 88af07403..31b2f206c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs @@ -3,10 +3,7 @@ use 
ceno_gpu::common::witgen_types::{GpuKeccakInstance, GpuKeccakWriteOp, Keccak use ff_ext::ExtensionField; use std::sync::Arc; -use crate::{ - instructions::riscv::ecall::keccak::EcallKeccakConfig, - precompiles::lookup_keccakf::KECCAK_INPUT32_SIZE, -}; +use crate::instructions::riscv::ecall::keccak::EcallKeccakConfig; use ceno_emul::SyscallWitness; @@ -203,4 +200,144 @@ mod tests { ); } } + + #[test] + fn test_gpu_witgen_keccak_correctness() { + use crate::e2e::ShardContext; + + let mut cs = ConstraintSystem::::new(|| "test_keccak_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let (config, _gkr_circuit) = + KeccakInstruction::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + // Get test data from emulator + let (step, _program, syscall_witnesses) = ceno_emul::test_utils::keccak_step(); + let steps = vec![step]; + let step_indices: Vec = vec![0]; + + // --- CPU path (force CPU via thread-local flag) --- + use super::super::witgen_gpu::set_force_cpu_path; + set_force_cpu_path(true); + let mut shard_ctx = ShardContext::default(); + shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses.clone()); + let (cpu_rmms, _cpu_lkm) = KeccakInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &step_indices, + ) + .unwrap(); + set_force_cpu_path(false); + let cpu_witness = &cpu_rmms[0]; + let cpu_structural = &cpu_rmms[1]; + + // --- GPU path (full pipeline via gpu_assign_keccak_instances) --- + use super::super::witgen_gpu::gpu_assign_keccak_instances; + let mut shard_ctx_gpu = ShardContext::default(); + shard_ctx_gpu.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); + let (gpu_rmms, gpu_lk) = gpu_assign_keccak_instances::( + &config, + &mut shard_ctx_gpu, + num_witin, + num_structural_witin, + &steps, + &step_indices, + ) + .unwrap() + .expect("GPU path 
should not return None"); + let gpu_witness = &gpu_rmms[0]; + let gpu_structural = &gpu_rmms[1]; + + // --- Compare witness (raw_witin) --- + let gpu_data = gpu_witness.values(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "witness size mismatch"); + + let mut mismatches = 0; + for (i, (g, c)) in gpu_data.iter().zip(cpu_data.iter()).enumerate() { + if g != c { + if mismatches < 20 { + let row = i / num_witin; + let col = i % num_witin; + eprintln!( + "Witness mismatch row={}, col={}: GPU={:?}, CPU={:?}", + row, col, g, c + ); + } + mismatches += 1; + } + } + eprintln!( + "Keccak witness: {} mismatches out of {} cells", + mismatches, + gpu_data.len() + ); + + // --- Compare structural witness --- + let gpu_struct_data = gpu_structural.values(); + let cpu_struct_data = cpu_structural.values(); + assert_eq!( + gpu_struct_data.len(), + cpu_struct_data.len(), + "structural witness size mismatch" + ); + + let mut struct_mismatches = 0; + for (i, (g, c)) in gpu_struct_data.iter().zip(cpu_struct_data.iter()).enumerate() { + if g != c { + if struct_mismatches < 20 { + let row = i / num_structural_witin; + let col = i % num_structural_witin; + eprintln!( + "Structural mismatch row={}, col={}: GPU={:?}, CPU={:?}", + row, col, g, c + ); + } + struct_mismatches += 1; + } + } + eprintln!( + "Keccak structural: {} mismatches out of {} cells", + struct_mismatches, + gpu_struct_data.len() + ); + + // --- Compare LK multiplicity --- + let mut lk_mismatches = 0; + for (table_idx, (gpu_map, cpu_map)) in gpu_lk.0.iter().zip(_cpu_lkm.0.iter()).enumerate() { + for (&k, &gpu_v) in gpu_map.iter() { + let cpu_v = cpu_map.get(&k).copied().unwrap_or(0); + if gpu_v != cpu_v { + if lk_mismatches < 30 { + eprintln!( + "LK mismatch table={}, key={:#x}: GPU={}, CPU={}", + table_idx, k, gpu_v, cpu_v, + ); + } + lk_mismatches += 1; + } + } + for (&k, &cpu_v) in cpu_map.iter() { + if !gpu_map.contains_key(&k) { + if lk_mismatches < 30 { + eprintln!( + "LK 
mismatch table={}, key={:#x}: GPU=missing, CPU={}", + table_idx, k, cpu_v, + ); + } + lk_mismatches += 1; + } + } + } + eprintln!("Keccak LK: {} mismatches", lk_mismatches); + + assert_eq!(mismatches, 0, "GPU vs CPU witness mismatch"); + assert_eq!(struct_mismatches, 0, "GPU vs CPU structural witness mismatch"); + assert_eq!(lk_mismatches, 0, "GPU vs CPU LK multiplicity mismatch"); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index c29988d3f..59e8f792a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -221,6 +221,8 @@ fn ensure_shard_metadata_cached( shared_ec_count_ptr: _, shared_addr_out_ptr: _, shared_addr_count_ptr: _, + shared_ec_capacity: _, + shared_addr_capacity: _, } = c.device_bufs; next_access_packed }); @@ -322,11 +324,16 @@ fn ensure_shard_metadata_cached( ); // Allocate shared EC/addr compact buffers for this shard. - // Each step produces up to 3 EC records and 3 addr_accessed entries. - // These buffers persist across all kernel invocations within the shard. - let ec_capacity = n_total_steps * 4; // extra headroom + // + // EC records: cross-shard only (sparse subset of RAM ops). + // 104 bytes each (26 u32s). Cap at 16M entries ≈ 1.6 GB. + // Addr records: every gpu_send() emits one (dense). + // 4 bytes each (1 u32). Cap at 256M entries ≈ 1 GB. 
+ let max_ops_per_step = 52u64; // keccak worst case + let total_ops_estimate = n_total_steps as u64 * max_ops_per_step; + let ec_capacity = total_ops_estimate.min(16 * 1024 * 1024) as usize; let ec_u32s = ec_capacity * 26; // 26 u32s per GpuShardRamRecord (104 bytes) - let addr_capacity = n_total_steps * 4; + let addr_capacity = total_ops_estimate.min(256 * 1024 * 1024) as usize; let shared_ec_buf = hal .alloc_u32_zeroed(ec_u32s, None) @@ -370,6 +377,8 @@ fn ensure_shard_metadata_cached( shared_ec_count_ptr, shared_addr_out_ptr, shared_addr_count_ptr, + shared_ec_capacity: ec_capacity as u32, + shared_addr_capacity: addr_capacity as u32, }, shared_ec_buf: Some(shared_ec_buf), shared_ec_count: Some(shared_ec_count), @@ -494,6 +503,15 @@ pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVME ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) })?; let ec_count = ec_count_vec[0] as usize; + let ec_capacity = c.device_bufs.shared_ec_capacity as usize; + + assert!( + ec_count <= ec_capacity, + "GPU shared EC buffer overflow: count={} > capacity={}. \ + Increase ec_capacity in ensure_shard_metadata_cached.", + ec_count, + ec_capacity, + ); if ec_count > 0 { // D2H EC records (only the active portion) @@ -524,6 +542,14 @@ pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVME ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()) })?; let addr_count = addr_count_vec[0] as usize; + let addr_capacity = c.device_bufs.shared_addr_capacity as usize; + + assert!( + addr_count <= addr_capacity, + "GPU shared addr buffer overflow: count={} > capacity={}", + addr_count, + addr_capacity, + ); if addr_count > 0 { let addr_buf = c.shared_addr_buf.as_ref().unwrap(); @@ -998,7 +1024,7 @@ type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; type CompactEcBuf = ceno_gpu::common::witgen_types::CompactEcResult; /// Compute fetch counter parameters from step data. 
-fn compute_fetch_params( +pub(crate) fn compute_fetch_params( shard_steps: &[StepRecord], step_indices: &[StepIndex], ) -> (u32, usize) { @@ -2618,7 +2644,7 @@ fn gpu_shard_ram_record_to_ec_point( } } -fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { +pub(crate) fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { let mut tables: [FxHashMap; 8] = Default::default(); // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) @@ -2829,6 +2855,8 @@ fn gpu_assign_keccak_inner( let lk_multiplicity = info_span!("gpu_lk_d2h") .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_result.lk_counters))?; + // Debug LK comparison is done in the unit test instead. + // Step 7: Handle compact EC records (shared buffer path) if gpu_result.compact_ec.is_none() && gpu_result.compact_addr.is_none() { // Shared buffer path: EC records + addr_accessed accumulated on device @@ -2960,7 +2988,8 @@ fn gpu_assign_keccak_inner( .sel_all .sparse_indices(); - for instance_chunk in raw_structural.iter_mut() { + // Only set selectors for real instances, not padding ones. 
+ for instance_chunk in raw_structural.iter_mut().take(num_instances) { // instance_chunk is a &mut [F] of size 32 * num_structural_witin for &idx in sel_first_indices { instance_chunk[idx * num_structural_witin + sel_first_id] = From 1deb849762e0a1654295e54221ce25144c9ee5b0 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 15:31:49 +0800 Subject: [PATCH 48/73] fix-2 --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 133 ++++++++++++++++++ .../src/instructions/riscv/memory/load.rs | 58 +++----- .../src/instructions/riscv/memory/load_v2.rs | 60 +++----- 3 files changed, 171 insertions(+), 80 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 59e8f792a..92bdb0eb6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -71,6 +71,7 @@ pub enum GpuWitgenKind { #[cfg(feature = "u16limb_circuit")] Div(u32), // 0=DIV, 1=DIVU, 2=REM, 3=REMU Lw, + Keccak, } /// Cached shard_steps device buffer with metadata for logging. @@ -702,6 +703,8 @@ fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { | GpuWitgenKind::LoadSub { .. } | GpuWitgenKind::Mul(_) | GpuWitgenKind::Div(_) => true, + // Keccak has its own dispatch path, never enters try_gpu_assign_instances. 
+ GpuWitgenKind::Keccak => false, #[cfg(not(feature = "u16limb_circuit"))] _ => false, } @@ -1793,6 +1796,9 @@ fn gpu_fill_witness>( }) }) } + GpuWitgenKind::Keccak => { + unreachable!("keccak uses gpu_assign_keccak_instances, not try_gpu_assign_instances") + } } } @@ -1858,6 +1864,7 @@ fn kind_tag(kind: GpuWitgenKind) -> &'static str { GpuWitgenKind::Mul(_) => "mul", #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Div(_) => "div", + GpuWitgenKind::Keccak => "keccak", } } @@ -1917,6 +1924,8 @@ fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { GpuWitgenKind::Mul(_) => true, #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Div(_) => true, + // Keccak has its own dispatch path with its own LK handling. + GpuWitgenKind::Keccak => false, #[cfg(not(feature = "u16limb_circuit"))] _ => false, } @@ -3009,9 +3018,133 @@ fn gpu_assign_keccak_inner( raw_structural }); + // Debug comparisons (activated by env vars) + debug_compare_keccak::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + &lk_multiplicity, + &raw_witin, + )?; + Ok(([raw_witin, raw_structural], lk_multiplicity)) } +/// Debug comparison for keccak GPU witgen. +/// Runs the CPU path and compares LK / witness / shard side effects. +/// +/// Activated by CENO_GPU_DEBUG_COMPARE_LK, CENO_GPU_DEBUG_COMPARE_WITNESS, +/// or CENO_GPU_DEBUG_COMPARE_SHARD environment variables. 
+#[cfg(feature = "gpu")] +fn debug_compare_keccak( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], + gpu_lk: &Multiplicity, + gpu_witin: &RowMajorMatrix, +) -> Result<(), ZKVMError> { + let want_lk = std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_some(); + let want_witness = std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_some(); + let want_shard = std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some(); + + if !want_lk && !want_witness && !want_shard { + return Ok(()); + } + + // Guard against recursion: is_gpu_witgen_disabled() uses OnceLock so env var + // manipulation doesn't work. Use a thread-local flag instead. + thread_local! { + static IN_DEBUG_COMPARE: Cell = const { Cell::new(false) }; + } + if IN_DEBUG_COMPARE.with(|f| f.get()) { + return Ok(()); + } + IN_DEBUG_COMPARE.with(|f| f.set(true)); + + tracing::info!("[GPU keccak debug] running CPU baseline for comparison"); + + // Run CPU path via assign_instances. The IN_DEBUG_COMPARE guard prevents + // gpu_assign_keccak_instances from calling debug_compare_keccak again, + // so it will produce the GPU result, which is then returned without + // re-entering this function. We need assign_instances (not cpu_assign_instances) + // because keccak has rotation matrices and 3 structural columns. + // + // To force the CPU path, we use is_force_cpu_path() by setting the env var. 
+ let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, cpu_lk) = { + use crate::instructions::riscv::ecall::keccak::KeccakInstruction; + // Set force-CPU flag so gpu_assign_keccak_instances returns None + set_force_cpu_path(true); + let result = as crate::instructions::Instruction>::assign_instances( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + ); + set_force_cpu_path(false); + IN_DEBUG_COMPARE.with(|f| f.set(false)); + result? + }; + + let kind = GpuWitgenKind::Keccak; + + if want_lk { + tracing::info!("[GPU keccak debug] comparing LK multiplicities"); + log_lk_diff(kind, &cpu_lk, gpu_lk); + } + + if want_witness { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(32); + let cpu_witin = &cpu_rmms[0]; + let gpu_vals = gpu_witin.values(); + let cpu_vals = cpu_witin.values(); + let mut diffs = 0usize; + for (i, (g, c)) in gpu_vals.iter().zip(cpu_vals.iter()).enumerate() { + if g != c { + if diffs < limit { + let row = i / num_witin; + let col = i % num_witin; + tracing::error!( + "[GPU keccak witness] row={} col={} gpu={:?} cpu={:?}", + row, col, g, c + ); + } + diffs += 1; + } + } + if diffs == 0 { + tracing::info!("[GPU keccak debug] witness matrices match ({} elements)", gpu_vals.len()); + } else { + tracing::error!("[GPU keccak debug] witness mismatch: {} diffs out of {}", diffs, gpu_vals.len()); + } + } + + if want_shard { + // Note: keccak uses shared EC/addr buffers that are accumulated on-device + // and only flushed to shard_ctx in flush_shared_ec_buffers() after ALL + // opcode circuits complete. At this point shard_ctx.get_addr_accessed() + // is expected to be empty — the data is still on GPU. 
+ let cpu_addr = cpu_ctx.get_addr_accessed(); + tracing::info!( + "[GPU keccak shard] CPU addr_accessed count={} (GPU uses shared buffer, flushed later)", + cpu_addr.len() + ); + } + + Ok(()) +} + /// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. /// /// GPU witgen kernels output column-major layout for better memory coalescing. diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index fe8d3b2e2..a42b53e06 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -236,32 +236,21 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { - match I::INST_KIND { - InsnKind::LW => { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = - unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .im_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = + unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); - let imm = InsnRecord::::imm_internal(&step.insn()); - let unaligned_addr = - ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); - config - .memory_addr - .collect_side_effects(&mut sink, unaligned_addr.into()); - Ok(()) - } - _ => Err(ZKVMError::InvalidWitness( - format!( - "lightweight side effects not implemented for {:?}", - I::INST_KIND - ) - .into(), - )), - } + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) } fn collect_shard_side_effects_instance( @@ -270,21 
+259,10 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { - match I::INST_KIND { - InsnKind::LW => { - config - .im_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } - _ => Err(ZKVMError::InvalidWitness( - format!( - "shard-only side effects not implemented for {:?}", - I::INST_KIND - ) - .into(), - )), - } + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) } #[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 8bc1eb1f2..7878d9994 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -264,32 +264,23 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { - match I::INST_KIND { - InsnKind::LW => { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = - unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .im_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); + // Side effects (shard send/addr) are identical for all load types (LW/LH/LB/LHU/LBU). + // Sub-word extraction only affects LK emissions, handled separately by GPU kernel. 
+ let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = + unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); - let imm = InsnRecord::::imm_internal(&step.insn()); - let unaligned_addr = - ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); - config - .memory_addr - .collect_side_effects(&mut sink, unaligned_addr.into()); - Ok(()) - } - _ => Err(ZKVMError::InvalidWitness( - format!( - "lightweight side effects not implemented for {:?}", - I::INST_KIND - ) - .into(), - )), - } + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) } fn collect_shard_side_effects_instance( @@ -298,21 +289,10 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { - match I::INST_KIND { - InsnKind::LW => { - config - .im_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } - _ => Err(ZKVMError::InvalidWitness( - format!( - "shard-only side effects not implemented for {:?}", - I::INST_KIND - ) - .into(), - )), - } + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) } #[cfg(feature = "gpu")] From 7c8d16a56401d7f28b116eaaa7aa022f6144e52e Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 15:31:49 +0800 Subject: [PATCH 49/73] fix-3 --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 102 ++++++++++++++++-- 1 file changed, 94 insertions(+), 8 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 92bdb0eb6..6e5a8b03f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -477,6 +477,34 @@ 
pub fn gpu_batch_continuation_ec_on_device( Ok((device_buf, n_writes, n_reads)) } +/// Read the current shared addr count from device (single u32 D2H). +/// Used by debug comparison to snapshot count before/after a kernel. +#[cfg(feature = "gpu")] +fn read_shared_addr_count() -> usize { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not cached"); + let buf = c.shared_addr_count.as_ref().expect("shared_addr_count not allocated"); + let v: Vec = buf.to_vec().expect("shared_addr_count D2H failed"); + v[0] as usize + }) +} + +/// Read a range of addr entries [start..end) from the shared addr buffer. +#[cfg(feature = "gpu")] +fn read_shared_addr_range(start: usize, end: usize) -> Vec { + if start >= end { + return Vec::new(); + } + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not cached"); + let buf = c.shared_addr_buf.as_ref().expect("shared_addr_buf not allocated"); + let all: Vec = buf.to_vec_n(end).expect("shared_addr_buf D2H failed"); + all[start..end].to_vec() + }) +} + /// Batch D2H of shared EC records and addr_accessed buffers after all kernel invocations. /// /// Called once per shard after all opcode `gpu_assign_instances_inner` calls complete. 
@@ -2746,6 +2774,10 @@ pub fn gpu_assign_keccak_instances( if is_gpu_witgen_disabled() || is_force_cpu_path() { return Ok(None); } + // CENO_GPU_DISABLE_KECCAK=1 → fall back to CPU keccak witgen + if std::env::var_os("CENO_GPU_DISABLE_KECCAK").is_some() { + return Ok(None); + } // GPU only supports BabyBear field if std::any::TypeId::of::() @@ -2839,6 +2871,13 @@ fn gpu_assign_keccak_inner( info_span!("ensure_shard_meta") .in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx, steps.len()))?; + // Snapshot shared addr count before kernel (for debug comparison) + let addr_count_before = if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some() { + read_shared_addr_count() + } else { + 0 + }; + // Step 5: Launch GPU kernel let gpu_result = info_span!("gpu_kernel").in_scope(|| { with_cached_shard_meta(|shard_bufs| { @@ -2860,6 +2899,18 @@ fn gpu_assign_keccak_inner( }) })?; + // D2H keccak's addr entries from shared buffer (delta since before kernel) + let gpu_keccak_addrs = if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some() { + let addr_count_after = read_shared_addr_count(); + if addr_count_after > addr_count_before { + read_shared_addr_range(addr_count_before, addr_count_after) + } else { + Vec::new() + } + } else { + Vec::new() + }; + // Step 6: Collect LK multiplicity let lk_multiplicity = info_span!("gpu_lk_d2h") .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_result.lk_counters))?; @@ -3028,6 +3079,7 @@ fn gpu_assign_keccak_inner( step_indices, &lk_multiplicity, &raw_witin, + &gpu_keccak_addrs, )?; Ok(([raw_witin, raw_structural], lk_multiplicity)) @@ -3048,6 +3100,7 @@ fn debug_compare_keccak( step_indices: &[StepIndex], gpu_lk: &Multiplicity, gpu_witin: &RowMajorMatrix, + gpu_addrs: &[u32], ) -> Result<(), ZKVMError> { let want_lk = std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_some(); let want_witness = std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_some(); @@ -3131,15 +3184,48 @@ fn debug_compare_keccak( } if want_shard { - // 
Note: keccak uses shared EC/addr buffers that are accumulated on-device - // and only flushed to shard_ctx in flush_shared_ec_buffers() after ALL - // opcode circuits complete. At this point shard_ctx.get_addr_accessed() - // is expected to be empty — the data is still on GPU. + // Compare addr_accessed: GPU entries were D2H'd from the shared buffer + // delta (before/after kernel launch) and passed in as gpu_addrs. let cpu_addr = cpu_ctx.get_addr_accessed(); - tracing::info!( - "[GPU keccak shard] CPU addr_accessed count={} (GPU uses shared buffer, flushed later)", - cpu_addr.len() - ); + let gpu_addr_set: rustc_hash::FxHashSet = + gpu_addrs.iter().map(|&a| WordAddr(a)).collect(); + + if cpu_addr.len() != gpu_addr_set.len() { + tracing::error!( + "[GPU keccak shard] addr_accessed count mismatch: cpu={} gpu={}", + cpu_addr.len(), gpu_addr_set.len() + ); + } + let mut missing_from_gpu = 0usize; + let mut extra_in_gpu = 0usize; + let limit = 16usize; + for addr in &cpu_addr { + if !gpu_addr_set.contains(addr) { + if missing_from_gpu < limit { + tracing::error!("[GPU keccak shard] addr {} in CPU but not GPU", addr.0); + } + missing_from_gpu += 1; + } + } + for &addr in gpu_addrs { + if !cpu_addr.contains(&WordAddr(addr)) { + if extra_in_gpu < limit { + tracing::error!("[GPU keccak shard] addr {} in GPU but not CPU", addr); + } + extra_in_gpu += 1; + } + } + if missing_from_gpu == 0 && extra_in_gpu == 0 { + tracing::info!( + "[GPU keccak shard] addr_accessed matches: {} entries", + cpu_addr.len() + ); + } else { + tracing::error!( + "[GPU keccak shard] addr_accessed diff: missing_from_gpu={} extra_in_gpu={}", + missing_from_gpu, extra_in_gpu + ); + } } Ok(()) From ed9eb6e7ac55db5b1a2f546d164507e926336d1b Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 15:31:49 +0800 Subject: [PATCH 50/73] minor --- ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 6e5a8b03f..682798fdc 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -1020,7 +1020,7 @@ fn gpu_assign_instances_inner>( raw_structural.padding_by_strategy(); // Step 4: Transpose (column-major → row-major) on GPU, then D2H copy to RowMajorMatrix - let mut raw_witin = info_span!("transpose_d2h").in_scope(|| { + let mut raw_witin = info_span!("transpose_d2h", rows = total_instances, cols = num_witin).in_scope(|| { gpu_witness_to_rmm::( hal, gpu_witness, @@ -2965,7 +2965,7 @@ fn gpu_assign_keccak_inner( } // Step 8: Transpose GPU witness (column-major -> row-major) + D2H - let raw_witin = info_span!("transpose_d2h").in_scope(|| { + let raw_witin = info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin).in_scope(|| { let mut rmm_buffer = hal .alloc_elems_on_device(num_padded_rows * num_witin, false, None) .map_err(|e| { From cc0d7eeb157077ea58c87cf7adacd86b36c666e8 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 17:40:47 +0800 Subject: [PATCH 51/73] part-a --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 93 +---------- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 67 +------- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 61 +------ .../src/instructions/riscv/gpu/branch_cmp.rs | 52 +----- .../src/instructions/riscv/gpu/branch_eq.rs | 52 +----- .../src/instructions/riscv/gpu/colmap_base.rs | 155 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/div.rs | 74 ++------- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 53 +----- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 80 ++------- .../src/instructions/riscv/gpu/load_sub.rs | 62 +------ .../src/instructions/riscv/gpu/logic_i.rs | 87 +--------- .../src/instructions/riscv/gpu/logic_r.rs | 94 +---------- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 46 +----- ceno_zkvm/src/instructions/riscv/gpu/lw.rs 
| 63 +------ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 52 +----- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 84 ++-------- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 84 ++-------- .../src/instructions/riscv/gpu/shift_i.rs | 76 +-------- .../src/instructions/riscv/gpu/shift_r.rs | 94 +---------- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 78 +-------- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 60 +------ ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 93 +---------- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 84 ++-------- 24 files changed, 340 insertions(+), 1406 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/colmap_base.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 4df630312..beaf9693f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::AddColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::arith::ArithConfig; /// Extract column map from a constructed ArithConfig (ADD variant). 
@@ -11,75 +12,14 @@ pub fn extract_add_column_map( config: &ArithConfig, num_witin: usize, ) -> AddColumnMap { - // StateInOut - let pc = config.r_insn.vm_state.pc.id as u32; - let ts = config.r_insn.vm_state.ts.id as u32; + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); - // ReadRS1 - let rs1_id = config.r_insn.rs1.id.id as u32; - let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs1.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // ReadRS2 - let rs2_id = config.r_insn.rs2.id.id as u32; - let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs2.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // WriteRD - let rd_id = config.r_insn.rd.id.id as u32; - let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let limbs = config - .r_insn - .rd - .prev_value - .wits_in() - .expect("WriteRD prev_value should have WitIn limbs"); - assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rd.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // Arithmetic: rs1/rs2 u16 limbs - let rs1_limbs: [u32; 2] = { - let limbs = config - .rs1_read - .wits_in() - .expect("rs1_read should have WitIn limbs"); - assert_eq!(limbs.len(), 2, "Expected 2 rs1_read limbs"); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let 
rs2_limbs: [u32; 2] = { - let limbs = config - .rs2_read - .wits_in() - .expect("rs2_read should have WitIn limbs"); - assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); - [limbs[0].id as u32, limbs[1].id as u32] - }; - - // rd carries - let rd_carries: [u32; 2] = { - let carries = config - .rd_written - .carries - .as_ref() - .expect("rd_written should have carries"); - assert_eq!(carries.len(), 2, "Expected 2 rd_written carries"); - [carries[0].id as u32, carries[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_carries = extract_carries::(&config.rd_written, "rd_written"); AddColumnMap { pc, @@ -191,22 +131,7 @@ mod tests { let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - // All column IDs should be unique and within range - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - // Check uniqueness - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 485eee423..9996633d1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::AddiColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig; /// Extract column map from a constructed InstructionConfig (ADDI v2). 
@@ -10,54 +11,14 @@ pub fn extract_addi_column_map( ) -> AddiColumnMap { let im = &config.i_insn; - // StateInOut - let pc = im.vm_state.pc.id as u32; - let ts = im.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = im.rs1.id.id as u32; - let rs1_prev_ts = im.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &im.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteRD - let rd_id = im.rd.id.id as u32; - let rd_prev_ts = im.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &im.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // rs1 u16 limbs - let rs1_limbs: [u32; 2] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); - // imm and imm_sign + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - - // rd carries (from the add operation: rs1 + sign_extend(imm)) - let rd_carries: [u32; 2] = { - let carries = config - .rd_written - .carries - .as_ref() - .expect("rd_written should have carries for ADDI"); - assert_eq!(carries.len(), 2); - [carries[0].id as u32, carries[1].id as u32] - }; + let rd_carries = extract_carries::(&config.rd_written, "rd_written"); AddiColumnMap { pc, @@ -98,21 +59,7 @@ mod tests { let col_map = extract_addi_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: 
{} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 431d1b257..9ece38a58 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::AuipcColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::auipc::AuipcConfig; /// Extract column map from a constructed AuipcConfig. @@ -10,47 +11,11 @@ pub fn extract_auipc_column_map( ) -> AuipcColumnMap { let im = &config.i_insn; - // StateInOut - let pc = im.vm_state.pc.id as u32; - let ts = im.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = im.rs1.id.id as u32; - let rs1_prev_ts = im.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &im.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteRD - let rd_id = im.rd.id.id as u32; - let rd_prev_ts = im.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &im.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); - // AUIPC-specific - let rd_bytes: [u32; 4] = { - let l = config - .rd_written - .wits_in() - .expect("rd_written UInt8 WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as 
u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); let pc_limbs: [u32; 2] = [config.pc_limbs[0].id as u32, config.pc_limbs[1].id as u32]; let imm_limbs: [u32; 3] = [ config.imm_limbs[0].id as u32, @@ -96,21 +61,7 @@ mod tests { let col_map = extract_auipc_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index 572e5bdbb..6edaf174d 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::BranchCmpColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; /// Extract column map from a constructed BranchConfig (BLT/BGE/BLTU/BGEU variant). 
@@ -8,16 +9,8 @@ pub fn extract_branch_cmp_column_map( config: &BranchConfig, num_witin: usize, ) -> BranchCmpColumnMap { - let rs1_limbs: [u32; 2] = { - let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rs2_limbs: [u32; 2] = { - let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.read_rs1, "read_rs1"); + let rs2_limbs = extract_uint_limbs::(&config.read_rs2, "read_rs2"); let lt_config = config.uint_lt_config.as_ref().unwrap(); let cmp_lt = lt_config.cmp_lt.id as u32; @@ -29,26 +22,9 @@ pub fn extract_branch_cmp_column_map( ]; let diff_val = lt_config.diff_val.id as u32; - let pc = config.b_insn.vm_state.pc.id as u32; - let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; - let ts = config.b_insn.vm_state.ts.id as u32; - - let rs1_id = config.b_insn.rs1.id.id as u32; - let rs1_prev_ts = config.b_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &config.b_insn.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - let rs2_id = config.b_insn.rs2.id.id as u32; - let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let d = &config.b_insn.rs2.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - + let (pc, next_pc, ts) = extract_state_branching(&config.b_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.b_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.b_insn.rs2); let imm = config.b_insn.imm.id as u32; BranchCmpColumnMap { @@ -94,21 +70,7 @@ mod tests { let col_map = extract_branch_cmp_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - 
"Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 178b16fab..4a51655bd 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::BranchEqColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; /// Extract column map from a constructed BranchConfig (BEQ/BNE variant). @@ -8,16 +9,8 @@ pub fn extract_branch_eq_column_map( config: &BranchConfig, num_witin: usize, ) -> BranchEqColumnMap { - let rs1_limbs: [u32; 2] = { - let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rs2_limbs: [u32; 2] = { - let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.read_rs1, "read_rs1"); + let rs2_limbs = extract_uint_limbs::(&config.read_rs2, "read_rs2"); let branch_taken = config.eq_branch_taken_bit.as_ref().unwrap().id as u32; let diff_inv_marker: [u32; 2] = { @@ -25,26 +18,9 @@ pub fn extract_branch_eq_column_map( [markers[0].id as u32, markers[1].id as u32] }; - let pc = config.b_insn.vm_state.pc.id as u32; - let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; - let ts = config.b_insn.vm_state.ts.id as u32; - - let rs1_id = config.b_insn.rs1.id.id as u32; - let rs1_prev_ts = 
config.b_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &config.b_insn.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - let rs2_id = config.b_insn.rs2.id.id as u32; - let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let d = &config.b_insn.rs2.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - + let (pc, next_pc, ts) = extract_state_branching(&config.b_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.b_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.b_insn.rs2); let imm = config.b_insn.imm.id as u32; BranchEqColumnMap { @@ -87,21 +63,7 @@ mod tests { let col_map = extract_branch_eq_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/colmap_base.rs b/ceno_zkvm/src/instructions/riscv/gpu/colmap_base.rs new file mode 100644 index 000000000..3d48e6c08 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/colmap_base.rs @@ -0,0 +1,155 @@ +//! Shared column-map extraction helpers for GPU witness generation. +//! +//! Each GPU column-map extractor (`add.rs`, `sub.rs`, …) reads `WitIn.id` +//! fields from a circuit config and packs them into a `#[repr(C)]` struct +//! for the CUDA kernel. The base instruction fields (pc, ts, rs1, rs2, rd, +//! mem timestamps) are identical across instruction formats — these helpers +//! eliminate the duplication. 
+ +use ff_ext::ExtensionField; +use gkr_iop::gadgets::AssertLtConfig; +use multilinear_extensions::WitIn; + +use crate::{ + instructions::riscv::insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, + uint::UIntLimbs, +}; + +// --------------------------------------------------------------------------- +// StateInOut +// --------------------------------------------------------------------------- + +/// Extract `(pc, ts)` from a non-branching `StateInOut`. +#[inline] +pub fn extract_state(vm: &StateInOut) -> (u32, u32) { + (vm.pc.id as u32, vm.ts.id as u32) +} + +/// Extract `(pc, next_pc, ts)` from a branching `StateInOut`. +#[inline] +pub fn extract_state_branching(vm: &StateInOut) -> (u32, u32, u32) { + ( + vm.pc.id as u32, + vm.next_pc.expect("branching StateInOut must have next_pc").id as u32, + vm.ts.id as u32, + ) +} + +// --------------------------------------------------------------------------- +// Register reads / writes +// --------------------------------------------------------------------------- + +/// Extract `(id, prev_ts, lt_diff[2])` from a `ReadRS1`. +#[inline] +pub fn extract_rs1(rs1: &ReadRS1) -> (u32, u32, [u32; 2]) { + ( + rs1.id.id as u32, + rs1.prev_ts.id as u32, + extract_lt_diff(&rs1.lt_cfg), + ) +} + +/// Extract `(id, prev_ts, lt_diff[2])` from a `ReadRS2`. +#[inline] +pub fn extract_rs2(rs2: &ReadRS2) -> (u32, u32, [u32; 2]) { + ( + rs2.id.id as u32, + rs2.prev_ts.id as u32, + extract_lt_diff(&rs2.lt_cfg), + ) +} + +/// Extract `(id, prev_ts, prev_val[2], lt_diff[2])` from a `WriteRD`. 
+#[inline] +pub fn extract_rd(rd: &WriteRD) -> (u32, u32, [u32; 2], [u32; 2]) { + let prev_val = extract_uint_limbs::(&rd.prev_value, "WriteRD prev_value"); + ( + rd.id.id as u32, + rd.prev_ts.id as u32, + prev_val, + extract_lt_diff(&rd.lt_cfg), + ) +} + +// --------------------------------------------------------------------------- +// Memory reads / writes +// --------------------------------------------------------------------------- + +/// Extract `(prev_ts, lt_diff[2])` from a `ReadMEM`. +#[inline] +pub fn extract_read_mem(mem: &ReadMEM) -> (u32, [u32; 2]) { + (mem.prev_ts.id as u32, extract_lt_diff(&mem.lt_cfg)) +} + +/// Extract `(prev_ts, lt_diff[2])` from a `WriteMEM`. +#[inline] +pub fn extract_write_mem(mem: &WriteMEM) -> (u32, [u32; 2]) { + (mem.prev_ts.id as u32, extract_lt_diff(&mem.lt_cfg)) +} + +// --------------------------------------------------------------------------- +// Primitive helpers +// --------------------------------------------------------------------------- + +/// Extract the two diff-limb column IDs from an `AssertLtConfig`. +#[inline] +pub fn extract_lt_diff(lt: &AssertLtConfig) -> [u32; 2] { + let d = &lt.0.diff; + assert_eq!(d.len(), 2, "Expected 2 AssertLt diff limbs, got {}", d.len()); + [d[0].id as u32, d[1].id as u32] +} + +/// Extract `N` limb column IDs from a `UIntLimbs` via `wits_in()`. +#[inline] +pub fn extract_uint_limbs( + u: &UIntLimbs, + label: &str, +) -> [u32; N] { + let limbs = u.wits_in().unwrap_or_else(|| panic!("{label} should have WitIn limbs")); + assert_eq!(limbs.len(), N, "Expected {N} limbs for {label}, got {}", limbs.len()); + std::array::from_fn(|i| limbs[i].id as u32) +} + +/// Extract `N` carry column IDs from a `UIntLimbs`'s carries.
+#[inline] +pub fn extract_carries( + u: &UIntLimbs, + label: &str, +) -> [u32; N] { + let carries = u + .carries + .as_ref() + .unwrap_or_else(|| panic!("{label} should have carries")); + assert_eq!(carries.len(), N, "Expected {N} carries for {label}, got {}", carries.len()); + std::array::from_fn(|i| carries[i].id as u32) +} + +/// Extract `N` column IDs from a `&[WitIn]` slice (e.g. byte decomposition). +#[inline] +pub fn extract_wit_ids<const N: usize>(wits: &[WitIn], label: &str) -> [u32; N] { + assert_eq!( + wits.len(), + N, + "Expected {N} WitIn entries for {label}, got {}", + wits.len() + ); + std::array::from_fn(|i| wits[i].id as u32) +} + +// --------------------------------------------------------------------------- +// Test helper +// --------------------------------------------------------------------------- + +#[cfg(test)] +pub fn validate_column_map(flat: &[u32], num_cols: u32) { + for (i, &col) in flat.iter().enumerate() { + assert!( + col < num_cols, + "Column index {i}: value {col} out of range (num_cols = {num_cols})" + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in flat { + assert!(seen.insert(col), "Duplicate column ID: {col}"); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index f7420445c..817495fec 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::DivColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::div::div_circuit_v2::DivRemConfig; /// Extract column map from a constructed DivRemConfig.
@@ -9,62 +10,15 @@ pub fn extract_div_column_map( config: &DivRemConfig, num_witin: usize, ) -> DivColumnMap { - let r = &config.r_insn; + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); - // R-type base - let pc = r.vm_state.pc.id as u32; - let ts = r.vm_state.ts.id as u32; - - let rs1_id = r.rs1.id.id as u32; - let rs1_prev_ts = r.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &r.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - let rs2_id = r.rs2.id.id as u32; - let rs2_prev_ts = r.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let d = &r.rs2.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - let rd_id = r.rd.id.id as u32; - let rd_prev_ts = r.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &r.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // Div-specific: operand limbs - let dividend: [u32; 2] = { - let l = config.dividend.wits_in().expect("dividend WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let divisor: [u32; 2] = { - let l = config.divisor.wits_in().expect("divisor WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let quotient: [u32; 2] = { - let l = config.quotient.wits_in().expect("quotient WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let remainder: [u32; 2] = { - let l = config.remainder.wits_in().expect("remainder WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let dividend = 
extract_uint_limbs::(&config.dividend, "dividend"); + let divisor = extract_uint_limbs::(&config.divisor, "divisor"); + let quotient = extract_uint_limbs::(&config.quotient, "quotient"); + let remainder = extract_uint_limbs::(&config.remainder, "remainder"); // Sign/control bits let dividend_sign = config.dividend_sign.id as u32; @@ -84,15 +38,7 @@ pub fn extract_div_column_map( // sign_xor let sign_xor = config.sign_xor.id as u32; - // remainder_prime - let remainder_prime: [u32; 2] = { - let l = config - .remainder_prime - .wits_in() - .expect("remainder_prime WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let remainder_prime = extract_uint_limbs::(&config.remainder_prime, "remainder_prime"); // lt_marker let lt_marker: [u32; 2] = [config.lt_marker[0].id as u32, config.lt_marker[1].id as u32]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 61710ef80..fd446f185 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::JalColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_state_branching, extract_uint_limbs}; use crate::instructions::riscv::jump::jal_v2::JalConfig; /// Extract column map from a constructed JalConfig. 
@@ -10,39 +11,9 @@ pub fn extract_jal_column_map( ) -> JalColumnMap { let jm = &config.j_insn; - // StateInOut (J-type: has next_pc) - let pc = jm.vm_state.pc.id as u32; - let next_pc = jm.vm_state.next_pc.expect("JAL must have next_pc").id as u32; - let ts = jm.vm_state.ts.id as u32; - - // WriteRD - let rd_id = jm.rd.id.id as u32; - let rd_prev_ts = jm.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = jm.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &jm.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // JAL-specific: rd u8 bytes - let rd_bytes: [u32; 4] = { - let l = config - .rd_written - .wits_in() - .expect("rd_written UInt8 WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; + let (pc, next_pc, ts) = extract_state_branching(&jm.vm_state); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&jm.rd); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); JalColumnMap { pc, @@ -78,21 +49,7 @@ mod tests { let col_map = extract_jal_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 03f6c510c..804b14934 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -1,6 +1,9 @@ use 
ceno_gpu::common::witgen_types::JalrColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{ + extract_rd, extract_rs1, extract_state_branching, extract_uint_limbs, extract_wit_ids, +}; use crate::instructions::riscv::jump::jalr_v2::JalrConfig; /// Extract column map from a constructed JalrConfig. @@ -10,64 +13,15 @@ pub fn extract_jalr_column_map( ) -> JalrColumnMap { let im = &config.i_insn; - // StateInOut (branching=true → has next_pc) - let pc = im.vm_state.pc.id as u32; - let next_pc = im.vm_state.next_pc.expect("JALR must have next_pc").id as u32; - let ts = im.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = im.rs1.id.id as u32; - let rs1_prev_ts = im.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &im.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteRD - let rd_id = im.rd.id.id as u32; - let rd_prev_ts = im.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &im.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // JALR-specific: rs1 u16 limbs - let rs1_limbs: [u32; 2] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let (pc, next_pc, ts) = extract_state_branching(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); - // imm, imm_sign + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - - // jump_pc_addr: MemAddr has addr (UInt = 2 limbs) + low_bits (Vec) - let jump_pc_addr: [u32; 2] = { - let l = config - .jump_pc_addr - .addr - .wits_in() - .expect("jump_pc_addr WitIns"); - assert_eq!(l.len(), 
2); - [l[0].id as u32, l[1].id as u32] - }; - let jump_pc_addr_bit: [u32; 2] = { - let bits = &config.jump_pc_addr.low_bits; - assert_eq!( - bits.len(), - 2, - "JALR MemAddr with n_zeros=0 must have 2 low_bits" - ); - [bits[0].id as u32, bits[1].id as u32] - }; + let jump_pc_addr = extract_uint_limbs::(&config.jump_pc_addr.addr, "jump_pc_addr"); + let jump_pc_addr_bit = extract_wit_ids::<2>(&config.jump_pc_addr.low_bits, "jump_pc_addr low_bits"); // rd_high let rd_high = config.rd_high.id as u32; @@ -114,21 +68,7 @@ mod tests { let col_map = extract_jalr_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 787f091c2..4efb17c12 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::LoadSubColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::memory::load_v2::LoadConfig; /// Extract column map from a constructed LoadConfig for sub-word loads (LH/LHU/LB/LBU). 
@@ -12,63 +13,16 @@ pub fn extract_load_sub_column_map( ) -> LoadSubColumnMap { let im = &config.im_insn; - // StateInOut - let pc = im.vm_state.pc.id as u32; - let ts = im.vm_state.ts.id as u32; + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + let (mem_prev_ts, mem_lt_diff) = extract_read_mem(&im.mem_read); - // ReadRS1 - let rs1_id = im.rs1.id.id as u32; - let rs1_prev_ts = im.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &im.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteRD - let rd_id = im.rd.id.id as u32; - let rd_prev_ts = im.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &im.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // ReadMEM - let mem_prev_ts = im.mem_read.prev_ts.id as u32; - let mem_lt_diff: [u32; 2] = { - let d = &im.mem_read.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // Load-specific - let rs1_limbs: [u32; 2] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - let mem_addr: [u32; 2] = { - let l = config - .memory_addr - .addr - .wits_in() - .expect("memory_addr WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let mem_read: [u32; 2] = { - let l = config.memory_read.wits_in().expect("memory_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, 
"memory_addr"); + let mem_read = extract_uint_limbs::(&config.memory_read, "memory_read"); // Sub-word specific: addr_bit_1 (all sub-word loads have at least 1 low_bit) let low_bits = &config.memory_addr.low_bits; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index 36e33f4e2..ffb6b5507 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::LogicIColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig; /// Extract column map from a constructed LogicConfig (I-type v2: ANDI/ORI/XORI). @@ -10,70 +11,14 @@ pub fn extract_logic_i_column_map( ) -> LogicIColumnMap { let im = &config.i_insn; - // StateInOut - let pc = im.vm_state.pc.id as u32; - let ts = im.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = im.rs1.id.id as u32; - let rs1_prev_ts = im.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &im.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteRD - let rd_id = im.rd.id.id as u32; - let rd_prev_ts = im.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &im.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // rs1 u8 bytes - let rs1_bytes: [u32; 4] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; - - // rd u8 bytes - let rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id 
as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; - - // imm_lo u8 bytes (UIntLimbs<16,8> = 2 x u8) - let imm_lo_bytes: [u32; 2] = { - let l = config.imm_lo.wits_in().expect("imm_lo WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); - // imm_hi u8 bytes (UIntLimbs<16,8> = 2 x u8) - let imm_hi_bytes: [u32; 2] = { - let l = config.imm_hi.wits_in().expect("imm_hi WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); + let imm_lo_bytes = extract_uint_limbs::(&config.imm_lo, "imm_lo"); + let imm_hi_bytes = extract_uint_limbs::(&config.imm_hi, "imm_hi"); LogicIColumnMap { pc, @@ -114,21 +59,7 @@ mod tests { let col_map = extract_logic_i_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index cd8c52375..68ba03060 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::LogicRColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, 
extract_uint_limbs}; use crate::instructions::riscv::logic::logic_circuit::LogicConfig; /// Extract column map from a constructed LogicConfig (R-type: AND/OR/XOR). @@ -8,78 +9,14 @@ pub fn extract_logic_r_column_map( config: &LogicConfig, num_witin: usize, ) -> LogicRColumnMap { - // StateInOut - let pc = config.r_insn.vm_state.pc.id as u32; - let ts = config.r_insn.vm_state.ts.id as u32; + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); - // ReadRS1 - let rs1_id = config.r_insn.rs1.id.id as u32; - let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs1.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // ReadRS2 - let rs2_id = config.r_insn.rs2.id.id as u32; - let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs2.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // WriteRD - let rd_id = config.r_insn.rd.id.id as u32; - let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let limbs = config - .r_insn - .rd - .prev_value - .wits_in() - .expect("rd prev_value WitIns"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rd.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // UInt8 byte limbs - let rs1_bytes: [u32; 4] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; - let rs2_bytes: [u32; 4] = { - let l = 
config.rs2_read.wits_in().expect("rs2_read WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; - let rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_bytes = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); LogicRColumnMap { pc, @@ -152,20 +89,7 @@ mod tests { let col_map = extract_logic_r_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 0c644a808..7ad254ed1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::LuiColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_state}; use crate::instructions::riscv::lui::LuiConfig; /// Extract column map from a constructed LuiConfig. 
@@ -10,32 +11,9 @@ pub fn extract_lui_column_map( ) -> LuiColumnMap { let im = &config.i_insn; - // StateInOut - let pc = im.vm_state.pc.id as u32; - let ts = im.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = im.rs1.id.id as u32; - let rs1_prev_ts = im.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &im.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteRD - let rd_id = im.rd.id.id as u32; - let rd_prev_ts = im.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &im.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); // LUI-specific: rd bytes (skip byte 0) + imm let rd_bytes: [u32; 3] = [ @@ -82,21 +60,7 @@ mod tests { let col_map = extract_lui_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 8e686d0cb..b518d3361 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -1,6 +1,8 @@ use ceno_gpu::common::witgen_types::LwColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, 
extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; + #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::memory::load::LoadConfig; #[cfg(feature = "u16limb_circuit")] @@ -13,66 +15,19 @@ pub fn extract_lw_column_map( ) -> LwColumnMap { let im = &config.im_insn; - // StateInOut - let pc = im.vm_state.pc.id as u32; - let ts = im.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = im.rs1.id.id as u32; - let rs1_prev_ts = im.rs1.prev_ts.id as u32; - let rs1_lt_diff = { - let d = &im.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteRD - let rd_id = im.rd.id.id as u32; - let rd_prev_ts = im.rd.prev_ts.id as u32; - let rd_prev_val = { - let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff = { - let d = &im.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + let (mem_prev_ts, mem_lt_diff) = extract_read_mem(&im.mem_read); - // ReadMEM - let mem_prev_ts = im.mem_read.prev_ts.id as u32; - let mem_lt_diff = { - let d = &im.mem_read.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // Load-specific - let rs1_limbs = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); let imm = config.imm.id as u32; #[cfg(feature = "u16limb_circuit")] let imm_sign = Some(config.imm_sign.id as u32); #[cfg(not(feature = "u16limb_circuit"))] let imm_sign = None; - let mem_addr_limbs = { - let l = config - .memory_addr - .addr - .wits_in() - .expect("memory_addr WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id 
as u32] - }; - let mem_read_limbs = { - let l = config.memory_read.wits_in().expect("memory_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let mem_addr_limbs = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); + let mem_read_limbs = extract_uint_limbs::(&config.memory_read, "memory_read"); LwColumnMap { pc, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index c488455c6..21e3fd8d2 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -47,4 +47,6 @@ pub mod keccak; #[cfg(feature = "gpu")] pub mod shard_ram; #[cfg(feature = "gpu")] +pub mod colmap_base; +#[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index efafd6bd1..2ea5ddbf5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::MulColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig; /// Extract column map from a constructed MulhConfig. 
@@ -10,52 +11,13 @@ pub fn extract_mul_column_map( num_witin: usize, mul_kind: u32, ) -> MulColumnMap { - let r = &config.r_insn; + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); - // R-type base - let pc = r.vm_state.pc.id as u32; - let ts = r.vm_state.ts.id as u32; - - let rs1_id = r.rs1.id.id as u32; - let rs1_prev_ts = r.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &r.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - let rs2_id = r.rs2.id.id as u32; - let rs2_prev_ts = r.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let d = &r.rs2.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - let rd_id = r.rd.id.id as u32; - let rd_prev_ts = r.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let d = &r.rd.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // Mul-specific - let rs1_limbs: [u32; 2] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rs2_limbs: [u32; 2] = { - let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let rd_low: [u32; 2] = [config.rd_low[0].id as u32, config.rd_low[1].id as u32]; // MULH/MULHU/MULHSU have rd_high + extensions diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs 
b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 10775d984..01ff6c50c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -1,6 +1,9 @@ use ceno_gpu::common::witgen_types::SbColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, +}; use crate::instructions::riscv::memory::store_v2::StoreConfig; /// Extract column map from a constructed StoreConfig (SB variant, N_ZEROS=0). @@ -10,66 +13,17 @@ pub fn extract_sb_column_map( ) -> SbColumnMap { let sm = &config.s_insn; - // StateInOut (not branching) - let pc = sm.vm_state.pc.id as u32; - let ts = sm.vm_state.ts.id as u32; + let (pc, ts) = extract_state(&sm.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&sm.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&sm.rs2); + let (mem_prev_ts, mem_lt_diff) = extract_write_mem(&sm.mem_write); - // ReadRS1 - let rs1_id = sm.rs1.id.id as u32; - let rs1_prev_ts = sm.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &sm.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // ReadRS2 - let rs2_id = sm.rs2.id.id as u32; - let rs2_prev_ts = sm.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let d = &sm.rs2.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // WriteMEM - let mem_prev_ts = sm.mem_write.prev_ts.id as u32; - let mem_lt_diff: [u32; 2] = { - let d = &sm.mem_write.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // Store-specific - let rs1_limbs: [u32; 2] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rs2_limbs: [u32; 2] = { - let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let rs1_limbs = 
extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - let prev_mem_val: [u32; 2] = { - let l = config - .prev_memory_value - .wits_in() - .expect("prev_memory_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let mem_addr: [u32; 2] = { - let l = config - .memory_addr - .addr - .wits_in() - .expect("memory_addr WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let prev_mem_val = extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); // SB-specific: 2 low_bits (bit_0, bit_1) assert_eq!( @@ -145,21 +99,7 @@ mod tests { let col_map = extract_sb_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index 72ea316f6..f8814b35c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -1,6 +1,9 @@ use ceno_gpu::common::witgen_types::ShColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, +}; use crate::instructions::riscv::memory::store_v2::StoreConfig; /// Extract column map from a constructed StoreConfig (SH variant, N_ZEROS=1). 
@@ -10,66 +13,17 @@ pub fn extract_sh_column_map( ) -> ShColumnMap { let sm = &config.s_insn; - // StateInOut (not branching) - let pc = sm.vm_state.pc.id as u32; - let ts = sm.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = sm.rs1.id.id as u32; - let rs1_prev_ts = sm.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &sm.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // ReadRS2 - let rs2_id = sm.rs2.id.id as u32; - let rs2_prev_ts = sm.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let d = &sm.rs2.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; + let (pc, ts) = extract_state(&sm.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&sm.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&sm.rs2); + let (mem_prev_ts, mem_lt_diff) = extract_write_mem(&sm.mem_write); - // WriteMEM - let mem_prev_ts = sm.mem_write.prev_ts.id as u32; - let mem_lt_diff: [u32; 2] = { - let d = &sm.mem_write.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // Store-specific - let rs1_limbs: [u32; 2] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rs2_limbs: [u32; 2] = { - let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - let prev_mem_val: [u32; 2] = { - let l = config - .prev_memory_value - .wits_in() - .expect("prev_memory_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let mem_addr: [u32; 2] = { - let l = config - .memory_addr - .addr - .wits_in() - .expect("memory_addr WitIns"); - assert_eq!(l.len(), 2); - [l[0].id 
as u32, l[1].id as u32] - }; + let prev_mem_val = extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); // SH-specific: 1 low_bit (bit_1 for halfword select) assert_eq!( @@ -122,21 +76,7 @@ mod tests { let col_map = extract_sh_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index 22dee5dab..4cf969e0c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::ShiftIColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig; /// Extract column map from a constructed ShiftImmConfig (I-type: SLLI/SRLI/SRAI). 
@@ -8,61 +9,12 @@ pub fn extract_shift_i_column_map( config: &ShiftImmConfig, num_witin: usize, ) -> ShiftIColumnMap { - // StateInOut - let pc = config.i_insn.vm_state.pc.id as u32; - let ts = config.i_insn.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = config.i_insn.rs1.id.id as u32; - let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let diffs = &config.i_insn.rs1.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // WriteRD - let rd_id = config.i_insn.rd.id.id as u32; - let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let limbs = config - .i_insn - .rd - .prev_value - .wits_in() - .expect("rd prev_value WitIns"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let diffs = &config.i_insn.rd.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // UInt8 byte limbs - let rs1_bytes: [u32; 4] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; - let rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; + let (pc, ts) = extract_state(&config.i_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.i_insn.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.i_insn.rd); - // Immediate + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); let imm = config.imm.id as u32; // ShiftBase @@ -120,21 +72,7 @@ mod tests { let col_map = extract_shift_i_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) 
in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 7498b84a8..72ef64f50 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::ShiftRColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig; /// Extract column map from a constructed ShiftRTypeConfig (R-type: SLL/SRL/SRA). 
@@ -8,78 +9,14 @@ pub fn extract_shift_r_column_map( config: &ShiftRTypeConfig, num_witin: usize, ) -> ShiftRColumnMap { - // StateInOut - let pc = config.r_insn.vm_state.pc.id as u32; - let ts = config.r_insn.vm_state.ts.id as u32; + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); - // ReadRS1 - let rs1_id = config.r_insn.rs1.id.id as u32; - let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs1.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // ReadRS2 - let rs2_id = config.r_insn.rs2.id.id as u32; - let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs2.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // WriteRD - let rd_id = config.r_insn.rd.id.id as u32; - let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let limbs = config - .r_insn - .rd - .prev_value - .wits_in() - .expect("rd prev_value WitIns"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rd.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // UInt8 byte limbs - let rs1_bytes: [u32; 4] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; - let rs2_bytes: [u32; 4] = { - let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; - let 
rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written WitIns"); - assert_eq!(l.len(), 4); - [ - l[0].id as u32, - l[1].id as u32, - l[2].id as u32, - l[3].id as u32, - ] - }; + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_bytes = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); // ShiftBase let bit_shift_marker: [u32; 8] = @@ -140,20 +77,7 @@ mod tests { let col_map = extract_shift_r_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index a8023edbd..c46c63d4e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::SltColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig; /// Extract column map from a constructed SetLessThanConfig (SLT/SLTU). 
@@ -8,25 +9,8 @@ pub fn extract_slt_column_map( config: &SetLessThanConfig, num_witin: usize, ) -> SltColumnMap { - // rs1_read: UInt (2 u16 limbs) - let rs1_limbs: [u32; 2] = { - let limbs = config - .rs1_read - .wits_in() - .expect("rs1_read should have WitIn limbs"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - - // rs2_read: UInt (2 u16 limbs) - let rs2_limbs: [u32; 2] = { - let limbs = config - .rs2_read - .wits_in() - .expect("rs2_read should have WitIn limbs"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); // UIntLimbsLT comparison gadget let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; @@ -38,43 +22,10 @@ pub fn extract_slt_column_map( ]; let diff_val = config.uint_lt_config.diff_val.id as u32; - // R-type base: StateInOut + ReadRS1 + ReadRS2 + WriteRD - let pc = config.r_insn.vm_state.pc.id as u32; - let ts = config.r_insn.vm_state.ts.id as u32; - - let rs1_id = config.r_insn.rs1.id.id as u32; - let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs1.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - let rs2_id = config.r_insn.rs2.id.id as u32; - let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs2.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - let rd_id = config.r_insn.rd.id.id as u32; - let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let limbs = config - .r_insn - .rd - .prev_value - .wits_in() - .expect("WriteRD prev_value should have WitIn limbs"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let diffs = 
&config.r_insn.rd.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); SltColumnMap { rs1_limbs, @@ -122,20 +73,7 @@ mod tests { let col_map = extract_slt_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index d0fcbca32..33f2c9247 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::SltiColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig; /// Extract column map from a constructed SetLessThanImmConfig (SLTI/SLTIU). 
@@ -8,16 +9,7 @@ pub fn extract_slti_column_map( config: &SetLessThanImmConfig, num_witin: usize, ) -> SltiColumnMap { - // rs1_read: UInt (2 u16 limbs) - let rs1_limbs: [u32; 2] = { - let limbs = config - .rs1_read - .wits_in() - .expect("rs1_read should have WitIn limbs"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; @@ -31,35 +23,9 @@ pub fn extract_slti_column_map( ]; let diff_val = config.uint_lt_config.diff_val.id as u32; - // I-type base: StateInOut + ReadRS1 + WriteRD - let pc = config.i_insn.vm_state.pc.id as u32; - let ts = config.i_insn.vm_state.ts.id as u32; - - let rs1_id = config.i_insn.rs1.id.id as u32; - let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let diffs = &config.i_insn.rs1.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - let rd_id = config.i_insn.rd.id.id as u32; - let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let limbs = config - .i_insn - .rd - .prev_value - .wits_in() - .expect("WriteRD prev_value should have WitIn limbs"); - assert_eq!(limbs.len(), 2); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let diffs = &config.i_insn.rd.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2); - [diffs[0].id as u32, diffs[1].id as u32] - }; + let (pc, ts) = extract_state(&config.i_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.i_insn.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.i_insn.rd); SltiColumnMap { rs1_limbs, @@ -104,21 +70,7 @@ mod tests { let col_map = extract_slti_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - 
"Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index fd729b996..31c6a562e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -1,6 +1,7 @@ use ceno_gpu::common::witgen_types::SubColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::arith::ArithConfig; /// Extract column map from a constructed ArithConfig (SUB variant). @@ -11,77 +12,14 @@ pub fn extract_sub_column_map( config: &ArithConfig, num_witin: usize, ) -> SubColumnMap { - // StateInOut - let pc = config.r_insn.vm_state.pc.id as u32; - let ts = config.r_insn.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = config.r_insn.rs1.id.id as u32; - let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs1.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // ReadRS2 - let rs2_id = config.r_insn.rs2.id.id as u32; - let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rs2.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); - [diffs[0].id as u32, diffs[1].id as u32] - }; + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, 
rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); - // WriteRD - let rd_id = config.r_insn.rd.id.id as u32; - let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; - let rd_prev_val: [u32; 2] = { - let limbs = config - .r_insn - .rd - .prev_value - .wits_in() - .expect("WriteRD prev_value should have WitIn limbs"); - assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); - [limbs[0].id as u32, limbs[1].id as u32] - }; - let rd_lt_diff: [u32; 2] = { - let diffs = &config.r_insn.rd.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); - [diffs[0].id as u32, diffs[1].id as u32] - }; - - // SUB: rs2_read limbs (rs2 value u16 decomposition) - let rs2_limbs: [u32; 2] = { - let limbs = config - .rs2_read - .wits_in() - .expect("rs2_read should have WitIn limbs"); - assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); - [limbs[0].id as u32, limbs[1].id as u32] - }; - - // SUB: rd_written limbs (rd.value.after u16 decomposition) - let rd_limbs: [u32; 2] = { - let limbs = config - .rd_written - .wits_in() - .expect("rd_written should have WitIn limbs for SUB"); - assert_eq!(limbs.len(), 2, "Expected 2 rd_written limbs"); - [limbs[0].id as u32, limbs[1].id as u32] - }; - - // SUB: carries from rs1_read (= rs2 + rd) - let carries: [u32; 2] = { - let carries = config - .rs1_read - .carries - .as_ref() - .expect("rs1_read should have carries for SUB"); - assert_eq!(carries.len(), 2, "Expected 2 rs1_read carries"); - [carries[0].id as u32, carries[1].id as u32] - }; + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_limbs = extract_uint_limbs::(&config.rd_written, "rd_written"); + let carries = extract_carries::(&config.rs1_read, "rs1_read"); SubColumnMap { pc, @@ -125,20 +63,7 @@ mod tests { let col_map = extract_sub_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, 
- "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 2142af2f1..267d5879d 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -1,6 +1,9 @@ use ceno_gpu::common::witgen_types::SwColumnMap; use ff_ext::ExtensionField; +use super::colmap_base::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, +}; use crate::instructions::riscv::memory::store_v2::StoreConfig; /// Extract column map from a constructed StoreConfig (SW variant, N_ZEROS=2). @@ -10,66 +13,17 @@ pub fn extract_sw_column_map( ) -> SwColumnMap { let sm = &config.s_insn; - // StateInOut (not branching) - let pc = sm.vm_state.pc.id as u32; - let ts = sm.vm_state.ts.id as u32; - - // ReadRS1 - let rs1_id = sm.rs1.id.id as u32; - let rs1_prev_ts = sm.rs1.prev_ts.id as u32; - let rs1_lt_diff: [u32; 2] = { - let d = &sm.rs1.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; - - // ReadRS2 - let rs2_id = sm.rs2.id.id as u32; - let rs2_prev_ts = sm.rs2.prev_ts.id as u32; - let rs2_lt_diff: [u32; 2] = { - let d = &sm.rs2.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as u32, d[1].id as u32] - }; + let (pc, ts) = extract_state(&sm.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&sm.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&sm.rs2); + let (mem_prev_ts, mem_lt_diff) = extract_write_mem(&sm.mem_write); - // WriteMEM - let mem_prev_ts = sm.mem_write.prev_ts.id as u32; - let mem_lt_diff: [u32; 2] = { - let d = &sm.mem_write.lt_cfg.0.diff; - assert_eq!(d.len(), 2); - [d[0].id as 
u32, d[1].id as u32] - }; - - // SW-specific - let rs1_limbs: [u32; 2] = { - let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let rs2_limbs: [u32; 2] = { - let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - let prev_mem_val: [u32; 2] = { - let l = config - .prev_memory_value - .wits_in() - .expect("prev_memory_value WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; - let mem_addr: [u32; 2] = { - let l = config - .memory_addr - .addr - .wits_in() - .expect("memory_addr WitIns"); - assert_eq!(l.len(), 2); - [l[0].id as u32, l[1].id as u32] - }; + let prev_mem_val = extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); SwColumnMap { pc, @@ -113,21 +67,7 @@ mod tests { let col_map = extract_sw_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - - for (i, &col) in flat.iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } + crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] From aa34af5c1e63da3989241f17767b50a5c90d97c6 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 18:23:23 +0800 Subject: [PATCH 52/73] part-b --- ceno_zkvm/src/instructions/riscv/gpu/d2h.rs | 369 ++++ .../instructions/riscv/gpu/debug_compare.rs | 729 +++++++ 
.../instructions/riscv/gpu/device_cache.rs | 532 +++++ .../src/instructions/riscv/gpu/gpu_config.rs | 241 +++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 8 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 1831 +---------------- 6 files changed, 1907 insertions(+), 1803 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/d2h.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/gpu_config.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs b/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs new file mode 100644 index 000000000..ac4175f5d --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs @@ -0,0 +1,369 @@ +/// Device-to-host conversion functions for GPU witness generation. +/// +/// This module handles: +/// - Type aliases for GPU buffer types +/// - D2H transfer of witness matrices (transpose + copy) +/// - D2H transfer of lookup counter buffers +/// - D2H transfer of compact EC records +/// - Conversion between host and GPU shard RAM record formats +/// - Batch EC point computation on GPU for continuation circuits +use ceno_emul::WordAddr; +use ceno_gpu::{ + Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose, +}; +use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord}; +use ff_ext::ExtensionField; +use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use p3::field::FieldAlgebra; +use rustc_hash::FxHashMap; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::{ + e2e::{RAMRecord, ShardContext}, + error::ZKVMError, +}; + +pub(crate) type WitBuf = ceno_gpu::common::BufferImpl< + 'static, + ::BaseField, +>; +pub(crate) type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; +pub(crate) type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; +pub(crate) 
type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; +pub(crate) type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; +pub(crate) type CompactEcBuf = ceno_gpu::common::witgen_types::CompactEcResult; + +/// CPU-side lightweight scan of GPU-produced RAM record slots. +/// +/// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, +/// replacing the previous `collect_shard_side_effects()` CPU loop. +pub(crate) fn gpu_collect_shard_records( + shard_ctx: &mut ShardContext, + slots: &[GpuRamRecordSlot], +) { + let current_shard_id = shard_ctx.shard_id; + + for slot in slots { + // Check was_sent flag (bit 4): this slot corresponds to a send() call + if slot.flags & (1 << 4) != 0 { + shard_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + + // Check active flag (bit 0): this slot has a read or write record + if slot.flags & 1 == 0 { + continue; + } + + let ram_type = match (slot.flags >> 5) & 0x7 { + 1 => RAMType::Register, + 2 => RAMType::Memory, + _ => continue, + }; + let has_prev_value = slot.flags & (1 << 3) != 0; + let prev_value = if has_prev_value { Some(slot.prev_value) } else { None }; + let addr = WordAddr(slot.addr); + + // Insert read record (bit 1) + if slot.flags & (1 << 1) != 0 { + shard_ctx.insert_read_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: 0, + prev_value, + value: slot.value, + shard_id: slot.read_shard_id as usize, + }, + ); + } + + // Insert write record (bit 2) + if slot.flags & (1 << 2) != 0 { + shard_ctx.insert_write_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: slot.shard_cycle, + prev_value, + value: slot.value, + shard_id: current_shard_id, + }, + ); + } + } +} + +/// D2H the compact EC result: read count, then partial-D2H only that many records. 
+pub(crate) fn gpu_compact_ec_d2h( + compact: &CompactEcResult, +) -> Result, ZKVMError> { + // D2H the count (1 u32) + let count_vec: Vec = compact.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()) + })?; + let count = count_vec[0] as usize; + if count == 0 { + return Ok(vec![]); + } + + // Partial D2H: only transfer the first `count` records (not the full allocation) + let record_u32s = std::mem::size_of::() / 4; // 26 + let total_u32s = count * record_u32s; + let buf_vec: Vec = compact.buffer.to_vec_n(total_u32s).map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) + })?; + + let records: Vec = unsafe { + let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; + std::slice::from_raw_parts(ptr, count).to_vec() + }; + tracing::debug!("GPU EC compact D2H: {} active records ({} bytes)", count, total_u32s * 4); + Ok(records) +} + +/// Batch compute EC points for continuation circuit ShardRamRecords on GPU. +/// +/// Converts ShardRamRecords to GPU format, launches the `batch_continuation_ec` +/// kernel to compute Poseidon2 + SepticCurve on device, and converts results +/// back to ShardRamInput (with EC points). +/// +/// Returns (write_inputs, read_inputs) maintaining the write-before-read ordering +/// invariant required by ShardRamCircuit::assign_instances. 
+pub fn gpu_batch_continuation_ec( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result< + ( + Vec>, + Vec>, + ), + ZKVMError, +> { + use crate::tables::ShardRamInput; + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let total = write_records.len() + read_records.len(); + if total == 0 { + return Ok((vec![], vec![])); + } + + // Convert ShardRamRecords to GpuShardRamRecord format + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(shard_ram_record_to_gpu(rec)); + } + + // GPU batch EC computation + let result = info_span!("gpu_batch_ec", n = total).in_scope(|| { + hal.batch_continuation_ec(&gpu_records) + }).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()) + })?; + + // Convert back to ShardRamInput, split into writes and reads + let mut write_inputs = Vec::with_capacity(write_records.len()); + let mut read_inputs = Vec::with_capacity(read_records.len()); + + for (i, gpu_rec) in result.iter().enumerate() { + let (rec, name) = if i < write_records.len() { + (&write_records[i].0, write_records[i].1) + } else { + let ri = i - write_records.len(); + (&read_records[ri].0, read_records[ri].1) + }; + + let ec_point = gpu_shard_ram_record_to_ec_point::(gpu_rec); + let input = ShardRamInput { + name, + record: rec.clone(), + ec_point, + }; + + if i < write_records.len() { + write_inputs.push(input); + } else { + read_inputs.push(input); + } + } + + Ok((write_inputs, read_inputs)) +} + +/// Convert a ShardRamRecord to GpuShardRamRecord (metadata only, EC fields zeroed). 
+pub(crate) fn shard_ram_record_to_gpu(rec: &crate::tables::ShardRamRecord) -> GpuShardRamRecord { + GpuShardRamRecord { + addr: rec.addr, + ram_type: match rec.ram_type { + RAMType::Register => 1, + RAMType::Memory => 2, + _ => 0, + }, + value: rec.value, + _pad0: 0, + shard: rec.shard, + local_clk: rec.local_clk, + global_clk: rec.global_clk, + is_to_write_set: if rec.is_to_write_set { 1 } else { 0 }, + nonce: 0, + point_x: [0; 7], + point_y: [0; 7], + } +} + +/// Convert a GPU-computed GpuShardRamRecord to ECPoint. +pub(crate) fn gpu_shard_ram_record_to_ec_point( + gpu_rec: &GpuShardRamRecord, +) -> crate::tables::ECPoint { + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + point_x_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_x[j]); + point_y_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_y[j]); + } + + let x = SepticExtension(point_x_arr); + let y = SepticExtension(point_y_arr); + let point = SepticPoint::from_affine(x, y); + + crate::tables::ECPoint { + nonce: gpu_rec.nonce, + point, + } +} + +pub(crate) fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { + let mut tables: [FxHashMap; 8] = Default::default(); + + // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) + info_span!("lk_dynamic_d2h").in_scope(|| { + let counts: Vec = counters.dynamic.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Dynamic as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; + + // Dense tables: same pattern, skip None + info_span!("lk_dense_d2h").in_scope(|| { + let dense: &[(LookupTable, &Option)] = &[ + 
(LookupTable::DoubleU8, &counters.double_u8), + (LookupTable::And, &counters.and_table), + (LookupTable::Or, &counters.or_table), + (LookupTable::Xor, &counters.xor_table), + (LookupTable::Ltu, &counters.ltu_table), + (LookupTable::Pow, &counters.pow_table), + ]; + for &(table, ref buf_opt) in dense { + if let Some(buf) = buf_opt { + let counts: Vec = buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU {:?} lk D2H failed: {e}", table).into(), + ) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[table as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + } + } + Ok::<(), ZKVMError>(()) + })?; + + // Fetch (Instruction table) + if let Some(fetch_buf) = counters.fetch { + info_span!("lk_fetch_d2h").in_scope(|| { + let base_pc = counters.fetch_base_pc; + let counts = fetch_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Instruction as usize]; + map.reserve(nnz); + for (slot_idx, &count) in counts.iter().enumerate() { + if count != 0 { + let pc = base_pc as u64 + (slot_idx as u64) * 4; + map.insert(pc, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; + } + + Ok(Multiplicity(tables)) +} + +/// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. +/// +/// GPU witgen kernels output column-major layout for better memory coalescing. +/// This function transposes to row-major on GPU before copying to host. 
+pub(crate) fn gpu_witness_to_rmm( + hal: &CudaHalBB31, + gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< + ceno_gpu::common::BufferImpl<'static, ::BaseField>, + >, + num_rows: usize, + num_cols: usize, + padding: InstancePaddingStrategy, +) -> Result, ZKVMError> { + // Transpose from column-major to row-major on GPU. + // Column-major (num_rows x num_cols) is stored as num_cols groups of num_rows elements, + // which is equivalent to a (num_cols x num_rows) row-major matrix. + // Transposing with cols=num_rows, rows=num_cols produces (num_rows x num_cols) row-major. + let mut rmm_buffer = hal + .alloc_elems_on_device(num_rows * num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.device_buffer, + num_rows, + num_cols, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_data: Vec<::BaseField> = rmm_buffer + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?; + + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + let data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok(RowMajorMatrix::::new_by_values( + data, num_cols, padding, + )) +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs b/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs new file mode 100644 index 000000000..be5af1af0 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs @@ -0,0 +1,729 @@ +/// Debug comparison functions for GPU witness generation. +/// +/// These functions compare GPU-produced results against CPU baselines +/// to validate correctness. 
Activated by environment variables: +/// - CENO_GPU_DEBUG_COMPARE_LK: compare lookup multiplicities +/// - CENO_GPU_DEBUG_COMPARE_WITNESS: compare witness matrices +/// - CENO_GPU_DEBUG_COMPARE_SHARD: compare shard side effects +/// - CENO_GPU_DEBUG_COMPARE_EC: compare EC points +use ceno_emul::{StepIndex, StepRecord, WordAddr}; +use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardRamRecord}; +use ff_ext::ExtensionField; +use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use p3::field::FieldAlgebra; +use std::cell::Cell; +use witness::RowMajorMatrix; + +use crate::{ + e2e::ShardContext, + error::ZKVMError, + instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, +}; + +use super::witgen_gpu::{GpuWitgenKind, set_force_cpu_path}; + +pub(crate) fn debug_compare_final_lk>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + mixed_lk: &Multiplicity, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_none() { + return Ok(()); + } + + // Compare against cpu_assign_instances (the true baseline using assign_instance) + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (_, cpu_assign_lk) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + tracing::info!("[GPU lk debug] kind={kind:?} comparing mixed_lk vs cpu_assign_instances lk"); + log_lk_diff(kind, &cpu_assign_lk, mixed_lk); + Ok(()) +} + +pub(crate) fn log_lk_diff(kind: GpuWitgenKind, cpu_lk: &Multiplicity, actual_lk: &Multiplicity) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_LK_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(32); + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, actual_table)) in 
cpu_lk.iter().zip(actual_lk.iter()).enumerate() { + let mut keys = cpu_table + .keys() + .chain(actual_table.keys()) + .copied() + .collect::>(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = Vec::new(); + for key in keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let actual_count = actual_table.get(&key).copied().unwrap_or(0); + if cpu_count != actual_count { + table_diffs.push((key, cpu_count, actual_count)); + } + } + + if !table_diffs.is_empty() { + total_diffs += table_diffs.len(); + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} diff_count={}", + lookup_table_name(table_idx), + table_diffs.len() + ); + for (key, cpu_count, actual_count) in table_diffs.into_iter().take(limit) { + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} key={} cpu={} gpu={}", + lookup_table_name(table_idx), + key, + cpu_count, + actual_count + ); + } + } + } + + if total_diffs == 0 { + tracing::info!("[GPU lk debug] kind={kind:?} CPU/GPU lookup multiplicities match"); + } +} + +pub(crate) fn debug_compare_witness>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + gpu_witness: &RowMajorMatrix, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, _) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + let cpu_witness = &cpu_rmms[0]; + let cpu_vals = cpu_witness.values(); + let gpu_vals = gpu_witness.values(); + if cpu_vals == gpu_vals { + return Ok(()); + } + + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + let cpu_num_cols = cpu_witness.n_col(); + let cpu_num_rows = 
cpu_vals.len() / cpu_num_cols; + let mut mismatches = 0usize; + for row in 0..cpu_num_rows { + for col in 0..cpu_num_cols { + let idx = row * cpu_num_cols + col; + if cpu_vals[idx] != gpu_vals[idx] { + mismatches += 1; + if mismatches <= limit { + tracing::error!( + "[GPU witness debug] kind={kind:?} row={} col={} cpu={:?} gpu={:?}", + row, + col, + cpu_vals[idx], + gpu_vals[idx] + ); + } + } + } + } + tracing::error!( + "[GPU witness debug] kind={kind:?} total_mismatches={}", + mismatches + ); + Ok(()) +} + +pub(crate) fn debug_compare_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let _ = cpu_collect_side_effects::(config, &mut cpu_ctx, shard_steps, step_indices)?; + + let mut mixed_ctx = shard_ctx.new_empty_like(); + let _ = + cpu_collect_shard_side_effects::(config, &mut mixed_ctx, shard_steps, step_indices)?; + + let cpu_addr = cpu_ctx.get_addr_accessed(); + let mixed_addr = mixed_ctx.get_addr_accessed(); + if cpu_addr != mixed_addr { + tracing::error!( + "[GPU shard debug] kind={kind:?} addr_accessed cpu={} gpu={}", + cpu_addr.len(), + mixed_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu_ctx.read_records()); + let mixed_reads = flatten_ram_records(mixed_ctx.read_records()); + if cpu_reads != mixed_reads { + log_ram_record_diff(kind, "read_records", &cpu_reads, &mixed_reads); + } + + let cpu_writes = flatten_ram_records(cpu_ctx.write_records()); + let mixed_writes = flatten_ram_records(mixed_ctx.write_records()); + if cpu_writes != mixed_writes { + log_ram_record_diff(kind, "write_records", &cpu_writes, &mixed_writes); + } + + Ok(()) +} + +/// Compare GPU shard context vs CPU shard context, field by field. 
+/// +/// Both paths are independent and produce equivalent ShardContext state: +/// CPU path: cpu_collect_shard_side_effects -> addr_accessed + write_records + read_records +/// GPU path: compact_records -> shard records (gpu_ec_records) +/// ram_slots WAS_SENT -> addr_accessed +/// (write_records and read_records stay empty for GPU EC kernels) +/// +/// This function builds both independently and compares: +/// A. addr_accessed sets +/// B. shard records (sorted, normalized to ShardRamRecord) +/// C. EC points (nonce + SepticPoint x,y) +/// +/// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. +pub(crate) fn debug_compare_shard_ec>( + compact_records: &[GpuShardRamRecord], + ram_slots: &[GpuRamRecordSlot], + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { + return; + } + + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + use crate::tables::{ECPoint, ShardRamRecord}; + use ff_ext::{PoseidonField, SmallField}; + + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + + // ========== Build CPU shard context (independent, isolated) ========== + let mut cpu_ctx = shard_ctx.new_empty_like(); + if let Err(e) = cpu_collect_shard_side_effects::( + config, &mut cpu_ctx, shard_steps, step_indices, + ) { + tracing::error!("[GPU EC debug] kind={kind:?} CPU shard side effects failed: {e:?}"); + return; + } + + let perm = ::get_default_perm(); + + // CPU: addr_accessed + let cpu_addr = cpu_ctx.get_addr_accessed(); + + // CPU: shard records (BTreeMap -> ShardRamRecord + ECPoint) + let mut cpu_entries: Vec<(ShardRamRecord, ECPoint)> = Vec::new(); + for records in cpu_ctx.write_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, true).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); 
+ } + } + for records in cpu_ctx.read_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, false).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + cpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Build GPU shard context (independent, from D2H data only) ========== + + // GPU: addr_accessed (from ram_slots WAS_SENT flags) + let gpu_addr: rustc_hash::FxHashSet = ram_slots + .iter() + .filter(|s| s.flags & (1 << 4) != 0) + .map(|s| WordAddr(s.addr)) + .collect(); + + // GPU: shard records (compact_records -> ShardRamRecord + ECPoint) + let mut gpu_entries: Vec<(ShardRamRecord, ECPoint)> = compact_records + .iter() + .map(|g| { + let rec = ShardRamRecord { + addr: g.addr, + ram_type: if g.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + value: g.value, + shard: g.shard, + local_clk: g.local_clk, + global_clk: g.global_clk, + is_to_write_set: g.is_to_write_set != 0, + }; + let x = SepticExtension(g.point_x.map(|v| E::BaseField::from_canonical_u32(v))); + let y = SepticExtension(g.point_y.map(|v| E::BaseField::from_canonical_u32(v))); + let point = SepticPoint::from_affine(x, y); + let ec = ECPoint:: { nonce: g.nonce, point }; + (rec, ec) + }) + .collect(); + gpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Compare A: addr_accessed ========== + if cpu_addr != gpu_addr { + let cpu_only: Vec<_> = cpu_addr.difference(&gpu_addr).collect(); + let gpu_only: Vec<_> = gpu_addr.difference(&cpu_addr).collect(); + tracing::error!( + "[GPU EC debug] kind={kind:?} ADDR_ACCESSED MISMATCH: cpu={} gpu={} \ + cpu_only={} gpu_only={}", + cpu_addr.len(), gpu_addr.len(), cpu_only.len(), gpu_only.len() + ); + for (i, addr) in cpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", addr.0); + } + for (i, addr) in 
gpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", addr.0); + } + } + + // ========== Compare B+C: shard records + EC points ========== + + // Check counts + if cpu_entries.len() != gpu_entries.len() { + tracing::error!( + "[GPU EC debug] kind={kind:?} RECORD COUNT MISMATCH: cpu={} gpu={}", + cpu_entries.len(), gpu_entries.len() + ); + let cpu_keys: std::collections::BTreeSet<_> = cpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let gpu_keys: std::collections::BTreeSet<_> = gpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let mut logged = 0usize; + for key in cpu_keys.difference(&gpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + for key in gpu_keys.difference(&cpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + } + + // Check GPU duplicates (BTreeMap deduplicates, atomicAdd doesn't) + let mut gpu_dup_count = 0usize; + for w in gpu_entries.windows(2) { + if w[0].0.addr == w[1].0.addr + && w[0].0.is_to_write_set == w[1].0.is_to_write_set + && w[0].0.ram_type == w[1].0.ram_type + { + gpu_dup_count += 1; + if gpu_dup_count <= limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} GPU DUPLICATE: addr={} is_write={} ram_type={:?}", + w[0].0.addr, w[0].0.is_to_write_set, w[0].0.ram_type + ); + } + } + } + + // Merge-walk sorted lists + let mut ci = 0usize; + let mut gi = 0usize; + let mut record_mismatches = 0usize; + let mut ec_mismatches = 0usize; + let mut matched = 0usize; + + while ci < cpu_entries.len() && gi < gpu_entries.len() { + let (cr, ce) = &cpu_entries[ci]; + let (gr, ge) = &gpu_entries[gi]; + let ck = (cr.addr, cr.is_to_write_set as u8, cr.ram_type as u8); + let gk = (gr.addr, gr.is_to_write_set as u8, 
gr.ram_type as u8); + + match ck.cmp(&gk) { + std::cmp::Ordering::Less => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + cr.addr, cr.is_to_write_set, cr.ram_type, cr.value, cr.shard, cr.global_clk + ); + } + record_mismatches += 1; + ci += 1; + continue; + } + std::cmp::Ordering::Greater => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + gr.addr, gr.is_to_write_set, gr.ram_type, gr.value, gr.shard, gr.global_clk + ); + } + record_mismatches += 1; + gi += 1; + continue; + } + std::cmp::Ordering::Equal => {} + } + + // Keys match -- compare record fields + let mut field_diff = false; + for (name, cv, gv) in [ + ("value", cr.value as u64, gr.value as u64), + ("shard", cr.shard, gr.shard), + ("local_clk", cr.local_clk, gr.local_clk), + ("global_clk", cr.global_clk, gr.global_clk), + ] { + if cv != gv { + field_diff = true; + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} {name}: cpu={cv} gpu={gv}", + cr.addr + ); + } + } + } + if field_diff { + record_mismatches += 1; + } + + // Compare EC points + let mut ec_diff = false; + if ce.nonce != ge.nonce { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} nonce: cpu={} gpu={}", + cr.addr, ce.nonce, ge.nonce + ); + } + } + for j in 0..7 { + let cv = ce.point.x.0[j].to_canonical_u64() as u32; + let gv = ge.point.x.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", cr.addr + ); + } + } + } + for j in 0..7 { + let cv = ce.point.y.0[j].to_canonical_u64() as u32; + let gv = ge.point.y.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + 
tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", cr.addr + ); + } + } + } + if ec_diff { + ec_mismatches += 1; + } + + matched += 1; + ci += 1; + gi += 1; + } + + // Remaining unmatched + while ci < cpu_entries.len() { + if record_mismatches < limit { + let (cr, _) = &cpu_entries[ci]; + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU (tail): addr={} is_write={} val={}", + cr.addr, cr.is_to_write_set, cr.value + ); + } + record_mismatches += 1; + ci += 1; + } + while gi < gpu_entries.len() { + if record_mismatches < limit { + let (gr, _) = &gpu_entries[gi]; + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU (tail): addr={} is_write={} val={}", + gr.addr, gr.is_to_write_set, gr.value + ); + } + record_mismatches += 1; + gi += 1; + } + + // ========== Summary ========== + let addr_ok = cpu_addr == gpu_addr; + if addr_ok && record_mismatches == 0 && ec_mismatches == 0 && gpu_dup_count == 0 { + tracing::info!( + "[GPU EC debug] kind={kind:?} ALL MATCH: {} records, {} addr_accessed, EC points OK", + matched, cpu_addr.len() + ); + } else { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISMATCH: matched={matched} record_diffs={record_mismatches} \ + ec_diffs={ec_mismatches} gpu_dups={gpu_dup_count} addr_ok={addr_ok} \ + (cpu_records={} gpu_records={} cpu_addrs={} gpu_addrs={})", + cpu_entries.len(), gpu_entries.len(), cpu_addr.len(), gpu_addr.len() + ); + } +} + +pub(crate) fn flatten_ram_records( + records: &[std::collections::BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +pub(crate) fn log_ram_record_diff( + kind: GpuWitgenKind, + label: &str, + cpu_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], + 
mixed_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], +) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_SHARD_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + tracing::error!( + "[GPU shard debug] kind={kind:?} {} cpu={} gpu={}", + label, + cpu_records.len(), + mixed_records.len() + ); + let max_len = cpu_records.len().max(mixed_records.len()); + let mut logged = 0usize; + for idx in 0..max_len { + let cpu = cpu_records.get(idx); + let gpu = mixed_records.get(idx); + if cpu != gpu { + tracing::error!( + "[GPU shard debug] kind={kind:?} {} idx={} cpu={:?} gpu={:?}", + label, + idx, + cpu, + gpu + ); + logged += 1; + if logged >= limit { + break; + } + } + } +} + +pub(crate) fn lookup_table_name(table_idx: usize) -> &'static str { + match table_idx { + x if x == LookupTable::Dynamic as usize => "Dynamic", + x if x == LookupTable::DoubleU8 as usize => "DoubleU8", + x if x == LookupTable::And as usize => "And", + x if x == LookupTable::Or as usize => "Or", + x if x == LookupTable::Xor as usize => "Xor", + x if x == LookupTable::Ltu as usize => "Ltu", + x if x == LookupTable::Pow as usize => "Pow", + x if x == LookupTable::Instruction as usize => "Instruction", + _ => "Unknown", + } +} + +/// Debug comparison for keccak GPU witgen. +/// Runs the CPU path and compares LK / witness / shard side effects. +/// +/// Activated by CENO_GPU_DEBUG_COMPARE_LK, CENO_GPU_DEBUG_COMPARE_WITNESS, +/// or CENO_GPU_DEBUG_COMPARE_SHARD environment variables. 
+#[cfg(feature = "gpu")] +pub(crate) fn debug_compare_keccak( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], + gpu_lk: &Multiplicity, + gpu_witin: &RowMajorMatrix, + gpu_addrs: &[u32], +) -> Result<(), ZKVMError> { + let want_lk = std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_some(); + let want_witness = std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_some(); + let want_shard = std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some(); + + if !want_lk && !want_witness && !want_shard { + return Ok(()); + } + + // Guard against recursion: is_gpu_witgen_disabled() uses OnceLock so env var + // manipulation doesn't work. Use a thread-local flag instead. + thread_local! { + static IN_DEBUG_COMPARE: Cell = const { Cell::new(false) }; + } + if IN_DEBUG_COMPARE.with(|f| f.get()) { + return Ok(()); + } + IN_DEBUG_COMPARE.with(|f| f.set(true)); + + tracing::info!("[GPU keccak debug] running CPU baseline for comparison"); + + // Run CPU path via assign_instances. The IN_DEBUG_COMPARE guard prevents + // gpu_assign_keccak_instances from calling debug_compare_keccak again, + // so it will produce the GPU result, which is then returned without + // re-entering this function. We need assign_instances (not cpu_assign_instances) + // because keccak has rotation matrices and 3 structural columns. + // + // To force the CPU path, we use is_force_cpu_path() by setting the env var. 
+ let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, cpu_lk) = { + use crate::instructions::riscv::ecall::keccak::KeccakInstruction; + // Set force-CPU flag so gpu_assign_keccak_instances returns None + set_force_cpu_path(true); + let result = as crate::instructions::Instruction>::assign_instances( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + ); + set_force_cpu_path(false); + IN_DEBUG_COMPARE.with(|f| f.set(false)); + result? + }; + + let kind = GpuWitgenKind::Keccak; + + if want_lk { + tracing::info!("[GPU keccak debug] comparing LK multiplicities"); + log_lk_diff(kind, &cpu_lk, gpu_lk); + } + + if want_witness { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(32); + let cpu_witin = &cpu_rmms[0]; + let gpu_vals = gpu_witin.values(); + let cpu_vals = cpu_witin.values(); + let mut diffs = 0usize; + for (i, (g, c)) in gpu_vals.iter().zip(cpu_vals.iter()).enumerate() { + if g != c { + if diffs < limit { + let row = i / num_witin; + let col = i % num_witin; + tracing::error!( + "[GPU keccak witness] row={} col={} gpu={:?} cpu={:?}", + row, col, g, c + ); + } + diffs += 1; + } + } + if diffs == 0 { + tracing::info!("[GPU keccak debug] witness matrices match ({} elements)", gpu_vals.len()); + } else { + tracing::error!("[GPU keccak debug] witness mismatch: {} diffs out of {}", diffs, gpu_vals.len()); + } + } + + if want_shard { + // Compare addr_accessed: GPU entries were D2H'd from the shared buffer + // delta (before/after kernel launch) and passed in as gpu_addrs. 
+ let cpu_addr = cpu_ctx.get_addr_accessed(); + let gpu_addr_set: rustc_hash::FxHashSet = + gpu_addrs.iter().map(|&a| WordAddr(a)).collect(); + + if cpu_addr.len() != gpu_addr_set.len() { + tracing::error!( + "[GPU keccak shard] addr_accessed count mismatch: cpu={} gpu={}", + cpu_addr.len(), gpu_addr_set.len() + ); + } + let mut missing_from_gpu = 0usize; + let mut extra_in_gpu = 0usize; + let limit = 16usize; + for addr in &cpu_addr { + if !gpu_addr_set.contains(addr) { + if missing_from_gpu < limit { + tracing::error!("[GPU keccak shard] addr {} in CPU but not GPU", addr.0); + } + missing_from_gpu += 1; + } + } + for &addr in gpu_addrs { + if !cpu_addr.contains(&WordAddr(addr)) { + if extra_in_gpu < limit { + tracing::error!("[GPU keccak shard] addr {} in GPU but not CPU", addr); + } + extra_in_gpu += 1; + } + } + if missing_from_gpu == 0 && extra_in_gpu == 0 { + tracing::info!( + "[GPU keccak shard] addr_accessed matches: {} entries", + cpu_addr.len() + ); + } else { + tracing::error!( + "[GPU keccak shard] addr_accessed diff: missing_from_gpu={} extra_in_gpu={}", + missing_from_gpu, extra_in_gpu + ); + } + } + + Ok(()) +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs b/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs new file mode 100644 index 000000000..086f1c983 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs @@ -0,0 +1,532 @@ +/// Device buffer caching for GPU witness generation. +/// +/// Manages thread-local caches for shard step data and shard metadata +/// device buffers, avoiding redundant host-to-device transfers within +/// the same shard. 
+use ceno_emul::{StepRecord, WordAddr}; +use ceno_gpu::{ + Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, bb31::ShardDeviceBuffers, + common::witgen_types::GpuShardRamRecord, common::witgen_types::GpuShardScalars, +}; +use std::cell::RefCell; +use tracing::info_span; + +use crate::{ + e2e::ShardContext, + error::ZKVMError, +}; + +/// Cached shard_steps device buffer with metadata for logging. +struct ShardStepsCache { + host_ptr: usize, + byte_len: usize, + shard_id: usize, + n_steps: usize, + device_buf: CudaSlice, +} + +// Thread-local cache for shard_steps device buffer. Invalidated when shard changes. +thread_local! { + static SHARD_STEPS_DEVICE: RefCell> = + const { RefCell::new(None) }; +} + +/// Upload shard_steps to GPU, reusing cached device buffer if the same data. +pub(crate) fn upload_shard_steps_cached( + hal: &CudaHalBB31, + shard_steps: &[StepRecord], + shard_id: usize, +) -> Result<(), ZKVMError> { + let ptr = shard_steps.as_ptr() as usize; + let byte_len = shard_steps.len() * std::mem::size_of::(); + + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.host_ptr == ptr && c.byte_len == byte_len { + return Ok(()); // cache hit + } + } + // Cache miss: upload + let mb = byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] uploading shard_steps: shard_id={}, n_steps={}, {:.2} MB", + shard_id, + shard_steps.len(), + mb, + ); + let bytes: &[u8] = + unsafe { std::slice::from_raw_parts(shard_steps.as_ptr() as *const u8, byte_len) }; + let device_buf = hal.inner.htod_copy_stream(None, bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard_steps H2D failed: {e}").into()) + })?; + *cache = Some(ShardStepsCache { + host_ptr: ptr, + byte_len, + shard_id, + n_steps: shard_steps.len(), + device_buf, + }); + Ok(()) + }) +} + +/// Borrow the cached device buffer for kernel launch. +/// Panics if `upload_shard_steps_cached` was not called first. 
+pub(crate) fn with_cached_shard_steps(f: impl FnOnce(&CudaSlice) -> R) -> R { + SHARD_STEPS_DEVICE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard_steps not uploaded"); + f(&c.device_buf) + }) +} + +/// Invalidate the cached shard_steps device buffer. +/// Call this when shard processing is complete to free GPU memory. +pub fn invalidate_shard_steps_cache() { + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + let mb = c.byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] releasing shard_steps cache: shard_id={}, n_steps={}, {:.2} MB", + c.shard_id, + c.n_steps, + mb, + ); + } + *cache = None; + }); +} + +/// Cached shard metadata device buffers for GPU shard records. +/// Invalidated when shard_id changes; shared across all kernel invocations in one shard. +struct ShardMetadataCache { + shard_id: usize, + device_bufs: ShardDeviceBuffers, + /// Shared EC record buffer (owns the GPU memory, pointer stored in device_bufs). + shared_ec_buf: Option>, + /// Shared EC record count buffer (single u32 counter). + shared_ec_count: Option>, + /// Shared addr_accessed buffer (u32 word addresses). + shared_addr_buf: Option>, + /// Shared addr_accessed count buffer (single u32 counter). + shared_addr_count: Option>, +} + +thread_local! { + static SHARD_META_CACHE: RefCell> = + const { RefCell::new(None) }; +} + +/// Build and cache shard metadata device buffers for GPU shard records. +/// +/// FA (future access) device buffers are global and identical across all shards, +/// so they are uploaded once and reused via move. Only per-shard data (scalars + +/// prev_shard_ranges) is re-uploaded when the shard changes. 
+pub(crate) fn ensure_shard_metadata_cached( + hal: &CudaHalBB31, + shard_ctx: &ShardContext, + n_total_steps: usize, +) -> Result<(), ZKVMError> { + let shard_id = shard_ctx.shard_id; + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.shard_id == shard_id { + return Ok(()); // cache hit + } + } + + // Move FA device buffer from previous cache (reuse across shards). + // FA data is global — identical across all shards — so we reuse, not re-upload. + let existing_fa = cache.take().map(|c| { + let ShardDeviceBuffers { + next_access_packed, + scalars: _, + prev_shard_cycle_range: _, + prev_shard_heap_range: _, + prev_shard_hint_range: _, + gpu_ec_shard_id: _, + shared_ec_out_ptr: _, + shared_ec_count_ptr: _, + shared_addr_out_ptr: _, + shared_addr_count_ptr: _, + shared_ec_capacity: _, + shared_addr_capacity: _, + } = c.device_bufs; + next_access_packed + }); + + let next_access_packed_device = if let Some(fa) = existing_fa { + fa // Reuse existing GPU memory — zero cost pointer move + } else { + // First shard: bulk H2D upload packed FA entries (no sort here) + let sorted = &shard_ctx.sorted_next_accesses; + tracing::info_span!("next_access_h2d").in_scope(|| -> Result<_, ZKVMError> { + let packed_bytes: &[u8] = if sorted.packed.is_empty() { + &[0u8; 16] // sentinel for empty + } else { + unsafe { + std::slice::from_raw_parts( + sorted.packed.as_ptr() as *const u8, + sorted.packed.len() * std::mem::size_of::(), + ) + } + }; + let buf = hal.inner.htod_copy_stream(None, packed_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) + })?; + let next_access_device = ceno_gpu::common::buffer::BufferImpl::new(buf); + let mb = packed_bytes.len() as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU shard] FA uploaded once: {} entries, {:.2} MB (packed)", + sorted.packed.len(), + mb, + ); + Ok(next_access_device) + })? 
+ }; + + // Per-shard: always re-upload scalars + prev_shard_ranges + let scalars = GpuShardScalars { + shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, + shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, + shard_offset_cycle: shard_ctx.current_shard_offset_cycle(), + shard_id: shard_id as u32, + heap_start: shard_ctx.platform.heap.start, + heap_end: shard_ctx.platform.heap.end, + hint_start: shard_ctx.platform.hints.start, + hint_end: shard_ctx.platform.hints.end, + shard_heap_start: shard_ctx.shard_heap_addr_range.start, + shard_heap_end: shard_ctx.shard_heap_addr_range.end, + shard_hint_start: shard_ctx.shard_hint_addr_range.start, + shard_hint_end: shard_ctx.shard_hint_addr_range.end, + next_access_count: shard_ctx.sorted_next_accesses.packed.len() as u32, + num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, + num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, + num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, + }; + + let (scalars_device, pscr_device, pshr_device, pshi_device) = + tracing::info_span!("shard_scalars_h2d").in_scope(|| -> Result<_, ZKVMError> { + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = + hal.inner + .htod_copy_stream(None, scalars_bytes) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("shard scalars H2D failed: {e}").into(), + ) + })?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()) + })?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()) + })?; + + let pshi = 
&shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()) + })?; + + Ok((scalars_device, pscr_device, pshr_device, pshi_device)) + })?; + + tracing::info!( + "[GPU shard] shard_id={}: per-shard scalars updated", + shard_id, + ); + + // Allocate shared EC/addr compact buffers for this shard. + // + // EC records: cross-shard only (sparse subset of RAM ops). + // 104 bytes each (26 u32s). Cap at 16M entries ≈ 1.6 GB. + // Addr records: every gpu_send() emits one (dense). + // 4 bytes each (1 u32). Cap at 256M entries ≈ 1 GB. + let max_ops_per_step = 52u64; // keccak worst case + let total_ops_estimate = n_total_steps as u64 * max_ops_per_step; + let ec_capacity = total_ops_estimate.min(16 * 1024 * 1024) as usize; + let ec_u32s = ec_capacity * 26; // 26 u32s per GpuShardRamRecord (104 bytes) + let addr_capacity = total_ops_estimate.min(256 * 1024 * 1024) as usize; + + let shared_ec_buf = hal + .alloc_u32_zeroed(ec_u32s, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf alloc: {e}").into()))?; + let shared_ec_count = hal + .alloc_u32_zeroed(1, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_count alloc: {e}").into()) + })?; + let shared_addr_buf = hal + .alloc_u32_zeroed(addr_capacity, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_buf alloc: {e}").into()) + })?; + let shared_addr_count = hal + .alloc_u32_zeroed(1, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count alloc: {e}").into()) + })?; + + let shared_ec_out_ptr = shared_ec_buf.device_ptr() as u64; + let shared_ec_count_ptr = shared_ec_count.device_ptr() as u64; + let shared_addr_out_ptr = shared_addr_buf.device_ptr() as u64; + let shared_addr_count_ptr = shared_addr_count.device_ptr() as u64; + + tracing::info!( + "[GPU shard] shard_id={}: shared buffers 
allocated: ec_capacity={}, addr_capacity={}", + shard_id, ec_capacity, addr_capacity, + ); + + *cache = Some(ShardMetadataCache { + shard_id, + device_bufs: ShardDeviceBuffers { + scalars: scalars_device, + next_access_packed: next_access_packed_device, + prev_shard_cycle_range: pscr_device, + prev_shard_heap_range: pshr_device, + prev_shard_hint_range: pshi_device, + gpu_ec_shard_id: Some(shard_id as u64), + shared_ec_out_ptr, + shared_ec_count_ptr, + shared_addr_out_ptr, + shared_addr_count_ptr, + shared_ec_capacity: ec_capacity as u32, + shared_addr_capacity: addr_capacity as u32, + }, + shared_ec_buf: Some(shared_ec_buf), + shared_ec_count: Some(shared_ec_count), + shared_addr_buf: Some(shared_addr_buf), + shared_addr_count: Some(shared_addr_count), + }); + Ok(()) + }) +} + +/// Borrow the cached shard device buffers for kernel launch. +pub(crate) fn with_cached_shard_meta(f: impl FnOnce(&ShardDeviceBuffers) -> R) -> R { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not uploaded"); + f(&c.device_bufs) + }) +} + +/// Invalidate the shard metadata cache (call when shard processing is complete). +pub fn invalidate_shard_meta_cache() { + SHARD_META_CACHE.with(|cache| { + *cache.borrow_mut() = None; + }); +} + +/// Take ownership of shared EC and addr_accessed device buffers from the cache. +/// +/// Returns (shared_ec_buf, ec_count, shared_addr_buf, addr_count) or None if unavailable. +/// The cache is invalidated after this call — must be called at most once per shard. 
+pub fn take_shared_device_buffers() -> Option { + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + let c = cache.as_mut()?; + + let ec_buf = c.shared_ec_buf.take()?; + let ec_count = c.shared_ec_count.take()?; + let addr_buf = c.shared_addr_buf.take()?; + let addr_count = c.shared_addr_count.take()?; + + Some(SharedDeviceBufferSet { + ec_buf, + ec_count, + addr_buf, + addr_count, + }) + }) +} + +/// Shared device buffers taken from the shard metadata cache. +pub struct SharedDeviceBufferSet { + pub ec_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub ec_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub addr_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub addr_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, +} + +/// Batch compute EC points for continuation records, keeping results on device. +/// +/// Returns (device_buf_as_u32, num_records) where the device buffer contains +/// GpuShardRamRecord entries with EC points computed. 
+pub fn gpu_batch_continuation_ec_on_device( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result<(ceno_gpu::common::buffer::BufferImpl<'static, u32>, usize, usize), ZKVMError> { + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let n_writes = write_records.len(); + let n_reads = read_records.len(); + let total = n_writes + n_reads; + if total == 0 { + let empty = hal.alloc_u32_zeroed(1, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("alloc: {e}").into()) + })?; + return Ok((empty, 0, 0)); + } + + // Convert to GpuShardRamRecord format (writes first, reads after) + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(super::d2h::shard_ram_record_to_gpu(rec)); + } + + // GPU batch EC, results stay on device + let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total).in_scope(|| { + hal.batch_continuation_ec_on_device(&gpu_records, None) + }).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) + })?; + + Ok((device_buf, n_writes, n_reads)) +} + +/// Read the current shared addr count from device (single u32 D2H). +/// Used by debug comparison to snapshot count before/after a kernel. +#[cfg(feature = "gpu")] +pub(crate) fn read_shared_addr_count() -> usize { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not cached"); + let buf = c.shared_addr_count.as_ref().expect("shared_addr_count not allocated"); + let v: Vec = buf.to_vec().expect("shared_addr_count D2H failed"); + v[0] as usize + }) +} + +/// Read a range of addr entries [start..end) from the shared addr buffer. 
+#[cfg(feature = "gpu")] +pub(crate) fn read_shared_addr_range(start: usize, end: usize) -> Vec { + if start >= end { + return Vec::new(); + } + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not cached"); + let buf = c.shared_addr_buf.as_ref().expect("shared_addr_buf not allocated"); + let all: Vec = buf.to_vec_n(end).expect("shared_addr_buf D2H failed"); + all[start..end].to_vec() + }) +} + +/// Batch D2H of shared EC records and addr_accessed buffers after all kernel invocations. +/// +/// Called once per shard after all opcode `gpu_assign_instances_inner` calls complete. +/// Transfers accumulated EC records and addresses from shared GPU buffers into `shard_ctx`. +/// +/// If the shared buffers have already been taken by `take_shared_device_buffers` +/// (for the full GPU pipeline), this is a no-op. +pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVMError> { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = match cache.as_ref() { + Some(c) => c, + None => return Ok(()), // cache already invalidated — no-op + }; + + // If buffers have been taken by take_shared_device_buffers, skip D2H + let ec_count_buf = match c.shared_ec_count.as_ref() { + Some(b) => b, + None => { + tracing::debug!("[GPU shard] flush_shared_ec_buffers: buffers already taken, no-op"); + return Ok(()); + } + }; + let ec_count_vec: Vec = ec_count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) + })?; + let ec_count = ec_count_vec[0] as usize; + let ec_capacity = c.device_bufs.shared_ec_capacity as usize; + + assert!( + ec_count <= ec_capacity, + "GPU shared EC buffer overflow: count={} > capacity={}. 
\ + Increase ec_capacity in ensure_shard_metadata_cached.", + ec_count, + ec_capacity, + ); + + if ec_count > 0 { + // D2H EC records (only the active portion) + let ec_buf = c.shared_ec_buf.as_ref().unwrap(); + let ec_u32s = ec_count * 26; // 26 u32s per GpuShardRamRecord + let raw_u32: Vec = ec_buf.to_vec_n(ec_u32s).map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_buf D2H: {e}").into()) + })?; + let raw_bytes = unsafe { + std::slice::from_raw_parts( + raw_u32.as_ptr() as *const u8, + raw_u32.len() * 4, + ) + }; + tracing::info!( + "[GPU shard] flush_shared_ec_buffers: {} EC records, {:.2} MB", + ec_count, + raw_bytes.len() as f64 / (1024.0 * 1024.0), + ); + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + } + + // D2H addr_accessed count + let addr_count_buf = c.shared_addr_count.as_ref().ok_or_else(|| { + ZKVMError::InvalidWitness("shared_addr_count not allocated".into()) + })?; + let addr_count_vec: Vec = addr_count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()) + })?; + let addr_count = addr_count_vec[0] as usize; + let addr_capacity = c.device_bufs.shared_addr_capacity as usize; + + assert!( + addr_count <= addr_capacity, + "GPU shared addr buffer overflow: count={} > capacity={}", + addr_count, + addr_capacity, + ); + + if addr_count > 0 { + let addr_buf = c.shared_addr_buf.as_ref().unwrap(); + let addrs: Vec = addr_buf.to_vec_n(addr_count).map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_buf D2H: {e}").into()) + })?; + tracing::info!( + "[GPU shard] flush_shared_ec_buffers: {} addr_accessed, {:.2} MB", + addr_count, + addr_count as f64 * 4.0 / (1024.0 * 1024.0), + ); + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + + Ok(()) + }) +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/gpu_config.rs b/ceno_zkvm/src/instructions/riscv/gpu/gpu_config.rs new file mode 100644 
index 000000000..3b2bee2d0 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/gpu_config.rs @@ -0,0 +1,241 @@ +/// GPU witgen path-control helpers: kind tags, verified-kind queries, and +/// environment-variable disable switches. +/// +/// Extracted from `witgen_gpu.rs` — pure code move, no behavioural changes. +use super::witgen_gpu::GpuWitgenKind; + +pub(crate) fn kind_tag(kind: GpuWitgenKind) -> &'static str { + match kind { + GpuWitgenKind::Add => "add", + GpuWitgenKind::Sub => "sub", + GpuWitgenKind::LogicR(_) => "logic_r", + GpuWitgenKind::Lw => "lw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => "logic_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => "addi", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => "lui", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => "auipc", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => "jal", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => "shift_r", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => "shift_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => "slt", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => "slti", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => "branch_eq", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => "branch_cmp", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => "jalr", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => "sw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => "sh", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => "sb", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. 
} => "load_sub", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => "mul", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => "div", + GpuWitgenKind::Keccak => "keccak", + } +} + +/// Returns true if the GPU CUDA kernel for this kind has been verified to produce +/// correct LK multiplicity counters matching the CPU baseline. +/// Unverified kinds fall back to CPU full side effects (GPU still handles witness). +/// +/// Override with `CENO_GPU_DISABLE_LK_KINDS=add,sub,...` to force specific kinds +/// back to CPU LK (for binary-search debugging). +/// Set `CENO_GPU_DISABLE_LK_KINDS=all` to disable GPU LK for ALL kinds. +pub(crate) fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { + if is_lk_kind_disabled(kind) { + return false; + } + match kind { + // Phase B verified (Add/Sub/LogicR/Lw) + GpuWitgenKind::Add => true, + GpuWitgenKind::Sub => true, + GpuWitgenKind::LogicR(_) => true, + GpuWitgenKind::Lw => true, + // Phase C verified via debug_compare_final_lk + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => true, + // Phase C CUDA kernel fixes applied + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => true, + // Remaining kinds enabled + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => true, 
+ #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. } => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => true, + // Keccak has its own dispatch path with its own LK handling. + GpuWitgenKind::Keccak => false, + #[cfg(not(feature = "u16limb_circuit"))] + _ => false, + } +} + +/// Returns true if GPU shard records are verified for this kind. +/// Set CENO_GPU_DISABLE_SHARD_KINDS=all to force ALL kinds back to CPU shard path. +pub(crate) fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { + // Global kill switch: force pure CPU shard path for baseline testing + if std::env::var_os("CENO_GPU_CPU_SHARD").is_some() { + return false; + } + if is_shard_kind_disabled(kind) { + return false; + } + match kind { + GpuWitgenKind::Add + | GpuWitgenKind::Sub + | GpuWitgenKind::LogicR(_) + | GpuWitgenKind::Lw => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) + | GpuWitgenKind::Addi + | GpuWitgenKind::Lui + | GpuWitgenKind::Auipc + | GpuWitgenKind::Jal + | GpuWitgenKind::ShiftR(_) + | GpuWitgenKind::ShiftI(_) + | GpuWitgenKind::Slt(_) + | GpuWitgenKind::Slti(_) + | GpuWitgenKind::BranchEq(_) + | GpuWitgenKind::BranchCmp(_) + | GpuWitgenKind::Jalr + | GpuWitgenKind::Sw + | GpuWitgenKind::Sh + | GpuWitgenKind::Sb + | GpuWitgenKind::LoadSub { .. } + | GpuWitgenKind::Mul(_) + | GpuWitgenKind::Div(_) => true, + // Keccak has its own dispatch path, never enters try_gpu_assign_instances. + GpuWitgenKind::Keccak => false, + #[cfg(not(feature = "u16limb_circuit"))] + _ => false, + } +} + +/// Check if GPU LK is disabled for a specific kind via CENO_GPU_DISABLE_LK_KINDS env var. 
+/// Format: CENO_GPU_DISABLE_LK_KINDS=add,sub,lw (comma-separated kind tags) +/// Special value: CENO_GPU_DISABLE_LK_KINDS=all (disables GPU LK for ALL kinds) +pub(crate) fn is_lk_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_LK_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Check if GPU shard records are disabled for a specific kind via env var. +pub(crate) fn is_shard_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_SHARD_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Check if a specific GPU witgen kind is disabled via CENO_GPU_DISABLE_KINDS env var. +/// Format: CENO_GPU_DISABLE_KINDS=add,sub,lw (comma-separated kind tags) +pub(crate) fn is_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! 
{ + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. +/// The value is cached at first access so it's immune to runtime env var manipulation. +pub(crate) fn is_gpu_witgen_disabled() -> bool { + use std::sync::OnceLock; + static DISABLED: OnceLock = OnceLock::new(); + *DISABLED.get_or_init(|| { + let val = std::env::var_os("CENO_GPU_DISABLE_WITGEN"); + let disabled = val.is_some(); + // Use eprintln to bypass tracing filters — always visible on stderr + eprintln!( + "[GPU witgen] CENO_GPU_DISABLE_WITGEN={:?} → disabled={}", + val, disabled + ); + disabled + }) +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 21e3fd8d2..1559187ca 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -49,4 +49,12 @@ pub mod shard_ram; #[cfg(feature = "gpu")] pub mod colmap_base; #[cfg(feature = "gpu")] +pub mod debug_compare; +#[cfg(feature = "gpu")] +pub mod gpu_config; +#[cfg(feature = "gpu")] +pub mod d2h; +#[cfg(feature = "gpu")] +pub mod device_cache; +#[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 682798fdc..496c68ff9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -6,20 +6,25 @@ /// 3. 
Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::{ - Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, + Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; -use ceno_gpu::bb31::ShardDeviceBuffers; -use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord, GpuShardScalars}; +use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardRamRecord}; use ff_ext::ExtensionField; -use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use gkr_iop::utils::lk_multiplicity::Multiplicity; use p3::field::FieldAlgebra; -use rustc_hash::FxHashMap; -use std::cell::{Cell, RefCell}; +use std::cell::Cell; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; +use super::debug_compare::{ + debug_compare_final_lk, debug_compare_keccak, debug_compare_shard_ec, + debug_compare_shard_side_effects, debug_compare_witness, +}; +use super::gpu_config::{ + is_gpu_witgen_disabled, is_kind_disabled, kind_has_verified_lk, kind_has_verified_shard, +}; use crate::{ - e2e::{RAMRecord, ShardContext}, + e2e::ShardContext, error::ZKVMError, instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, tables::RMMCollections, @@ -74,19 +79,23 @@ pub enum GpuWitgenKind { Keccak, } -/// Cached shard_steps device buffer with metadata for logging. -struct ShardStepsCache { - host_ptr: usize, - byte_len: usize, - shard_id: usize, - n_steps: usize, - device_buf: CudaSlice, -} +// Re-exports from device_cache module for external callers (e2e.rs, structs.rs). +pub use super::device_cache::{ + SharedDeviceBufferSet, flush_shared_ec_buffers, gpu_batch_continuation_ec_on_device, + invalidate_shard_meta_cache, invalidate_shard_steps_cache, take_shared_device_buffers, +}; +// Re-export for external callers (structs.rs). 
+pub use super::d2h::gpu_batch_continuation_ec; +use super::d2h::{ + CompactEcBuf, LkResult, RamBuf, WitResult, gpu_collect_shard_records, gpu_compact_ec_d2h, + gpu_lk_counters_to_multiplicity, gpu_witness_to_rmm, +}; +use super::device_cache::{ + ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, + upload_shard_steps_cached, with_cached_shard_meta, with_cached_shard_steps, +}; -// Thread-local cache for shard_steps device buffer. Invalidated when shard changes. thread_local! { - static SHARD_STEPS_DEVICE: RefCell> = - const { RefCell::new(None) }; /// Thread-local flag to force CPU path (used by debug comparison code). static FORCE_CPU_PATH: Cell = const { Cell::new(false) }; } @@ -101,683 +110,6 @@ fn is_force_cpu_path() -> bool { FORCE_CPU_PATH.with(|f| f.get()) } -/// Upload shard_steps to GPU, reusing cached device buffer if the same data. -fn upload_shard_steps_cached( - hal: &CudaHalBB31, - shard_steps: &[StepRecord], - shard_id: usize, -) -> Result<(), ZKVMError> { - let ptr = shard_steps.as_ptr() as usize; - let byte_len = shard_steps.len() * std::mem::size_of::(); - - SHARD_STEPS_DEVICE.with(|cache| { - let mut cache = cache.borrow_mut(); - if let Some(c) = cache.as_ref() { - if c.host_ptr == ptr && c.byte_len == byte_len { - return Ok(()); // cache hit - } - } - // Cache miss: upload - let mb = byte_len as f64 / (1024.0 * 1024.0); - tracing::info!( - "[GPU witgen] uploading shard_steps: shard_id={}, n_steps={}, {:.2} MB", - shard_id, - shard_steps.len(), - mb, - ); - let bytes: &[u8] = - unsafe { std::slice::from_raw_parts(shard_steps.as_ptr() as *const u8, byte_len) }; - let device_buf = hal.inner.htod_copy_stream(None, bytes).map_err(|e| { - ZKVMError::InvalidWitness(format!("shard_steps H2D failed: {e}").into()) - })?; - *cache = Some(ShardStepsCache { - host_ptr: ptr, - byte_len, - shard_id, - n_steps: shard_steps.len(), - device_buf, - }); - Ok(()) - }) -} - -/// Borrow the cached device buffer for kernel launch. 
-/// Panics if `upload_shard_steps_cached` was not called first. -fn with_cached_shard_steps(f: impl FnOnce(&CudaSlice) -> R) -> R { - SHARD_STEPS_DEVICE.with(|cache| { - let cache = cache.borrow(); - let c = cache.as_ref().expect("shard_steps not uploaded"); - f(&c.device_buf) - }) -} - -/// Invalidate the cached shard_steps device buffer. -/// Call this when shard processing is complete to free GPU memory. -pub fn invalidate_shard_steps_cache() { - SHARD_STEPS_DEVICE.with(|cache| { - let mut cache = cache.borrow_mut(); - if let Some(c) = cache.as_ref() { - let mb = c.byte_len as f64 / (1024.0 * 1024.0); - tracing::info!( - "[GPU witgen] releasing shard_steps cache: shard_id={}, n_steps={}, {:.2} MB", - c.shard_id, - c.n_steps, - mb, - ); - } - *cache = None; - }); -} - -/// Cached shard metadata device buffers for GPU shard records. -/// Invalidated when shard_id changes; shared across all kernel invocations in one shard. -struct ShardMetadataCache { - shard_id: usize, - device_bufs: ShardDeviceBuffers, - /// Shared EC record buffer (owns the GPU memory, pointer stored in device_bufs). - shared_ec_buf: Option>, - /// Shared EC record count buffer (single u32 counter). - shared_ec_count: Option>, - /// Shared addr_accessed buffer (u32 word addresses). - shared_addr_buf: Option>, - /// Shared addr_accessed count buffer (single u32 counter). - shared_addr_count: Option>, -} - -thread_local! { - static SHARD_META_CACHE: RefCell> = - const { RefCell::new(None) }; -} - -/// Build and cache shard metadata device buffers for GPU shard records. -/// -/// FA (future access) device buffers are global and identical across all shards, -/// so they are uploaded once and reused via move. Only per-shard data (scalars + -/// prev_shard_ranges) is re-uploaded when the shard changes. 
-fn ensure_shard_metadata_cached( - hal: &CudaHalBB31, - shard_ctx: &ShardContext, - n_total_steps: usize, -) -> Result<(), ZKVMError> { - let shard_id = shard_ctx.shard_id; - SHARD_META_CACHE.with(|cache| { - let mut cache = cache.borrow_mut(); - if let Some(c) = cache.as_ref() { - if c.shard_id == shard_id { - return Ok(()); // cache hit - } - } - - // Move FA device buffer from previous cache (reuse across shards). - // FA data is global — identical across all shards — so we reuse, not re-upload. - let existing_fa = cache.take().map(|c| { - let ShardDeviceBuffers { - next_access_packed, - scalars: _, - prev_shard_cycle_range: _, - prev_shard_heap_range: _, - prev_shard_hint_range: _, - gpu_ec_shard_id: _, - shared_ec_out_ptr: _, - shared_ec_count_ptr: _, - shared_addr_out_ptr: _, - shared_addr_count_ptr: _, - shared_ec_capacity: _, - shared_addr_capacity: _, - } = c.device_bufs; - next_access_packed - }); - - let next_access_packed_device = if let Some(fa) = existing_fa { - fa // Reuse existing GPU memory — zero cost pointer move - } else { - // First shard: bulk H2D upload packed FA entries (no sort here) - let sorted = &shard_ctx.sorted_next_accesses; - tracing::info_span!("next_access_h2d").in_scope(|| -> Result<_, ZKVMError> { - let packed_bytes: &[u8] = if sorted.packed.is_empty() { - &[0u8; 16] // sentinel for empty - } else { - unsafe { - std::slice::from_raw_parts( - sorted.packed.as_ptr() as *const u8, - sorted.packed.len() * std::mem::size_of::(), - ) - } - }; - let buf = hal.inner.htod_copy_stream(None, packed_bytes).map_err(|e| { - ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) - })?; - let next_access_device = ceno_gpu::common::buffer::BufferImpl::new(buf); - let mb = packed_bytes.len() as f64 / (1024.0 * 1024.0); - tracing::info!( - "[GPU shard] FA uploaded once: {} entries, {:.2} MB (packed)", - sorted.packed.len(), - mb, - ); - Ok(next_access_device) - })? 
- }; - - // Per-shard: always re-upload scalars + prev_shard_ranges - let scalars = GpuShardScalars { - shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, - shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, - shard_offset_cycle: shard_ctx.current_shard_offset_cycle(), - shard_id: shard_id as u32, - heap_start: shard_ctx.platform.heap.start, - heap_end: shard_ctx.platform.heap.end, - hint_start: shard_ctx.platform.hints.start, - hint_end: shard_ctx.platform.hints.end, - shard_heap_start: shard_ctx.shard_heap_addr_range.start, - shard_heap_end: shard_ctx.shard_heap_addr_range.end, - shard_hint_start: shard_ctx.shard_hint_addr_range.start, - shard_hint_end: shard_ctx.shard_hint_addr_range.end, - next_access_count: shard_ctx.sorted_next_accesses.packed.len() as u32, - num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, - num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, - num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, - }; - - let (scalars_device, pscr_device, pshr_device, pshi_device) = - tracing::info_span!("shard_scalars_h2d").in_scope(|| -> Result<_, ZKVMError> { - let scalars_bytes: &[u8] = unsafe { - std::slice::from_raw_parts( - &scalars as *const GpuShardScalars as *const u8, - std::mem::size_of::(), - ) - }; - let scalars_device = - hal.inner - .htod_copy_stream(None, scalars_bytes) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("shard scalars H2D failed: {e}").into(), - ) - })?; - - let pscr = &shard_ctx.prev_shard_cycle_range; - let pscr_device = hal - .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()) - })?; - - let pshr = &shard_ctx.prev_shard_heap_range; - let pshr_device = hal - .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()) - })?; - - let pshi = 
&shard_ctx.prev_shard_hint_range; - let pshi_device = hal - .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()) - })?; - - Ok((scalars_device, pscr_device, pshr_device, pshi_device)) - })?; - - tracing::info!( - "[GPU shard] shard_id={}: per-shard scalars updated", - shard_id, - ); - - // Allocate shared EC/addr compact buffers for this shard. - // - // EC records: cross-shard only (sparse subset of RAM ops). - // 104 bytes each (26 u32s). Cap at 16M entries ≈ 1.6 GB. - // Addr records: every gpu_send() emits one (dense). - // 4 bytes each (1 u32). Cap at 256M entries ≈ 1 GB. - let max_ops_per_step = 52u64; // keccak worst case - let total_ops_estimate = n_total_steps as u64 * max_ops_per_step; - let ec_capacity = total_ops_estimate.min(16 * 1024 * 1024) as usize; - let ec_u32s = ec_capacity * 26; // 26 u32s per GpuShardRamRecord (104 bytes) - let addr_capacity = total_ops_estimate.min(256 * 1024 * 1024) as usize; - - let shared_ec_buf = hal - .alloc_u32_zeroed(ec_u32s, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf alloc: {e}").into()))?; - let shared_ec_count = hal - .alloc_u32_zeroed(1, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_ec_count alloc: {e}").into()) - })?; - let shared_addr_buf = hal - .alloc_u32_zeroed(addr_capacity, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_addr_buf alloc: {e}").into()) - })?; - let shared_addr_count = hal - .alloc_u32_zeroed(1, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_addr_count alloc: {e}").into()) - })?; - - let shared_ec_out_ptr = shared_ec_buf.device_ptr() as u64; - let shared_ec_count_ptr = shared_ec_count.device_ptr() as u64; - let shared_addr_out_ptr = shared_addr_buf.device_ptr() as u64; - let shared_addr_count_ptr = shared_addr_count.device_ptr() as u64; - - tracing::info!( - "[GPU shard] shard_id={}: shared buffers 
allocated: ec_capacity={}, addr_capacity={}", - shard_id, ec_capacity, addr_capacity, - ); - - *cache = Some(ShardMetadataCache { - shard_id, - device_bufs: ShardDeviceBuffers { - scalars: scalars_device, - next_access_packed: next_access_packed_device, - prev_shard_cycle_range: pscr_device, - prev_shard_heap_range: pshr_device, - prev_shard_hint_range: pshi_device, - gpu_ec_shard_id: Some(shard_id as u64), - shared_ec_out_ptr, - shared_ec_count_ptr, - shared_addr_out_ptr, - shared_addr_count_ptr, - shared_ec_capacity: ec_capacity as u32, - shared_addr_capacity: addr_capacity as u32, - }, - shared_ec_buf: Some(shared_ec_buf), - shared_ec_count: Some(shared_ec_count), - shared_addr_buf: Some(shared_addr_buf), - shared_addr_count: Some(shared_addr_count), - }); - Ok(()) - }) -} - -/// Borrow the cached shard device buffers for kernel launch. -fn with_cached_shard_meta(f: impl FnOnce(&ShardDeviceBuffers) -> R) -> R { - SHARD_META_CACHE.with(|cache| { - let cache = cache.borrow(); - let c = cache.as_ref().expect("shard metadata not uploaded"); - f(&c.device_bufs) - }) -} - -/// Invalidate the shard metadata cache (call when shard processing is complete). -pub fn invalidate_shard_meta_cache() { - SHARD_META_CACHE.with(|cache| { - *cache.borrow_mut() = None; - }); -} - -/// Take ownership of shared EC and addr_accessed device buffers from the cache. -/// -/// Returns (shared_ec_buf, ec_count, shared_addr_buf, addr_count) or None if unavailable. -/// The cache is invalidated after this call — must be called at most once per shard. 
-pub fn take_shared_device_buffers() -> Option { - SHARD_META_CACHE.with(|cache| { - let mut cache = cache.borrow_mut(); - let c = cache.as_mut()?; - - let ec_buf = c.shared_ec_buf.take()?; - let ec_count = c.shared_ec_count.take()?; - let addr_buf = c.shared_addr_buf.take()?; - let addr_count = c.shared_addr_count.take()?; - - Some(SharedDeviceBufferSet { - ec_buf, - ec_count, - addr_buf, - addr_count, - }) - }) -} - -/// Shared device buffers taken from the shard metadata cache. -pub struct SharedDeviceBufferSet { - pub ec_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, - pub ec_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, - pub addr_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, - pub addr_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, -} - -/// Batch compute EC points for continuation records, keeping results on device. -/// -/// Returns (device_buf_as_u32, num_records) where the device buffer contains -/// GpuShardRamRecord entries with EC points computed. 
-pub fn gpu_batch_continuation_ec_on_device( - write_records: &[(crate::tables::ShardRamRecord, &'static str)], - read_records: &[(crate::tables::ShardRamRecord, &'static str)], -) -> Result<(ceno_gpu::common::buffer::BufferImpl<'static, u32>, usize, usize), ZKVMError> { - use gkr_iop::gpu::get_cuda_hal; - - let hal = get_cuda_hal().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) - })?; - - let n_writes = write_records.len(); - let n_reads = read_records.len(); - let total = n_writes + n_reads; - if total == 0 { - let empty = hal.alloc_u32_zeroed(1, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("alloc: {e}").into()) - })?; - return Ok((empty, 0, 0)); - } - - // Convert to GpuShardRamRecord format (writes first, reads after) - let mut gpu_records: Vec = Vec::with_capacity(total); - for (rec, _name) in write_records.iter().chain(read_records.iter()) { - gpu_records.push(shard_ram_record_to_gpu(rec)); - } - - // GPU batch EC, results stay on device - let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total).in_scope(|| { - hal.batch_continuation_ec_on_device(&gpu_records, None) - }).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) - })?; - - Ok((device_buf, n_writes, n_reads)) -} - -/// Read the current shared addr count from device (single u32 D2H). -/// Used by debug comparison to snapshot count before/after a kernel. -#[cfg(feature = "gpu")] -fn read_shared_addr_count() -> usize { - SHARD_META_CACHE.with(|cache| { - let cache = cache.borrow(); - let c = cache.as_ref().expect("shard metadata not cached"); - let buf = c.shared_addr_count.as_ref().expect("shared_addr_count not allocated"); - let v: Vec = buf.to_vec().expect("shared_addr_count D2H failed"); - v[0] as usize - }) -} - -/// Read a range of addr entries [start..end) from the shared addr buffer. 
-#[cfg(feature = "gpu")] -fn read_shared_addr_range(start: usize, end: usize) -> Vec { - if start >= end { - return Vec::new(); - } - SHARD_META_CACHE.with(|cache| { - let cache = cache.borrow(); - let c = cache.as_ref().expect("shard metadata not cached"); - let buf = c.shared_addr_buf.as_ref().expect("shared_addr_buf not allocated"); - let all: Vec = buf.to_vec_n(end).expect("shared_addr_buf D2H failed"); - all[start..end].to_vec() - }) -} - -/// Batch D2H of shared EC records and addr_accessed buffers after all kernel invocations. -/// -/// Called once per shard after all opcode `gpu_assign_instances_inner` calls complete. -/// Transfers accumulated EC records and addresses from shared GPU buffers into `shard_ctx`. -/// -/// If the shared buffers have already been taken by `take_shared_device_buffers` -/// (for the full GPU pipeline), this is a no-op. -pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVMError> { - SHARD_META_CACHE.with(|cache| { - let cache = cache.borrow(); - let c = match cache.as_ref() { - Some(c) => c, - None => return Ok(()), // cache already invalidated — no-op - }; - - // If buffers have been taken by take_shared_device_buffers, skip D2H - let ec_count_buf = match c.shared_ec_count.as_ref() { - Some(b) => b, - None => { - tracing::debug!("[GPU shard] flush_shared_ec_buffers: buffers already taken, no-op"); - return Ok(()); - } - }; - let ec_count_vec: Vec = ec_count_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) - })?; - let ec_count = ec_count_vec[0] as usize; - let ec_capacity = c.device_bufs.shared_ec_capacity as usize; - - assert!( - ec_count <= ec_capacity, - "GPU shared EC buffer overflow: count={} > capacity={}. 
\ - Increase ec_capacity in ensure_shard_metadata_cached.", - ec_count, - ec_capacity, - ); - - if ec_count > 0 { - // D2H EC records (only the active portion) - let ec_buf = c.shared_ec_buf.as_ref().unwrap(); - let ec_u32s = ec_count * 26; // 26 u32s per GpuShardRamRecord - let raw_u32: Vec = ec_buf.to_vec_n(ec_u32s).map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_ec_buf D2H: {e}").into()) - })?; - let raw_bytes = unsafe { - std::slice::from_raw_parts( - raw_u32.as_ptr() as *const u8, - raw_u32.len() * 4, - ) - }; - tracing::info!( - "[GPU shard] flush_shared_ec_buffers: {} EC records, {:.2} MB", - ec_count, - raw_bytes.len() as f64 / (1024.0 * 1024.0), - ); - shard_ctx.extend_gpu_ec_records_raw(raw_bytes); - } - - // D2H addr_accessed count - let addr_count_buf = c.shared_addr_count.as_ref().ok_or_else(|| { - ZKVMError::InvalidWitness("shared_addr_count not allocated".into()) - })?; - let addr_count_vec: Vec = addr_count_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()) - })?; - let addr_count = addr_count_vec[0] as usize; - let addr_capacity = c.device_bufs.shared_addr_capacity as usize; - - assert!( - addr_count <= addr_capacity, - "GPU shared addr buffer overflow: count={} > capacity={}", - addr_count, - addr_capacity, - ); - - if addr_count > 0 { - let addr_buf = c.shared_addr_buf.as_ref().unwrap(); - let addrs: Vec = addr_buf.to_vec_n(addr_count).map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_addr_buf D2H: {e}").into()) - })?; - tracing::info!( - "[GPU shard] flush_shared_ec_buffers: {} addr_accessed, {:.2} MB", - addr_count, - addr_count as f64 * 4.0 / (1024.0 * 1024.0), - ); - let mut forked = shard_ctx.get_forked(); - let thread_ctx = &mut forked[0]; - for &addr in &addrs { - thread_ctx.push_addr_accessed(WordAddr(addr)); - } - } - - Ok(()) - }) -} - -/// CPU-side lightweight scan of GPU-produced RAM record slots. 
-/// -/// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, -/// replacing the previous `collect_shard_side_effects()` CPU loop. -fn gpu_collect_shard_records( - shard_ctx: &mut ShardContext, - slots: &[GpuRamRecordSlot], -) { - let current_shard_id = shard_ctx.shard_id; - - for slot in slots { - // Check was_sent flag (bit 4): this slot corresponds to a send() call - if slot.flags & (1 << 4) != 0 { - shard_ctx.push_addr_accessed(WordAddr(slot.addr)); - } - - // Check active flag (bit 0): this slot has a read or write record - if slot.flags & 1 == 0 { - continue; - } - - let ram_type = match (slot.flags >> 5) & 0x7 { - 1 => RAMType::Register, - 2 => RAMType::Memory, - _ => continue, - }; - let has_prev_value = slot.flags & (1 << 3) != 0; - let prev_value = if has_prev_value { Some(slot.prev_value) } else { None }; - let addr = WordAddr(slot.addr); - - // Insert read record (bit 1) - if slot.flags & (1 << 1) != 0 { - shard_ctx.insert_read_record( - addr, - RAMRecord { - ram_type, - reg_id: slot.reg_id as u64, - addr, - prev_cycle: slot.prev_cycle, - cycle: slot.cycle, - shard_cycle: 0, - prev_value, - value: slot.value, - shard_id: slot.read_shard_id as usize, - }, - ); - } - - // Insert write record (bit 2) - if slot.flags & (1 << 2) != 0 { - shard_ctx.insert_write_record( - addr, - RAMRecord { - ram_type, - reg_id: slot.reg_id as u64, - addr, - prev_cycle: slot.prev_cycle, - cycle: slot.cycle, - shard_cycle: slot.shard_cycle, - prev_value, - value: slot.value, - shard_id: current_shard_id, - }, - ); - } - } -} - -/// D2H the compact EC result: read count, then partial-D2H only that many records. 
-fn gpu_compact_ec_d2h( - compact: &CompactEcResult, -) -> Result, ZKVMError> { - // D2H the count (1 u32) - let count_vec: Vec = compact.count_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()) - })?; - let count = count_vec[0] as usize; - if count == 0 { - return Ok(vec![]); - } - - // Partial D2H: only transfer the first `count` records (not the full allocation) - let record_u32s = std::mem::size_of::() / 4; // 26 - let total_u32s = count * record_u32s; - let buf_vec: Vec = compact.buffer.to_vec_n(total_u32s).map_err(|e| { - ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) - })?; - - let records: Vec = unsafe { - let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; - std::slice::from_raw_parts(ptr, count).to_vec() - }; - tracing::debug!("GPU EC compact D2H: {} active records ({} bytes)", count, total_u32s * 4); - Ok(records) -} - -/// Returns true if GPU shard records are verified for this kind. -/// Set CENO_GPU_DISABLE_SHARD_KINDS=all to force ALL kinds back to CPU shard path. -fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { - // Global kill switch: force pure CPU shard path for baseline testing - if std::env::var_os("CENO_GPU_CPU_SHARD").is_some() { - return false; - } - if is_shard_kind_disabled(kind) { - return false; - } - match kind { - GpuWitgenKind::Add - | GpuWitgenKind::Sub - | GpuWitgenKind::LogicR(_) - | GpuWitgenKind::Lw => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LogicI(_) - | GpuWitgenKind::Addi - | GpuWitgenKind::Lui - | GpuWitgenKind::Auipc - | GpuWitgenKind::Jal - | GpuWitgenKind::ShiftR(_) - | GpuWitgenKind::ShiftI(_) - | GpuWitgenKind::Slt(_) - | GpuWitgenKind::Slti(_) - | GpuWitgenKind::BranchEq(_) - | GpuWitgenKind::BranchCmp(_) - | GpuWitgenKind::Jalr - | GpuWitgenKind::Sw - | GpuWitgenKind::Sh - | GpuWitgenKind::Sb - | GpuWitgenKind::LoadSub { .. 
} - | GpuWitgenKind::Mul(_) - | GpuWitgenKind::Div(_) => true, - // Keccak has its own dispatch path, never enters try_gpu_assign_instances. - GpuWitgenKind::Keccak => false, - #[cfg(not(feature = "u16limb_circuit"))] - _ => false, - } -} - -/// Check if GPU shard records are disabled for a specific kind via env var. -fn is_shard_kind_disabled(kind: GpuWitgenKind) -> bool { - thread_local! { - static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; - } - DISABLED.with(|cell| { - let disabled = cell.get_or_init(|| { - std::env::var("CENO_GPU_DISABLE_SHARD_KINDS") - .ok() - .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) - .unwrap_or_default() - }); - if disabled.is_empty() { - return false; - } - if disabled.iter().any(|d| d == "all") { - return true; - } - let tag = kind_tag(kind); - disabled.iter().any(|d| d == tag) - }) -} - -/// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. -/// The value is cached at first access so it's immune to runtime env var manipulation. -fn is_gpu_witgen_disabled() -> bool { - use std::sync::OnceLock; - static DISABLED: OnceLock = OnceLock::new(); - *DISABLED.get_or_init(|| { - let val = std::env::var_os("CENO_GPU_DISABLE_WITGEN"); - let disabled = val.is_some(); - // Use eprintln to bypass tracing filters — always visible on stderr - eprintln!( - "[GPU witgen] CENO_GPU_DISABLE_WITGEN={:?} → disabled={}", - val, disabled - ); - disabled - }) -} - /// Try to run GPU witness generation for the given instruction. /// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). 
/// @@ -1044,15 +376,7 @@ fn gpu_assign_instances_inner>( Ok(([raw_witin, raw_structural], lk_multiplicity)) } -type WitBuf = ceno_gpu::common::BufferImpl< - 'static, - ::BaseField, ->; -type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; -type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; -type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; -type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; -type CompactEcBuf = ceno_gpu::common::witgen_types::CompactEcResult; +// Type aliases and D2H conversion functions live in super::d2h. /// Compute fetch counter parameters from step data. pub(crate) fn compute_fetch_params( @@ -1850,909 +1174,6 @@ fn collect_shard_side_effects>( cpu_collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices) } -fn kind_tag(kind: GpuWitgenKind) -> &'static str { - match kind { - GpuWitgenKind::Add => "add", - GpuWitgenKind::Sub => "sub", - GpuWitgenKind::LogicR(_) => "logic_r", - GpuWitgenKind::Lw => "lw", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LogicI(_) => "logic_i", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Addi => "addi", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Lui => "lui", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Auipc => "auipc", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Jal => "jal", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::ShiftR(_) => "shift_r", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::ShiftI(_) => "shift_i", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Slt(_) => "slt", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Slti(_) => "slti", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::BranchEq(_) => "branch_eq", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::BranchCmp(_) => "branch_cmp", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Jalr => "jalr", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sw => "sw", - #[cfg(feature = 
"u16limb_circuit")] - GpuWitgenKind::Sh => "sh", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sb => "sb", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LoadSub { .. } => "load_sub", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Mul(_) => "mul", - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Div(_) => "div", - GpuWitgenKind::Keccak => "keccak", - } -} - -/// Returns true if the GPU CUDA kernel for this kind has been verified to produce -/// correct LK multiplicity counters matching the CPU baseline. -/// Unverified kinds fall back to CPU full side effects (GPU still handles witness). -/// -/// Override with `CENO_GPU_DISABLE_LK_KINDS=add,sub,...` to force specific kinds -/// back to CPU LK (for binary-search debugging). -/// Set `CENO_GPU_DISABLE_LK_KINDS=all` to disable GPU LK for ALL kinds. -fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { - if is_lk_kind_disabled(kind) { - return false; - } - match kind { - // Phase B verified (Add/Sub/LogicR/Lw) - GpuWitgenKind::Add => true, - GpuWitgenKind::Sub => true, - GpuWitgenKind::LogicR(_) => true, - GpuWitgenKind::Lw => true, - // Phase C verified via debug_compare_final_lk - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Addi => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LogicI(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Lui => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Slti(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::BranchEq(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::BranchCmp(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sw => true, - // Phase C CUDA kernel fixes applied - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::ShiftI(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Auipc => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Jal => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Jalr => 
true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sb => true, - // Remaining kinds enabled - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::ShiftR(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Slt(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sh => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LoadSub { .. } => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Mul(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Div(_) => true, - // Keccak has its own dispatch path with its own LK handling. - GpuWitgenKind::Keccak => false, - #[cfg(not(feature = "u16limb_circuit"))] - _ => false, - } -} - -/// Check if GPU LK is disabled for a specific kind via CENO_GPU_DISABLE_LK_KINDS env var. -/// Format: CENO_GPU_DISABLE_LK_KINDS=add,sub,lw (comma-separated kind tags) -/// Special value: CENO_GPU_DISABLE_LK_KINDS=all (disables GPU LK for ALL kinds) -fn is_lk_kind_disabled(kind: GpuWitgenKind) -> bool { - thread_local! { - static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; - } - DISABLED.with(|cell| { - let disabled = cell.get_or_init(|| { - std::env::var("CENO_GPU_DISABLE_LK_KINDS") - .ok() - .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) - .unwrap_or_default() - }); - if disabled.is_empty() { - return false; - } - if disabled.iter().any(|d| d == "all") { - return true; - } - let tag = kind_tag(kind); - disabled.iter().any(|d| d == tag) - }) -} - -/// Check if a specific GPU witgen kind is disabled via CENO_GPU_DISABLE_KINDS env var. -/// Format: CENO_GPU_DISABLE_KINDS=add,sub,lw (comma-separated kind tags) -fn is_kind_disabled(kind: GpuWitgenKind) -> bool { - thread_local! 
{ - static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; - } - DISABLED.with(|cell| { - let disabled = cell.get_or_init(|| { - std::env::var("CENO_GPU_DISABLE_KINDS") - .ok() - .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) - .unwrap_or_default() - }); - if disabled.is_empty() { - return false; - } - let tag = kind_tag(kind); - disabled.iter().any(|d| d == tag) - }) -} - -fn debug_compare_final_lk>( - config: &I::InstructionConfig, - shard_ctx: &ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - kind: GpuWitgenKind, - mixed_lk: &Multiplicity, -) -> Result<(), ZKVMError> { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_none() { - return Ok(()); - } - - // Compare against cpu_assign_instances (the true baseline using assign_instance) - let mut cpu_ctx = shard_ctx.new_empty_like(); - let (_, cpu_assign_lk) = crate::instructions::cpu_assign_instances::( - config, - &mut cpu_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - )?; - tracing::info!("[GPU lk debug] kind={kind:?} comparing mixed_lk vs cpu_assign_instances lk"); - log_lk_diff(kind, &cpu_assign_lk, mixed_lk); - Ok(()) -} - -fn log_lk_diff(kind: GpuWitgenKind, cpu_lk: &Multiplicity, actual_lk: &Multiplicity) { - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_LK_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(32); - - let mut total_diffs = 0usize; - for (table_idx, (cpu_table, actual_table)) in cpu_lk.iter().zip(actual_lk.iter()).enumerate() { - let mut keys = cpu_table - .keys() - .chain(actual_table.keys()) - .copied() - .collect::>(); - keys.sort_unstable(); - keys.dedup(); - - let mut table_diffs = Vec::new(); - for key in keys { - let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); - let actual_count = actual_table.get(&key).copied().unwrap_or(0); - if cpu_count != actual_count { - table_diffs.push((key, cpu_count, actual_count)); - } 
- } - - if !table_diffs.is_empty() { - total_diffs += table_diffs.len(); - tracing::error!( - "[GPU lk debug] kind={kind:?} table={} diff_count={}", - lookup_table_name(table_idx), - table_diffs.len() - ); - for (key, cpu_count, actual_count) in table_diffs.into_iter().take(limit) { - tracing::error!( - "[GPU lk debug] kind={kind:?} table={} key={} cpu={} gpu={}", - lookup_table_name(table_idx), - key, - cpu_count, - actual_count - ); - } - } - } - - if total_diffs == 0 { - tracing::info!("[GPU lk debug] kind={kind:?} CPU/GPU lookup multiplicities match"); - } -} - -fn debug_compare_witness>( - config: &I::InstructionConfig, - shard_ctx: &ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - kind: GpuWitgenKind, - gpu_witness: &RowMajorMatrix, -) -> Result<(), ZKVMError> { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_none() { - return Ok(()); - } - - let mut cpu_ctx = shard_ctx.new_empty_like(); - let (cpu_rmms, _) = crate::instructions::cpu_assign_instances::( - config, - &mut cpu_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - )?; - let cpu_witness = &cpu_rmms[0]; - let cpu_vals = cpu_witness.values(); - let gpu_vals = gpu_witness.values(); - if cpu_vals == gpu_vals { - return Ok(()); - } - - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(16); - let cpu_num_cols = cpu_witness.n_col(); - let cpu_num_rows = cpu_vals.len() / cpu_num_cols; - let mut mismatches = 0usize; - for row in 0..cpu_num_rows { - for col in 0..cpu_num_cols { - let idx = row * cpu_num_cols + col; - if cpu_vals[idx] != gpu_vals[idx] { - mismatches += 1; - if mismatches <= limit { - tracing::error!( - "[GPU witness debug] kind={kind:?} row={} col={} cpu={:?} gpu={:?}", - row, - col, - cpu_vals[idx], - gpu_vals[idx] - ); - } - } - } - } - tracing::error!( - "[GPU witness debug] kind={kind:?} total_mismatches={}", 
- mismatches - ); - Ok(()) -} - -fn debug_compare_shard_side_effects>( - config: &I::InstructionConfig, - shard_ctx: &ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - kind: GpuWitgenKind, -) -> Result<(), ZKVMError> { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_none() { - return Ok(()); - } - - let mut cpu_ctx = shard_ctx.new_empty_like(); - let _ = cpu_collect_side_effects::(config, &mut cpu_ctx, shard_steps, step_indices)?; - - let mut mixed_ctx = shard_ctx.new_empty_like(); - let _ = - cpu_collect_shard_side_effects::(config, &mut mixed_ctx, shard_steps, step_indices)?; - - let cpu_addr = cpu_ctx.get_addr_accessed(); - let mixed_addr = mixed_ctx.get_addr_accessed(); - if cpu_addr != mixed_addr { - tracing::error!( - "[GPU shard debug] kind={kind:?} addr_accessed cpu={} gpu={}", - cpu_addr.len(), - mixed_addr.len() - ); - } - - let cpu_reads = flatten_ram_records(cpu_ctx.read_records()); - let mixed_reads = flatten_ram_records(mixed_ctx.read_records()); - if cpu_reads != mixed_reads { - log_ram_record_diff(kind, "read_records", &cpu_reads, &mixed_reads); - } - - let cpu_writes = flatten_ram_records(cpu_ctx.write_records()); - let mixed_writes = flatten_ram_records(mixed_ctx.write_records()); - if cpu_writes != mixed_writes { - log_ram_record_diff(kind, "write_records", &cpu_writes, &mixed_writes); - } - - Ok(()) -} - -/// Compare GPU shard context vs CPU shard context, field by field. -/// -/// Both paths are independent and produce equivalent ShardContext state: -/// CPU path: cpu_collect_shard_side_effects → addr_accessed + write_records + read_records -/// GPU path: compact_records → shard records (④gpu_ec_records) -/// ram_slots WAS_SENT → addr_accessed (①) -/// (②write_records and ③read_records stay empty for GPU EC kernels) -/// -/// This function builds both independently and compares: -/// A. addr_accessed sets -/// B. shard records (sorted, normalized to ShardRamRecord) -/// C. 
EC points (nonce + SepticPoint x,y) -/// -/// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. -fn debug_compare_shard_ec>( - compact_records: &[GpuShardRamRecord], - ram_slots: &[GpuRamRecordSlot], - config: &I::InstructionConfig, - shard_ctx: &ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - kind: GpuWitgenKind, -) { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { - return; - } - - use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; - use crate::tables::{ECPoint, ShardRamRecord}; - use ff_ext::{PoseidonField, SmallField}; - - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(16); - - // ========== Build CPU shard context (independent, isolated) ========== - let mut cpu_ctx = shard_ctx.new_empty_like(); - if let Err(e) = cpu_collect_shard_side_effects::( - config, &mut cpu_ctx, shard_steps, step_indices, - ) { - tracing::error!("[GPU EC debug] kind={kind:?} CPU shard side effects failed: {e:?}"); - return; - } - - let perm = ::get_default_perm(); - - // CPU: addr_accessed - let cpu_addr = cpu_ctx.get_addr_accessed(); - - // CPU: shard records (BTreeMap → ShardRamRecord + ECPoint) - let mut cpu_entries: Vec<(ShardRamRecord, ECPoint)> = Vec::new(); - for records in cpu_ctx.write_records() { - for (vma, record) in records { - let rec: ShardRamRecord = (vma, record, true).into(); - let ec = rec.to_ec_point::(&perm); - cpu_entries.push((rec, ec)); - } - } - for records in cpu_ctx.read_records() { - for (vma, record) in records { - let rec: ShardRamRecord = (vma, record, false).into(); - let ec = rec.to_ec_point::(&perm); - cpu_entries.push((rec, ec)); - } - } - cpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); - - // ========== Build GPU shard context (independent, from D2H data only) ========== - - // GPU: addr_accessed (from ram_slots WAS_SENT flags) - let gpu_addr: rustc_hash::FxHashSet = ram_slots - .iter() - 
.filter(|s| s.flags & (1 << 4) != 0) - .map(|s| WordAddr(s.addr)) - .collect(); - - // GPU: shard records (compact_records → ShardRamRecord + ECPoint) - let mut gpu_entries: Vec<(ShardRamRecord, ECPoint)> = compact_records - .iter() - .map(|g| { - let rec = ShardRamRecord { - addr: g.addr, - ram_type: if g.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, - value: g.value, - shard: g.shard, - local_clk: g.local_clk, - global_clk: g.global_clk, - is_to_write_set: g.is_to_write_set != 0, - }; - let x = SepticExtension(g.point_x.map(|v| E::BaseField::from_canonical_u32(v))); - let y = SepticExtension(g.point_y.map(|v| E::BaseField::from_canonical_u32(v))); - let point = SepticPoint::from_affine(x, y); - let ec = ECPoint:: { nonce: g.nonce, point }; - (rec, ec) - }) - .collect(); - gpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); - - // ========== Compare A: addr_accessed ========== - if cpu_addr != gpu_addr { - let cpu_only: Vec<_> = cpu_addr.difference(&gpu_addr).collect(); - let gpu_only: Vec<_> = gpu_addr.difference(&cpu_addr).collect(); - tracing::error!( - "[GPU EC debug] kind={kind:?} ADDR_ACCESSED MISMATCH: cpu={} gpu={} \ - cpu_only={} gpu_only={}", - cpu_addr.len(), gpu_addr.len(), cpu_only.len(), gpu_only.len() - ); - for (i, addr) in cpu_only.iter().enumerate() { - if i >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", addr.0); - } - for (i, addr) in gpu_only.iter().enumerate() { - if i >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", addr.0); - } - } - - // ========== Compare B+C: shard records + EC points ========== - - // Check counts - if cpu_entries.len() != gpu_entries.len() { - tracing::error!( - "[GPU EC debug] kind={kind:?} RECORD COUNT MISMATCH: cpu={} gpu={}", - cpu_entries.len(), gpu_entries.len() - ); - let cpu_keys: std::collections::BTreeSet<_> = cpu_entries - .iter().map(|(r, _)| (r.addr, 
r.is_to_write_set)).collect(); - let gpu_keys: std::collections::BTreeSet<_> = gpu_entries - .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); - let mut logged = 0usize; - for key in cpu_keys.difference(&gpu_keys) { - if logged >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", key.0, key.1); - logged += 1; - } - for key in gpu_keys.difference(&cpu_keys) { - if logged >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", key.0, key.1); - logged += 1; - } - } - - // Check GPU duplicates (BTreeMap deduplicates, atomicAdd doesn't) - let mut gpu_dup_count = 0usize; - for w in gpu_entries.windows(2) { - if w[0].0.addr == w[1].0.addr - && w[0].0.is_to_write_set == w[1].0.is_to_write_set - && w[0].0.ram_type == w[1].0.ram_type - { - gpu_dup_count += 1; - if gpu_dup_count <= limit { - tracing::error!( - "[GPU EC debug] kind={kind:?} GPU DUPLICATE: addr={} is_write={} ram_type={:?}", - w[0].0.addr, w[0].0.is_to_write_set, w[0].0.ram_type - ); - } - } - } - - // Merge-walk sorted lists - let mut ci = 0usize; - let mut gi = 0usize; - let mut record_mismatches = 0usize; - let mut ec_mismatches = 0usize; - let mut matched = 0usize; - - while ci < cpu_entries.len() && gi < gpu_entries.len() { - let (cr, ce) = &cpu_entries[ci]; - let (gr, ge) = &gpu_entries[gi]; - let ck = (cr.addr, cr.is_to_write_set as u8, cr.ram_type as u8); - let gk = (gr.addr, gr.is_to_write_set as u8, gr.ram_type as u8); - - match ck.cmp(&gk) { - std::cmp::Ordering::Less => { - if record_mismatches < limit { - tracing::error!( - "[GPU EC debug] kind={kind:?} MISSING in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", - cr.addr, cr.is_to_write_set, cr.ram_type, cr.value, cr.shard, cr.global_clk - ); - } - record_mismatches += 1; - ci += 1; - continue; - } - std::cmp::Ordering::Greater => { - if record_mismatches < limit { - tracing::error!( - "[GPU EC debug] kind={kind:?} EXTRA in GPU: addr={} 
is_write={} ram={:?} val={} shard={} clk={}", - gr.addr, gr.is_to_write_set, gr.ram_type, gr.value, gr.shard, gr.global_clk - ); - } - record_mismatches += 1; - gi += 1; - continue; - } - std::cmp::Ordering::Equal => {} - } - - // Keys match — compare record fields - let mut field_diff = false; - for (name, cv, gv) in [ - ("value", cr.value as u64, gr.value as u64), - ("shard", cr.shard, gr.shard), - ("local_clk", cr.local_clk, gr.local_clk), - ("global_clk", cr.global_clk, gr.global_clk), - ] { - if cv != gv { - field_diff = true; - if record_mismatches < limit { - tracing::error!( - "[GPU EC debug] kind={kind:?} addr={} {name}: cpu={cv} gpu={gv}", - cr.addr - ); - } - } - } - if field_diff { - record_mismatches += 1; - } - - // Compare EC points - let mut ec_diff = false; - if ce.nonce != ge.nonce { - ec_diff = true; - if ec_mismatches < limit { - tracing::error!( - "[GPU EC debug] kind={kind:?} addr={} nonce: cpu={} gpu={}", - cr.addr, ce.nonce, ge.nonce - ); - } - } - for j in 0..7 { - let cv = ce.point.x.0[j].to_canonical_u64() as u32; - let gv = ge.point.x.0[j].to_canonical_u64() as u32; - if cv != gv { - ec_diff = true; - if ec_mismatches < limit { - tracing::error!( - "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", cr.addr - ); - } - } - } - for j in 0..7 { - let cv = ce.point.y.0[j].to_canonical_u64() as u32; - let gv = ge.point.y.0[j].to_canonical_u64() as u32; - if cv != gv { - ec_diff = true; - if ec_mismatches < limit { - tracing::error!( - "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", cr.addr - ); - } - } - } - if ec_diff { - ec_mismatches += 1; - } - - matched += 1; - ci += 1; - gi += 1; - } - - // Remaining unmatched - while ci < cpu_entries.len() { - if record_mismatches < limit { - let (cr, _) = &cpu_entries[ci]; - tracing::error!( - "[GPU EC debug] kind={kind:?} MISSING in GPU (tail): addr={} is_write={} val={}", - cr.addr, cr.is_to_write_set, cr.value - ); - } - record_mismatches += 1; - ci += 1; - } - while 
gi < gpu_entries.len() { - if record_mismatches < limit { - let (gr, _) = &gpu_entries[gi]; - tracing::error!( - "[GPU EC debug] kind={kind:?} EXTRA in GPU (tail): addr={} is_write={} val={}", - gr.addr, gr.is_to_write_set, gr.value - ); - } - record_mismatches += 1; - gi += 1; - } - - // ========== Summary ========== - let addr_ok = cpu_addr == gpu_addr; - if addr_ok && record_mismatches == 0 && ec_mismatches == 0 && gpu_dup_count == 0 { - tracing::info!( - "[GPU EC debug] kind={kind:?} ALL MATCH: {} records, {} addr_accessed, EC points OK", - matched, cpu_addr.len() - ); - } else { - tracing::error!( - "[GPU EC debug] kind={kind:?} MISMATCH: matched={matched} record_diffs={record_mismatches} \ - ec_diffs={ec_mismatches} gpu_dups={gpu_dup_count} addr_ok={addr_ok} \ - (cpu_records={} gpu_records={} cpu_addrs={} gpu_addrs={})", - cpu_entries.len(), gpu_entries.len(), cpu_addr.len(), gpu_addr.len() - ); - } -} - -fn flatten_ram_records( - records: &[std::collections::BTreeMap], -) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { - let mut flat = Vec::new(); - for table in records { - for (addr, record) in table { - flat.push(( - addr.0, - record.reg_id, - record.prev_cycle, - record.cycle, - record.shard_cycle, - record.prev_value, - record.value, - record.shard_id, - )); - } - } - flat -} - -fn log_ram_record_diff( - kind: GpuWitgenKind, - label: &str, - cpu_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], - mixed_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], -) { - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_SHARD_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(16); - tracing::error!( - "[GPU shard debug] kind={kind:?} {} cpu={} gpu={}", - label, - cpu_records.len(), - mixed_records.len() - ); - let max_len = cpu_records.len().max(mixed_records.len()); - let mut logged = 0usize; - for idx in 0..max_len { - let cpu = cpu_records.get(idx); - let gpu = mixed_records.get(idx); - if cpu != gpu { - tracing::error!( - 
"[GPU shard debug] kind={kind:?} {} idx={} cpu={:?} gpu={:?}", - label, - idx, - cpu, - gpu - ); - logged += 1; - if logged >= limit { - break; - } - } - } -} - -fn lookup_table_name(table_idx: usize) -> &'static str { - match table_idx { - x if x == LookupTable::Dynamic as usize => "Dynamic", - x if x == LookupTable::DoubleU8 as usize => "DoubleU8", - x if x == LookupTable::And as usize => "And", - x if x == LookupTable::Or as usize => "Or", - x if x == LookupTable::Xor as usize => "Xor", - x if x == LookupTable::Ltu as usize => "Ltu", - x if x == LookupTable::Pow as usize => "Pow", - x if x == LookupTable::Instruction as usize => "Instruction", - _ => "Unknown", - } -} - -/// Batch compute EC points for continuation circuit ShardRamRecords on GPU. -/// -/// Converts ShardRamRecords to GPU format, launches the `batch_continuation_ec` -/// kernel to compute Poseidon2 + SepticCurve on device, and converts results -/// back to ShardRamInput (with EC points). -/// -/// Returns (write_inputs, read_inputs) maintaining the write-before-read ordering -/// invariant required by ShardRamCircuit::assign_instances. 
-pub fn gpu_batch_continuation_ec( - write_records: &[(crate::tables::ShardRamRecord, &'static str)], - read_records: &[(crate::tables::ShardRamRecord, &'static str)], -) -> Result< - ( - Vec>, - Vec>, - ), - ZKVMError, -> { - use crate::tables::ShardRamInput; - use gkr_iop::gpu::get_cuda_hal; - - let hal = get_cuda_hal().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) - })?; - - let total = write_records.len() + read_records.len(); - if total == 0 { - return Ok((vec![], vec![])); - } - - // Convert ShardRamRecords to GpuShardRamRecord format - let mut gpu_records: Vec = Vec::with_capacity(total); - for (rec, _name) in write_records.iter().chain(read_records.iter()) { - gpu_records.push(shard_ram_record_to_gpu(rec)); - } - - // GPU batch EC computation - let result = info_span!("gpu_batch_ec", n = total).in_scope(|| { - hal.batch_continuation_ec(&gpu_records) - }).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()) - })?; - - // Convert back to ShardRamInput, split into writes and reads - let mut write_inputs = Vec::with_capacity(write_records.len()); - let mut read_inputs = Vec::with_capacity(read_records.len()); - - for (i, gpu_rec) in result.iter().enumerate() { - let (rec, name) = if i < write_records.len() { - (&write_records[i].0, write_records[i].1) - } else { - let ri = i - write_records.len(); - (&read_records[ri].0, read_records[ri].1) - }; - - let ec_point = gpu_shard_ram_record_to_ec_point::(gpu_rec); - let input = ShardRamInput { - name, - record: rec.clone(), - ec_point, - }; - - if i < write_records.len() { - write_inputs.push(input); - } else { - read_inputs.push(input); - } - } - - Ok((write_inputs, read_inputs)) -} - -/// Convert a ShardRamRecord to GpuShardRamRecord (metadata only, EC fields zeroed). 
-fn shard_ram_record_to_gpu(rec: &crate::tables::ShardRamRecord) -> GpuShardRamRecord { - GpuShardRamRecord { - addr: rec.addr, - ram_type: match rec.ram_type { - RAMType::Register => 1, - RAMType::Memory => 2, - _ => 0, - }, - value: rec.value, - _pad0: 0, - shard: rec.shard, - local_clk: rec.local_clk, - global_clk: rec.global_clk, - is_to_write_set: if rec.is_to_write_set { 1 } else { 0 }, - nonce: 0, - point_x: [0; 7], - point_y: [0; 7], - } -} - -/// Convert a GPU-computed GpuShardRamRecord to ECPoint. -fn gpu_shard_ram_record_to_ec_point( - gpu_rec: &GpuShardRamRecord, -) -> crate::tables::ECPoint { - use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; - - let mut point_x_arr = [E::BaseField::ZERO; 7]; - let mut point_y_arr = [E::BaseField::ZERO; 7]; - for j in 0..7 { - point_x_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_x[j]); - point_y_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_y[j]); - } - - let x = SepticExtension(point_x_arr); - let y = SepticExtension(point_y_arr); - let point = SepticPoint::from_affine(x, y); - - crate::tables::ECPoint { - nonce: gpu_rec.nonce, - point, - } -} - -pub(crate) fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { - let mut tables: [FxHashMap; 8] = Default::default(); - - // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) - info_span!("lk_dynamic_d2h").in_scope(|| { - let counts: Vec = counters.dynamic.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) - })?; - let nnz = counts.iter().filter(|&&c| c != 0).count(); - let map = &mut tables[LookupTable::Dynamic as usize]; - map.reserve(nnz); - for (key, &count) in counts.iter().enumerate() { - if count != 0 { - map.insert(key as u64, count as usize); - } - } - Ok::<(), ZKVMError>(()) - })?; - - // Dense tables: same pattern, skip None - info_span!("lk_dense_d2h").in_scope(|| { - let dense: &[(LookupTable, &Option)] = &[ - (LookupTable::DoubleU8, 
&counters.double_u8), - (LookupTable::And, &counters.and_table), - (LookupTable::Or, &counters.or_table), - (LookupTable::Xor, &counters.xor_table), - (LookupTable::Ltu, &counters.ltu_table), - (LookupTable::Pow, &counters.pow_table), - ]; - for &(table, ref buf_opt) in dense { - if let Some(buf) = buf_opt { - let counts: Vec = buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU {:?} lk D2H failed: {e}", table).into(), - ) - })?; - let nnz = counts.iter().filter(|&&c| c != 0).count(); - let map = &mut tables[table as usize]; - map.reserve(nnz); - for (key, &count) in counts.iter().enumerate() { - if count != 0 { - map.insert(key as u64, count as usize); - } - } - } - } - Ok::<(), ZKVMError>(()) - })?; - - // Fetch (Instruction table) - if let Some(fetch_buf) = counters.fetch { - info_span!("lk_fetch_d2h").in_scope(|| { - let base_pc = counters.fetch_base_pc; - let counts = fetch_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) - })?; - let nnz = counts.iter().filter(|&&c| c != 0).count(); - let map = &mut tables[LookupTable::Instruction as usize]; - map.reserve(nnz); - for (slot_idx, &count) in counts.iter().enumerate() { - if count != 0 { - let pc = base_pc as u64 + (slot_idx as u64) * 4; - map.insert(pc, count as usize); - } - } - Ok::<(), ZKVMError>(()) - })?; - } - - Ok(Multiplicity(tables)) -} - /// GPU dispatch entry point for keccak ecall witness generation. /// /// Unlike `try_gpu_assign_instances`, keccak has a rotation-aware matrix layout @@ -3084,199 +1505,3 @@ fn gpu_assign_keccak_inner( Ok(([raw_witin, raw_structural], lk_multiplicity)) } - -/// Debug comparison for keccak GPU witgen. -/// Runs the CPU path and compares LK / witness / shard side effects. -/// -/// Activated by CENO_GPU_DEBUG_COMPARE_LK, CENO_GPU_DEBUG_COMPARE_WITNESS, -/// or CENO_GPU_DEBUG_COMPARE_SHARD environment variables. 
-#[cfg(feature = "gpu")] -fn debug_compare_keccak( - config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, - shard_ctx: &ShardContext, - num_witin: usize, - num_structural_witin: usize, - steps: &[StepRecord], - step_indices: &[StepIndex], - gpu_lk: &Multiplicity, - gpu_witin: &RowMajorMatrix, - gpu_addrs: &[u32], -) -> Result<(), ZKVMError> { - let want_lk = std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_some(); - let want_witness = std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_some(); - let want_shard = std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some(); - - if !want_lk && !want_witness && !want_shard { - return Ok(()); - } - - // Guard against recursion: is_gpu_witgen_disabled() uses OnceLock so env var - // manipulation doesn't work. Use a thread-local flag instead. - thread_local! { - static IN_DEBUG_COMPARE: Cell = const { Cell::new(false) }; - } - if IN_DEBUG_COMPARE.with(|f| f.get()) { - return Ok(()); - } - IN_DEBUG_COMPARE.with(|f| f.set(true)); - - tracing::info!("[GPU keccak debug] running CPU baseline for comparison"); - - // Run CPU path via assign_instances. The IN_DEBUG_COMPARE guard prevents - // gpu_assign_keccak_instances from calling debug_compare_keccak again, - // so it will produce the GPU result, which is then returned without - // re-entering this function. We need assign_instances (not cpu_assign_instances) - // because keccak has rotation matrices and 3 structural columns. - // - // To force the CPU path, we use is_force_cpu_path() by setting the env var. 
- let mut cpu_ctx = shard_ctx.new_empty_like(); - let (cpu_rmms, cpu_lk) = { - use crate::instructions::riscv::ecall::keccak::KeccakInstruction; - // Set force-CPU flag so gpu_assign_keccak_instances returns None - set_force_cpu_path(true); - let result = as crate::instructions::Instruction>::assign_instances( - config, - &mut cpu_ctx, - num_witin, - num_structural_witin, - steps, - step_indices, - ); - set_force_cpu_path(false); - IN_DEBUG_COMPARE.with(|f| f.set(false)); - result? - }; - - let kind = GpuWitgenKind::Keccak; - - if want_lk { - tracing::info!("[GPU keccak debug] comparing LK multiplicities"); - log_lk_diff(kind, &cpu_lk, gpu_lk); - } - - if want_witness { - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(32); - let cpu_witin = &cpu_rmms[0]; - let gpu_vals = gpu_witin.values(); - let cpu_vals = cpu_witin.values(); - let mut diffs = 0usize; - for (i, (g, c)) in gpu_vals.iter().zip(cpu_vals.iter()).enumerate() { - if g != c { - if diffs < limit { - let row = i / num_witin; - let col = i % num_witin; - tracing::error!( - "[GPU keccak witness] row={} col={} gpu={:?} cpu={:?}", - row, col, g, c - ); - } - diffs += 1; - } - } - if diffs == 0 { - tracing::info!("[GPU keccak debug] witness matrices match ({} elements)", gpu_vals.len()); - } else { - tracing::error!("[GPU keccak debug] witness mismatch: {} diffs out of {}", diffs, gpu_vals.len()); - } - } - - if want_shard { - // Compare addr_accessed: GPU entries were D2H'd from the shared buffer - // delta (before/after kernel launch) and passed in as gpu_addrs. 
- let cpu_addr = cpu_ctx.get_addr_accessed(); - let gpu_addr_set: rustc_hash::FxHashSet = - gpu_addrs.iter().map(|&a| WordAddr(a)).collect(); - - if cpu_addr.len() != gpu_addr_set.len() { - tracing::error!( - "[GPU keccak shard] addr_accessed count mismatch: cpu={} gpu={}", - cpu_addr.len(), gpu_addr_set.len() - ); - } - let mut missing_from_gpu = 0usize; - let mut extra_in_gpu = 0usize; - let limit = 16usize; - for addr in &cpu_addr { - if !gpu_addr_set.contains(addr) { - if missing_from_gpu < limit { - tracing::error!("[GPU keccak shard] addr {} in CPU but not GPU", addr.0); - } - missing_from_gpu += 1; - } - } - for &addr in gpu_addrs { - if !cpu_addr.contains(&WordAddr(addr)) { - if extra_in_gpu < limit { - tracing::error!("[GPU keccak shard] addr {} in GPU but not CPU", addr); - } - extra_in_gpu += 1; - } - } - if missing_from_gpu == 0 && extra_in_gpu == 0 { - tracing::info!( - "[GPU keccak shard] addr_accessed matches: {} entries", - cpu_addr.len() - ); - } else { - tracing::error!( - "[GPU keccak shard] addr_accessed diff: missing_from_gpu={} extra_in_gpu={}", - missing_from_gpu, extra_in_gpu - ); - } - } - - Ok(()) -} - -/// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. -/// -/// GPU witgen kernels output column-major layout for better memory coalescing. -/// This function transposes to row-major on GPU before copying to host. -fn gpu_witness_to_rmm( - hal: &CudaHalBB31, - gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< - ceno_gpu::common::BufferImpl<'static, ::BaseField>, - >, - num_rows: usize, - num_cols: usize, - padding: InstancePaddingStrategy, -) -> Result, ZKVMError> { - // Transpose from column-major to row-major on GPU. - // Column-major (num_rows x num_cols) is stored as num_cols groups of num_rows elements, - // which is equivalent to a (num_cols x num_rows) row-major matrix. - // Transposing with cols=num_rows, rows=num_cols produces (num_rows x num_cols) row-major. 
- let mut rmm_buffer = hal - .alloc_elems_on_device(num_rows * num_cols, false, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) - })?; - matrix_transpose::( - &hal.inner, - &mut rmm_buffer, - &gpu_result.device_buffer, - num_rows, - num_cols, - ) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; - - let gpu_data: Vec<::BaseField> = rmm_buffer - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?; - - // Safety: BabyBear is the only supported GPU field, and E::BaseField must match - let data: Vec = unsafe { - let mut data = std::mem::ManuallyDrop::new(gpu_data); - Vec::from_raw_parts( - data.as_mut_ptr() as *mut E::BaseField, - data.len(), - data.capacity(), - ) - }; - - Ok(RowMajorMatrix::::new_by_values( - data, num_cols, padding, - )) -} From e29851175cad4153ce32a3205edf4c988c6ef1cc Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 18:41:52 +0800 Subject: [PATCH 53/73] part-b: fix --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 7 +++++++ ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs | 5 +++++ ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 5 +++++ 3 files changed, 17 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index beaf9693f..223b0a4fe 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -228,6 +228,13 @@ mod tests { .unwrap() .expect("GPU path should be available"); + // Flush shared EC/addr buffers from GPU device to shard_ctx + // (in the e2e pipeline this is called once per shard after all opcode circuits) + crate::instructions::riscv::gpu::device_cache::flush_shared_ec_buffers( + &mut shard_ctx_full_gpu, + ) + .unwrap(); + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); assert_eq!( diff --git 
a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 68ba03060..d084fe324 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -217,6 +217,11 @@ mod tests { .unwrap() .expect("GPU path should be available"); + crate::instructions::riscv::gpu::device_cache::flush_shared_ec_buffers( + &mut shard_ctx_full_gpu, + ) + .unwrap(); + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); assert_eq!( diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index b518d3361..111010f12 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -240,6 +240,11 @@ mod tests { .unwrap() .expect("GPU path should be available"); + crate::instructions::riscv::gpu::device_cache::flush_shared_ec_buffers( + &mut shard_ctx_full_gpu, + ) + .unwrap(); + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); assert_eq!( From 033d729f1a9985ba38977c529aab81fefb1cbb02 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 18:54:03 +0800 Subject: [PATCH 54/73] part-c --- ceno_zkvm/src/instructions.rs | 155 +-------- .../src/instructions/host_ops/cpu_fallback.rs | 166 ++++++++++ ceno_zkvm/src/instructions/host_ops/emit.rs | 160 +++++++++ ceno_zkvm/src/instructions/host_ops/lk_ops.rs | 76 +++++ .../{side_effects.rs => host_ops/mod.rs} | 303 +----------------- ceno_zkvm/src/instructions/host_ops/sink.rs | 61 ++++ 6 files changed, 482 insertions(+), 439 deletions(-) create mode 100644 ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs create mode 100644 ceno_zkvm/src/instructions/host_ops/emit.rs create mode 100644 ceno_zkvm/src/instructions/host_ops/lk_ops.rs rename ceno_zkvm/src/instructions/{side_effects.rs => host_ops/mod.rs} (77%) create mode 100644 
ceno_zkvm/src/instructions/host_ops/sink.rs diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 71467d370..548e99a9f 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -20,7 +20,12 @@ use rayon::{ use witness::{InstancePaddingStrategy, RowMajorMatrix}; pub mod riscv; -pub mod side_effects; +pub mod host_ops; + +/// Backward-compatible re-export: old `side_effects` path still works. +pub use host_ops as side_effects; + +pub use host_ops::{cpu_assign_instances, cpu_collect_shard_side_effects, cpu_collect_side_effects}; pub trait Instruction { type InstructionConfig: Send + Sync; @@ -224,151 +229,3 @@ pub fn full_step_indices(steps: &[StepRecord]) -> Vec { (0..steps.len()).collect() } -/// CPU-only assign_instances. Extracted so GPU-enabled instructions can call this as fallback. -pub fn cpu_assign_instances>( - config: &I::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> Result< - ( - crate::tables::RMMCollections, - gkr_iop::utils::lk_multiplicity::Multiplicity, - ), - ZKVMError, -> { - assert!(num_structural_witin == 0 || num_structural_witin == 1); - let num_structural_witin = num_structural_witin.max(1); - - let nthreads = multilinear_extensions::util::max_usable_threads(); - let total_instances = step_indices.len(); - let num_instance_per_batch = if total_instances > 256 { - total_instances.div_ceil(nthreads) - } else { - total_instances - } - .max(1); - let lk_multiplicity = crate::witness::LkMultiplicity::default(); - let mut raw_witin = - RowMajorMatrix::::new(total_instances, num_witin, I::padding_strategy()); - let mut raw_structual_witin = RowMajorMatrix::::new( - total_instances, - num_structural_witin, - I::padding_strategy(), - ); - let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); - let raw_structual_witin_iter = 
raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); - let shard_ctx_vec = shard_ctx.get_forked(); - - raw_witin_iter - .zip_eq(raw_structual_witin_iter) - .zip_eq(step_indices.par_chunks(num_instance_per_batch)) - .zip(shard_ctx_vec) - .flat_map( - |(((instances, structural_instance), indices), mut shard_ctx)| { - let mut lk_multiplicity = lk_multiplicity.clone(); - instances - .chunks_mut(num_witin) - .zip_eq(structural_instance.chunks_mut(num_structural_witin)) - .zip_eq(indices.iter().copied()) - .map(|((instance, structural_instance), step_idx)| { - *structural_instance.last_mut().unwrap() = E::BaseField::ONE; - I::assign_instance( - config, - &mut shard_ctx, - instance, - &mut lk_multiplicity, - &shard_steps[step_idx], - ) - }) - .collect::>() - }, - ) - .collect::>()?; - - raw_witin.padding_by_strategy(); - raw_structual_witin.padding_by_strategy(); - Ok(( - [raw_witin, raw_structual_witin], - lk_multiplicity.into_finalize_result(), - )) -} - -/// CPU-only side-effect collection for GPU-enabled instructions. -/// -/// This path deliberately avoids scratch witness buffers and calls only the -/// instruction-specific side-effect collector. -pub fn cpu_collect_side_effects>( - config: &I::InstructionConfig, - shard_ctx: &mut ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> Result, ZKVMError> { - cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, false) -} - -/// CPU-side `send()` / `addr_accessed` collection for GPU-assisted lk paths. -/// -/// Implementations may still increment fetch multiplicity on CPU, but all other -/// lookup multiplicities are expected to come from the GPU path. 
-pub fn cpu_collect_shard_side_effects>( - config: &I::InstructionConfig, - shard_ctx: &mut ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> Result, ZKVMError> { - cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, true) -} - -fn cpu_collect_side_effects_inner>( - config: &I::InstructionConfig, - shard_ctx: &mut ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - shard_only: bool, -) -> Result, ZKVMError> { - let nthreads = max_usable_threads(); - let total = step_indices.len(); - let batch_size = if total > 256 { - total.div_ceil(nthreads) - } else { - total - } - .max(1); - - let lk_multiplicity = LkMultiplicity::default(); - let shard_ctx_vec = shard_ctx.get_forked(); - - step_indices - .par_chunks(batch_size) - .zip(shard_ctx_vec) - .flat_map(|(indices, mut shard_ctx)| { - let mut lk_multiplicity = lk_multiplicity.clone(); - indices - .iter() - .copied() - .map(|step_idx| { - if shard_only { - I::collect_shard_side_effects_instance( - config, - &mut shard_ctx, - &mut lk_multiplicity, - &shard_steps[step_idx], - ) - } else { - I::collect_side_effects_instance( - config, - &mut shard_ctx, - &mut lk_multiplicity, - &shard_steps[step_idx], - ) - } - }) - .collect::>() - }) - .collect::>()?; - - Ok(lk_multiplicity.into_finalize_result()) -} diff --git a/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs b/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs new file mode 100644 index 000000000..e3a63c943 --- /dev/null +++ b/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs @@ -0,0 +1,166 @@ +use ceno_emul::StepIndex; +use ff_ext::ExtensionField; +use gkr_iop::utils::lk_multiplicity::Multiplicity; +use itertools::Itertools; +use multilinear_extensions::util::max_usable_threads; +use p3::field::FieldAlgebra; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; +use witness::RowMajorMatrix; + +use crate::{ + e2e::ShardContext, error::ZKVMError, 
tables::RMMCollections, witness::LkMultiplicity, +}; + +use super::super::Instruction; + +/// CPU-only assign_instances. Extracted so GPU-enabled instructions can call this as fallback. +pub fn cpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], +) -> Result< + ( + RMMCollections, + Multiplicity, + ), + ZKVMError, +> { + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + + let nthreads = max_usable_threads(); + let total_instances = step_indices.len(); + let num_instance_per_batch = if total_instances > 256 { + total_instances.div_ceil(nthreads) + } else { + total_instances + } + .max(1); + let lk_multiplicity = LkMultiplicity::default(); + let mut raw_witin = + RowMajorMatrix::::new(total_instances, num_witin, I::padding_strategy()); + let mut raw_structual_witin = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), indices), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(indices.iter().copied()) + .map(|((instance, structural_instance), step_idx)| { + *structural_instance.last_mut().unwrap() = E::BaseField::ONE; + I::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + }) + .collect::>() + }, + ) + 
.collect::>()?; + + raw_witin.padding_by_strategy(); + raw_structual_witin.padding_by_strategy(); + Ok(( + [raw_witin, raw_structual_witin], + lk_multiplicity.into_finalize_result(), + )) +} + +/// CPU-only side-effect collection for GPU-enabled instructions. +/// +/// This path deliberately avoids scratch witness buffers and calls only the +/// instruction-specific side-effect collector. +pub fn cpu_collect_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, false) +} + +/// CPU-side `send()` / `addr_accessed` collection for GPU-assisted lk paths. +/// +/// Implementations may still increment fetch multiplicity on CPU, but all other +/// lookup multiplicities are expected to come from the GPU path. +pub fn cpu_collect_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, true) +} + +fn cpu_collect_side_effects_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + shard_only: bool, +) -> Result, ZKVMError> { + let nthreads = max_usable_threads(); + let total = step_indices.len(); + let batch_size = if total > 256 { + total.div_ceil(nthreads) + } else { + total + } + .max(1); + + let lk_multiplicity = LkMultiplicity::default(); + let shard_ctx_vec = shard_ctx.get_forked(); + + step_indices + .par_chunks(batch_size) + .zip(shard_ctx_vec) + .flat_map(|(indices, mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + indices + .iter() + .copied() + .map(|step_idx| { + if shard_only { + I::collect_shard_side_effects_instance( + config, + &mut shard_ctx, 
+ &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } else { + I::collect_side_effects_instance( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } + }) + .collect::>() + }) + .collect::>()?; + + Ok(lk_multiplicity.into_finalize_result()) +} diff --git a/ceno_zkvm/src/instructions/host_ops/emit.rs b/ceno_zkvm/src/instructions/host_ops/emit.rs new file mode 100644 index 000000000..4f84d3e05 --- /dev/null +++ b/ceno_zkvm/src/instructions/host_ops/emit.rs @@ -0,0 +1,160 @@ +use gkr_iop::{ + gadgets::{AssertLtConfig, cal_lt_diff}, + tables::{LookupTable, OpsTable}, +}; + +use crate::instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS}; + +use super::{LkOp, SideEffectSink}; + +pub fn emit_assert_lt_ops( + sink: &mut impl SideEffectSink, + lt_cfg: &AssertLtConfig, + lhs: u64, + rhs: u64, +) { + let max_bits = lt_cfg.0.max_bits; + let diff = cal_lt_diff(lhs < rhs, max_bits, lhs, rhs); + for i in 0..(max_bits / u16::BITS as usize) { + let value = ((diff >> (i * u16::BITS as usize)) & 0xffff) as u16; + sink.emit_lk(LkOp::AssertU16 { value }); + } + let remain_bits = max_bits % u16::BITS as usize; + if remain_bits > 1 { + let value = (diff >> ((lt_cfg.0.diff.len() - 1) * u16::BITS as usize)) & 0xffff; + sink.emit_lk(LkOp::DynamicRange { + value, + bits: remain_bits as u32, + }); + } +} + +pub fn emit_u16_limbs(sink: &mut impl SideEffectSink, value: u32) { + sink.emit_lk(LkOp::AssertU16 { + value: (value & 0xffff) as u16, + }); + sink.emit_lk(LkOp::AssertU16 { + value: (value >> 16) as u16, + }); +} + +pub fn emit_const_range_op(sink: &mut impl SideEffectSink, value: u64, bits: usize) { + match bits { + 0 | 1 => {} + 14 => sink.emit_lk(LkOp::AssertU14 { + value: value as u16, + }), + 16 => sink.emit_lk(LkOp::AssertU16 { + value: value as u16, + }), + _ => sink.emit_lk(LkOp::DynamicRange { + value, + bits: bits as u32, + }), + } +} + +pub fn emit_byte_decomposition_ops(sink: &mut impl SideEffectSink, bytes: &[u8]) { + for chunk in 
bytes.chunks(2) { + match chunk { + [a, b] => sink.emit_lk(LkOp::DoubleU8 { a: *a, b: *b }), + [a] => emit_const_range_op(sink, *a as u64, 8), + _ => unreachable!(), + } + } +} + +pub fn emit_signed_extend_op(sink: &mut impl SideEffectSink, n_bits: usize, value: u64) { + let msb = value >> (n_bits - 1); + sink.emit_lk(LkOp::DynamicRange { + value: 2 * value - (msb << n_bits), + bits: n_bits as u32, + }); +} + +pub fn emit_logic_u8_ops( + sink: &mut impl SideEffectSink, + lhs: u64, + rhs: u64, + num_bytes: usize, +) { + for i in 0..num_bytes { + let a = ((lhs >> (i * 8)) & 0xff) as u8; + let b = ((rhs >> (i * 8)) & 0xff) as u8; + let op = match OP::ROM_TYPE { + LookupTable::And => LkOp::And { a, b }, + LookupTable::Or => LkOp::Or { a, b }, + LookupTable::Xor => LkOp::Xor { a, b }, + LookupTable::Ltu => LkOp::Ltu { a, b }, + rom_type => unreachable!("unsupported logic table: {rom_type:?}"), + }; + sink.emit_lk(op); + } +} + +pub fn emit_uint_limbs_lt_ops( + sink: &mut impl SideEffectSink, + is_sign_comparison: bool, + a: &[u16], + b: &[u16], +) { + assert_eq!(a.len(), UINT_LIMBS); + assert_eq!(b.len(), UINT_LIMBS); + + let last = UINT_LIMBS - 1; + let sign_mask = 1 << (LIMB_BITS - 1); + let is_a_neg = is_sign_comparison && (a[last] & sign_mask) != 0; + let is_b_neg = is_sign_comparison && (b[last] & sign_mask) != 0; + + let (cmp_lt, diff_idx) = (0..UINT_LIMBS) + .rev() + .find(|&i| a[i] != b[i]) + .map(|i| ((a[i] < b[i]) ^ is_a_neg ^ is_b_neg, i)) + .unwrap_or((false, UINT_LIMBS)); + + let a_msb_range = if is_a_neg { + a[last] - sign_mask + } else { + a[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + let b_msb_range = if is_b_neg { + b[last] - sign_mask + } else { + b[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + + let to_signed = |value: u16, is_neg: bool| -> i32 { + if is_neg { + value as i32 - (1 << LIMB_BITS) + } else { + value as i32 + } + }; + let diff_val = if diff_idx == UINT_LIMBS { + 0 + } else if diff_idx == last { + let 
a_signed = to_signed(a[last], is_a_neg); + let b_signed = to_signed(b[last], is_b_neg); + if cmp_lt { + (b_signed - a_signed) as u16 + } else { + (a_signed - b_signed) as u16 + } + } else if cmp_lt { + b[diff_idx] - a[diff_idx] + } else { + a[diff_idx] - b[diff_idx] + }; + + emit_const_range_op( + sink, + if diff_idx == UINT_LIMBS { + 0 + } else { + (diff_val - 1) as u64 + }, + LIMB_BITS, + ); + emit_const_range_op(sink, a_msb_range as u64, LIMB_BITS); + emit_const_range_op(sink, b_msb_range as u64, LIMB_BITS); +} diff --git a/ceno_zkvm/src/instructions/host_ops/lk_ops.rs b/ceno_zkvm/src/instructions/host_ops/lk_ops.rs new file mode 100644 index 000000000..ef409f845 --- /dev/null +++ b/ceno_zkvm/src/instructions/host_ops/lk_ops.rs @@ -0,0 +1,76 @@ +use ceno_emul::{Cycle, Word, WordAddr}; +use gkr_iop::tables::LookupTable; +use smallvec::SmallVec; + +use crate::structs::RAMType; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LkOp { + AssertU16 { value: u16 }, + DynamicRange { value: u64, bits: u32 }, + AssertU14 { value: u16 }, + Fetch { pc: u32 }, + DoubleU8 { a: u8, b: u8 }, + And { a: u8, b: u8 }, + Or { a: u8, b: u8 }, + Xor { a: u8, b: u8 }, + Ltu { a: u8, b: u8 }, + Pow2 { value: u8 }, + ShrByte { shift: u8, carry: u8, bits: u8 }, +} + +impl LkOp { + pub fn encode_all(&self) -> SmallVec<[(LookupTable, u64); 2]> { + match *self { + LkOp::AssertU16 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 16) + value as u64)]) + } + LkOp::DynamicRange { value, bits } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << bits) + value)]) + } + LkOp::AssertU14 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 14) + value as u64)]) + } + LkOp::Fetch { pc } => SmallVec::from_slice(&[(LookupTable::Instruction, pc as u64)]), + LkOp::DoubleU8 { a, b } => { + SmallVec::from_slice(&[(LookupTable::DoubleU8, ((a as u64) << 8) + b as u64)]) + } + LkOp::And { a, b } => { + SmallVec::from_slice(&[(LookupTable::And, (a as 
u64) | ((b as u64) << 8))]) + } + LkOp::Or { a, b } => { + SmallVec::from_slice(&[(LookupTable::Or, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Xor { a, b } => { + SmallVec::from_slice(&[(LookupTable::Xor, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Ltu { a, b } => { + SmallVec::from_slice(&[(LookupTable::Ltu, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Pow2 { value } => { + SmallVec::from_slice(&[(LookupTable::Pow, 2u64 | ((value as u64) << 8))]) + } + LkOp::ShrByte { shift, carry, bits } => SmallVec::from_slice(&[ + ( + LookupTable::DoubleU8, + ((shift as u64) << 8) + ((shift as u64) << bits), + ), + ( + LookupTable::DoubleU8, + ((carry as u64) << 8) + ((carry as u64) << (8 - bits)), + ), + ]), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SendEvent { + pub ram_type: RAMType, + pub addr: WordAddr, + pub id: u64, + pub cycle: Cycle, + pub prev_cycle: Cycle, + pub value: Word, + pub prev_value: Option, +} diff --git a/ceno_zkvm/src/instructions/side_effects.rs b/ceno_zkvm/src/instructions/host_ops/mod.rs similarity index 77% rename from ceno_zkvm/src/instructions/side_effects.rs rename to ceno_zkvm/src/instructions/host_ops/mod.rs index 97695c526..d18ed10b8 100644 --- a/ceno_zkvm/src/instructions/side_effects.rs +++ b/ceno_zkvm/src/instructions/host_ops/mod.rs @@ -1,295 +1,17 @@ -use ceno_emul::{Cycle, Word, WordAddr}; -use gkr_iop::{ - gadgets::{AssertLtConfig, cal_lt_diff}, - tables::{LookupTable, OpsTable}, -}; -use smallvec::SmallVec; -use std::marker::PhantomData; +//! Host-side operations for GPU-CPU hybrid witness generation. +//! +//! Contains lookup/shard side-effect collection abstractions and CPU fallback paths. 
-use crate::{ - e2e::ShardContext, - instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS}, - structs::RAMType, - witness::LkMultiplicity, -}; +mod lk_ops; +mod sink; +mod emit; +mod cpu_fallback; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum LkOp { - AssertU16 { value: u16 }, - DynamicRange { value: u64, bits: u32 }, - AssertU14 { value: u16 }, - Fetch { pc: u32 }, - DoubleU8 { a: u8, b: u8 }, - And { a: u8, b: u8 }, - Or { a: u8, b: u8 }, - Xor { a: u8, b: u8 }, - Ltu { a: u8, b: u8 }, - Pow2 { value: u8 }, - ShrByte { shift: u8, carry: u8, bits: u8 }, -} - -impl LkOp { - pub fn encode_all(&self) -> SmallVec<[(LookupTable, u64); 2]> { - match *self { - LkOp::AssertU16 { value } => { - SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 16) + value as u64)]) - } - LkOp::DynamicRange { value, bits } => { - SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << bits) + value)]) - } - LkOp::AssertU14 { value } => { - SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 14) + value as u64)]) - } - LkOp::Fetch { pc } => SmallVec::from_slice(&[(LookupTable::Instruction, pc as u64)]), - LkOp::DoubleU8 { a, b } => { - SmallVec::from_slice(&[(LookupTable::DoubleU8, ((a as u64) << 8) + b as u64)]) - } - LkOp::And { a, b } => { - SmallVec::from_slice(&[(LookupTable::And, (a as u64) | ((b as u64) << 8))]) - } - LkOp::Or { a, b } => { - SmallVec::from_slice(&[(LookupTable::Or, (a as u64) | ((b as u64) << 8))]) - } - LkOp::Xor { a, b } => { - SmallVec::from_slice(&[(LookupTable::Xor, (a as u64) | ((b as u64) << 8))]) - } - LkOp::Ltu { a, b } => { - SmallVec::from_slice(&[(LookupTable::Ltu, (a as u64) | ((b as u64) << 8))]) - } - LkOp::Pow2 { value } => { - SmallVec::from_slice(&[(LookupTable::Pow, 2u64 | ((value as u64) << 8))]) - } - LkOp::ShrByte { shift, carry, bits } => SmallVec::from_slice(&[ - ( - LookupTable::DoubleU8, - ((shift as u64) << 8) + ((shift as u64) << bits), - ), - ( - LookupTable::DoubleU8, - ((carry as u64) << 8) + ((carry as u64) << (8 - 
bits)), - ), - ]), - } - } -} - -#[derive(Debug, Clone, Copy)] -pub struct SendEvent { - pub ram_type: RAMType, - pub addr: WordAddr, - pub id: u64, - pub cycle: Cycle, - pub prev_cycle: Cycle, - pub value: Word, - pub prev_value: Option, -} - -pub trait SideEffectSink { - fn emit_lk(&mut self, op: LkOp); - fn emit_send(&mut self, event: SendEvent); - fn touch_addr(&mut self, addr: WordAddr); -} - -pub struct CpuSideEffectSink<'ctx, 'shard, 'lk> { - shard_ctx: *mut ShardContext<'shard>, - lk: &'lk mut LkMultiplicity, - _marker: PhantomData<&'ctx mut ShardContext<'shard>>, -} - -impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { - pub unsafe fn from_raw( - shard_ctx: *mut ShardContext<'shard>, - lk: &'lk mut LkMultiplicity, - ) -> Self { - Self { - shard_ctx, - lk, - _marker: PhantomData, - } - } - - fn shard_ctx(&mut self) -> &mut ShardContext<'shard> { - // Safety: `from_raw` is only constructed from a live `&mut ShardContext` - // for the duration of side-effect collection. 
- unsafe { &mut *self.shard_ctx } - } -} - -impl SideEffectSink for CpuSideEffectSink<'_, '_, '_> { - fn emit_lk(&mut self, op: LkOp) { - for (table, key) in op.encode_all() { - self.lk.increment(table, key); - } - } - - fn emit_send(&mut self, event: SendEvent) { - self.shard_ctx().record_send_without_touch( - event.ram_type, - event.addr, - event.id, - event.cycle, - event.prev_cycle, - event.value, - event.prev_value, - ); - } - - fn touch_addr(&mut self, addr: WordAddr) { - self.shard_ctx().push_addr_accessed(addr); - } -} - -pub fn emit_assert_lt_ops( - sink: &mut impl SideEffectSink, - lt_cfg: &AssertLtConfig, - lhs: u64, - rhs: u64, -) { - let max_bits = lt_cfg.0.max_bits; - let diff = cal_lt_diff(lhs < rhs, max_bits, lhs, rhs); - for i in 0..(max_bits / u16::BITS as usize) { - let value = ((diff >> (i * u16::BITS as usize)) & 0xffff) as u16; - sink.emit_lk(LkOp::AssertU16 { value }); - } - let remain_bits = max_bits % u16::BITS as usize; - if remain_bits > 1 { - let value = (diff >> ((lt_cfg.0.diff.len() - 1) * u16::BITS as usize)) & 0xffff; - sink.emit_lk(LkOp::DynamicRange { - value, - bits: remain_bits as u32, - }); - } -} - -pub fn emit_u16_limbs(sink: &mut impl SideEffectSink, value: u32) { - sink.emit_lk(LkOp::AssertU16 { - value: (value & 0xffff) as u16, - }); - sink.emit_lk(LkOp::AssertU16 { - value: (value >> 16) as u16, - }); -} - -pub fn emit_const_range_op(sink: &mut impl SideEffectSink, value: u64, bits: usize) { - match bits { - 0 | 1 => {} - 14 => sink.emit_lk(LkOp::AssertU14 { - value: value as u16, - }), - 16 => sink.emit_lk(LkOp::AssertU16 { - value: value as u16, - }), - _ => sink.emit_lk(LkOp::DynamicRange { - value, - bits: bits as u32, - }), - } -} - -pub fn emit_byte_decomposition_ops(sink: &mut impl SideEffectSink, bytes: &[u8]) { - for chunk in bytes.chunks(2) { - match chunk { - [a, b] => sink.emit_lk(LkOp::DoubleU8 { a: *a, b: *b }), - [a] => emit_const_range_op(sink, *a as u64, 8), - _ => unreachable!(), - } - } -} - -pub fn 
emit_signed_extend_op(sink: &mut impl SideEffectSink, n_bits: usize, value: u64) { - let msb = value >> (n_bits - 1); - sink.emit_lk(LkOp::DynamicRange { - value: 2 * value - (msb << n_bits), - bits: n_bits as u32, - }); -} - -pub fn emit_logic_u8_ops( - sink: &mut impl SideEffectSink, - lhs: u64, - rhs: u64, - num_bytes: usize, -) { - for i in 0..num_bytes { - let a = ((lhs >> (i * 8)) & 0xff) as u8; - let b = ((rhs >> (i * 8)) & 0xff) as u8; - let op = match OP::ROM_TYPE { - LookupTable::And => LkOp::And { a, b }, - LookupTable::Or => LkOp::Or { a, b }, - LookupTable::Xor => LkOp::Xor { a, b }, - LookupTable::Ltu => LkOp::Ltu { a, b }, - rom_type => unreachable!("unsupported logic table: {rom_type:?}"), - }; - sink.emit_lk(op); - } -} - -pub fn emit_uint_limbs_lt_ops( - sink: &mut impl SideEffectSink, - is_sign_comparison: bool, - a: &[u16], - b: &[u16], -) { - assert_eq!(a.len(), UINT_LIMBS); - assert_eq!(b.len(), UINT_LIMBS); - - let last = UINT_LIMBS - 1; - let sign_mask = 1 << (LIMB_BITS - 1); - let is_a_neg = is_sign_comparison && (a[last] & sign_mask) != 0; - let is_b_neg = is_sign_comparison && (b[last] & sign_mask) != 0; - - let (cmp_lt, diff_idx) = (0..UINT_LIMBS) - .rev() - .find(|&i| a[i] != b[i]) - .map(|i| ((a[i] < b[i]) ^ is_a_neg ^ is_b_neg, i)) - .unwrap_or((false, UINT_LIMBS)); - - let a_msb_range = if is_a_neg { - a[last] - sign_mask - } else { - a[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) - }; - let b_msb_range = if is_b_neg { - b[last] - sign_mask - } else { - b[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) - }; - - let to_signed = |value: u16, is_neg: bool| -> i32 { - if is_neg { - value as i32 - (1 << LIMB_BITS) - } else { - value as i32 - } - }; - let diff_val = if diff_idx == UINT_LIMBS { - 0 - } else if diff_idx == last { - let a_signed = to_signed(a[last], is_a_neg); - let b_signed = to_signed(b[last], is_b_neg); - if cmp_lt { - (b_signed - a_signed) as u16 - } else { - (a_signed - b_signed) as u16 - } - } else 
if cmp_lt { - b[diff_idx] - a[diff_idx] - } else { - a[diff_idx] - b[diff_idx] - }; - - emit_const_range_op( - sink, - if diff_idx == UINT_LIMBS { - 0 - } else { - (diff_val - 1) as u64 - }, - LIMB_BITS, - ); - emit_const_range_op(sink, a_msb_range as u64, LIMB_BITS); - emit_const_range_op(sink, b_msb_range as u64, LIMB_BITS); -} +// Re-export all public types for convenience +pub use lk_ops::*; +pub use sink::*; +pub use emit::*; +pub use cpu_fallback::*; #[cfg(test)] mod tests { @@ -313,6 +35,7 @@ mod tests { }, }, structs::ProgramParams, + witness::LkMultiplicity, }; use ceno_emul::{ ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, diff --git a/ceno_zkvm/src/instructions/host_ops/sink.rs b/ceno_zkvm/src/instructions/host_ops/sink.rs new file mode 100644 index 000000000..258269332 --- /dev/null +++ b/ceno_zkvm/src/instructions/host_ops/sink.rs @@ -0,0 +1,61 @@ +use ceno_emul::WordAddr; +use std::marker::PhantomData; + +use crate::{e2e::ShardContext, witness::LkMultiplicity}; + +use super::{LkOp, SendEvent}; + +pub trait SideEffectSink { + fn emit_lk(&mut self, op: LkOp); + fn emit_send(&mut self, event: SendEvent); + fn touch_addr(&mut self, addr: WordAddr); +} + +pub struct CpuSideEffectSink<'ctx, 'shard, 'lk> { + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + _marker: PhantomData<&'ctx mut ShardContext<'shard>>, +} + +impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { + pub unsafe fn from_raw( + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + ) -> Self { + Self { + shard_ctx, + lk, + _marker: PhantomData, + } + } + + fn shard_ctx(&mut self) -> &mut ShardContext<'shard> { + // Safety: `from_raw` is only constructed from a live `&mut ShardContext` + // for the duration of side-effect collection. 
+ unsafe { &mut *self.shard_ctx } + } +} + +impl SideEffectSink for CpuSideEffectSink<'_, '_, '_> { + fn emit_lk(&mut self, op: LkOp) { + for (table, key) in op.encode_all() { + self.lk.increment(table, key); + } + } + + fn emit_send(&mut self, event: SendEvent) { + self.shard_ctx().record_send_without_touch( + event.ram_type, + event.addr, + event.id, + event.cycle, + event.prev_cycle, + event.value, + event.prev_value, + ); + } + + fn touch_addr(&mut self, addr: WordAddr) { + self.shard_ctx().push_addr_accessed(addr); + } +} From b54af5041536226cf0337538bdbca9b3edd2e4a6 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 19:09:43 +0800 Subject: [PATCH 55/73] part-c: fix --- ceno_zkvm/src/scheme/septic_curve.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 23a7ad0e3..98edd8615 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1203,9 +1203,12 @@ mod tests { let mut mismatches = 0; for (i, (input, gpu_rec)) in test_inputs.iter().zip(gpu_results.iter()).enumerate() { - // Build CPU ShardRamRecord and compute EC point + // Build CPU ShardRamRecord and compute EC point. + // GPU kernel treats slot.addr as word address and converts to byte address + // via `<< 2` for Memory type; Register type uses reg_id directly. 
+ let addr = if input.ram_type == 2 { input.addr << 2 } else { input.addr }; let cpu_record = ShardRamRecord { - addr: input.addr, + addr, ram_type: if input.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, value: input.value, shard: input.shard, From 7ee32ed84d7a313fbbd24c41b68b863dce6b1332 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 20:13:40 +0800 Subject: [PATCH 56/73] simplify: macro-1,2 --- ceno_zkvm/src/instructions.rs | 101 +++++++++++++++ ceno_zkvm/src/instructions/riscv/arith.rs | 63 ++-------- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 51 +------- ceno_zkvm/src/instructions/riscv/auipc.rs | 51 +------- .../riscv/branch/branch_circuit_v2.rs | 69 ++--------- .../instructions/riscv/div/div_circuit_v2.rs | 72 ++--------- .../src/instructions/riscv/jump/jal_v2.rs | 51 +------- .../src/instructions/riscv/jump/jalr_v2.rs | 51 +------- .../instructions/riscv/logic/logic_circuit.rs | 61 ++-------- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 61 ++-------- ceno_zkvm/src/instructions/riscv/lui.rs | 51 +------- .../src/instructions/riscv/memory/load.rs | 55 +-------- .../src/instructions/riscv/memory/load_v2.rs | 91 ++++---------- .../src/instructions/riscv/memory/store_v2.rs | 64 ++-------- .../riscv/mulh/mulh_circuit_v2.rs | 73 ++--------- .../riscv/shift/shift_circuit_v2.rs | 115 +++--------------- .../instructions/riscv/slt/slt_circuit_v2.rs | 60 ++------- .../riscv/slti/slti_circuit_v2.rs | 60 ++------- 18 files changed, 234 insertions(+), 966 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 548e99a9f..ce2febe7f 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -229,3 +229,104 @@ pub fn full_step_indices(steps: &[StepRecord]) -> Vec { (0..steps.len()).collect() } +// --------------------------------------------------------------------------- +// Macros to reduce per-chip boilerplate +// 
--------------------------------------------------------------------------- + +/// Implement `collect_shard_side_effects_instance` by delegating to +/// `config.$field.collect_shard_effects(shard_ctx, lk_multiplicity, step)`. +/// +/// Every chip's implementation is identical except for the config field name +/// (`r_insn`, `i_insn`, `b_insn`, `s_insn`, `j_insn`, `im_insn`). +/// +/// Usage inside `impl Instruction for MyChip`: +/// ```ignore +/// impl_collect_shard!(r_insn); +/// ``` +#[macro_export] +macro_rules! impl_collect_shard { + ($field:ident) => { + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut $crate::e2e::ShardContext, + lk_multiplicity: &mut $crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), $crate::error::ZKVMError> { + config + .$field + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + }; +} + +/// Implement the `#[cfg(feature = "gpu")] fn assign_instances` override that: +/// 1. Computes `Option` from `$kind_expr` +/// 2. Tries `try_gpu_assign_instances` → returns on success +/// 3. Falls back to `cpu_assign_instances` +/// +/// Usage inside `impl Instruction for MyChip`: +/// ```ignore +/// // Single kind (always GPU): +/// impl_gpu_assign!(GpuWitgenKind::Lui); +/// +/// // Match expression → Option: +/// impl_gpu_assign!(match I::INST_KIND { +/// InsnKind::ADD => Some(GpuWitgenKind::Add), +/// InsnKind::SUB => Some(GpuWitgenKind::Sub), +/// _ => None, +/// }); +/// ``` +#[macro_export] +macro_rules! 
impl_gpu_assign { + // Match/block → Option + (match $($rest:tt)*) => { + $crate::impl_gpu_assign!(@ match $($rest)*); + }; + // Single kind — always use GPU + ($kind:expr) => { + $crate::impl_gpu_assign!(@ Some($kind)); + }; + (@ $kind_expr:expr) => { + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut $crate::e2e::ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[ceno_emul::StepIndex], + ) -> Result< + ( + $crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + $crate::error::ZKVMError, + > { + use $crate::instructions::riscv::gpu::witgen_gpu; + let gpu_kind: Option = $kind_expr; + if let Some(kind) = gpu_kind { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? { + return Ok(result); + } + } + $crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } + }; +} + diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 5e9291499..523bec45d 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ + impl_collect_shard, impl_gpu_assign, circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, @@ -16,13 +17,6 @@ use crate::{ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - /// This config handles R-Instructions that represent registers values as 2 * u16. 
#[derive(Debug)] pub struct ArithConfig { @@ -177,56 +171,13 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { - config - .r_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(r_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - let gpu_kind = match I::INST_KIND { - InsnKind::ADD => Some(witgen_gpu::GpuWitgenKind::Add), - InsnKind::SUB => Some(witgen_gpu::GpuWitgenKind::Sub), - _ => None, - }; - if let Some(kind) = gpu_kind { - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - kind, - )? { - return Ok(result); - } - } - // Fallback to CPU path - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(match I::INST_KIND { + InsnKind::ADD => Some(witgen_gpu::GpuWitgenKind::Add), + InsnKind::SUB => Some(witgen_gpu::GpuWitgenKind::Sub), + _ => None, + }); } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 53219490c..282ec6c4e 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -3,6 +3,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, @@ -19,13 +20,6 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use 
witness::set_val; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct AddiInstruction(PhantomData); pub struct InstructionConfig { @@ -131,46 +125,7 @@ impl Instruction for AddiInstruction { Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config - .i_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(i_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Addi, - )? 
{ - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Addi); } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index d7ada6ff6..b4a7ec6ed 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -6,6 +6,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -28,13 +29,6 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct AuipcConfig { pub i_insn: IInstructionConfig, // The limbs of the immediate except the least significant limb since it is always 0 @@ -241,48 +235,9 @@ impl Instruction for AuipcInstruction { Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - config - .i_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(i_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[ceno_emul::StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - 
step_indices, - witgen_gpu::GpuWitgenKind::Auipc, - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Auipc); } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 0951ab7ba..b0452c112 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -4,6 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -237,61 +238,15 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { - config - .b_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } - - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[ceno_emul::StepIndex], - ) -> Result< - ( - crate::tables::RMMCollections, - gkr_iop::utils::lk_multiplicity::Multiplicity, - ), - crate::error::ZKVMError, - > { - use crate::instructions::riscv::gpu::witgen_gpu; - let kind = match I::INST_KIND { - InsnKind::BEQ => witgen_gpu::GpuWitgenKind::BranchEq(1), - InsnKind::BNE => witgen_gpu::GpuWitgenKind::BranchEq(0), - InsnKind::BLT => witgen_gpu::GpuWitgenKind::BranchCmp(1), - InsnKind::BGE => witgen_gpu::GpuWitgenKind::BranchCmp(1), - InsnKind::BLTU => witgen_gpu::GpuWitgenKind::BranchCmp(0), - InsnKind::BGEU => witgen_gpu::GpuWitgenKind::BranchCmp(0), - _ => unreachable!(), - }; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - kind, - )? 
{ - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_collect_shard!(b_insn); + + impl_gpu_assign!(match I::INST_KIND { + InsnKind::BEQ => Some(witgen_gpu::GpuWitgenKind::BranchEq(1)), + InsnKind::BNE => Some(witgen_gpu::GpuWitgenKind::BranchEq(0)), + InsnKind::BLT => Some(witgen_gpu::GpuWitgenKind::BranchCmp(1)), + InsnKind::BGE => Some(witgen_gpu::GpuWitgenKind::BranchCmp(1)), + InsnKind::BLTU => Some(witgen_gpu::GpuWitgenKind::BranchCmp(0)), + InsnKind::BGEU => Some(witgen_gpu::GpuWitgenKind::BranchCmp(0)), + _ => unreachable!(), + }); } diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index 474174ed9..302fc8f23 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -14,6 +14,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::constants::LIMB_BITS, @@ -382,58 +383,13 @@ impl Instruction for ArithInstruction Result< - ( - crate::tables::RMMCollections, - gkr_iop::utils::lk_multiplicity::Multiplicity, - ), - ZKVMError, - > { - use crate::instructions::riscv::gpu::witgen_gpu; - let div_kind = match I::INST_KIND { - InsnKind::DIV => 0u32, - InsnKind::DIVU => 1u32, - InsnKind::REM => 2u32, - InsnKind::REMU => 3u32, - _ => { - return crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ); - } - }; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Div(div_kind), - )? 
{ - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(match I::INST_KIND { + InsnKind::DIV => Some(witgen_gpu::GpuWitgenKind::Div(0u32)), + InsnKind::DIVU => Some(witgen_gpu::GpuWitgenKind::Div(1u32)), + InsnKind::REM => Some(witgen_gpu::GpuWitgenKind::Div(2u32)), + InsnKind::REMU => Some(witgen_gpu::GpuWitgenKind::Div(3u32)), + _ => None, + }); fn assign_instance( config: &Self::InstructionConfig, @@ -675,17 +631,7 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { - config - .r_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(r_insn); } #[derive(Debug, Eq, PartialEq)] diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 130dac8fc..23e55c2e7 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -6,6 +6,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -23,13 +24,6 @@ use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr}; use p3::field::FieldAlgebra; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct JalConfig { pub j_insn: JInstructionConfig, pub rd_written: UInt8, @@ -159,46 +153,7 @@ impl Instruction for JalInstruction { Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - config - .j_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - 
} + impl_collect_shard!(j_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[ceno_emul::StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Jal, - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Jal); } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 644ba2a45..236a672eb 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -7,6 +7,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -26,13 +27,6 @@ use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct JalrConfig { pub i_insn: IInstructionConfig, pub rs1_read: UInt, @@ -224,46 +218,7 @@ impl Instruction for JalrInstruction { Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - config - .i_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, 
step); - Ok(()) - } + impl_collect_shard!(i_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[ceno_emul::StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Jalr, - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Jalr); } diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index aae61aef5..087409ba7 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -8,6 +8,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, @@ -19,13 +20,6 @@ use crate::{ }; use ceno_emul::{InsnKind, StepRecord}; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - /// This trait defines a logic instruction, connecting an instruction type to a lookup table. 
pub trait LogicOp { const INST_KIND: InsnKind; @@ -102,53 +96,14 @@ impl Instruction for LogicInstruction { Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config - .r_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(r_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::LogicR(match I::INST_KIND { - InsnKind::AND => 0, - InsnKind::OR => 1, - InsnKind::XOR => 2, - kind => unreachable!("unsupported logic GPU kind: {kind:?}"), - }), - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::LogicR(match I::INST_KIND { + InsnKind::AND => 0, + InsnKind::OR => 1, + InsnKind::XOR => 2, + kind => unreachable!("unsupported logic GPU kind: {kind:?}"), + })); } /// This config implements R-Instructions that represent registers values as 4 * u8. 
diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 2eb89036c..a4c258c41 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -9,6 +9,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -25,13 +26,6 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; - -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::ToExpr; /// The Instruction circuit for a given LogicOp. @@ -160,53 +154,14 @@ impl Instruction for LogicInstruction { Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config - .i_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(i_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::LogicI(match I::INST_KIND { - InsnKind::ANDI => 0, - InsnKind::ORI => 1, - InsnKind::XORI => 2, - kind => unreachable!("unsupported logic_imm GPU kind: {kind:?}"), - }), - )? 
{ - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::LogicI(match I::INST_KIND { + InsnKind::ANDI => 0, + InsnKind::ORI => 1, + InsnKind::XORI => 2, + kind => unreachable!("unsupported logic_imm GPU kind: {kind:?}"), + })); } /// This config implements I-Instructions that represent registers values as 4 * u8. diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index bc661d7a7..7359a96d5 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -6,6 +6,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -24,13 +25,6 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct LuiConfig { pub i_insn: IInstructionConfig, pub imm: WitIn, @@ -145,48 +139,9 @@ impl Instruction for LuiInstruction { Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - config - .i_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(i_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[ceno_emul::StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use 
crate::instructions::riscv::gpu::witgen_gpu; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Lui, - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Lui); } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index a42b53e06..75b69ca35 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -4,6 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -253,54 +254,10 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { - config - .im_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(im_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut crate::e2e::ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[ceno_emul::StepIndex], - ) -> Result< - ( - crate::tables::RMMCollections, - gkr_iop::utils::lk_multiplicity::Multiplicity, - ), - crate::error::ZKVMError, - > { - use crate::instructions::riscv::gpu::witgen_gpu; - if I::INST_KIND == InsnKind::LW { - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Lw, - )? 
{ - return Ok(result); - } - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(match I::INST_KIND { + InsnKind::LW => Some(witgen_gpu::GpuWitgenKind::Lw), + _ => None, + }); } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 7878d9994..42a0848d0 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -4,6 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -283,74 +284,26 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { - config - .im_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(im_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut crate::e2e::ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[ceno_emul::StepRecord], - step_indices: &[ceno_emul::StepIndex], - ) -> Result< - ( - crate::tables::RMMCollections, - gkr_iop::utils::lk_multiplicity::Multiplicity, - ), - crate::error::ZKVMError, - > { - use crate::instructions::riscv::gpu::witgen_gpu; - let gpu_kind = match I::INST_KIND { - InsnKind::LW => Some(witgen_gpu::GpuWitgenKind::Lw), - InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { - load_width: 16, - is_signed: 1, - }), - InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { - load_width: 16, - is_signed: 0, - }), - InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { - load_width: 8, - is_signed: 1, - }), - InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { - load_width: 8, - is_signed: 0, - }), - _ => None, - }; - if let Some(kind) = gpu_kind { - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - 
config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - kind, - )? { - return Ok(result); - } - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(match I::INST_KIND { + InsnKind::LW => Some(witgen_gpu::GpuWitgenKind::Lw), + InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 1, + }), + InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 0, + }), + InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 1, + }), + InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 0, + }), + _ => None, + }); } diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index ddb1dffb7..235cd1f5a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -3,6 +3,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -24,13 +25,6 @@ use multilinear_extensions::{ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct StoreConfig { pub(crate) s_insn: SInstructionConfig, @@ -221,54 +215,12 @@ impl Instruction Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config - .s_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + 
impl_collect_shard!(s_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - let gpu_kind = match I::INST_KIND { - InsnKind::SW => Some(witgen_gpu::GpuWitgenKind::Sw), - InsnKind::SH => Some(witgen_gpu::GpuWitgenKind::Sh), - InsnKind::SB => Some(witgen_gpu::GpuWitgenKind::Sb), - _ => None, - }; - if let Some(kind) = gpu_kind { - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - kind, - )? { - return Ok(result); - } - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(match I::INST_KIND { + InsnKind::SW => Some(witgen_gpu::GpuWitgenKind::Sw), + InsnKind::SH => Some(witgen_gpu::GpuWitgenKind::Sh), + InsnKind::SB => Some(witgen_gpu::GpuWitgenKind::Sb), + _ => None, + }); } diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index a04359256..99643c982 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -1,6 +1,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -24,13 +25,6 @@ use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct MulhInstructionBase(PhantomData<(E, I)>); 
pub struct MulhConfig { @@ -433,64 +427,15 @@ impl Instruction for MulhInstructionBas Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config - .r_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(r_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - let mul_kind = match I::INST_KIND { - InsnKind::MUL => 0u32, - InsnKind::MULH => 1u32, - InsnKind::MULHU => 2u32, - InsnKind::MULHSU => 3u32, - _ => { - return crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ); - } - }; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Mul(mul_kind), - )? 
{ - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(match I::INST_KIND { + InsnKind::MUL => Some(witgen_gpu::GpuWitgenKind::Mul(0u32)), + InsnKind::MULH => Some(witgen_gpu::GpuWitgenKind::Mul(1u32)), + InsnKind::MULHU => Some(witgen_gpu::GpuWitgenKind::Mul(2u32)), + InsnKind::MULHSU => Some(witgen_gpu::GpuWitgenKind::Mul(3u32)), + _ => None, + }); } fn run_mulh( diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 38f6b7758..91c1845bd 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,8 +1,7 @@ use crate::e2e::ShardContext; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -20,11 +19,7 @@ use crate::{ utils::{split_to_limb, split_to_u8}, }; use ceno_emul::InsnKind; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; use ff_ext::{ExtensionField, FieldInto}; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -440,54 +435,14 @@ impl Instruction for ShiftLogicalInstru Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut crate::witness::LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), crate::error::ZKVMError> { - config - .r_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(r_insn); - #[cfg(feature = "gpu")] - fn 
assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[ceno_emul::StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - let shift_kind = match I::INST_KIND { - InsnKind::SLL => 0u32, - InsnKind::SRL => 1u32, - InsnKind::SRA => 2u32, - _ => unreachable!(), - }; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::ShiftR(shift_kind), - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::ShiftR(match I::INST_KIND { + InsnKind::SLL => 0u32, + InsnKind::SRL => 1u32, + InsnKind::SRA => 2u32, + _ => unreachable!(), + })); } pub struct ShiftImmConfig { @@ -619,54 +574,14 @@ impl Instruction for ShiftImmInstructio Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut crate::witness::LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), crate::error::ZKVMError> { - config - .i_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(i_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[ceno_emul::StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - let shift_kind = match I::INST_KIND { - InsnKind::SLLI => 0u32, - InsnKind::SRLI => 1u32, - InsnKind::SRAI => 2u32, - _ => 
unreachable!(), - }; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::ShiftI(shift_kind), - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::ShiftI(match I::INST_KIND { + InsnKind::SLLI => 0u32, + InsnKind::SRLI => 1u32, + InsnKind::SRAI => 2u32, + _ => unreachable!(), + })); } fn run_shift( diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 15e5c104b..3cc7fae23 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -4,6 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, @@ -16,13 +17,6 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - pub struct SetLessThanInstruction(PhantomData<(E, I)>); /// This config handles R-Instructions that represent registers values as 2 * u16. 
@@ -151,51 +145,11 @@ impl Instruction for SetLessThanInstruc Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config - .r_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(r_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - let is_signed = match I::INST_KIND { - InsnKind::SLT => 1u32, - InsnKind::SLTU => 0u32, - _ => unreachable!(), - }; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Slt(is_signed), - )? 
{ - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Slt(match I::INST_KIND { + InsnKind::SLT => 1u32, + InsnKind::SLTU => 0u32, + _ => unreachable!(), + })); } diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index da60ca953..030613c2b 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -4,6 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, + impl_collect_shard, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -24,13 +25,6 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; -#[cfg(feature = "gpu")] -use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; - #[derive(Debug)] pub struct SetLessThanImmConfig { pub(crate) i_insn: IInstructionConfig, @@ -170,51 +164,11 @@ impl Instruction for SetLessThanImmInst Ok(()) } - fn collect_shard_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config - .i_insn - .collect_shard_effects(shard_ctx, lk_multiplicity, step); - Ok(()) - } + impl_collect_shard!(i_insn); - #[cfg(feature = "gpu")] - fn assign_instances( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], - ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { - use crate::instructions::riscv::gpu::witgen_gpu; - let is_signed = match I::INST_KIND { - InsnKind::SLTI => 1u32, - 
InsnKind::SLTIU => 0u32, - _ => unreachable!(), - }; - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - witgen_gpu::GpuWitgenKind::Slti(is_signed), - )? { - return Ok(result); - } - crate::instructions::cpu_assign_instances::( - config, - shard_ctx, - num_witin, - num_structural_witin, - shard_steps, - step_indices, - ) - } + impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Slti(match I::INST_KIND { + InsnKind::SLTI => 1u32, + InsnKind::SLTIU => 0u32, + _ => unreachable!(), + })); } From 6daebf49019886a7603cf479c77583906c2bbfa1 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 20:34:03 +0800 Subject: [PATCH 57/73] simplify: macro-3 --- ceno_zkvm/src/instructions.rs | 44 +++++++++++++++++ ceno_zkvm/src/instructions/host_ops/sink.rs | 20 ++++++++ ceno_zkvm/src/instructions/riscv/arith.rs | 28 +++-------- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 22 ++------- ceno_zkvm/src/instructions/riscv/auipc.rs | 28 +++-------- .../riscv/branch/branch_circuit_v2.rs | 24 ++-------- .../instructions/riscv/div/div_circuit_v2.rs | 26 +++------- .../src/instructions/riscv/jump/jal_v2.rs | 24 ++-------- .../src/instructions/riscv/jump/jalr_v2.rs | 28 +++-------- .../instructions/riscv/logic/logic_circuit.rs | 20 ++------ .../riscv/logic_imm/logic_imm_circuit_v2.rs | 25 +++------- ceno_zkvm/src/instructions/riscv/lui.rs | 24 ++-------- .../src/instructions/riscv/memory/load.rs | 23 ++------- .../src/instructions/riscv/memory/load_v2.rs | 23 ++------- .../src/instructions/riscv/memory/store_v2.rs | 30 ++++-------- .../riscv/mulh/mulh_circuit_v2.rs | 22 ++------- .../riscv/shift/shift_circuit_v2.rs | 48 ++++--------------- .../instructions/riscv/slt/slt_circuit_v2.rs | 24 ++-------- .../riscv/slti/slti_circuit_v2.rs | 24 ++-------- 19 files changed, 162 insertions(+), 345 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 
ce2febe7f..151481911 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -233,6 +233,50 @@ pub fn full_step_indices(steps: &[StepRecord]) -> Vec { // Macros to reduce per-chip boilerplate // --------------------------------------------------------------------------- +/// Implement `collect_side_effects_instance` with a common prologue +/// (create `CpuSideEffectSink`, dispatch to `config.$field.collect_side_effects`) +/// and a chip-specific body for additional LK ops. +/// +/// The closure receives `(sink, step, config, ctx)`: +/// - `sink: &mut CpuSideEffectSink` — emit LK ops and send events +/// - `step: &StepRecord` — current step +/// - `config: &Self::InstructionConfig` — circuit config (for sub-configs) +/// - `ctx: &ShardContext` — read-only shard context +/// +/// Usage inside `impl Instruction for MyChip`: +/// ```ignore +/// impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { +/// emit_u16_limbs(sink, step.rd().unwrap().value.after); +/// }); +/// ``` +#[macro_export] +macro_rules! 
impl_collect_side_effects { + ($field:ident, |$sink:ident, $step:ident, $config:ident, $ctx:ident| $body:block) => { + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut $crate::e2e::ShardContext, + lk_multiplicity: &mut $crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), $crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut $crate::e2e::ShardContext; + let _ctx = unsafe { &*shard_ctx_ptr }; + let mut _sink_val = unsafe { + $crate::instructions::side_effects::CpuSideEffectSink::from_raw( + shard_ctx_ptr, + lk_multiplicity, + ) + }; + config.$field.collect_side_effects(&mut _sink_val, _ctx, step); + let $sink = &mut _sink_val; + let $step = step; + let $config = config; + let $ctx = _ctx; + $body + Ok(()) + } + }; +} + /// Implement `collect_shard_side_effects_instance` by delegating to /// `config.$field.collect_shard_effects(shard_ctx, lk_multiplicity, step)`. /// diff --git a/ceno_zkvm/src/instructions/host_ops/sink.rs b/ceno_zkvm/src/instructions/host_ops/sink.rs index 258269332..4a07d7e43 100644 --- a/ceno_zkvm/src/instructions/host_ops/sink.rs +++ b/ceno_zkvm/src/instructions/host_ops/sink.rs @@ -36,6 +36,26 @@ impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { } } +/// Create a `CpuSideEffectSink` and an immutable view of `ShardContext`, +/// then pass both to the closure `f`. +/// +/// This encapsulates the raw-pointer trick needed to hold `&mut ShardContext` +/// (inside the sink, for writes) and `&ShardContext` (for reads like +/// `aligned_prev_ts`) simultaneously. +/// +/// # Safety +/// Safe to call — the unsafety is internal and bounded by the closure lifetime. 
+pub fn with_cpu_sink<'a, R>( + shard_ctx: &'a mut ShardContext<'a>, + lk: &'a mut LkMultiplicity, + f: impl FnOnce(&mut CpuSideEffectSink<'a, 'a, 'a>, &ShardContext) -> R, +) -> R { + let ptr = shard_ctx as *mut ShardContext; + let view = unsafe { &*ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(ptr, lk) }; + f(&mut sink, view) +} + impl SideEffectSink for CpuSideEffectSink<'_, '_, '_> { fn emit_lk(&mut self, op: LkOp) { for (table, key) in op.encode_all() { diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 523bec45d..d7c213e32 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,13 +2,13 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, instructions::{ Instruction, - side_effects::{CpuSideEffectSink, emit_u16_limbs}, + side_effects::emit_u16_limbs, }, structs::ProgramParams, uint::Value, @@ -144,32 +144,18 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .r_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { match I::INST_KIND { InsnKind::ADD => { - emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + emit_u16_limbs(sink, step.rd().unwrap().value.after); } InsnKind::SUB => { - emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); - emit_u16_limbs(&mut sink, step.rs1().unwrap().value); + emit_u16_limbs(sink, step.rd().unwrap().value.after); + emit_u16_limbs(sink, step.rs1().unwrap().value); } 
_ => unreachable!("Unsupported instruction kind"), } - - Ok(()) - } + }); impl_collect_shard!(r_insn); diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 282ec6c4e..288fe0a5b 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -3,11 +3,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, - side_effects::{CpuSideEffectSink, emit_u16_limbs}, + side_effects::emit_u16_limbs, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -109,21 +109,9 @@ impl Instruction for AddiInstruction { Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .i_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); - Ok(()) - } + impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { + emit_u16_limbs(sink, step.rd().unwrap().value.after); + }); impl_collect_shard!(i_insn); diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index b4a7ec6ed..15f4a44aa 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -6,7 +6,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - 
impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -14,7 +14,7 @@ use crate::{ i_insn::IInstructionConfig, }, side_effects::{ - CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + LkOp, SideEffectSink, emit_byte_decomposition_ops, emit_const_range_op, }, }, @@ -193,27 +193,15 @@ impl Instruction for AuipcInstruction { Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .i_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(i_insn, |sink, step, config, _ctx| { let rd_written = split_to_u8(step.rd().unwrap().value.after); - emit_byte_decomposition_ops(&mut sink, &rd_written); + emit_byte_decomposition_ops(sink, &rd_written); let pc = split_to_u8(step.pc().before.0); // Only iterate over the middle limbs that have witness columns (pc_limbs has UINT_BYTE_LIMBS-2 elements). // The MSB limb is range-checked via XOR below, the LSB is shared with rd_written[0]. 
for val in pc.iter().skip(1).take(config.pc_limbs.len()) { - emit_const_range_op(&mut sink, *val as u64, 8); + emit_const_range_op(sink, *val as u64, 8); } let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; @@ -221,7 +209,7 @@ impl Instruction for AuipcInstruction { .into_iter() .take(config.imm_limbs.len()) { - emit_const_range_op(&mut sink, val as u64, 8); + emit_const_range_op(sink, val as u64, 8); } let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); @@ -231,9 +219,7 @@ impl Instruction for AuipcInstruction { a: pc[3], b: additional_bits as u8, }); - - Ok(()) - } + }); impl_collect_shard!(i_insn); diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index b0452c112..888d7062a 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -12,7 +12,7 @@ use crate::{ b_insn::BInstructionConfig, constants::{UINT_LIMBS, UInt}, }, - side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, + side_effects::emit_uint_limbs_lt_ops, }, structs::ProgramParams, witness::LkMultiplicity, @@ -209,34 +209,20 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .b_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(b_insn, |sink, step, _config, _ctx| { if !matches!(I::INST_KIND, InsnKind::BEQ | InsnKind::BNE) { let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); let 
rs2_value = Value::new_unchecked(step.rs2().unwrap().value); let rs1_limbs = rs1_value.as_u16_limbs(); let rs2_limbs = rs2_value.as_u16_limbs(); emit_uint_limbs_lt_ops( - &mut sink, + sink, matches!(I::INST_KIND, InsnKind::BLT | InsnKind::BGE), &rs1_limbs, &rs2_limbs, ); } - - Ok(()) - } + }); impl_collect_shard!(b_insn); diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index 302fc8f23..e73298136 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -14,11 +14,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::constants::LIMB_BITS, - side_effects::{CpuSideEffectSink, LkOp, SideEffectSink, emit_u16_limbs}, + side_effects::{LkOp, SideEffectSink, emit_u16_limbs}, }, structs::ProgramParams, uint::Value, @@ -538,19 +538,7 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; - config - .r_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { let dividend = step.rs1().unwrap().value; let divisor = step.rs2().unwrap().value; let dividend_value = Value::new_unchecked(dividend); @@ -562,8 +550,8 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Instruction for JalInstruction { Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let 
shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .j_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(j_insn, |sink, step, _config, _ctx| { let rd_written = split_to_u8(step.rd().unwrap().value.after); - emit_byte_decomposition_ops(&mut sink, &rd_written); + emit_byte_decomposition_ops(sink, &rd_written); let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); let additional_bits = @@ -149,9 +137,7 @@ impl Instruction for JalInstruction { a: rd_written[3], b: additional_bits as u8, }); - - Ok(()) - } + }); impl_collect_shard!(j_insn); diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 236a672eb..7ca91ddbd 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -7,7 +7,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -15,7 +15,7 @@ use crate::{ i_insn::IInstructionConfig, insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD}, }, - side_effects::{CpuSideEffectSink, emit_const_range_op}, + side_effects::emit_const_range_op, }, structs::ProgramParams, tables::InsnRecord, @@ -193,30 +193,16 @@ impl Instruction for JalrInstruction { Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .i_insn - .collect_side_effects(&mut sink, shard_ctx_view, 
step); - + impl_collect_side_effects!(i_insn, |sink, step, config, _ctx| { let rd_value = Value::new_unchecked(step.rd().unwrap().value.after); let rd_limb = rd_value.as_u16_limbs(); - emit_const_range_op(&mut sink, rd_limb[0] as u64, 16); - emit_const_range_op(&mut sink, rd_limb[1] as u64, PC_BITS - 16); + emit_const_range_op(sink, rd_limb[0] as u64, 16); + emit_const_range_op(sink, rd_limb[1] as u64, PC_BITS - 16); let imm = InsnRecord::::imm_internal(&step.insn()); let jump_pc = step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32); - config.jump_pc_addr.collect_side_effects(&mut sink, jump_pc); - - Ok(()) - } + config.jump_pc_addr.collect_side_effects(sink, jump_pc); + }); impl_collect_shard!(i_insn); diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 087409ba7..439701c5b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -8,11 +8,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, - side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, + side_effects::emit_logic_u8_ops, }, structs::ProgramParams, utils::split_to_u8, @@ -77,24 +77,14 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config.collect_side_effects(&mut sink, 
shard_ctx_view, step); + impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { emit_logic_u8_ops::( - &mut sink, + sink, step.rs1().unwrap().value as u64, step.rs2().unwrap().value as u64, 4, ); - Ok(()) - } + }); impl_collect_shard!(r_insn); diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index a4c258c41..243a9ee6d 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -9,7 +9,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -17,7 +17,7 @@ use crate::{ i_insn::IInstructionConfig, logic_imm::LogicOp, }, - side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, + side_effects::emit_logic_u8_ops, }, structs::ProgramParams, tables::InsnRecord, @@ -129,19 +129,7 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .i_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { let rs1_lo = step.rs1().unwrap().value & LIMB_MASK; let rs1_hi = (step.rs1().unwrap().value >> LIMB_BITS) & LIMB_MASK; let imm_lo = InsnRecord::::imm_internal(&step.insn()).0 as u32 & LIMB_MASK; @@ -149,10 +137,9 @@ impl Instruction for LogicInstruction { >> LIMB_BITS) & LIMB_MASK; - 
emit_logic_u8_ops::(&mut sink, rs1_lo.into(), imm_lo.into(), 2); - emit_logic_u8_ops::(&mut sink, rs1_hi.into(), imm_hi.into(), 2); - Ok(()) - } + emit_logic_u8_ops::(sink, rs1_lo.into(), imm_lo.into(), 2); + emit_logic_u8_ops::(sink, rs1_hi.into(), imm_hi.into(), 2); + }); impl_collect_shard!(i_insn); diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 7359a96d5..e3d251727 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -6,14 +6,14 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, - side_effects::{CpuSideEffectSink, emit_const_range_op}, + side_effects::emit_const_range_op, }, structs::ProgramParams, tables::InsnRecord, @@ -118,26 +118,12 @@ impl Instruction for LuiInstruction { Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .i_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { let rd_written = split_to_u8::(step.rd().unwrap().value.after); for val in rd_written.iter().skip(1) { - emit_const_range_op(&mut sink, *val as u64, 8); + emit_const_range_op(sink, *val as u64, 8); } - - Ok(()) - } + }); impl_collect_shard!(i_insn); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 75b69ca35..c416e2a48 100644 --- 
a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -4,13 +4,12 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, insn_base::MemAddr, }, - side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -231,28 +230,14 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = - unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .im_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(im_insn, |sink, step, config, _ctx| { let imm = InsnRecord::::imm_internal(&step.insn()); let unaligned_addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .memory_addr - .collect_side_effects(&mut sink, unaligned_addr.into()); - Ok(()) - } + .collect_side_effects(sink, unaligned_addr.into()); + }); impl_collect_shard!(im_insn); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 42a0848d0..8aab7bbef 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -13,7 +13,6 @@ use crate::{ im_insn::IMInstructionConfig, insn_base::MemAddr, }, - side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -259,30 +258,16 @@ impl Instruction for LoadInstruction 
Result<(), ZKVMError> { + impl_collect_side_effects!(im_insn, |sink, step, config, _ctx| { // Side effects (shard send/addr) are identical for all load types (LW/LH/LB/LHU/LBU). // Sub-word extraction only affects LK emissions, handled separately by GPU kernel. - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = - unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .im_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - let imm = InsnRecord::::imm_internal(&step.insn()); let unaligned_addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .memory_addr - .collect_side_effects(&mut sink, unaligned_addr.into()); - Ok(()) - } + .collect_side_effects(sink, unaligned_addr.into()); + }); impl_collect_shard!(im_insn); diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index 235cd1f5a..9f96cfeb7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -3,7 +3,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -13,7 +13,7 @@ use crate::{ memory::gadget::MemWordUtil, s_insn::SInstructionConfig, }, - side_effects::{CpuSideEffectSink, emit_const_range_op, emit_u16_limbs}, + side_effects::{emit_const_range_op, emit_u16_limbs}, }, structs::ProgramParams, tables::InsnRecord, @@ -176,26 +176,14 @@ impl Instruction Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - 
let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .s_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - - emit_u16_limbs(&mut sink, step.memory_op().unwrap().value.before); + impl_collect_side_effects!(s_insn, |sink, step, config, _ctx| { + emit_u16_limbs(sink, step.memory_op().unwrap().value.before); let imm = InsnRecord::::imm_internal(&step.insn()); let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .memory_addr - .collect_side_effects(&mut sink, addr.into()); + .collect_side_effects(sink, addr.into()); if N_ZEROS == 0 { let memory_op = step.memory_op().unwrap(); @@ -205,15 +193,13 @@ impl Instruction let rs2_limb = rs2_value.as_u16_limbs()[0]; for byte in prev_limb.to_le_bytes() { - emit_const_range_op(&mut sink, byte as u64, 8); + emit_const_range_op(sink, byte as u64, 8); } for byte in rs2_limb.to_le_bytes() { - emit_const_range_op(&mut sink, byte as u64, 8); + emit_const_range_op(sink, byte as u64, 8); } } - - Ok(()) - } + }); impl_collect_shard!(s_insn); diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index 99643c982..e0c30b781 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -1,7 +1,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -9,7 +9,7 @@ use crate::{ constants::{LIMB_BITS, UINT_LIMBS, UInt}, r_insn::RInstructionConfig, }, - side_effects::{CpuSideEffectSink, LkOp, SideEffectSink}, + side_effects::{LkOp, SideEffectSink}, }, structs::ProgramParams, uint::Value, @@ -332,24 +332,12 @@ impl Instruction for MulhInstructionBas Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut 
ShardContext, - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { + impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { let rs1 = step.rs1().unwrap().value; let rs1_val = Value::new_unchecked(rs1); let rs2 = step.rs2().unwrap().value; let rs2_val = Value::new_unchecked(rs2); - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .r_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( I::INST_KIND, rs1_val @@ -423,9 +411,7 @@ impl Instruction for MulhInstructionBas } _ => {} } - - Ok(()) - } + }); impl_collect_shard!(r_insn); diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 91c1845bd..3726af270 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,7 +1,7 @@ use crate::e2e::ShardContext; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -11,7 +11,7 @@ use crate::{ r_insn::RInstructionConfig, }, side_effects::{ - CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + LkOp, SideEffectSink, emit_byte_decomposition_ops, emit_const_range_op, }, }, @@ -410,30 +410,16 @@ impl Instruction for ShiftLogicalInstru Ok(()) } - fn collect_side_effects_instance( - config: &ShiftRTypeConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut crate::witness::LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), crate::error::ZKVMError> { - let shard_ctx_ptr = 
shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .r_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(r_insn, |sink, step, config, _ctx| { let rd_written = split_to_u8::(step.rd().unwrap().value.after); - emit_byte_decomposition_ops(&mut sink, &rd_written); + emit_byte_decomposition_ops(sink, &rd_written); config.shift_base_config.collect_side_effects( - &mut sink, + sink, I::INST_KIND, step.rs1().unwrap().value, step.rs2().unwrap().value, ); - - Ok(()) - } + }); impl_collect_shard!(r_insn); @@ -549,30 +535,16 @@ impl Instruction for ShiftImmInstructio Ok(()) } - fn collect_side_effects_instance( - config: &ShiftImmConfig, - shard_ctx: &mut ShardContext, - lk_multiplicity: &mut crate::witness::LkMultiplicity, - step: &ceno_emul::StepRecord, - ) -> Result<(), crate::error::ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; - config - .i_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(i_insn, |sink, step, config, _ctx| { let rd_written = split_to_u8::(step.rd().unwrap().value.after); - emit_byte_decomposition_ops(&mut sink, &rd_written); + emit_byte_decomposition_ops(sink, &rd_written); config.shift_base_config.collect_side_effects( - &mut sink, + sink, I::INST_KIND, step.rs1().unwrap().value, step.insn().imm as i16 as u16 as u32, ); - - Ok(()) - } + }); impl_collect_shard!(i_insn); diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 3cc7fae23..557131583 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -4,11 +4,11 @@ use crate::{ 
e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, - side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, + side_effects::emit_uint_limbs_lt_ops, }, structs::ProgramParams, witness::LkMultiplicity, @@ -118,32 +118,18 @@ impl Instruction for SetLessThanInstruc Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lkm: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; - config - .r_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); let rs1_limbs = rs1_value.as_u16_limbs(); let rs2_limbs = rs2_value.as_u16_limbs(); emit_uint_limbs_lt_ops( - &mut sink, + sink, matches!(I::INST_KIND, InsnKind::SLT), &rs1_limbs, &rs2_limbs, ); - - Ok(()) - } + }); impl_collect_shard!(r_insn); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 030613c2b..47c60443a 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shard, impl_gpu_assign, + impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -12,7 +12,7 @@ use crate::{ constants::{UINT_LIMBS, UInt}, 
i_insn::IInstructionConfig, }, - side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, + side_effects::emit_uint_limbs_lt_ops, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -138,31 +138,17 @@ impl Instruction for SetLessThanImmInst Ok(()) } - fn collect_side_effects_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - lkm: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - let shard_ctx_ptr = shard_ctx as *mut ShardContext; - let shard_ctx_view = unsafe { &*shard_ctx_ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; - config - .i_insn - .collect_side_effects(&mut sink, shard_ctx_view, step); - + impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); let rs1_limbs = rs1_value.as_u16_limbs(); let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); emit_uint_limbs_lt_ops( - &mut sink, + sink, matches!(I::INST_KIND, InsnKind::SLTI), &rs1_limbs, &imm_sign_extend, ); - - Ok(()) - } + }); impl_collect_shard!(i_insn); From 78a51d218f52130df7b8490f23e4c047dfb3fd97 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 23 Mar 2026 21:17:14 +0800 Subject: [PATCH 58/73] simplify: naming --- ceno_zkvm/src/instructions.rs | 30 +++++++-------- .../src/instructions/host_ops/cpu_fallback.rs | 14 +++---- ceno_zkvm/src/instructions/host_ops/mod.rs | 8 ++-- ceno_zkvm/src/instructions/riscv/arith.rs | 6 +-- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 6 +-- ceno_zkvm/src/instructions/riscv/auipc.rs | 6 +-- ceno_zkvm/src/instructions/riscv/b_insn.rs | 12 +++--- .../riscv/branch/branch_circuit_v2.rs | 6 +-- .../instructions/riscv/div/div_circuit_v2.rs | 6 +-- ceno_zkvm/src/instructions/riscv/gpu/d2h.rs | 2 +- .../instructions/riscv/gpu/debug_compare.rs | 10 ++--- .../src/instructions/riscv/gpu/witgen_gpu.rs | 14 +++---- ceno_zkvm/src/instructions/riscv/i_insn.rs | 12 +++--- 
ceno_zkvm/src/instructions/riscv/im_insn.rs | 16 ++++---- ceno_zkvm/src/instructions/riscv/insn_base.rs | 38 +++++++++---------- ceno_zkvm/src/instructions/riscv/j_insn.rs | 8 ++-- .../src/instructions/riscv/jump/jal_v2.rs | 6 +-- .../src/instructions/riscv/jump/jalr_v2.rs | 8 ++-- .../instructions/riscv/logic/logic_circuit.rs | 10 ++--- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 6 +-- ceno_zkvm/src/instructions/riscv/lui.rs | 6 +-- .../src/instructions/riscv/memory/load.rs | 8 ++-- .../src/instructions/riscv/memory/load_v2.rs | 8 ++-- .../src/instructions/riscv/memory/store_v2.rs | 8 ++-- .../riscv/mulh/mulh_circuit_v2.rs | 6 +-- ceno_zkvm/src/instructions/riscv/r_insn.rs | 16 ++++---- ceno_zkvm/src/instructions/riscv/s_insn.rs | 16 ++++---- .../riscv/shift/shift_circuit_v2.rs | 16 ++++---- .../instructions/riscv/slt/slt_circuit_v2.rs | 6 +-- .../riscv/slti/slti_circuit_v2.rs | 6 +-- 30 files changed, 160 insertions(+), 160 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 151481911..c351e45ea 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -25,7 +25,7 @@ pub mod host_ops; /// Backward-compatible re-export: old `side_effects` path still works. 
pub use host_ops as side_effects; -pub use host_ops::{cpu_assign_instances, cpu_collect_shard_side_effects, cpu_collect_side_effects}; +pub use host_ops::{cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram}; pub trait Instruction { type InstructionConfig: Send + Sync; @@ -104,7 +104,7 @@ pub trait Instruction { step: &StepRecord, ) -> Result<(), ZKVMError>; - fn collect_side_effects_instance( + fn collect_lk_and_shardram( _config: &Self::InstructionConfig, _shard_ctx: &mut ShardContext, _lk_multiplicity: &mut LkMultiplicity, @@ -119,7 +119,7 @@ pub trait Instruction { )) } - fn collect_shard_side_effects_instance( + fn collect_shardram( _config: &Self::InstructionConfig, _shard_ctx: &mut ShardContext, _lk_multiplicity: &mut LkMultiplicity, @@ -233,8 +233,8 @@ pub fn full_step_indices(steps: &[StepRecord]) -> Vec { // Macros to reduce per-chip boilerplate // --------------------------------------------------------------------------- -/// Implement `collect_side_effects_instance` with a common prologue -/// (create `CpuSideEffectSink`, dispatch to `config.$field.collect_side_effects`) +/// Implement `collect_lk_and_shardram` with a common prologue +/// (create `CpuSideEffectSink`, dispatch to `config.$field.emit_lk_and_shardram`) /// and a chip-specific body for additional LK ops. /// /// The closure receives `(sink, step, config, ctx)`: @@ -245,14 +245,14 @@ pub fn full_step_indices(steps: &[StepRecord]) -> Vec { /// /// Usage inside `impl Instruction for MyChip`: /// ```ignore -/// impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { +/// impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { /// emit_u16_limbs(sink, step.rd().unwrap().value.after); /// }); /// ``` #[macro_export] -macro_rules! impl_collect_side_effects { +macro_rules! 
impl_collect_lk_and_shardram { ($field:ident, |$sink:ident, $step:ident, $config:ident, $ctx:ident| $body:block) => { - fn collect_side_effects_instance( + fn collect_lk_and_shardram( config: &Self::InstructionConfig, shard_ctx: &mut $crate::e2e::ShardContext, lk_multiplicity: &mut $crate::witness::LkMultiplicity, @@ -266,7 +266,7 @@ macro_rules! impl_collect_side_effects { lk_multiplicity, ) }; - config.$field.collect_side_effects(&mut _sink_val, _ctx, step); + config.$field.emit_lk_and_shardram(&mut _sink_val, _ctx, step); let $sink = &mut _sink_val; let $step = step; let $config = config; @@ -277,20 +277,20 @@ macro_rules! impl_collect_side_effects { }; } -/// Implement `collect_shard_side_effects_instance` by delegating to -/// `config.$field.collect_shard_effects(shard_ctx, lk_multiplicity, step)`. +/// Implement `collect_shardram` by delegating to +/// `config.$field.emit_shardram(shard_ctx, lk_multiplicity, step)`. /// /// Every chip's implementation is identical except for the config field name /// (`r_insn`, `i_insn`, `b_insn`, `s_insn`, `j_insn`, `im_insn`). /// /// Usage inside `impl Instruction for MyChip`: /// ```ignore -/// impl_collect_shard!(r_insn); +/// impl_collect_shardram!(r_insn); /// ``` #[macro_export] -macro_rules! impl_collect_shard { +macro_rules! impl_collect_shardram { ($field:ident) => { - fn collect_shard_side_effects_instance( + fn collect_shardram( config: &Self::InstructionConfig, shard_ctx: &mut $crate::e2e::ShardContext, lk_multiplicity: &mut $crate::witness::LkMultiplicity, @@ -298,7 +298,7 @@ macro_rules! 
impl_collect_shard { ) -> Result<(), $crate::error::ZKVMError> { config .$field - .collect_shard_effects(shard_ctx, lk_multiplicity, step); + .emit_shardram(shard_ctx, lk_multiplicity, step); Ok(()) } }; diff --git a/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs b/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs index e3a63c943..302935ecc 100644 --- a/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs +++ b/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs @@ -92,29 +92,29 @@ pub fn cpu_assign_instances>( /// /// This path deliberately avoids scratch witness buffers and calls only the /// instruction-specific side-effect collector. -pub fn cpu_collect_side_effects>( +pub fn cpu_collect_lk_and_shardram>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, shard_steps: &[ceno_emul::StepRecord], step_indices: &[StepIndex], ) -> Result, ZKVMError> { - cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, false) + cpu_collect_lk_shardram_inner::(config, shard_ctx, shard_steps, step_indices, false) } /// CPU-side `send()` / `addr_accessed` collection for GPU-assisted lk paths. /// /// Implementations may still increment fetch multiplicity on CPU, but all other /// lookup multiplicities are expected to come from the GPU path. 
-pub fn cpu_collect_shard_side_effects>( +pub fn cpu_collect_shardram>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, shard_steps: &[ceno_emul::StepRecord], step_indices: &[StepIndex], ) -> Result, ZKVMError> { - cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, true) + cpu_collect_lk_shardram_inner::(config, shard_ctx, shard_steps, step_indices, true) } -fn cpu_collect_side_effects_inner>( +fn cpu_collect_lk_shardram_inner>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, shard_steps: &[ceno_emul::StepRecord], @@ -143,14 +143,14 @@ fn cpu_collect_side_effects_inner>( .copied() .map(|step_idx| { if shard_only { - I::collect_shard_side_effects_instance( + I::collect_shardram( config, &mut shard_ctx, &mut lk_multiplicity, &shard_steps[step_idx], ) } else { - I::collect_side_effects_instance( + I::collect_lk_and_shardram( config, &mut shard_ctx, &mut lk_multiplicity, diff --git a/ceno_zkvm/src/instructions/host_ops/mod.rs b/ceno_zkvm/src/instructions/host_ops/mod.rs index d18ed10b8..194fd4b64 100644 --- a/ceno_zkvm/src/instructions/host_ops/mod.rs +++ b/ceno_zkvm/src/instructions/host_ops/mod.rs @@ -20,8 +20,8 @@ mod tests { circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, instructions::{ - Instruction, cpu_assign_instances, cpu_collect_shard_side_effects, - cpu_collect_side_effects, + Instruction, cpu_assign_instances, cpu_collect_shardram, + cpu_collect_lk_and_shardram, riscv::{ AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, branch::{BeqInstruction, BltInstruction}, @@ -67,7 +67,7 @@ mod tests { let mut collect_ctx = ShardContext::default(); let actual_lk = - cpu_collect_side_effects::(config, &mut collect_ctx, steps, &indices).unwrap(); + cpu_collect_lk_and_shardram::(config, &mut collect_ctx, steps, &indices).unwrap(); assert_eq!(flatten_lk(&expected_lk), flatten_lk(&actual_lk)); assert_eq!( @@ -105,7 +105,7 @@ mod tests { let mut collect_ctx = 
ShardContext::default(); let actual_lk = - cpu_collect_shard_side_effects::(config, &mut collect_ctx, steps, &indices) + cpu_collect_shardram::(config, &mut collect_ctx, steps, &indices) .unwrap(); assert_eq!( diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index d7c213e32..edbd86b6c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, @@ -144,7 +144,7 @@ impl Instruction for ArithInstruction { emit_u16_limbs(sink, step.rd().unwrap().value.after); @@ -157,7 +157,7 @@ impl Instruction for ArithInstruction Some(witgen_gpu::GpuWitgenKind::Add), diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 288fe0a5b..442372323 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -3,7 +3,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, @@ -109,11 +109,11 @@ impl Instruction for AddiInstruction { Ok(()) } - impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { emit_u16_limbs(sink, step.rd().unwrap().value.after); }); - impl_collect_shard!(i_insn); + impl_collect_shardram!(i_insn); 
impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Addi); } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 15f4a44aa..04055d4b7 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -6,7 +6,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -193,7 +193,7 @@ impl Instruction for AuipcInstruction { Ok(()) } - impl_collect_side_effects!(i_insn, |sink, step, config, _ctx| { + impl_collect_lk_and_shardram!(i_insn, |sink, step, config, _ctx| { let rd_written = split_to_u8(step.rd().unwrap().value.after); emit_byte_decomposition_ops(sink, &rd_written); @@ -221,7 +221,7 @@ impl Instruction for AuipcInstruction { }); }); - impl_collect_shard!(i_insn); + impl_collect_shardram!(i_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Auipc); } diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index c33ca6037..a7e7acb1f 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -115,18 +115,18 @@ impl BInstructionConfig { Ok(()) } - pub fn collect_shard_effects( + pub fn emit_shardram( &self, shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) { lk_multiplicity.fetch(step.pc().before.0); - self.rs1.collect_shard_effects(shard_ctx, step); - self.rs2.collect_shard_effects(shard_ctx, step); + self.rs1.emit_shardram(shard_ctx, step); + self.rs2.emit_shardram(shard_ctx, step); } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -135,7 +135,7 @@ impl BInstructionConfig { sink.emit_lk(LkOp::Fetch { pc: step.pc().before.0, }); - 
self.rs1.collect_side_effects(sink, shard_ctx, step); - self.rs2.collect_side_effects(sink, shard_ctx, step); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rs2.emit_lk_and_shardram(sink, shard_ctx, step); } } diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 888d7062a..13ea76578 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -209,7 +209,7 @@ impl Instruction for BranchCircuit Instruction for BranchCircuit Some(witgen_gpu::GpuWitgenKind::BranchEq(1)), diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index e73298136..ce455fcaa 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -14,7 +14,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::constants::LIMB_BITS, @@ -538,7 +538,7 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction(config, &mut cpu_ctx, shard_steps, step_indices)?; + let _ = cpu_collect_lk_and_shardram::(config, &mut cpu_ctx, shard_steps, step_indices)?; let mut mixed_ctx = shard_ctx.new_empty_like(); let _ = - cpu_collect_shard_side_effects::(config, &mut mixed_ctx, shard_steps, step_indices)?; + cpu_collect_shardram::(config, &mut mixed_ctx, shard_steps, step_indices)?; let cpu_addr = 
cpu_ctx.get_addr_accessed(); let mixed_addr = mixed_ctx.get_addr_accessed(); @@ -207,7 +207,7 @@ pub(crate) fn debug_compare_shard_side_effects addr_accessed + write_records + read_records +/// CPU path: cpu_collect_shardram -> addr_accessed + write_records + read_records /// GPU path: compact_records -> shard records (gpu_ec_records) /// ram_slots WAS_SENT -> addr_accessed /// (write_records and read_records stay empty for GPU EC kernels) @@ -242,7 +242,7 @@ pub(crate) fn debug_compare_shard_ec>( // ========== Build CPU shard context (independent, isolated) ========== let mut cpu_ctx = shard_ctx.new_empty_like(); - if let Err(e) = cpu_collect_shard_side_effects::( + if let Err(e) = cpu_collect_shardram::( config, &mut cpu_ctx, shard_steps, step_indices, ) { tracing::error!("[GPU EC debug] kind={kind:?} CPU shard side effects failed: {e:?}"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 496c68ff9..2b709bce7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -26,7 +26,7 @@ use super::gpu_config::{ use crate::{ e2e::ShardContext, error::ZKVMError, - instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, + instructions::{Instruction, cpu_collect_shardram, cpu_collect_lk_and_shardram}, tables::RMMCollections, witness::LkMultiplicity, }; @@ -326,7 +326,7 @@ fn gpu_assign_instances_inner>( } else { // CPU: collect shard records only (send/addr_accessed). 
info_span!("cpu_shard_records").in_scope(|| { - let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; + let _ = collect_shardram::(config, shard_ctx, shard_steps, step_indices)?; Ok::<(), ZKVMError>(()) })?; } @@ -334,7 +334,7 @@ fn gpu_assign_instances_inner>( } else { // GPU LK counters missing or unverified — fall back to full CPU side effects info_span!("cpu_side_effects").in_scope(|| { - collect_side_effects::(config, shard_ctx, shard_steps, step_indices) + collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) })? }; debug_compare_final_lk::(config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, kind, &lk_multiplicity)?; @@ -1156,22 +1156,22 @@ fn gpu_fill_witness>( /// CPU-side loop to collect side effects only (shard_ctx.send, lk_multiplicity). /// Runs assign_instance with a scratch buffer per thread. -fn collect_side_effects>( +fn collect_lk_and_shardram>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, shard_steps: &[StepRecord], step_indices: &[StepIndex], ) -> Result, ZKVMError> { - cpu_collect_side_effects::(config, shard_ctx, shard_steps, step_indices) + cpu_collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) } -fn collect_shard_side_effects>( +fn collect_shardram>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, shard_steps: &[StepRecord], step_indices: &[StepIndex], ) -> Result, ZKVMError> { - cpu_collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices) + cpu_collect_shardram::(config, shard_ctx, shard_steps, step_indices) } /// GPU dispatch entry point for keccak ecall witness generation. 
diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index ff3fc72a3..317921307 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -80,7 +80,7 @@ impl IInstructionConfig { Ok(()) } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -89,18 +89,18 @@ impl IInstructionConfig { sink.emit_lk(LkOp::Fetch { pc: step.pc().before.0, }); - self.rs1.collect_side_effects(sink, shard_ctx, step); - self.rd.collect_side_effects(sink, shard_ctx, step); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); } - pub fn collect_shard_effects( + pub fn emit_shardram( &self, shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) { lk_multiplicity.fetch(step.pc().before.0); - self.rs1.collect_shard_effects(shard_ctx, step); - self.rd.collect_shard_effects(shard_ctx, step); + self.rs1.emit_shardram(shard_ctx, step); + self.rd.emit_shardram(shard_ctx, step); } } diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index cafd104aa..4414261e5 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -89,7 +89,7 @@ impl IMInstructionConfig { Ok(()) } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -98,20 +98,20 @@ impl IMInstructionConfig { sink.emit_lk(LkOp::Fetch { pc: step.pc().before.0, }); - self.rs1.collect_side_effects(sink, shard_ctx, step); - self.rd.collect_side_effects(sink, shard_ctx, step); - self.mem_read.collect_side_effects(sink, shard_ctx, step); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); + self.mem_read.emit_lk_and_shardram(sink, shard_ctx, step); } - pub fn 
collect_shard_effects( + pub fn emit_shardram( &self, shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) { lk_multiplicity.fetch(step.pc().before.0); - self.rs1.collect_shard_effects(shard_ctx, step); - self.rd.collect_shard_effects(shard_ctx, step); - self.mem_read.collect_shard_effects(shard_ctx, step); + self.rs1.emit_shardram(shard_ctx, step); + self.rd.emit_shardram(shard_ctx, step); + self.mem_read.emit_shardram(shard_ctx, step); } } diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 51e84be17..84a1e8f07 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -143,7 +143,7 @@ impl ReadRS1 { Ok(()) } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -170,7 +170,7 @@ impl ReadRS1 { sink.touch_addr(op.addr); } - pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { let op = step.rs1().expect("rs1 op"); shard_ctx.record_send_without_touch( RAMType::Register, @@ -252,7 +252,7 @@ impl ReadRS2 { Ok(()) } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -279,7 +279,7 @@ impl ReadRS2 { sink.touch_addr(op.addr); } - pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { let op = step.rs2().expect("rs2 op"); shard_ctx.record_send_without_touch( RAMType::Register, @@ -379,7 +379,7 @@ impl WriteRD { Ok(()) } - pub fn collect_op_side_effects( + pub fn emit_op_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -406,22 +406,22 @@ impl WriteRD { sink.touch_addr(op.addr); } - pub fn collect_side_effects( + 
pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, step: &StepRecord, ) { let op = step.rd().expect("rd op"); - self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + self.emit_op_lk_and_shardram(sink, shard_ctx, step.cycle(), &op) } - pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { let op = step.rd().expect("rd op"); - self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + self.emit_op_shardram(shard_ctx, step.cycle(), &op) } - pub fn collect_op_shard_effects( + pub fn emit_op_shardram( &self, shard_ctx: &mut ShardContext, cycle: Cycle, @@ -505,7 +505,7 @@ impl ReadMEM { Ok(()) } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -532,7 +532,7 @@ impl ReadMEM { sink.touch_addr(op.addr); } - pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { let op = step.memory_op().expect("memory op"); shard_ctx.record_send_without_touch( RAMType::Memory, @@ -619,7 +619,7 @@ impl WriteMEM { Ok(()) } - pub fn collect_op_side_effects( + pub fn emit_op_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -646,22 +646,22 @@ impl WriteMEM { sink.touch_addr(op.addr); } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, step: &StepRecord, ) { let op = step.memory_op().expect("memory op"); - self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + self.emit_op_lk_and_shardram(sink, shard_ctx, step.cycle(), &op) } - pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { let op = 
step.memory_op().expect("memory op"); - self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + self.emit_op_shardram(shard_ctx, step.cycle(), &op) } - pub fn collect_op_shard_effects( + pub fn emit_op_shardram( &self, shard_ctx: &mut ShardContext, cycle: Cycle, @@ -828,7 +828,7 @@ impl MemAddr { Ok(()) } - pub fn collect_side_effects(&self, sink: &mut impl SideEffectSink, addr: Word) { + pub fn emit_lk_and_shardram(&self, sink: &mut impl SideEffectSink, addr: Word) { let mid_u14 = ((addr & 0xffff) >> Self::N_LOW_BITS) as u16; sink.emit_lk(LkOp::AssertU14 { value: mid_u14 }); diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index eb3f7b693..219af3e4c 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -72,17 +72,17 @@ impl JInstructionConfig { Ok(()) } - pub fn collect_shard_effects( + pub fn emit_shardram( &self, shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) { lk_multiplicity.fetch(step.pc().before.0); - self.rd.collect_shard_effects(shard_ctx, step); + self.rd.emit_shardram(shard_ctx, step); } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -91,6 +91,6 @@ impl JInstructionConfig { sink.emit_lk(LkOp::Fetch { pc: step.pc().before.0, }); - self.rd.collect_side_effects(sink, shard_ctx, step); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); } } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 4ed671eba..1af4ad390 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -6,7 +6,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, 
instructions::{ Instruction, riscv::{ @@ -126,7 +126,7 @@ impl Instruction for JalInstruction { Ok(()) } - impl_collect_side_effects!(j_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(j_insn, |sink, step, _config, _ctx| { let rd_written = split_to_u8(step.rd().unwrap().value.after); emit_byte_decomposition_ops(sink, &rd_written); @@ -139,7 +139,7 @@ impl Instruction for JalInstruction { }); }); - impl_collect_shard!(j_insn); + impl_collect_shardram!(j_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Jal); } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7ca91ddbd..15b047b10 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -7,7 +7,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -193,7 +193,7 @@ impl Instruction for JalrInstruction { Ok(()) } - impl_collect_side_effects!(i_insn, |sink, step, config, _ctx| { + impl_collect_lk_and_shardram!(i_insn, |sink, step, config, _ctx| { let rd_value = Value::new_unchecked(step.rd().unwrap().value.after); let rd_limb = rd_value.as_u16_limbs(); emit_const_range_op(sink, rd_limb[0] as u64, 16); @@ -201,10 +201,10 @@ impl Instruction for JalrInstruction { let imm = InsnRecord::::imm_internal(&step.insn()); let jump_pc = step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32); - config.jump_pc_addr.collect_side_effects(sink, jump_pc); + config.jump_pc_addr.emit_lk_and_shardram(sink, jump_pc); }); - impl_collect_shard!(i_insn); + impl_collect_shardram!(i_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Jalr); } diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 
439701c5b..f7469645a 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -8,7 +8,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, @@ -77,7 +77,7 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } - impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { emit_logic_u8_ops::( sink, step.rs1().unwrap().value as u64, @@ -86,7 +86,7 @@ impl Instruction for LogicInstruction { ); }); - impl_collect_shard!(r_insn); + impl_collect_shardram!(r_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::LogicR(match I::INST_KIND { InsnKind::AND => 0, @@ -154,12 +154,12 @@ impl LogicConfig { Ok(()) } - fn collect_side_effects( + fn emit_lk_and_shardram( &self, sink: &mut impl crate::instructions::side_effects::SideEffectSink, shard_ctx: &ShardContext, step: &StepRecord, ) { - self.r_insn.collect_side_effects(sink, shard_ctx, step); + self.r_insn.emit_lk_and_shardram(sink, shard_ctx, step); } } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 243a9ee6d..5d43380a0 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -9,7 +9,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -129,7 +129,7 @@ impl 
Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } - impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { let rs1_lo = step.rs1().unwrap().value & LIMB_MASK; let rs1_hi = (step.rs1().unwrap().value >> LIMB_BITS) & LIMB_MASK; let imm_lo = InsnRecord::::imm_internal(&step.insn()).0 as u32 & LIMB_MASK; @@ -141,7 +141,7 @@ impl Instruction for LogicInstruction { emit_logic_u8_ops::(sink, rs1_hi.into(), imm_hi.into(), 2); }); - impl_collect_shard!(i_insn); + impl_collect_shardram!(i_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::LogicI(match I::INST_KIND { InsnKind::ANDI => 0, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index e3d251727..52983fe40 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -6,7 +6,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -118,14 +118,14 @@ impl Instruction for LuiInstruction { Ok(()) } - impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { let rd_written = split_to_u8::(step.rd().unwrap().value.after); for val in rd_written.iter().skip(1) { emit_const_range_op(sink, *val as u64, 8); } }); - impl_collect_shard!(i_insn); + impl_collect_shardram!(i_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Lui); } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index c416e2a48..d1c93194a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, 
gadgets::SignedExtendConfig, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -230,16 +230,16 @@ impl Instruction for LoadInstruction::imm_internal(&step.insn()); let unaligned_addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .memory_addr - .collect_side_effects(sink, unaligned_addr.into()); + .emit_lk_and_shardram(sink, unaligned_addr.into()); }); - impl_collect_shard!(im_insn); + impl_collect_shardram!(im_insn); impl_gpu_assign!(match I::INST_KIND { InsnKind::LW => Some(witgen_gpu::GpuWitgenKind::Lw), diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 8aab7bbef..d8f6b8128 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -258,7 +258,7 @@ impl Instruction for LoadInstruction::imm_internal(&step.insn()); @@ -266,10 +266,10 @@ impl Instruction for LoadInstruction Some(witgen_gpu::GpuWitgenKind::Lw), diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index 9f96cfeb7..b6c1e3064 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -3,7 +3,7 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -176,14 +176,14 @@ impl Instruction Ok(()) } - 
impl_collect_side_effects!(s_insn, |sink, step, config, _ctx| { + impl_collect_lk_and_shardram!(s_insn, |sink, step, config, _ctx| { emit_u16_limbs(sink, step.memory_op().unwrap().value.before); let imm = InsnRecord::::imm_internal(&step.insn()); let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config .memory_addr - .collect_side_effects(sink, addr.into()); + .emit_lk_and_shardram(sink, addr.into()); if N_ZEROS == 0 { let memory_op = step.memory_op().unwrap(); @@ -201,7 +201,7 @@ impl Instruction } }); - impl_collect_shard!(s_insn); + impl_collect_shardram!(s_insn); impl_gpu_assign!(match I::INST_KIND { InsnKind::SW => Some(witgen_gpu::GpuWitgenKind::Sw), diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index e0c30b781..47180ecfb 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -1,7 +1,7 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -332,7 +332,7 @@ impl Instruction for MulhInstructionBas Ok(()) } - impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { let rs1 = step.rs1().unwrap().value; let rs1_val = Value::new_unchecked(rs1); let rs2 = step.rs2().unwrap().value; @@ -413,7 +413,7 @@ impl Instruction for MulhInstructionBas } }); - impl_collect_shard!(r_insn); + impl_collect_shardram!(r_insn); impl_gpu_assign!(match I::INST_KIND { InsnKind::MUL => Some(witgen_gpu::GpuWitgenKind::Mul(0u32)), diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index b0e8089d0..bf68809de 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ 
b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -85,7 +85,7 @@ impl RInstructionConfig { Ok(()) } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -94,20 +94,20 @@ impl RInstructionConfig { sink.emit_lk(LkOp::Fetch { pc: step.pc().before.0, }); - self.rs1.collect_side_effects(sink, shard_ctx, step); - self.rs2.collect_side_effects(sink, shard_ctx, step); - self.rd.collect_side_effects(sink, shard_ctx, step); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rs2.emit_lk_and_shardram(sink, shard_ctx, step); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); } - pub fn collect_shard_effects( + pub fn emit_shardram( &self, shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) { lk_multiplicity.fetch(step.pc().before.0); - self.rs1.collect_shard_effects(shard_ctx, step); - self.rs2.collect_shard_effects(shard_ctx, step); - self.rd.collect_shard_effects(shard_ctx, step); + self.rs1.emit_shardram(shard_ctx, step); + self.rs2.emit_shardram(shard_ctx, step); + self.rd.emit_shardram(shard_ctx, step); } } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index 9b5f8d88e..332becf61 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -95,20 +95,20 @@ impl SInstructionConfig { Ok(()) } - pub fn collect_shard_effects( + pub fn emit_shardram( &self, shard_ctx: &mut ShardContext, lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) { lk_multiplicity.fetch(step.pc().before.0); - self.rs1.collect_shard_effects(shard_ctx, step); - self.rs2.collect_shard_effects(shard_ctx, step); - self.mem_write.collect_shard_effects(shard_ctx, step); + self.rs1.emit_shardram(shard_ctx, step); + self.rs2.emit_shardram(shard_ctx, step); + self.mem_write.emit_shardram(shard_ctx, step); } #[allow(dead_code)] - pub fn collect_side_effects( + pub fn 
emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, shard_ctx: &ShardContext, @@ -117,8 +117,8 @@ impl SInstructionConfig { sink.emit_lk(LkOp::Fetch { pc: step.pc().before.0, }); - self.rs1.collect_side_effects(sink, shard_ctx, step); - self.rs2.collect_side_effects(sink, shard_ctx, step); - self.mem_write.collect_side_effects(sink, shard_ctx, step); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rs2.emit_lk_and_shardram(sink, shard_ctx, step); + self.mem_write.emit_lk_and_shardram(sink, shard_ctx, step); } } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 3726af270..61411dbd0 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,7 +1,7 @@ use crate::e2e::ShardContext; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -211,7 +211,7 @@ impl }) } - pub fn collect_side_effects( + pub fn emit_lk_and_shardram( &self, sink: &mut impl SideEffectSink, kind: InsnKind, @@ -410,10 +410,10 @@ impl Instruction for ShiftLogicalInstru Ok(()) } - impl_collect_side_effects!(r_insn, |sink, step, config, _ctx| { + impl_collect_lk_and_shardram!(r_insn, |sink, step, config, _ctx| { let rd_written = split_to_u8::(step.rd().unwrap().value.after); emit_byte_decomposition_ops(sink, &rd_written); - config.shift_base_config.collect_side_effects( + config.shift_base_config.emit_lk_and_shardram( sink, I::INST_KIND, step.rs1().unwrap().value, @@ -421,7 +421,7 @@ impl Instruction for ShiftLogicalInstru ); }); - impl_collect_shard!(r_insn); + impl_collect_shardram!(r_insn); 
impl_gpu_assign!(witgen_gpu::GpuWitgenKind::ShiftR(match I::INST_KIND { InsnKind::SLL => 0u32, @@ -535,10 +535,10 @@ impl Instruction for ShiftImmInstructio Ok(()) } - impl_collect_side_effects!(i_insn, |sink, step, config, _ctx| { + impl_collect_lk_and_shardram!(i_insn, |sink, step, config, _ctx| { let rd_written = split_to_u8::(step.rd().unwrap().value.after); emit_byte_decomposition_ops(sink, &rd_written); - config.shift_base_config.collect_side_effects( + config.shift_base_config.emit_lk_and_shardram( sink, I::INST_KIND, step.rs1().unwrap().value, @@ -546,7 +546,7 @@ impl Instruction for ShiftImmInstructio ); }); - impl_collect_shard!(i_insn); + impl_collect_shardram!(i_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::ShiftI(match I::INST_KIND { InsnKind::SLLI => 0u32, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 557131583..9dda4abe2 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, @@ -118,7 +118,7 @@ impl Instruction for SetLessThanInstruc Ok(()) } - impl_collect_side_effects!(r_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); let rs1_limbs = rs1_value.as_u16_limbs(); @@ -131,7 +131,7 @@ impl Instruction for SetLessThanInstruc ); }); - impl_collect_shard!(r_insn); + impl_collect_shardram!(r_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Slt(match I::INST_KIND { 
InsnKind::SLT => 1u32, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 47c60443a..d56076fd1 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shard, impl_collect_side_effects, impl_gpu_assign, + impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -138,7 +138,7 @@ impl Instruction for SetLessThanImmInst Ok(()) } - impl_collect_side_effects!(i_insn, |sink, step, _config, _ctx| { + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); let rs1_limbs = rs1_value.as_u16_limbs(); let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); @@ -150,7 +150,7 @@ impl Instruction for SetLessThanImmInst ); }); - impl_collect_shard!(i_insn); + impl_collect_shardram!(i_insn); impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Slti(match I::INST_KIND { InsnKind::SLTI => 1u32, From b690229d4b106bec7a68ef817aac28b5c6b68490 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 11:29:41 +0800 Subject: [PATCH 59/73] gpu: hal.witgen --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 4 +- .../src/instructions/riscv/gpu/branch_cmp.rs | 4 +- .../src/instructions/riscv/gpu/branch_eq.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/d2h.rs | 12 ++--- .../instructions/riscv/gpu/debug_compare.rs | 2 +- .../instructions/riscv/gpu/device_cache.rs | 14 +++--- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 4 +- .../src/instructions/riscv/gpu/keccak.rs | 2 +- 
.../src/instructions/riscv/gpu/load_sub.rs | 4 +- .../src/instructions/riscv/gpu/logic_i.rs | 4 +- .../src/instructions/riscv/gpu/logic_r.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 4 +- .../src/instructions/riscv/gpu/shard_ram.rs | 2 +- .../src/instructions/riscv/gpu/shift_i.rs | 4 +- .../src/instructions/riscv/gpu/shift_r.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 4 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 48 +++++++++---------- ceno_zkvm/src/structs.rs | 6 +-- ceno_zkvm/src/tables/shard_ram.rs | 20 ++++---- 30 files changed, 97 insertions(+), 97 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 223b0a4fe..35702d74a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::AddColumnMap; +use ceno_gpu::common::witgen::types::AddColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; @@ -181,7 +181,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 9996633d1..e263250bf 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ 
b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::AddiColumnMap; +use ceno_gpu::common::witgen::types::AddiColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs}; @@ -122,7 +122,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 9ece38a58..07b984295 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::AuipcColumnMap; +use ceno_gpu::common::witgen::types::AuipcColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; @@ -124,7 +124,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index 6edaf174d..dcba4eef7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::BranchCmpColumnMap; +use ceno_gpu::common::witgen::types::BranchCmpColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; @@ -138,7 +138,7 @@ mod tests { }; let 
gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 4a51655bd..99c8621a3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::BranchEqColumnMap; +use ceno_gpu::common::witgen::types::BranchEqColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; @@ -135,7 +135,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs b/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs index 0f80e2c84..5ef9068f5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs @@ -11,7 +11,7 @@ use ceno_emul::WordAddr; use ceno_gpu::{ Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; -use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord}; +use ceno_gpu::common::witgen::types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord}; use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; @@ -30,9 +30,9 @@ pub(crate) type WitBuf = ceno_gpu::common::BufferImpl< >; pub(crate) type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; pub(crate) type 
RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; -pub(crate) type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; -pub(crate) type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; -pub(crate) type CompactEcBuf = ceno_gpu::common::witgen_types::CompactEcResult; +pub(crate) type WitResult = ceno_gpu::common::witgen::types::GpuWitnessResult; +pub(crate) type LkResult = ceno_gpu::common::witgen::types::GpuLookupCountersResult; +pub(crate) type CompactEcBuf = ceno_gpu::common::witgen::types::CompactEcResult; /// CPU-side lightweight scan of GPU-produced RAM record slots. /// @@ -168,7 +168,7 @@ pub fn gpu_batch_continuation_ec( // GPU batch EC computation let result = info_span!("gpu_batch_ec", n = total).in_scope(|| { - hal.batch_continuation_ec(&gpu_records) + hal.witgen.batch_continuation_ec(&gpu_records) }).map_err(|e| { ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()) })?; @@ -324,7 +324,7 @@ pub(crate) fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result( hal: &CudaHalBB31, - gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< + gpu_result: ceno_gpu::common::witgen::types::GpuWitnessResult< ceno_gpu::common::BufferImpl<'static, ::BaseField>, >, num_rows: usize, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs b/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs index 2ee6b630b..f33ac96bd 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs @@ -7,7 +7,7 @@ /// - CENO_GPU_DEBUG_COMPARE_SHARD: compare shard side effects /// - CENO_GPU_DEBUG_COMPARE_EC: compare EC points use ceno_emul::{StepIndex, StepRecord, WordAddr}; -use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardRamRecord}; +use ceno_gpu::common::witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}; use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use 
p3::field::FieldAlgebra; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs b/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs index 086f1c983..d2d02c6c0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs @@ -6,7 +6,7 @@ use ceno_emul::{StepRecord, WordAddr}; use ceno_gpu::{ Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, bb31::ShardDeviceBuffers, - common::witgen_types::GpuShardRamRecord, common::witgen_types::GpuShardScalars, + common::witgen::types::GpuShardRamRecord, common::witgen::types::GpuShardScalars, }; use std::cell::RefCell; use tracing::info_span; @@ -266,20 +266,20 @@ pub(crate) fn ensure_shard_metadata_cached( let ec_u32s = ec_capacity * 26; // 26 u32s per GpuShardRamRecord (104 bytes) let addr_capacity = total_ops_estimate.min(256 * 1024 * 1024) as usize; - let shared_ec_buf = hal + let shared_ec_buf = hal.witgen .alloc_u32_zeroed(ec_u32s, None) .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf alloc: {e}").into()))?; - let shared_ec_count = hal + let shared_ec_count = hal.witgen .alloc_u32_zeroed(1, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("shared_ec_count alloc: {e}").into()) })?; - let shared_addr_buf = hal + let shared_addr_buf = hal.witgen .alloc_u32_zeroed(addr_capacity, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("shared_addr_buf alloc: {e}").into()) })?; - let shared_addr_count = hal + let shared_addr_count = hal.witgen .alloc_u32_zeroed(1, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("shared_addr_count alloc: {e}").into()) @@ -385,7 +385,7 @@ pub fn gpu_batch_continuation_ec_on_device( let n_reads = read_records.len(); let total = n_writes + n_reads; if total == 0 { - let empty = hal.alloc_u32_zeroed(1, None).map_err(|e| { + let empty = hal.witgen.alloc_u32_zeroed(1, None).map_err(|e| { ZKVMError::InvalidWitness(format!("alloc: {e}").into()) })?; return Ok((empty, 0, 0)); @@ -399,7 +399,7 @@ pub fn 
gpu_batch_continuation_ec_on_device( // GPU batch EC, results stay on device let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total).in_scope(|| { - hal.batch_continuation_ec_on_device(&gpu_records, None) + hal.witgen.batch_continuation_ec_on_device(&gpu_records, None) }).map_err(|e| { ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) })?; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index 817495fec..6d6bd11c9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::DivColumnMap; +use ceno_gpu::common::witgen::types::DivColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; @@ -337,7 +337,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_div( &col_map, &gpu_records, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index fd446f185..21421ffb5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::JalColumnMap; +use ceno_gpu::common::witgen::types::JalColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_state_branching, extract_uint_limbs}; @@ -112,7 +112,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs 
b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 804b14934..1c9447346 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::JalrColumnMap; +use ceno_gpu::common::witgen::types::JalrColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{ @@ -133,7 +133,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs b/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs index 31b2f206c..2e0654156 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs @@ -1,5 +1,5 @@ use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::common::witgen_types::{GpuKeccakInstance, GpuKeccakWriteOp, KeccakColumnMap}; +use ceno_gpu::common::witgen::types::{GpuKeccakInstance, GpuKeccakWriteOp, KeccakColumnMap}; use ff_ext::ExtensionField; use std::sync::Arc; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 4efb17c12..df3d35937 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::LoadSubColumnMap; +use ceno_gpu::common::witgen::types::LoadSubColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; @@ -349,7 +349,7 @@ mod tests { let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let load_width: u32 = if is_byte { 8 } else { 16 }; let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; - let gpu_result = hal + let gpu_result = hal.witgen 
.witgen_load_sub( &col_map, &gpu_records, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index ffb6b5507..e61e24796 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::LogicIColumnMap; +use ceno_gpu::common::witgen::types::LogicIColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; @@ -139,7 +139,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index d084fe324..9803777b4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::LogicRColumnMap; +use ceno_gpu::common::witgen::types::LogicRColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; @@ -173,7 +173,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 7ad254ed1..b72832a24 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -1,4 +1,4 @@ -use 
ceno_gpu::common::witgen_types::LuiColumnMap; +use ceno_gpu::common::witgen::types::LuiColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_state}; @@ -123,7 +123,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 111010f12..98596da99 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::LwColumnMap; +use ceno_gpu::common::witgen::types::LwColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; @@ -195,7 +195,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index 2ea5ddbf5..56b3c8b40 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::MulColumnMap; +use ceno_gpu::common::witgen::types::MulColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; @@ -298,7 +298,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result 
= hal + let gpu_result = hal.witgen .witgen_mul( &col_map, &gpu_records, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 01ff6c50c..eca04efd3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::SbColumnMap; +use ceno_gpu::common::witgen::types::SbColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{ @@ -181,7 +181,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index f8814b35c..350381e37 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::ShColumnMap; +use ceno_gpu::common::witgen::types::ShColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{ @@ -158,7 +158,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs b/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs index 1e08b96e9..7981bef47 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::ShardRamColumnMap; +use ceno_gpu::common::witgen::types::ShardRamColumnMap; use ff_ext::ExtensionField; use crate::tables::ShardRamConfig; diff --git 
a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index 4cf969e0c..303a9dc6f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::ShiftIColumnMap; +use ceno_gpu::common::witgen::types::ShiftIColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; @@ -150,7 +150,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 72ef64f50..5fc4a571e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::ShiftRColumnMap; +use ceno_gpu::common::witgen::types::ShiftRColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; @@ -155,7 +155,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index c46c63d4e..13af3c3fe 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::SltColumnMap; +use 
ceno_gpu::common::witgen::types::SltColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; @@ -142,7 +142,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 33f2c9247..e651d3e02 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::SltiColumnMap; +use ceno_gpu::common::witgen::types::SltiColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; @@ -137,7 +137,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 31c6a562e..f8afa30ee 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::SubColumnMap; +use ceno_gpu::common::witgen::types::SubColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; @@ -143,7 +143,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = 
hal + let gpu_result = hal.witgen .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 267d5879d..644c87bf4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -1,4 +1,4 @@ -use ceno_gpu::common::witgen_types::SwColumnMap; +use ceno_gpu::common::witgen::types::SwColumnMap; use ff_ext::ExtensionField; use super::colmap_base::{ @@ -141,7 +141,7 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal + let gpu_result = hal.witgen .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 2b709bce7..5eb63aba9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -8,7 +8,7 @@ use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::{ Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; -use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardRamRecord}; +use ceno_gpu::common::witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use p3::field::FieldAlgebra; @@ -445,7 +445,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_add( &col_map, gpu_records, @@ -475,7 +475,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen 
.witgen_sub( &col_map, gpu_records, @@ -505,7 +505,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_logic_r( &col_map, gpu_records, @@ -537,7 +537,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_logic_i( &col_map, gpu_records, @@ -569,7 +569,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_addi").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_addi( &col_map, gpu_records, @@ -600,7 +600,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_lui").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_lui( &col_map, gpu_records, @@ -629,7 +629,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_auipc").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_auipc( &col_map, gpu_records, @@ -660,7 +660,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_jal").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_jal( &col_map, gpu_records, @@ -691,7 +691,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_shift_r( &col_map, gpu_records, @@ -725,7 +725,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_shift_i( &col_map, 
gpu_records, @@ -757,7 +757,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_slt").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_slt( &col_map, gpu_records, @@ -789,7 +789,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_slti").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_slti( &col_map, gpu_records, @@ -824,7 +824,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_branch_eq( &col_map, gpu_records, @@ -859,7 +859,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_branch_cmp( &col_map, gpu_records, @@ -891,7 +891,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_jalr").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_jalr( &col_map, gpu_records, @@ -921,7 +921,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sw").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_sw( &col_map, gpu_records, @@ -952,7 +952,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sh").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_sh( &col_map, gpu_records, @@ -983,7 +983,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sb").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_sb( &col_map, gpu_records, @@ -1025,7 +1025,7 @@ fn 
gpu_fill_witness>( info_span!("hal_witgen_load_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_load_sub( &col_map, gpu_records, @@ -1059,7 +1059,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_mul( &col_map, gpu_records, @@ -1091,7 +1091,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_div").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_div( &col_map, gpu_records, @@ -1129,7 +1129,7 @@ fn gpu_fill_witness>( info_span!("hal_witgen_lw").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal + split_full!(hal.witgen .witgen_lw( &col_map, gpu_records, @@ -1302,7 +1302,7 @@ fn gpu_assign_keccak_inner( // Step 5: Launch GPU kernel let gpu_result = info_span!("gpu_kernel").in_scope(|| { with_cached_shard_meta(|shard_bufs| { - hal.witgen_keccak( + hal.witgen.witgen_keccak( &col_map, &packed_instances, num_padded_rows, diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 495aa41b8..12c2d9505 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -775,7 +775,7 @@ impl ZKVMWitnesses { // 3. GPU sort addr_accessed + dedup, then D2H sorted unique addrs let addr_accessed: Vec = if addr_count > 0 { info_span!("gpu_sort_addr").in_scope(|| { - let (deduped, unique_count) = hal + let (deduped, unique_count) = hal.witgen .sort_and_dedup_u32(&mut shared.addr_buf, addr_count, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU sort addr: {e}").into()) @@ -883,7 +883,7 @@ impl ZKVMWitnesses { // 6. 
GPU merge shared_ec + batch_ec, then partition by is_to_write_set let (partitioned_buf, num_writes, total_records) = info_span!("gpu_merge_partition").in_scope(|| { - hal.merge_and_partition_records( + hal.witgen.merge_and_partition_records( &shared.ec_buf, ec_count, &cont_ec_buf, @@ -908,7 +908,7 @@ impl ZKVMWitnesses { let max_chunk = shard_ctx.max_num_cross_shard_accesses; // Record sizes needed for chunking - let record_u32s = std::mem::size_of::() / 4; + let record_u32s = std::mem::size_of::() / 4; let circuit_inputs = info_span!("shard_ram_assign_from_device", n = total_records) .in_scope(|| { diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 77d121148..54d478c96 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -665,7 +665,7 @@ impl ShardRamCircuit { bb31::CudaHalBB31, common::{ transpose::matrix_transpose, - witgen_types::GpuShardRamRecord, + witgen::types::GpuShardRamRecord, }, }; use gkr_iop::gpu::gpu_prover::get_cuda_hal; @@ -730,7 +730,7 @@ impl ShardRamCircuit { n = steps.len(), num_rows_padded, num_witin, - ).in_scope(|| hal + ).in_scope(|| hal.witgen .witgen_shard_ram_per_row( &col_map, &gpu_records, @@ -784,7 +784,7 @@ impl ShardRamCircuit { break; } - let (next_x, next_y) = hal + let (next_x, next_y) = hal.witgen .shard_ram_ec_tree_layer( &gpu_cols, &cur_x, @@ -816,7 +816,7 @@ impl ShardRamCircuit { ).in_scope(|| -> Result<_, ZKVMError> { let wit_num_rows = num_rows_padded; let wit_num_cols = num_witin; - let mut rmm_buf = hal + let mut rmm_buf = hal.witgen .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) @@ -843,7 +843,7 @@ impl ShardRamCircuit { }; let struct_num_cols = num_structural_witin; - let mut struct_rmm_buf = hal + let mut struct_rmm_buf = hal.witgen .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) .map_err(|e| { 
ZKVMError::InvalidWitness( @@ -941,7 +941,7 @@ impl ShardRamCircuit { n = num_records, num_rows_padded, num_witin, - ).in_scope(|| hal + ).in_scope(|| hal.witgen .witgen_shard_ram_per_row_from_device( &col_map, device_records, @@ -968,7 +968,7 @@ impl ShardRamCircuit { })?; // Extract point_x/y from device records into flat arrays - let (mut cur_x, mut cur_y) = hal + let (mut cur_x, mut cur_y) = hal.witgen .extract_ec_points_from_device(device_records, num_records, n, None) .map_err(|e| { ZKVMError::InvalidWitness( @@ -985,7 +985,7 @@ impl ShardRamCircuit { break; } - let (next_x, next_y) = hal + let (next_x, next_y) = hal.witgen .shard_ram_ec_tree_layer( &gpu_cols, &cur_x, @@ -1017,7 +1017,7 @@ impl ShardRamCircuit { ).in_scope(|| -> Result<_, ZKVMError> { let wit_num_rows = num_rows_padded; let wit_num_cols = num_witin; - let mut rmm_buf = hal + let mut rmm_buf = hal.witgen .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) @@ -1044,7 +1044,7 @@ impl ShardRamCircuit { }; let struct_num_cols = num_structural_witin; - let mut struct_rmm_buf = hal + let mut struct_rmm_buf = hal.witgen .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) .map_err(|e| { ZKVMError::InvalidWitness( From 130cc4a452009eafb798bdf47145b83f3966fb10 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 12:04:23 +0800 Subject: [PATCH 60/73] path: ceno_zkvm/src/instructions/gpu --- ceno_zkvm/src/e2e.rs | 8 ++++---- ceno_zkvm/src/instructions.rs | 10 +++++----- ceno_zkvm/src/instructions/{riscv => }/gpu/add.rs | 8 ++++---- ceno_zkvm/src/instructions/{riscv => }/gpu/addi.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/auipc.rs | 2 +- .../src/instructions/{riscv => }/gpu/branch_cmp.rs | 2 +- .../src/instructions/{riscv => }/gpu/branch_eq.rs | 2 +- .../src/instructions/{riscv => }/gpu/colmap_base.rs | 0 ceno_zkvm/src/instructions/{riscv => }/gpu/d2h.rs | 0 
.../src/instructions/{riscv => }/gpu/debug_compare.rs | 0 .../src/instructions/{riscv => }/gpu/device_cache.rs | 0 ceno_zkvm/src/instructions/{riscv => }/gpu/div.rs | 0 .../src/instructions/{riscv => }/gpu/gpu_config.rs | 0 .../instructions/{ => gpu}/host_ops/cpu_fallback.rs | 2 +- ceno_zkvm/src/instructions/{ => gpu}/host_ops/emit.rs | 0 .../src/instructions/{ => gpu}/host_ops/lk_ops.rs | 0 ceno_zkvm/src/instructions/{ => gpu}/host_ops/mod.rs | 0 ceno_zkvm/src/instructions/{ => gpu}/host_ops/sink.rs | 0 ceno_zkvm/src/instructions/{riscv => }/gpu/jal.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/jalr.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/keccak.rs | 0 ceno_zkvm/src/instructions/{riscv => }/gpu/load_sub.rs | 0 ceno_zkvm/src/instructions/{riscv => }/gpu/logic_i.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/logic_r.rs | 8 ++++---- ceno_zkvm/src/instructions/{riscv => }/gpu/lui.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/lw.rs | 6 +++--- ceno_zkvm/src/instructions/{riscv => }/gpu/mod.rs | 1 + ceno_zkvm/src/instructions/{riscv => }/gpu/mul.rs | 0 ceno_zkvm/src/instructions/{riscv => }/gpu/sb.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/sh.rs | 2 +- .../src/instructions/{riscv => }/gpu/shard_ram.rs | 0 ceno_zkvm/src/instructions/{riscv => }/gpu/shift_i.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/shift_r.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/slt.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/slti.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/sub.rs | 2 +- ceno_zkvm/src/instructions/{riscv => }/gpu/sw.rs | 2 +- .../src/instructions/{riscv => }/gpu/witgen_gpu.rs | 0 ceno_zkvm/src/instructions/riscv.rs | 7 +++---- ceno_zkvm/src/instructions/riscv/ecall/keccak.rs | 2 +- .../src/instructions/riscv/logic/logic_circuit.rs | 2 +- ceno_zkvm/src/structs.rs | 4 ++-- ceno_zkvm/src/tables/shard_ram.rs | 4 ++-- 43 files changed, 47 insertions(+), 47 deletions(-) rename ceno_zkvm/src/instructions/{riscv 
=> }/gpu/add.rs (96%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/addi.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/auipc.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/branch_cmp.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/branch_eq.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/colmap_base.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/d2h.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/debug_compare.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/device_cache.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/div.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/gpu_config.rs (100%) rename ceno_zkvm/src/instructions/{ => gpu}/host_ops/cpu_fallback.rs (99%) rename ceno_zkvm/src/instructions/{ => gpu}/host_ops/emit.rs (100%) rename ceno_zkvm/src/instructions/{ => gpu}/host_ops/lk_ops.rs (100%) rename ceno_zkvm/src/instructions/{ => gpu}/host_ops/mod.rs (100%) rename ceno_zkvm/src/instructions/{ => gpu}/host_ops/sink.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/jal.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/jalr.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/keccak.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/load_sub.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/logic_i.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/logic_r.rs (95%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/lui.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/lw.rs (97%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/mod.rs (99%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/mul.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/sb.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/sh.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/shard_ram.rs (100%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/shift_i.rs (98%) rename 
ceno_zkvm/src/instructions/{riscv => }/gpu/shift_r.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/slt.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/slti.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/sub.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/sw.rs (98%) rename ceno_zkvm/src/instructions/{riscv => }/gpu/witgen_gpu.rs (100%) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 01b4b0f92..612835cfe 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1511,7 +1511,7 @@ pub fn generate_witness<'a, E: ExtensionField>( // This batch-D2Hs accumulated EC records and addr_accessed into shard_ctx. #[cfg(feature = "gpu")] info_span!("flush_shared_ec").in_scope(|| { - crate::instructions::riscv::gpu::witgen_gpu::flush_shared_ec_buffers( + crate::instructions::gpu::witgen_gpu::flush_shared_ec_buffers( &mut shard_ctx, ) }).unwrap(); @@ -1519,7 +1519,7 @@ pub fn generate_witness<'a, E: ExtensionField>( // Free GPU shard_steps cache after all opcode circuits are done. #[cfg(feature = "gpu")] { - crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + crate::instructions::gpu::witgen_gpu::invalidate_shard_steps_cache(); if std::env::var_os("CENO_GPU_TRIM_AFTER_WITGEN").is_some() { use gkr_iop::gpu::gpu_prover::get_cuda_hal; @@ -1554,7 +1554,7 @@ pub fn generate_witness<'a, E: ExtensionField>( // Force CPU path for the debug comparison (thread-local, no env var races). 
#[cfg(feature = "gpu")] - crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(true); + crate::instructions::gpu::witgen_gpu::set_force_cpu_path(true); system_config .config @@ -1579,7 +1579,7 @@ pub fn generate_witness<'a, E: ExtensionField>( cpu_witness.finalize_lk_multiplicities(); #[cfg(feature = "gpu")] - crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(false); + crate::instructions::gpu::witgen_gpu::set_force_cpu_path(false); log_shard_ctx_diff("post_opcode_assignment", &cpu_shard_ctx, &shard_ctx); diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index c351e45ea..99e1da0b1 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -20,12 +20,12 @@ use rayon::{ use witness::{InstancePaddingStrategy, RowMajorMatrix}; pub mod riscv; -pub mod host_ops; +pub mod gpu; /// Backward-compatible re-export: old `side_effects` path still works. -pub use host_ops as side_effects; +pub use gpu::host_ops as side_effects; -pub use host_ops::{cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram}; +pub use gpu::host_ops::{cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram}; pub trait Instruction { type InstructionConfig: Send + Sync; @@ -261,7 +261,7 @@ macro_rules! impl_collect_lk_and_shardram { let shard_ctx_ptr = shard_ctx as *mut $crate::e2e::ShardContext; let _ctx = unsafe { &*shard_ctx_ptr }; let mut _sink_val = unsafe { - $crate::instructions::side_effects::CpuSideEffectSink::from_raw( + $crate::instructions::gpu::host_ops::CpuSideEffectSink::from_raw( shard_ctx_ptr, lk_multiplicity, ) @@ -347,7 +347,7 @@ macro_rules! 
impl_gpu_assign { ), $crate::error::ZKVMError, > { - use $crate::instructions::riscv::gpu::witgen_gpu; + use $crate::instructions::gpu::witgen_gpu; let gpu_kind: Option = $kind_expr; if let Some(kind) = gpu_kind { if let Some(result) = witgen_gpu::try_gpu_assign_instances::( diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/gpu/add.rs similarity index 96% rename from ceno_zkvm/src/instructions/riscv/gpu/add.rs rename to ceno_zkvm/src/instructions/gpu/add.rs index 35702d74a..a83011cae 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/gpu/add.rs @@ -131,7 +131,7 @@ mod tests { let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] @@ -213,7 +213,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + crate::instructions::gpu::witgen_gpu::try_gpu_assign_instances::< E, AddInstruction, >( @@ -223,14 +223,14 @@ mod tests { num_structural_witin, &steps, &indices, - crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Add, + crate::instructions::gpu::witgen_gpu::GpuWitgenKind::Add, ) .unwrap() .expect("GPU path should be available"); // Flush shared EC/addr buffers from GPU device to shard_ctx // (in the e2e pipeline this is called once per shard after all opcode circuits) - crate::instructions::riscv::gpu::device_cache::flush_shared_ec_buffers( + crate::instructions::gpu::device_cache::flush_shared_ec_buffers( &mut shard_ctx_full_gpu, ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/gpu/addi.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/addi.rs rename to 
ceno_zkvm/src/instructions/gpu/addi.rs index e263250bf..d206f1a4e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/gpu/addi.rs @@ -59,7 +59,7 @@ mod tests { let col_map = extract_addi_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/gpu/auipc.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/auipc.rs rename to ceno_zkvm/src/instructions/gpu/auipc.rs index 07b984295..64d345ec7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/gpu/auipc.rs @@ -61,7 +61,7 @@ mod tests { let col_map = extract_auipc_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/gpu/branch_cmp.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs rename to ceno_zkvm/src/instructions/gpu/branch_cmp.rs index dcba4eef7..b5716aa9c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/gpu/branch_cmp.rs @@ -70,7 +70,7 @@ mod tests { let col_map = extract_branch_cmp_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs 
b/ceno_zkvm/src/instructions/gpu/branch_eq.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs rename to ceno_zkvm/src/instructions/gpu/branch_eq.rs index 99c8621a3..300fef35d 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/gpu/branch_eq.rs @@ -63,7 +63,7 @@ mod tests { let col_map = extract_branch_eq_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/colmap_base.rs b/ceno_zkvm/src/instructions/gpu/colmap_base.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/colmap_base.rs rename to ceno_zkvm/src/instructions/gpu/colmap_base.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/d2h.rs b/ceno_zkvm/src/instructions/gpu/d2h.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/d2h.rs rename to ceno_zkvm/src/instructions/gpu/d2h.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs b/ceno_zkvm/src/instructions/gpu/debug_compare.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/debug_compare.rs rename to ceno_zkvm/src/instructions/gpu/debug_compare.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs b/ceno_zkvm/src/instructions/gpu/device_cache.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/device_cache.rs rename to ceno_zkvm/src/instructions/gpu/device_cache.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/gpu/div.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/div.rs rename to ceno_zkvm/src/instructions/gpu/div.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/gpu_config.rs b/ceno_zkvm/src/instructions/gpu/gpu_config.rs similarity 
index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/gpu_config.rs rename to ceno_zkvm/src/instructions/gpu/gpu_config.rs diff --git a/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs b/ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs similarity index 99% rename from ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs rename to ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs index 302935ecc..8e3e6b194 100644 --- a/ceno_zkvm/src/instructions/host_ops/cpu_fallback.rs +++ b/ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs @@ -14,7 +14,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, tables::RMMCollections, witness::LkMultiplicity, }; -use super::super::Instruction; +use crate::instructions::Instruction; /// CPU-only assign_instances. Extracted so GPU-enabled instructions can call this as fallback. pub fn cpu_assign_instances>( diff --git a/ceno_zkvm/src/instructions/host_ops/emit.rs b/ceno_zkvm/src/instructions/gpu/host_ops/emit.rs similarity index 100% rename from ceno_zkvm/src/instructions/host_ops/emit.rs rename to ceno_zkvm/src/instructions/gpu/host_ops/emit.rs diff --git a/ceno_zkvm/src/instructions/host_ops/lk_ops.rs b/ceno_zkvm/src/instructions/gpu/host_ops/lk_ops.rs similarity index 100% rename from ceno_zkvm/src/instructions/host_ops/lk_ops.rs rename to ceno_zkvm/src/instructions/gpu/host_ops/lk_ops.rs diff --git a/ceno_zkvm/src/instructions/host_ops/mod.rs b/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs similarity index 100% rename from ceno_zkvm/src/instructions/host_ops/mod.rs rename to ceno_zkvm/src/instructions/gpu/host_ops/mod.rs diff --git a/ceno_zkvm/src/instructions/host_ops/sink.rs b/ceno_zkvm/src/instructions/gpu/host_ops/sink.rs similarity index 100% rename from ceno_zkvm/src/instructions/host_ops/sink.rs rename to ceno_zkvm/src/instructions/gpu/host_ops/sink.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/gpu/jal.rs similarity index 98% rename from 
ceno_zkvm/src/instructions/riscv/gpu/jal.rs rename to ceno_zkvm/src/instructions/gpu/jal.rs index 21421ffb5..dc7da0bb6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/gpu/jal.rs @@ -49,7 +49,7 @@ mod tests { let col_map = extract_jal_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/gpu/jalr.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/jalr.rs rename to ceno_zkvm/src/instructions/gpu/jalr.rs index 1c9447346..e68d961ca 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/gpu/jalr.rs @@ -68,7 +68,7 @@ mod tests { let col_map = extract_jalr_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/keccak.rs b/ceno_zkvm/src/instructions/gpu/keccak.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/keccak.rs rename to ceno_zkvm/src/instructions/gpu/keccak.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/gpu/load_sub.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs rename to ceno_zkvm/src/instructions/gpu/load_sub.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/gpu/logic_i.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs rename to ceno_zkvm/src/instructions/gpu/logic_i.rs index e61e24796..d7bfb068d 100644 --- 
a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/gpu/logic_i.rs @@ -59,7 +59,7 @@ mod tests { let col_map = extract_logic_i_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/gpu/logic_r.rs similarity index 95% rename from ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs rename to ceno_zkvm/src/instructions/gpu/logic_r.rs index 9803777b4..bbc4b56e8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/gpu/logic_r.rs @@ -89,7 +89,7 @@ mod tests { let col_map = extract_logic_r_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] @@ -202,7 +202,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + crate::instructions::gpu::witgen_gpu::try_gpu_assign_instances::< E, AndInstruction, >( @@ -212,12 +212,12 @@ mod tests { num_structural_witin, &steps, &indices, - crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::LogicR(0), + crate::instructions::gpu::witgen_gpu::GpuWitgenKind::LogicR(0), ) .unwrap() .expect("GPU path should be available"); - crate::instructions::riscv::gpu::device_cache::flush_shared_ec_buffers( + crate::instructions::gpu::device_cache::flush_shared_ec_buffers( &mut shard_ctx_full_gpu, ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/gpu/lui.rs similarity index 98% rename from 
ceno_zkvm/src/instructions/riscv/gpu/lui.rs rename to ceno_zkvm/src/instructions/gpu/lui.rs index b72832a24..3d1d04166 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/gpu/lui.rs @@ -60,7 +60,7 @@ mod tests { let col_map = extract_lui_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/gpu/lw.rs similarity index 97% rename from ceno_zkvm/src/instructions/riscv/gpu/lw.rs rename to ceno_zkvm/src/instructions/gpu/lw.rs index 98596da99..56c73a55e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/gpu/lw.rs @@ -225,7 +225,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + crate::instructions::gpu::witgen_gpu::try_gpu_assign_instances::< E, LwInstruction, >( @@ -235,12 +235,12 @@ mod tests { num_structural_witin, &steps, &indices, - crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Lw, + crate::instructions::gpu::witgen_gpu::GpuWitgenKind::Lw, ) .unwrap() .expect("GPU path should be available"); - crate::instructions::riscv::gpu::device_cache::flush_shared_ec_buffers( + crate::instructions::gpu::device_cache::flush_shared_ec_buffers( &mut shard_ctx_full_gpu, ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/gpu/mod.rs similarity index 99% rename from ceno_zkvm/src/instructions/riscv/gpu/mod.rs rename to ceno_zkvm/src/instructions/gpu/mod.rs index 1559187ca..2a7e8c7da 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/mod.rs @@ -1,3 +1,4 @@ +pub mod host_ops; #[cfg(feature = 
"gpu")] pub mod add; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/gpu/mul.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/mul.rs rename to ceno_zkvm/src/instructions/gpu/mul.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/gpu/sb.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/sb.rs rename to ceno_zkvm/src/instructions/gpu/sb.rs index eca04efd3..35652a3e1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/gpu/sb.rs @@ -99,7 +99,7 @@ mod tests { let col_map = extract_sb_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/gpu/sh.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/sh.rs rename to ceno_zkvm/src/instructions/gpu/sh.rs index 350381e37..8e259d037 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/gpu/sh.rs @@ -76,7 +76,7 @@ mod tests { let col_map = extract_sh_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/shard_ram.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/shard_ram.rs rename to ceno_zkvm/src/instructions/gpu/shard_ram.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/gpu/shift_i.rs similarity 
index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs rename to ceno_zkvm/src/instructions/gpu/shift_i.rs index 303a9dc6f..43afca628 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/gpu/shift_i.rs @@ -72,7 +72,7 @@ mod tests { let col_map = extract_shift_i_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/gpu/shift_r.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs rename to ceno_zkvm/src/instructions/gpu/shift_r.rs index 5fc4a571e..54b3e582f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/gpu/shift_r.rs @@ -77,7 +77,7 @@ mod tests { let col_map = extract_shift_r_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/gpu/slt.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/slt.rs rename to ceno_zkvm/src/instructions/gpu/slt.rs index 13af3c3fe..9d90432d7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/gpu/slt.rs @@ -73,7 +73,7 @@ mod tests { let col_map = extract_slt_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git 
a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/gpu/slti.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/slti.rs rename to ceno_zkvm/src/instructions/gpu/slti.rs index e651d3e02..2c8c8aa93 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/gpu/slti.rs @@ -70,7 +70,7 @@ mod tests { let col_map = extract_slti_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/gpu/sub.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/sub.rs rename to ceno_zkvm/src/instructions/gpu/sub.rs index f8afa30ee..6b876e719 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/gpu/sub.rs @@ -63,7 +63,7 @@ mod tests { let col_map = extract_sub_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/gpu/sw.rs similarity index 98% rename from ceno_zkvm/src/instructions/riscv/gpu/sw.rs rename to ceno_zkvm/src/instructions/gpu/sw.rs index 644c87bf4..dc7914747 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/gpu/sw.rs @@ -67,7 +67,7 @@ mod tests { let col_map = extract_sw_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::riscv::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::colmap_base::validate_column_map(&flat, 
col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/gpu/witgen_gpu.rs similarity index 100% rename from ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs rename to ceno_zkvm/src/instructions/gpu/witgen_gpu.rs diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index c70264071..ab15d2295 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -32,14 +32,13 @@ mod r_insn; mod ecall_insn; -pub mod gpu; #[cfg(feature = "u16limb_circuit")] -mod auipc; +pub(crate) mod auipc; mod im_insn; #[cfg(feature = "u16limb_circuit")] -mod lui; -mod memory; +pub(crate) mod lui; +pub(crate) mod memory; mod s_insn; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index ac164b942..6490aff11 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -180,7 +180,7 @@ impl Instruction for KeccakInstruction { ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { #[cfg(feature = "gpu")] { - use crate::instructions::riscv::gpu::witgen_gpu::gpu_assign_keccak_instances; + use crate::instructions::gpu::witgen_gpu::gpu_assign_keccak_instances; if let Some(result) = gpu_assign_keccak_instances::( config, shard_ctx, diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index f7469645a..529198b91 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -156,7 +156,7 @@ impl LogicConfig { fn emit_lk_and_shardram( &self, - sink: &mut impl crate::instructions::side_effects::SideEffectSink, + sink: &mut impl crate::instructions::gpu::host_ops::SideEffectSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/structs.rs 
b/ceno_zkvm/src/structs.rs index 12c2d9505..80bc04304 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -587,7 +587,7 @@ impl ZKVMWitnesses { let global_input = { #[cfg(feature = "gpu")] let ec_result = { - use crate::instructions::riscv::gpu::witgen_gpu::gpu_batch_continuation_ec; + use crate::instructions::gpu::witgen_gpu::gpu_batch_continuation_ec; gpu_batch_continuation_ec::(&write_record_pairs, &read_record_pairs) .ok() }; @@ -733,7 +733,7 @@ impl ZKVMWitnesses { final_mem: &[(&'static str, Option>, &[MemFinalRecord])], config: & as TableCircuit>::TableConfig, ) -> Result { - use crate::instructions::riscv::gpu::witgen_gpu::{ + use crate::instructions::gpu::witgen_gpu::{ gpu_batch_continuation_ec_on_device, take_shared_device_buffers, }; use ceno_gpu::Buffer; diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 54d478c96..1c82ee198 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -720,7 +720,7 @@ impl ShardRamCircuit { .collect()); // 2. Extract column map - let col_map = crate::instructions::riscv::gpu::shard_ram::extract_shard_ram_column_map( + let col_map = crate::instructions::gpu::shard_ram::extract_shard_ram_column_map( config, num_witin, ); @@ -931,7 +931,7 @@ impl ShardRamCircuit { let num_rows_padded = 2 * n; // 1. 
Extract column map (same as regular path) - let col_map = crate::instructions::riscv::gpu::shard_ram::extract_shard_ram_column_map( + let col_map = crate::instructions::gpu::shard_ram::extract_shard_ram_column_map( config, num_witin, ); From 211c875db0b653c6f60af39684c331d9f6c4ca73 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 12:23:13 +0800 Subject: [PATCH 61/73] naming: side_effects to lk_shardram --- ceno_zkvm/src/instructions.rs | 14 +- .../src/instructions/gpu/debug_compare.rs | 8 +- ceno_zkvm/src/instructions/gpu/gpu_config.rs | 2 +- .../instructions/gpu/host_ops/cpu_fallback.rs | 2 +- .../src/instructions/gpu/host_ops/emit.rs | 16 +-- .../src/instructions/gpu/host_ops/mod.rs | 120 +++++++++--------- .../src/instructions/gpu/host_ops/sink.rs | 16 +-- ceno_zkvm/src/instructions/gpu/witgen_gpu.rs | 20 +-- ceno_zkvm/src/instructions/riscv/arith.rs | 4 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 6 +- ceno_zkvm/src/instructions/riscv/b_insn.rs | 4 +- .../riscv/branch/branch_circuit_v2.rs | 4 +- .../instructions/riscv/div/div_circuit_v2.rs | 4 +- .../src/instructions/riscv/ecall/keccak.rs | 2 +- .../instructions/riscv/ecall/sha_extend.rs | 2 +- .../src/instructions/riscv/ecall/uint256.rs | 4 +- .../riscv/ecall/weierstrass_add.rs | 2 +- .../riscv/ecall/weierstrass_decompress.rs | 2 +- .../riscv/ecall/weierstrass_double.rs | 2 +- ceno_zkvm/src/instructions/riscv/i_insn.rs | 4 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 4 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 18 +-- ceno_zkvm/src/instructions/riscv/j_insn.rs | 4 +- .../src/instructions/riscv/jump/jal_v2.rs | 4 +- .../src/instructions/riscv/jump/jalr_v2.rs | 4 +- .../instructions/riscv/logic/logic_circuit.rs | 6 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/lui.rs | 4 +- .../src/instructions/riscv/memory/load.rs | 2 +- .../src/instructions/riscv/memory/load_v2.rs | 4 +- 
.../src/instructions/riscv/memory/store_v2.rs | 4 +- .../riscv/mulh/mulh_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 4 +- ceno_zkvm/src/instructions/riscv/s_insn.rs | 4 +- .../riscv/shift/shift_circuit_v2.rs | 10 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 4 +- .../riscv/slti/slti_circuit_v2.rs | 4 +- 38 files changed, 164 insertions(+), 166 deletions(-) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 99e1da0b1..c121f00bc 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -22,8 +22,6 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; pub mod riscv; pub mod gpu; -/// Backward-compatible re-export: old `side_effects` path still works. -pub use gpu::host_ops as side_effects; pub use gpu::host_ops::{cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram}; @@ -31,7 +29,7 @@ pub trait Instruction { type InstructionConfig: Send + Sync; type InsnType: Clone + Copy; - const GPU_SIDE_EFFECTS: bool = false; + const GPU_LK_SHARDRAM: bool = false; fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default @@ -112,7 +110,7 @@ pub trait Instruction { ) -> Result<(), ZKVMError> { Err(ZKVMError::InvalidWitness( format!( - "{} does not implement lightweight side effects collection", + "{} does not implement lk and shardram collection", Self::name() ) .into(), @@ -127,7 +125,7 @@ pub trait Instruction { ) -> Result<(), ZKVMError> { Err(ZKVMError::InvalidWitness( format!( - "{} does not implement shard-only side effects collection", + "{} does not implement shardram-only collection", Self::name() ) .into(), @@ -234,11 +232,11 @@ pub fn full_step_indices(steps: &[StepRecord]) -> Vec { // --------------------------------------------------------------------------- /// Implement `collect_lk_and_shardram` with a common prologue -/// (create `CpuSideEffectSink`, dispatch to `config.$field.emit_lk_and_shardram`) +/// (create `CpuLkShardramSink`, 
dispatch to `config.$field.emit_lk_and_shardram`) /// and a chip-specific body for additional LK ops. /// /// The closure receives `(sink, step, config, ctx)`: -/// - `sink: &mut CpuSideEffectSink` — emit LK ops and send events +/// - `sink: &mut CpuLkShardramSink` — emit LK ops and send events /// - `step: &StepRecord` — current step /// - `config: &Self::InstructionConfig` — circuit config (for sub-configs) /// - `ctx: &ShardContext` — read-only shard context @@ -261,7 +259,7 @@ macro_rules! impl_collect_lk_and_shardram { let shard_ctx_ptr = shard_ctx as *mut $crate::e2e::ShardContext; let _ctx = unsafe { &*shard_ctx_ptr }; let mut _sink_val = unsafe { - $crate::instructions::gpu::host_ops::CpuSideEffectSink::from_raw( + $crate::instructions::gpu::host_ops::CpuLkShardramSink::from_raw( shard_ctx_ptr, lk_multiplicity, ) diff --git a/ceno_zkvm/src/instructions/gpu/debug_compare.rs b/ceno_zkvm/src/instructions/gpu/debug_compare.rs index f33ac96bd..12b9f3550 100644 --- a/ceno_zkvm/src/instructions/gpu/debug_compare.rs +++ b/ceno_zkvm/src/instructions/gpu/debug_compare.rs @@ -4,7 +4,7 @@ /// to validate correctness. 
Activated by environment variables: /// - CENO_GPU_DEBUG_COMPARE_LK: compare lookup multiplicities /// - CENO_GPU_DEBUG_COMPARE_WITNESS: compare witness matrices -/// - CENO_GPU_DEBUG_COMPARE_SHARD: compare shard side effects +/// - CENO_GPU_DEBUG_COMPARE_SHARD: compare shardram records /// - CENO_GPU_DEBUG_COMPARE_EC: compare EC points use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::common::witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}; @@ -161,7 +161,7 @@ pub(crate) fn debug_compare_witness>( Ok(()) } -pub(crate) fn debug_compare_shard_side_effects>( +pub(crate) fn debug_compare_shardram>( config: &I::InstructionConfig, shard_ctx: &ShardContext, shard_steps: &[StepRecord], @@ -245,7 +245,7 @@ pub(crate) fn debug_compare_shard_ec>( if let Err(e) = cpu_collect_shardram::( config, &mut cpu_ctx, shard_steps, step_indices, ) { - tracing::error!("[GPU EC debug] kind={kind:?} CPU shard side effects failed: {e:?}"); + tracing::error!("[GPU EC debug] kind={kind:?} CPU shardram records failed: {e:?}"); return; } @@ -583,7 +583,7 @@ pub(crate) fn lookup_table_name(table_idx: usize) -> &'static str { } /// Debug comparison for keccak GPU witgen. -/// Runs the CPU path and compares LK / witness / shard side effects. +/// Runs the CPU path and compares LK / witness / shardram records. /// /// Activated by CENO_GPU_DEBUG_COMPARE_LK, CENO_GPU_DEBUG_COMPARE_WITNESS, /// or CENO_GPU_DEBUG_COMPARE_SHARD environment variables. diff --git a/ceno_zkvm/src/instructions/gpu/gpu_config.rs b/ceno_zkvm/src/instructions/gpu/gpu_config.rs index 3b2bee2d0..9fe72d4ed 100644 --- a/ceno_zkvm/src/instructions/gpu/gpu_config.rs +++ b/ceno_zkvm/src/instructions/gpu/gpu_config.rs @@ -52,7 +52,7 @@ pub(crate) fn kind_tag(kind: GpuWitgenKind) -> &'static str { /// Returns true if the GPU CUDA kernel for this kind has been verified to produce /// correct LK multiplicity counters matching the CPU baseline. 
-/// Unverified kinds fall back to CPU full side effects (GPU still handles witness). +/// Unverified kinds fall back to CPU full lk_shardram (GPU still handles witness). /// /// Override with `CENO_GPU_DISABLE_LK_KINDS=add,sub,...` to force specific kinds /// back to CPU LK (for binary-search debugging). diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs b/ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs index 8e3e6b194..376e9872e 100644 --- a/ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs +++ b/ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs @@ -88,7 +88,7 @@ pub fn cpu_assign_instances>( )) } -/// CPU-only side-effect collection for GPU-enabled instructions. +/// CPU-only lk_shardram collection for GPU-enabled instructions. /// /// This path deliberately avoids scratch witness buffers and calls only the /// instruction-specific side-effect collector. diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/emit.rs b/ceno_zkvm/src/instructions/gpu/host_ops/emit.rs index 4f84d3e05..ae4647dea 100644 --- a/ceno_zkvm/src/instructions/gpu/host_ops/emit.rs +++ b/ceno_zkvm/src/instructions/gpu/host_ops/emit.rs @@ -5,10 +5,10 @@ use gkr_iop::{ use crate::instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS}; -use super::{LkOp, SideEffectSink}; +use super::{LkOp, LkShardramSink}; pub fn emit_assert_lt_ops( - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, lt_cfg: &AssertLtConfig, lhs: u64, rhs: u64, @@ -29,7 +29,7 @@ pub fn emit_assert_lt_ops( } } -pub fn emit_u16_limbs(sink: &mut impl SideEffectSink, value: u32) { +pub fn emit_u16_limbs(sink: &mut impl LkShardramSink, value: u32) { sink.emit_lk(LkOp::AssertU16 { value: (value & 0xffff) as u16, }); @@ -38,7 +38,7 @@ pub fn emit_u16_limbs(sink: &mut impl SideEffectSink, value: u32) { }); } -pub fn emit_const_range_op(sink: &mut impl SideEffectSink, value: u64, bits: usize) { +pub fn emit_const_range_op(sink: &mut impl LkShardramSink, value: u64, bits: usize) { match 
bits { 0 | 1 => {} 14 => sink.emit_lk(LkOp::AssertU14 { @@ -54,7 +54,7 @@ pub fn emit_const_range_op(sink: &mut impl SideEffectSink, value: u64, bits: usi } } -pub fn emit_byte_decomposition_ops(sink: &mut impl SideEffectSink, bytes: &[u8]) { +pub fn emit_byte_decomposition_ops(sink: &mut impl LkShardramSink, bytes: &[u8]) { for chunk in bytes.chunks(2) { match chunk { [a, b] => sink.emit_lk(LkOp::DoubleU8 { a: *a, b: *b }), @@ -64,7 +64,7 @@ pub fn emit_byte_decomposition_ops(sink: &mut impl SideEffectSink, bytes: &[u8]) } } -pub fn emit_signed_extend_op(sink: &mut impl SideEffectSink, n_bits: usize, value: u64) { +pub fn emit_signed_extend_op(sink: &mut impl LkShardramSink, n_bits: usize, value: u64) { let msb = value >> (n_bits - 1); sink.emit_lk(LkOp::DynamicRange { value: 2 * value - (msb << n_bits), @@ -73,7 +73,7 @@ pub fn emit_signed_extend_op(sink: &mut impl SideEffectSink, n_bits: usize, valu } pub fn emit_logic_u8_ops( - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, lhs: u64, rhs: u64, num_bytes: usize, @@ -93,7 +93,7 @@ pub fn emit_logic_u8_ops( } pub fn emit_uint_limbs_lt_ops( - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, is_sign_comparison: bool, a: &[u16], b: &[u16], diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs b/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs index 194fd4b64..23b6986ae 100644 --- a/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs @@ -1,6 +1,6 @@ //! Host-side operations for GPU-CPU hybrid witness generation. //! -//! Contains lookup/shard side-effect collection abstractions and CPU fallback paths. +//! Contains lookup/shard lk_shardram collection abstractions and CPU fallback paths. 
mod lk_ops; mod sink; @@ -46,7 +46,7 @@ mod tests { type E = GoldilocksExt2; - fn assert_side_effects_match>( + fn assert_lk_shardram_match>( config: &I::InstructionConfig, num_witin: usize, num_structural_witin: usize, @@ -84,7 +84,7 @@ mod tests { ); } - fn assert_shard_side_effects_match>( + fn assert_shard_lk_shardram_match>( config: &I::InstructionConfig, num_witin: usize, num_structural_witin: usize, @@ -164,8 +164,8 @@ mod tests { } #[test] - fn test_add_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "add_side_effects"); + fn test_add_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -189,7 +189,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -198,8 +198,8 @@ mod tests { } #[test] - fn test_and_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "and_side_effects"); + fn test_and_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -223,7 +223,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -232,8 +232,8 @@ mod tests { } #[test] - fn test_add_shard_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "add_shard_side_effects"); + fn test_add_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_shard_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = AddInstruction::::construct_circuit(&mut cb, 
&ProgramParams::default()).unwrap(); @@ -257,7 +257,7 @@ mod tests { }) .collect(); - assert_shard_side_effects_match::>( + assert_shard_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -266,8 +266,8 @@ mod tests { } #[test] - fn test_and_shard_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "and_shard_side_effects"); + fn test_and_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_shard_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -291,7 +291,7 @@ mod tests { }) .collect(); - assert_shard_side_effects_match::>( + assert_shard_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -300,8 +300,8 @@ mod tests { } #[test] - fn test_lw_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "lw_side_effects"); + fn test_lw_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -331,7 +331,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -340,8 +340,8 @@ mod tests { } #[test] - fn test_lw_shard_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "lw_shard_side_effects"); + fn test_lw_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_shard_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -371,7 +371,7 @@ mod tests { }) .collect(); - assert_shard_side_effects_match::>( + assert_shard_lk_shardram_match::>( 
&config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -380,8 +380,8 @@ mod tests { } #[test] - fn test_beq_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "beq_side_effects"); + fn test_beq_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "beq_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -410,7 +410,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -419,8 +419,8 @@ mod tests { } #[test] - fn test_blt_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "blt_side_effects"); + fn test_blt_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "blt_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -446,7 +446,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -455,8 +455,8 @@ mod tests { } #[test] - fn test_jal_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "jal_side_effects"); + fn test_jal_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jal_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -476,7 +476,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -485,8 +485,8 @@ mod tests { } #[test] - fn test_jalr_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| 
"jalr_side_effects"); + fn test_jalr_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jalr_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -508,7 +508,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -517,8 +517,8 @@ mod tests { } #[test] - fn test_slt_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "slt_side_effects"); + fn test_slt_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slt_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -541,7 +541,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -550,8 +550,8 @@ mod tests { } #[test] - fn test_slti_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "slti_side_effects"); + fn test_slti_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slti_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -573,7 +573,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -582,8 +582,8 @@ mod tests { } #[test] - fn test_sra_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "sra_side_effects"); + fn test_sra_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sra_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = 
SraInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -606,7 +606,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -615,8 +615,8 @@ mod tests { } #[test] - fn test_slli_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "slli_side_effects"); + fn test_slli_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slli_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -637,7 +637,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -646,8 +646,8 @@ mod tests { } #[test] - fn test_sb_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "sb_side_effects"); + fn test_sb_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sb_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -677,7 +677,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -686,8 +686,8 @@ mod tests { } #[test] - fn test_mul_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "mul_side_effects"); + fn test_mul_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mul_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -714,7 +714,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as 
usize, cb.cs.num_structural_witin as usize, @@ -723,8 +723,8 @@ mod tests { } #[test] - fn test_mulh_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "mulh_side_effects"); + fn test_mulh_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mulh_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -752,7 +752,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -761,8 +761,8 @@ mod tests { } #[test] - fn test_div_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "div_side_effects"); + fn test_div_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "div_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -794,7 +794,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, @@ -803,8 +803,8 @@ mod tests { } #[test] - fn test_remu_side_effects_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "remu_side_effects"); + fn test_remu_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "remu_lk_shardram"); let mut cb = CircuitBuilder::new(&mut cs); let config = RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); @@ -832,7 +832,7 @@ mod tests { }) .collect(); - assert_side_effects_match::>( + assert_lk_shardram_match::>( &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/sink.rs b/ceno_zkvm/src/instructions/gpu/host_ops/sink.rs index 4a07d7e43..3c19f44bc 100644 --- 
a/ceno_zkvm/src/instructions/gpu/host_ops/sink.rs +++ b/ceno_zkvm/src/instructions/gpu/host_ops/sink.rs @@ -5,19 +5,19 @@ use crate::{e2e::ShardContext, witness::LkMultiplicity}; use super::{LkOp, SendEvent}; -pub trait SideEffectSink { +pub trait LkShardramSink { fn emit_lk(&mut self, op: LkOp); fn emit_send(&mut self, event: SendEvent); fn touch_addr(&mut self, addr: WordAddr); } -pub struct CpuSideEffectSink<'ctx, 'shard, 'lk> { +pub struct CpuLkShardramSink<'ctx, 'shard, 'lk> { shard_ctx: *mut ShardContext<'shard>, lk: &'lk mut LkMultiplicity, _marker: PhantomData<&'ctx mut ShardContext<'shard>>, } -impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { +impl<'ctx, 'shard, 'lk> CpuLkShardramSink<'ctx, 'shard, 'lk> { pub unsafe fn from_raw( shard_ctx: *mut ShardContext<'shard>, lk: &'lk mut LkMultiplicity, @@ -31,12 +31,12 @@ impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { fn shard_ctx(&mut self) -> &mut ShardContext<'shard> { // Safety: `from_raw` is only constructed from a live `&mut ShardContext` - // for the duration of side-effect collection. + // for the duration of lk_shardram collection. unsafe { &mut *self.shard_ctx } } } -/// Create a `CpuSideEffectSink` and an immutable view of `ShardContext`, +/// Create a `CpuLkShardramSink` and an immutable view of `ShardContext`, /// then pass both to the closure `f`. 
/// /// This encapsulates the raw-pointer trick needed to hold `&mut ShardContext` @@ -48,15 +48,15 @@ impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { pub fn with_cpu_sink<'a, R>( shard_ctx: &'a mut ShardContext<'a>, lk: &'a mut LkMultiplicity, - f: impl FnOnce(&mut CpuSideEffectSink<'a, 'a, 'a>, &ShardContext) -> R, + f: impl FnOnce(&mut CpuLkShardramSink<'a, 'a, 'a>, &ShardContext) -> R, ) -> R { let ptr = shard_ctx as *mut ShardContext; let view = unsafe { &*ptr }; - let mut sink = unsafe { CpuSideEffectSink::from_raw(ptr, lk) }; + let mut sink = unsafe { CpuLkShardramSink::from_raw(ptr, lk) }; f(&mut sink, view) } -impl SideEffectSink for CpuSideEffectSink<'_, '_, '_> { +impl LkShardramSink for CpuLkShardramSink<'_, '_, '_> { fn emit_lk(&mut self, op: LkOp) { for (table, key) in op.encode_all() { self.lk.increment(table, key); diff --git a/ceno_zkvm/src/instructions/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/gpu/witgen_gpu.rs index 5eb63aba9..a44b318d9 100644 --- a/ceno_zkvm/src/instructions/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/gpu/witgen_gpu.rs @@ -2,8 +2,8 @@ /// /// This module provides `try_gpu_assign_instances` which: /// 1. Runs the GPU kernel to fill the witness matrix (fast) -/// 2. Runs a lightweight CPU loop to collect side effects without witness replay -/// 3. Returns the GPU-generated witness + CPU-collected side effects +/// 2. Runs a lightweight CPU loop to collect lk and shardram without witness replay +/// 3. 
Returns the GPU-generated witness + CPU-collected lk and shardram use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::{ Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose, @@ -18,7 +18,7 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use super::debug_compare::{ debug_compare_final_lk, debug_compare_keccak, debug_compare_shard_ec, - debug_compare_shard_side_effects, debug_compare_witness, + debug_compare_shardram, debug_compare_witness, }; use super::gpu_config::{ is_gpu_witgen_disabled, is_kind_disabled, kind_has_verified_lk, kind_has_verified_shard, @@ -135,7 +135,7 @@ pub(crate) fn try_gpu_assign_instances>( return Ok(None); } - if !I::GPU_SIDE_EFFECTS { + if !I::GPU_LK_SHARDRAM { return Ok(None); } @@ -211,8 +211,8 @@ fn gpu_assign_instances_inner>( ) })?; - // Step 2: Collect side effects - // Priority: GPU shard records > CPU shard records > full CPU side effects + // Step 2: Collect lk and shardram + // Priority: GPU shard records > CPU shard records > full CPU lk and shardram let lk_multiplicity = if gpu_lk_counters.is_some() && kind_has_verified_lk(kind) { let lk_multiplicity = info_span!("gpu_lk_d2h").in_scope(|| { gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) @@ -332,13 +332,13 @@ fn gpu_assign_instances_inner>( } lk_multiplicity } else { - // GPU LK counters missing or unverified — fall back to full CPU side effects - info_span!("cpu_side_effects").in_scope(|| { + // GPU LK counters missing or unverified — fall back to full CPU lk and shardram + info_span!("cpu_lk_shardram").in_scope(|| { collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) })? 
}; debug_compare_final_lk::(config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, kind, &lk_multiplicity)?; - debug_compare_shard_side_effects::(config, shard_ctx, shard_steps, step_indices, kind)?; + debug_compare_shardram::(config, shard_ctx, shard_steps, step_indices, kind)?; // Step 3: Build structural witness (just selector = ONE) let mut raw_structural = RowMajorMatrix::::new( @@ -1154,7 +1154,7 @@ fn gpu_fill_witness>( } } -/// CPU-side loop to collect side effects only (shard_ctx.send, lk_multiplicity). +/// CPU-side loop to collect lk and shardram only (shard_ctx.send, lk_multiplicity). /// Runs assign_instance with a scratch buffer per thread. fn collect_lk_and_shardram>( config: &I::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index edbd86b6c..f7dd04305 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ Instruction, - side_effects::emit_u16_limbs, + gpu::host_ops::emit_u16_limbs, }, structs::ProgramParams, uint::Value, @@ -45,7 +45,7 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::ADD | InsnKind::SUB); + const GPU_LK_SHARDRAM: bool = matches!(I::INST_KIND, InsnKind::ADD | InsnKind::SUB); fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 442372323..0eade7a25 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -7,7 +7,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, - side_effects::emit_u16_limbs, + gpu::host_ops::emit_u16_limbs, 
}, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -36,7 +36,7 @@ impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::ADDI] diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 04055d4b7..7e7da705d 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -13,8 +13,8 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, - side_effects::{ - LkOp, SideEffectSink, emit_byte_decomposition_ops, + gpu::host_ops::{ + LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op, }, }, @@ -44,7 +44,7 @@ impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::AUIPC] diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index a7e7acb1f..b5f7d546a 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -9,7 +9,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, - side_effects::{LkOp, SideEffectSink}, + gpu::host_ops::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::{LkMultiplicity, set_val}, @@ -128,7 +128,7 @@ impl BInstructionConfig { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 13ea76578..2f1c42975 100644 --- 
a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -12,7 +12,7 @@ use crate::{ b_insn::BInstructionConfig, constants::{UINT_LIMBS, UInt}, }, - side_effects::emit_uint_limbs_lt_ops, + gpu::host_ops::emit_uint_limbs_lt_ops, }, structs::ProgramParams, witness::LkMultiplicity, @@ -43,7 +43,7 @@ impl Instruction for BranchCircuit; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index ce455fcaa..de7476e9d 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -18,7 +18,7 @@ use crate::{ instructions::{ Instruction, riscv::constants::LIMB_BITS, - side_effects::{LkOp, SideEffectSink, emit_u16_limbs}, + gpu::host_ops::{LkOp, LkShardramSink, emit_u16_limbs}, }, structs::ProgramParams, uint::Value, @@ -55,7 +55,7 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 6490aff11..f94c274cb 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -50,7 +50,7 @@ pub struct EcallKeccakConfig { pub(crate) mem_rw: Vec, } -/// KeccakInstruction can handle any instruction and produce its side-effects. +/// KeccakInstruction can handle any instruction and produce its lk and shardram data. 
pub struct KeccakInstruction(PhantomData); impl Instruction for KeccakInstruction { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index 3a0f42ab2..8044e5d2e 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -47,7 +47,7 @@ pub struct EcallShaExtendConfig { mem_rw: Vec, } -/// ShaExtendInstruction can handle any instruction and produce its side-effects. +/// ShaExtendInstruction can handle any instruction and produce its lk and shardram data. pub struct ShaExtendInstruction(PhantomData); impl Instruction for ShaExtendInstruction { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index 67a042cd7..92c3778c5 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -63,7 +63,7 @@ pub struct EcallUint256MulConfig { mem_rw: Vec, } -/// Uint256MulInstruction can handle any instruction and produce its side-effects. +/// Uint256MulInstruction can handle any instruction and produce its lk and shardram data. pub struct Uint256MulInstruction(PhantomData); impl Instruction for Uint256MulInstruction { @@ -372,7 +372,7 @@ impl Instruction for Uint256MulInstruction { } } -/// Uint256InvInstruction can handle any instruction and produce its side-effects. +/// Uint256InvInstruction can handle any instruction and produce its lk and shardram data. 
pub struct Uint256InvInstruction(PhantomData<(E, P)>); pub struct Secp256K1EcallSpec; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index ca4f59ed3..27470aff1 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -52,7 +52,7 @@ pub struct EcallWeierstrassAddAssignConfig mem_rw: Vec, } -/// WeierstrassAddAssignInstruction can handle any instruction and produce its side-effects. +/// WeierstrassAddAssignInstruction can handle any instruction and produce its lk and shardram data. pub struct WeierstrassAddAssignInstruction(PhantomData<(E, EC)>); impl Instruction diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 1d79efd63..682b2b967 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -59,7 +59,7 @@ pub struct EcallWeierstrassDecompressConfig, } -/// WeierstrassDecompressInstruction can handle any instruction and produce its side-effects. +/// WeierstrassDecompressInstruction can handle any instruction and produce its lk and shardram data. pub struct WeierstrassDecompressInstruction(PhantomData<(E, EC)>); impl Instruction diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 206ef143d..a95221b7c 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -54,7 +54,7 @@ pub struct EcallWeierstrassDoubleAssignConfig< mem_rw: Vec, } -/// WeierstrassDoubleAssignInstruction can handle any instruction and produce its side-effects. +/// WeierstrassDoubleAssignInstruction can handle any instruction and produce its lk and shardram data. 
pub struct WeierstrassDoubleAssignInstruction(PhantomData<(E, EC)>); impl Instruction diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index 317921307..ddda27c95 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, - side_effects::{LkOp, SideEffectSink}, + gpu::host_ops::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, @@ -82,7 +82,7 @@ impl IInstructionConfig { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 4414261e5..ce939a00a 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -4,7 +4,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, - side_effects::{LkOp, SideEffectSink}, + gpu::host_ops::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, @@ -91,7 +91,7 @@ impl IMInstructionConfig { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 84a1e8f07..6bbdbe583 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -13,7 +13,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, - instructions::side_effects::{LkOp, SendEvent, SideEffectSink, emit_assert_lt_ops}, + instructions::gpu::host_ops::{LkOp, SendEvent, LkShardramSink, emit_assert_lt_ops}, structs::RAMType, uint::Value, 
witness::{LkMultiplicity, set_val}, @@ -145,7 +145,7 @@ impl ReadRS1 { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { @@ -254,7 +254,7 @@ impl ReadRS2 { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { @@ -381,7 +381,7 @@ impl WriteRD { pub fn emit_op_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, cycle: Cycle, op: &WriteOp, @@ -408,7 +408,7 @@ impl WriteRD { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { @@ -507,7 +507,7 @@ impl ReadMEM { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { @@ -621,7 +621,7 @@ impl WriteMEM { pub fn emit_op_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, cycle: Cycle, op: &WriteOp, @@ -648,7 +648,7 @@ impl WriteMEM { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { @@ -828,7 +828,7 @@ impl MemAddr { Ok(()) } - pub fn emit_lk_and_shardram(&self, sink: &mut impl SideEffectSink, addr: Word) { + pub fn emit_lk_and_shardram(&self, sink: &mut impl LkShardramSink, addr: Word) { let mid_u14 = ((addr & 0xffff) >> Self::N_LOW_BITS) as u16; sink.emit_lk(LkOp::AssertU14 { value: mid_u14 }); diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 219af3e4c..da95642e0 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ 
riscv::insn_base::{StateInOut, WriteRD}, - side_effects::{LkOp, SideEffectSink}, + gpu::host_ops::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, @@ -84,7 +84,7 @@ impl JInstructionConfig { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 1af4ad390..d6555533a 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -13,7 +13,7 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, }, - side_effects::{LkOp, SideEffectSink, emit_byte_decomposition_ops}, + gpu::host_ops::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, }, structs::ProgramParams, utils::split_to_u8, @@ -46,7 +46,7 @@ impl Instruction for JalInstruction { type InstructionConfig = JalConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JAL] diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 15b047b10..2e4c89916 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -15,7 +15,7 @@ use crate::{ i_insn::IInstructionConfig, insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD}, }, - side_effects::emit_const_range_op, + gpu::host_ops::emit_const_range_op, }, structs::ProgramParams, tables::InsnRecord, @@ -46,7 +46,7 @@ impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JALR] diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs 
b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 529198b91..8c10683e2 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -12,7 +12,7 @@ use crate::{ instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, - side_effects::emit_logic_u8_ops, + gpu::host_ops::emit_logic_u8_ops, }, structs::ProgramParams, utils::split_to_u8, @@ -33,7 +33,7 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] @@ -156,7 +156,7 @@ impl LogicConfig { fn emit_lk_and_shardram( &self, - sink: &mut impl crate::instructions::gpu::host_ops::SideEffectSink, + sink: &mut impl crate::instructions::gpu::host_ops::LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 5d43380a0..bfe36b3ae 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -17,7 +17,7 @@ use crate::{ i_insn::IInstructionConfig, logic_imm::LogicOp, }, - side_effects::emit_logic_u8_ops, + gpu::host_ops::emit_logic_u8_ops, }, structs::ProgramParams, tables::InsnRecord, @@ -35,7 +35,7 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 52983fe40..866b90f2d 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -13,7 +13,7 
@@ use crate::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, - side_effects::emit_const_range_op, + gpu::host_ops::emit_const_range_op, }, structs::ProgramParams, tables::InsnRecord, @@ -38,7 +38,7 @@ impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::LUI] diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index d1c93194a..d3fe72181 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -41,7 +41,7 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + const GPU_LK_SHARDRAM: bool = matches!(I::INST_KIND, InsnKind::LW); fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index d8f6b8128..6ce464ed9 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -45,7 +45,7 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = matches!( + const GPU_LK_SHARDRAM: bool = matches!( I::INST_KIND, InsnKind::LW | InsnKind::LB | InsnKind::LBU | InsnKind::LH | InsnKind::LHU ); @@ -259,7 +259,7 @@ impl Instruction for LoadInstruction::imm_internal(&step.insn()); let unaligned_addr = diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index b6c1e3064..d74214a93 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -13,7 +13,7 @@ use crate::{ memory::gadget::MemWordUtil, s_insn::SInstructionConfig, }, - 
side_effects::{emit_const_range_op, emit_u16_limbs}, + gpu::host_ops::{emit_const_range_op, emit_u16_limbs}, }, structs::ProgramParams, tables::InsnRecord, @@ -46,7 +46,7 @@ impl Instruction type InstructionConfig = StoreConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index 47180ecfb..76719b168 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -9,7 +9,7 @@ use crate::{ constants::{LIMB_BITS, UINT_LIMBS, UInt}, r_insn::RInstructionConfig, }, - side_effects::{LkOp, SideEffectSink}, + gpu::host_ops::{LkOp, LkShardramSink}, }, structs::ProgramParams, uint::Value, @@ -42,7 +42,7 @@ impl Instruction for MulhInstructionBas type InstructionConfig = MulhConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index bf68809de..b372a8bb2 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, - side_effects::{LkOp, SideEffectSink}, + gpu::host_ops::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, @@ -87,7 +87,7 @@ impl RInstructionConfig { pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index 332becf61..d136101b4 100644 --- 
a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -5,7 +5,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, - side_effects::{LkOp, SideEffectSink}, + gpu::host_ops::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, @@ -110,7 +110,7 @@ impl SInstructionConfig { #[allow(dead_code)] pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 61411dbd0..b9496ac20 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -10,8 +10,8 @@ use crate::{ i_insn::IInstructionConfig, r_insn::RInstructionConfig, }, - side_effects::{ - LkOp, SideEffectSink, emit_byte_decomposition_ops, + gpu::host_ops::{ + LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op, }, }, @@ -213,7 +213,7 @@ impl pub fn emit_lk_and_shardram( &self, - sink: &mut impl SideEffectSink, + sink: &mut impl LkShardramSink, kind: InsnKind, b: u32, c: u32, @@ -322,7 +322,7 @@ impl Instruction for ShiftLogicalInstru type InstructionConfig = ShiftRTypeConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] @@ -445,7 +445,7 @@ impl Instruction for ShiftImmInstructio type InstructionConfig = ShiftImmConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 9dda4abe2..562712954 100644 --- 
a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -8,7 +8,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, - side_effects::emit_uint_limbs_lt_ops, + gpu::host_ops::emit_uint_limbs_lt_ops, }, structs::ProgramParams, witness::LkMultiplicity, @@ -34,7 +34,7 @@ impl Instruction for SetLessThanInstruc type InstructionConfig = SetLessThanConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index d56076fd1..19a30c647 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -12,7 +12,7 @@ use crate::{ constants::{UINT_LIMBS, UInt}, i_insn::IInstructionConfig, }, - side_effects::emit_uint_limbs_lt_ops, + gpu::host_ops::emit_uint_limbs_lt_ops, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -45,7 +45,7 @@ impl Instruction for SetLessThanImmInst type InstructionConfig = SetLessThanImmConfig; type InsnType = InsnKind; - const GPU_SIDE_EFFECTS: bool = true; + const GPU_LK_SHARDRAM: bool = true; fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] From 659586084910ce3e5d5ccd4128e481a4edd6cd9a Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 12:31:54 +0800 Subject: [PATCH 62/73] path: gpu/host_ops to gpu/utils --- ceno_zkvm/src/instructions.rs | 4 +- .../src/instructions/gpu/host_ops/mod.rs | 880 ------------------ ceno_zkvm/src/instructions/gpu/mod.rs | 830 ++++++++++++++++- .../gpu/{host_ops => utils}/emit.rs | 0 .../cpu_fallback.rs => utils/fallback.rs} | 0 .../gpu/{host_ops => utils}/lk_ops.rs | 0 ceno_zkvm/src/instructions/gpu/utils/mod.rs | 60 ++ 
.../gpu/{host_ops => utils}/sink.rs | 0 ceno_zkvm/src/instructions/riscv/arith.rs | 2 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 2 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 2 +- ceno_zkvm/src/instructions/riscv/b_insn.rs | 2 +- .../riscv/branch/branch_circuit_v2.rs | 2 +- .../instructions/riscv/div/div_circuit_v2.rs | 2 +- ceno_zkvm/src/instructions/riscv/i_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 2 +- ceno_zkvm/src/instructions/riscv/j_insn.rs | 2 +- .../src/instructions/riscv/jump/jal_v2.rs | 2 +- .../src/instructions/riscv/jump/jalr_v2.rs | 2 +- .../instructions/riscv/logic/logic_circuit.rs | 4 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 2 +- ceno_zkvm/src/instructions/riscv/lui.rs | 2 +- .../src/instructions/riscv/memory/store_v2.rs | 2 +- .../riscv/mulh/mulh_circuit_v2.rs | 2 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/s_insn.rs | 2 +- .../riscv/shift/shift_circuit_v2.rs | 2 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 2 +- .../riscv/slti/slti_circuit_v2.rs | 2 +- 30 files changed, 914 insertions(+), 906 deletions(-) delete mode 100644 ceno_zkvm/src/instructions/gpu/host_ops/mod.rs rename ceno_zkvm/src/instructions/gpu/{host_ops => utils}/emit.rs (100%) rename ceno_zkvm/src/instructions/gpu/{host_ops/cpu_fallback.rs => utils/fallback.rs} (100%) rename ceno_zkvm/src/instructions/gpu/{host_ops => utils}/lk_ops.rs (100%) create mode 100644 ceno_zkvm/src/instructions/gpu/utils/mod.rs rename ceno_zkvm/src/instructions/gpu/{host_ops => utils}/sink.rs (100%) diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index c121f00bc..ed95ce846 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -23,7 +23,7 @@ pub mod riscv; pub mod gpu; -pub use gpu::host_ops::{cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram}; +pub use gpu::utils::{cpu_assign_instances, 
cpu_collect_lk_and_shardram, cpu_collect_shardram}; pub trait Instruction { type InstructionConfig: Send + Sync; @@ -259,7 +259,7 @@ macro_rules! impl_collect_lk_and_shardram { let shard_ctx_ptr = shard_ctx as *mut $crate::e2e::ShardContext; let _ctx = unsafe { &*shard_ctx_ptr }; let mut _sink_val = unsafe { - $crate::instructions::gpu::host_ops::CpuLkShardramSink::from_raw( + $crate::instructions::gpu::utils::CpuLkShardramSink::from_raw( shard_ctx_ptr, lk_multiplicity, ) diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs b/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs deleted file mode 100644 index 23b6986ae..000000000 --- a/ceno_zkvm/src/instructions/gpu/host_ops/mod.rs +++ /dev/null @@ -1,880 +0,0 @@ -//! Host-side operations for GPU-CPU hybrid witness generation. -//! -//! Contains lookup/shard lk_shardram collection abstractions and CPU fallback paths. - -mod lk_ops; -mod sink; -mod emit; -mod cpu_fallback; - -// Re-export all public types for convenience -pub use lk_ops::*; -pub use sink::*; -pub use emit::*; -pub use cpu_fallback::*; - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::ShardContext, - instructions::{ - Instruction, cpu_assign_instances, cpu_collect_shardram, - cpu_collect_lk_and_shardram, - riscv::{ - AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, - branch::{BeqInstruction, BltInstruction}, - div::{DivInstruction, RemuInstruction}, - logic::AndInstruction, - mulh::{MulInstruction, MulhInstruction}, - shift::SraInstruction, - shift_imm::SlliInstruction, - slt::SltInstruction, - slti::SltiInstruction, - }, - }, - structs::ProgramParams, - witness::LkMultiplicity, - }; - use ceno_emul::{ - ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, - encode_rv32, - }; - use ff_ext::GoldilocksExt2; - use gkr_iop::tables::LookupTable; - - type E = GoldilocksExt2; - - fn assert_lk_shardram_match>( - config: 
&I::InstructionConfig, - num_witin: usize, - num_structural_witin: usize, - steps: &[StepRecord], - ) { - let indices: Vec = (0..steps.len()).collect(); - - let mut assign_ctx = ShardContext::default(); - let (_, expected_lk) = cpu_assign_instances::( - config, - &mut assign_ctx, - num_witin, - num_structural_witin, - steps, - &indices, - ) - .unwrap(); - - let mut collect_ctx = ShardContext::default(); - let actual_lk = - cpu_collect_lk_and_shardram::(config, &mut collect_ctx, steps, &indices).unwrap(); - - assert_eq!(flatten_lk(&expected_lk), flatten_lk(&actual_lk)); - assert_eq!( - assign_ctx.get_addr_accessed(), - collect_ctx.get_addr_accessed() - ); - assert_eq!( - flatten_records(assign_ctx.read_records()), - flatten_records(collect_ctx.read_records()) - ); - assert_eq!( - flatten_records(assign_ctx.write_records()), - flatten_records(collect_ctx.write_records()) - ); - } - - fn assert_shard_lk_shardram_match>( - config: &I::InstructionConfig, - num_witin: usize, - num_structural_witin: usize, - steps: &[StepRecord], - ) { - let indices: Vec = (0..steps.len()).collect(); - - let mut assign_ctx = ShardContext::default(); - let (_, expected_lk) = cpu_assign_instances::( - config, - &mut assign_ctx, - num_witin, - num_structural_witin, - steps, - &indices, - ) - .unwrap(); - - let mut collect_ctx = ShardContext::default(); - let actual_lk = - cpu_collect_shardram::(config, &mut collect_ctx, steps, &indices) - .unwrap(); - - assert_eq!( - expected_lk[LookupTable::Instruction as usize], - actual_lk[LookupTable::Instruction as usize] - ); - for (table_idx, table) in actual_lk.iter().enumerate() { - if table_idx != LookupTable::Instruction as usize { - assert!( - table.is_empty(), - "unexpected non-fetch shard-only multiplicity in table {table_idx}: {table:?}" - ); - } - } - assert_eq!( - assign_ctx.get_addr_accessed(), - collect_ctx.get_addr_accessed() - ); - assert_eq!( - flatten_records(assign_ctx.read_records()), - flatten_records(collect_ctx.read_records()) - 
); - assert_eq!( - flatten_records(assign_ctx.write_records()), - flatten_records(collect_ctx.write_records()) - ); - } - - fn flatten_records( - records: &[std::collections::BTreeMap], - ) -> Vec<(WordAddr, u64, u64, usize)> { - records - .iter() - .flat_map(|table| { - table - .iter() - .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) - }) - .collect() - } - - fn flatten_lk( - multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, - ) -> Vec> { - multiplicity - .iter() - .map(|table| { - let mut entries = table - .iter() - .map(|(key, count)| (*key, *count)) - .collect::>(); - entries.sort_unstable(); - entries - }) - .collect() - } - - #[test] - fn test_add_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "add_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let steps: Vec<_> = (0..4) - .map(|i| { - let rd = 2 + i; - let rs1 = 8 + i; - let rs2 = 16 + i; - let lhs = 10 + i as u32; - let rhs = 100 + i as u32; - let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); - StepRecord::new_r_instruction( - 4 + (i as u64) * 4, - ByteAddr(0x1000 + i as u32 * 4), - insn, - lhs, - rhs, - Change::new(0, lhs.wrapping_add(rhs)), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_and_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "and_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let steps: Vec<_> = (0..4) - .map(|i| { - let rd = 2 + i; - let rs1 = 8 + i; - let rs2 = 16 + i; - let lhs = 0xdead_0000 | i as u32; - let rhs = 0x00ff_ff00 | ((i as u32) << 8); - let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); - 
StepRecord::new_r_instruction( - 4 + (i as u64) * 4, - ByteAddr(0x2000 + i as u32 * 4), - insn, - lhs, - rhs, - Change::new(0, lhs & rhs), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_add_shard_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "add_shard_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let steps: Vec<_> = (0..4) - .map(|i| { - let rd = 2 + i; - let rs1 = 8 + i; - let rs2 = 16 + i; - let lhs = 10 + i as u32; - let rhs = 100 + i as u32; - let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); - StepRecord::new_r_instruction( - 84 + (i as u64) * 4, - ByteAddr(0x5000 + i as u32 * 4), - insn, - lhs, - rhs, - Change::new(0, lhs.wrapping_add(rhs)), - 0, - ) - }) - .collect(); - - assert_shard_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_and_shard_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "and_shard_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let steps: Vec<_> = (0..4) - .map(|i| { - let rd = 2 + i; - let rs1 = 8 + i; - let rs2 = 16 + i; - let lhs = 0xdead_0000 | i as u32; - let rhs = 0x00ff_ff00 | ((i as u32) << 8); - let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); - StepRecord::new_r_instruction( - 100 + (i as u64) * 4, - ByteAddr(0x5100 + i as u32 * 4), - insn, - lhs, - rhs, - Change::new(0, lhs & rhs), - 0, - ) - }) - .collect(); - - assert_shard_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_lw_lk_shardram_match_assign_instance() { 
- let mut cs = ConstraintSystem::::new(|| "lw_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let steps: Vec<_> = (0..4) - .map(|i| { - let rd = 2 + i; - let rs1 = 8 + i; - let rs1_val = 0x1000u32 + (i as u32) * 16; - let imm = (i as i32) * 4 - 4; - let mem_addr = rs1_val.wrapping_add_signed(imm); - let mem_val = 0xabc0_0000 | i as u32; - let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); - let mem_read = ReadOp { - addr: WordAddr::from(ByteAddr(mem_addr)), - value: mem_val, - previous_cycle: 0, - }; - StepRecord::new_im_instruction( - 4 + (i as u64) * 4, - ByteAddr(0x3000 + i as u32 * 4), - insn, - rs1_val, - Change::new(0, mem_val), - mem_read, - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_lw_shard_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "lw_shard_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let steps: Vec<_> = (0..4) - .map(|i| { - let rd = 2 + i; - let rs1 = 8 + i; - let rs1_val = 0x1400u32 + (i as u32) * 16; - let imm = (i as i32) * 4 - 4; - let mem_addr = rs1_val.wrapping_add_signed(imm); - let mem_val = 0xabd0_0000 | i as u32; - let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); - let mem_read = ReadOp { - addr: WordAddr::from(ByteAddr(mem_addr)), - value: mem_val, - previous_cycle: 0, - }; - StepRecord::new_im_instruction( - 116 + (i as u64) * 4, - ByteAddr(0x5200 + i as u32 * 4), - insn, - rs1_val, - Change::new(0, mem_val), - mem_read, - 0, - ) - }) - .collect(); - - assert_shard_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn 
test_beq_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "beq_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [ - (true, 0x1122_3344, 0x1122_3344), - (false, 0x5566_7788, 0x99aa_bbcc), - ]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (taken, lhs, rhs))| { - let pc = ByteAddr(0x4000 + i as u32 * 4); - let next_pc = if taken { - ByteAddr(pc.0 + 8) - } else { - pc + PC_STEP_SIZE - }; - StepRecord::new_b_instruction( - 4 + i as u64 * 4, - Change::new(pc, next_pc), - encode_rv32(InsnKind::BEQ, 8 + i as u32, 16 + i as u32, 0, 8), - lhs, - rhs, - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_blt_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "blt_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(true, (-2i32) as u32, 1u32), (false, 7u32, (-3i32) as u32)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (taken, lhs, rhs))| { - let pc = ByteAddr(0x4100 + i as u32 * 4); - let next_pc = if taken { - ByteAddr(pc.0.wrapping_sub(8)) - } else { - pc + PC_STEP_SIZE - }; - StepRecord::new_b_instruction( - 12 + i as u64 * 4, - Change::new(pc, next_pc), - encode_rv32(InsnKind::BLT, 4 + i as u32, 5 + i as u32, 0, -8), - lhs, - rhs, - 10, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_jal_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "jal_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - 
JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let offsets = [8, -8]; - let steps: Vec<_> = offsets - .into_iter() - .enumerate() - .map(|(i, offset)| { - let pc = ByteAddr(0x4200 + i as u32 * 4); - StepRecord::new_j_instruction( - 20 + i as u64 * 4, - Change::new(pc, ByteAddr(pc.0.wrapping_add_signed(offset))), - encode_rv32(InsnKind::JAL, 0, 0, 3 + i as u32, offset), - Change::new(0, (pc + PC_STEP_SIZE).into()), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_jalr_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "jalr_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(100u32, 3), (0x4010u32, -5)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (rs1, imm))| { - let pc = ByteAddr(0x4300 + i as u32 * 4); - let next_pc = ByteAddr(rs1.wrapping_add_signed(imm) & !1); - StepRecord::new_i_instruction( - 28 + i as u64 * 4, - Change::new(pc, next_pc), - encode_rv32(InsnKind::JALR, 2 + i as u32, 0, 6 + i as u32, imm), - rs1, - Change::new(0, (pc + PC_STEP_SIZE).into()), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_slt_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "slt_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [((-1i32) as u32, 0u32), (5u32, (-2i32) as u32)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (lhs, rhs))| { - let insn = - encode_rv32(InsnKind::SLT, 9 + i as u32, 10 + i as u32, 11 + i as u32, 0); - 
StepRecord::new_r_instruction( - 36 + i as u64 * 4, - ByteAddr(0x4400 + i as u32 * 4), - insn, - lhs, - rhs, - Change::new(0, ((lhs as i32) < (rhs as i32)) as u32), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_slti_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "slti_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(0u32, -1), ((-2i32) as u32, 1)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (rs1, imm))| { - let insn = encode_rv32(InsnKind::SLTI, 12 + i as u32, 0, 13 + i as u32, imm); - let pc = ByteAddr(0x4500 + i as u32 * 4); - StepRecord::new_i_instruction( - 44 + i as u64 * 4, - Change::new(pc, pc + PC_STEP_SIZE), - insn, - rs1, - Change::new(0, ((rs1 as i32) < imm) as u32), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_sra_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "sra_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SraInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(0x8765_4321u32, 4u32), (0xf000_0000u32, 31u32)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (lhs, rhs))| { - let shift = rhs & 31; - let rd = ((lhs as i32) >> shift) as u32; - StepRecord::new_r_instruction( - 52 + i as u64 * 4, - ByteAddr(0x4600 + i as u32 * 4), - encode_rv32(InsnKind::SRA, 6 + i as u32, 7 + i as u32, 8 + i as u32, 0), - lhs, - rhs, - Change::new(0, rd), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - 
&steps, - ); - } - - #[test] - fn test_slli_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "slli_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(0x1234_5678u32, 3), (0x0000_0001u32, 31)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (rs1, imm))| { - let pc = ByteAddr(0x4700 + i as u32 * 4); - StepRecord::new_i_instruction( - 60 + i as u64 * 4, - Change::new(pc, pc + PC_STEP_SIZE), - encode_rv32(InsnKind::SLLI, 9 + i as u32, 0, 10 + i as u32, imm), - rs1, - Change::new(0, rs1.wrapping_shl((imm & 31) as u32)), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_sb_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "sb_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let steps: Vec<_> = (0..2) - .map(|i| { - let rs1 = 0x4800u32 + i * 16; - let rs2 = 0x1234_5600u32 | i; - let imm = i as i32 - 1; - let addr = ByteAddr::from(rs1.wrapping_add_signed(imm)); - let prev = 0x4030_2010u32 + i; - let shift = (addr.shift() * 8) as usize; - let mut next = prev & !(0xff << shift); - next |= (rs2 & 0xff) << shift; - StepRecord::new_s_instruction( - 68 + i as u64 * 4, - ByteAddr(0x4800 + i * 4), - encode_rv32(InsnKind::SB, 11 + i, 12 + i, 0, imm), - rs1, - rs2, - WriteOp { - addr: addr.waddr(), - value: Change::new(prev, next), - previous_cycle: 4, - }, - 8, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_mul_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "mul_lk_shardram"); - let 
mut cb = CircuitBuilder::new(&mut cs); - let config = - MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(2u32, 11u32), (u32::MAX, 17u32)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (lhs, rhs))| { - StepRecord::new_r_instruction( - 76 + i as u64 * 4, - ByteAddr(0x4900 + i as u32 * 4), - encode_rv32( - InsnKind::MUL, - 13 + i as u32, - 14 + i as u32, - 15 + i as u32, - 0, - ), - lhs, - rhs, - Change::new(0, lhs.wrapping_mul(rhs)), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_mulh_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "mulh_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(2i32, -11i32), (i32::MIN, -1i32)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (lhs, rhs))| { - let outcome = ((lhs as i64).wrapping_mul(rhs as i64) >> 32) as u32; - StepRecord::new_r_instruction( - 84 + i as u64 * 4, - ByteAddr(0x4a00 + i as u32 * 4), - encode_rv32( - InsnKind::MULH, - 16 + i as u32, - 17 + i as u32, - 18 + i as u32, - 0, - ), - lhs as u32, - rhs as u32, - Change::new(0, outcome), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_div_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "div_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(17i32, -3i32), (i32::MIN, -1i32)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (lhs, rhs))| { - let out = if rhs == 0 { - -1i32 - } else { - 
lhs.wrapping_div(rhs) - } as u32; - StepRecord::new_r_instruction( - 92 + i as u64 * 4, - ByteAddr(0x4b00 + i as u32 * 4), - encode_rv32( - InsnKind::DIV, - 19 + i as u32, - 20 + i as u32, - 21 + i as u32, - 0, - ), - lhs as u32, - rhs as u32, - Change::new(0, out), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_remu_lk_shardram_match_assign_instance() { - let mut cs = ConstraintSystem::::new(|| "remu_lk_shardram"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let cases = [(17u32, 3u32), (0x8000_0001u32, 0u32)]; - let steps: Vec<_> = cases - .into_iter() - .enumerate() - .map(|(i, (lhs, rhs))| { - let out = if rhs == 0 { lhs } else { lhs % rhs }; - StepRecord::new_r_instruction( - 100 + i as u64 * 4, - ByteAddr(0x4c00 + i as u32 * 4), - encode_rv32( - InsnKind::REMU, - 22 + i as u32, - 23 + i as u32, - 24 + i as u32, - 0, - ), - lhs, - rhs, - Change::new(0, out), - 0, - ) - }) - .collect(); - - assert_lk_shardram_match::>( - &config, - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - &steps, - ); - } - - #[test] - fn test_lk_op_encodings_match_cpu_multiplicity() { - let ops = [ - LkOp::AssertU16 { value: 7 }, - LkOp::DynamicRange { value: 11, bits: 8 }, - LkOp::AssertU14 { value: 5 }, - LkOp::Fetch { pc: 0x1234 }, - LkOp::DoubleU8 { a: 1, b: 2 }, - LkOp::And { a: 3, b: 4 }, - LkOp::Or { a: 5, b: 6 }, - LkOp::Xor { a: 7, b: 8 }, - LkOp::Ltu { a: 9, b: 10 }, - LkOp::Pow2 { value: 12 }, - LkOp::ShrByte { - shift: 3, - carry: 17, - bits: 2, - }, - ]; - - let mut lk = LkMultiplicity::default(); - for op in ops { - for (table, key) in op.encode_all() { - lk.increment(table, key); - } - } - - let finalized = lk.into_finalize_result(); - assert_eq!(finalized[LookupTable::Dynamic as usize].len(), 3); - 
assert_eq!(finalized[LookupTable::Instruction as usize].len(), 1); - assert_eq!(finalized[LookupTable::DoubleU8 as usize].len(), 3); - assert_eq!(finalized[LookupTable::And as usize].len(), 1); - assert_eq!(finalized[LookupTable::Or as usize].len(), 1); - assert_eq!(finalized[LookupTable::Xor as usize].len(), 1); - assert_eq!(finalized[LookupTable::Ltu as usize].len(), 1); - assert_eq!(finalized[LookupTable::Pow as usize].len(), 1); - } -} diff --git a/ceno_zkvm/src/instructions/gpu/mod.rs b/ceno_zkvm/src/instructions/gpu/mod.rs index 2a7e8c7da..3983cfcd7 100644 --- a/ceno_zkvm/src/instructions/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/mod.rs @@ -1,4 +1,4 @@ -pub mod host_ops; +pub mod utils; #[cfg(feature = "gpu")] pub mod add; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] @@ -59,3 +59,831 @@ pub mod d2h; pub mod device_cache; #[cfg(feature = "gpu")] pub mod witgen_gpu; + +#[cfg(test)] +mod tests { + use super::utils::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{ + Instruction, cpu_assign_instances, cpu_collect_shardram, + cpu_collect_lk_and_shardram, + riscv::{ + AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, + branch::{BeqInstruction, BltInstruction}, + div::{DivInstruction, RemuInstruction}, + logic::AndInstruction, + mulh::{MulInstruction, MulhInstruction}, + shift::SraInstruction, + shift_imm::SlliInstruction, + slt::SltInstruction, + slti::SltiInstruction, + }, + }, + structs::ProgramParams, + witness::LkMultiplicity, + }; + use ceno_emul::{ + ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, + encode_rv32, + }; + use ff_ext::GoldilocksExt2; + use gkr_iop::tables::LookupTable; + + type E = GoldilocksExt2; + + fn assert_lk_shardram_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut 
assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_lk_and_shardram::(config, &mut collect_ctx, steps, &indices).unwrap(); + + assert_eq!(flatten_lk(&expected_lk), flatten_lk(&actual_lk)); + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn assert_shard_lk_shardram_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_shardram::(config, &mut collect_ctx, steps, &indices) + .unwrap(); + + assert_eq!( + expected_lk[LookupTable::Instruction as usize], + actual_lk[LookupTable::Instruction as usize] + ); + for (table_idx, table) in actual_lk.iter().enumerate() { + if table_idx != LookupTable::Instruction as usize { + assert!( + table.is_empty(), + "unexpected non-fetch shard-only multiplicity in table {table_idx}: {table:?}" + ); + } + } + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn flatten_records( + records: 
&[std::collections::BTreeMap], + ) -> Vec<(WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + #[test] + fn test_add_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x1000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x2000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + 
+ assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_add_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_shard_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 84 + (i as u64) * 4, + ByteAddr(0x5000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_shard_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_shard_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 100 + (i as u64) * 4, + ByteAddr(0x5100 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_shard_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, 
&ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabc0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x3000 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_shard_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1400u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabd0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 116 + (i as u64) * 4, + ByteAddr(0x5200 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_shard_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_beq_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "beq_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + 
BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [ + (true, 0x1122_3344, 0x1122_3344), + (false, 0x5566_7788, 0x99aa_bbcc), + ]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4000 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 4 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BEQ, 8 + i as u32, 16 + i as u32, 0, 8), + lhs, + rhs, + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_blt_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "blt_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(true, (-2i32) as u32, 1u32), (false, 7u32, (-3i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4100 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 12 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BLT, 4 + i as u32, 5 + i as u32, 0, -8), + lhs, + rhs, + 10, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jal_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jal_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let offsets = [8, -8]; + let steps: Vec<_> = offsets + .into_iter() + .enumerate() + .map(|(i, offset)| { + let pc 
= ByteAddr(0x4200 + i as u32 * 4); + StepRecord::new_j_instruction( + 20 + i as u64 * 4, + Change::new(pc, ByteAddr(pc.0.wrapping_add_signed(offset))), + encode_rv32(InsnKind::JAL, 0, 0, 3 + i as u32, offset), + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jalr_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jalr_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(100u32, 3), (0x4010u32, -5)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4300 + i as u32 * 4); + let next_pc = ByteAddr(rs1.wrapping_add_signed(imm) & !1); + StepRecord::new_i_instruction( + 28 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::JALR, 2 + i as u32, 0, 6 + i as u32, imm), + rs1, + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slt_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slt_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [((-1i32) as u32, 0u32), (5u32, (-2i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let insn = + encode_rv32(InsnKind::SLT, 9 + i as u32, 10 + i as u32, 11 + i as u32, 0); + StepRecord::new_r_instruction( + 36 + i as u64 * 4, + ByteAddr(0x4400 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, ((lhs as i32) < (rhs as i32)) as u32), + 0, + ) + }) + .collect(); + + 
assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slti_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slti_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0u32, -1), ((-2i32) as u32, 1)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let insn = encode_rv32(InsnKind::SLTI, 12 + i as u32, 0, 13 + i as u32, imm); + let pc = ByteAddr(0x4500 + i as u32 * 4); + StepRecord::new_i_instruction( + 44 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + insn, + rs1, + Change::new(0, ((rs1 as i32) < imm) as u32), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sra_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sra_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SraInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x8765_4321u32, 4u32), (0xf000_0000u32, 31u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let shift = rhs & 31; + let rd = ((lhs as i32) >> shift) as u32; + StepRecord::new_r_instruction( + 52 + i as u64 * 4, + ByteAddr(0x4600 + i as u32 * 4), + encode_rv32(InsnKind::SRA, 6 + i as u32, 7 + i as u32, 8 + i as u32, 0), + lhs, + rhs, + Change::new(0, rd), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slli_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slli_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config 
= + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x1234_5678u32, 3), (0x0000_0001u32, 31)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4700 + i as u32 * 4); + StepRecord::new_i_instruction( + 60 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + encode_rv32(InsnKind::SLLI, 9 + i as u32, 0, 10 + i as u32, imm), + rs1, + Change::new(0, rs1.wrapping_shl((imm & 31) as u32)), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sb_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sb_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..2) + .map(|i| { + let rs1 = 0x4800u32 + i * 16; + let rs2 = 0x1234_5600u32 | i; + let imm = i as i32 - 1; + let addr = ByteAddr::from(rs1.wrapping_add_signed(imm)); + let prev = 0x4030_2010u32 + i; + let shift = (addr.shift() * 8) as usize; + let mut next = prev & !(0xff << shift); + next |= (rs2 & 0xff) << shift; + StepRecord::new_s_instruction( + 68 + i as u64 * 4, + ByteAddr(0x4800 + i * 4), + encode_rv32(InsnKind::SB, 11 + i, 12 + i, 0, imm), + rs1, + rs2, + WriteOp { + addr: addr.waddr(), + value: Change::new(prev, next), + previous_cycle: 4, + }, + 8, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mul_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mul_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2u32, 11u32), (u32::MAX, 17u32)]; + let steps: 
Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + StepRecord::new_r_instruction( + 76 + i as u64 * 4, + ByteAddr(0x4900 + i as u32 * 4), + encode_rv32( + InsnKind::MUL, + 13 + i as u32, + 14 + i as u32, + 15 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, lhs.wrapping_mul(rhs)), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mulh_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mulh_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2i32, -11i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let outcome = ((lhs as i64).wrapping_mul(rhs as i64) >> 32) as u32; + StepRecord::new_r_instruction( + 84 + i as u64 * 4, + ByteAddr(0x4a00 + i as u32 * 4), + encode_rv32( + InsnKind::MULH, + 16 + i as u32, + 17 + i as u32, + 18 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, outcome), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_div_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "div_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17i32, -3i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { + -1i32 + } else { + lhs.wrapping_div(rhs) + } as u32; + StepRecord::new_r_instruction( + 92 + i as u64 * 4, + ByteAddr(0x4b00 + i as u32 * 4), + encode_rv32( + InsnKind::DIV, + 19 + i as u32, + 20 + i as u32, + 21 + i as 
u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_remu_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "remu_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17u32, 3u32), (0x8000_0001u32, 0u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { lhs } else { lhs % rhs }; + StepRecord::new_r_instruction( + 100 + i as u64 * 4, + ByteAddr(0x4c00 + i as u32 * 4), + encode_rv32( + InsnKind::REMU, + 22 + i as u32, + 23 + i as u32, + 24 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/emit.rs b/ceno_zkvm/src/instructions/gpu/utils/emit.rs similarity index 100% rename from ceno_zkvm/src/instructions/gpu/host_ops/emit.rs rename to ceno_zkvm/src/instructions/gpu/utils/emit.rs diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs b/ceno_zkvm/src/instructions/gpu/utils/fallback.rs similarity index 100% rename from ceno_zkvm/src/instructions/gpu/host_ops/cpu_fallback.rs rename to ceno_zkvm/src/instructions/gpu/utils/fallback.rs diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/lk_ops.rs b/ceno_zkvm/src/instructions/gpu/utils/lk_ops.rs similarity index 100% rename from ceno_zkvm/src/instructions/gpu/host_ops/lk_ops.rs rename to ceno_zkvm/src/instructions/gpu/utils/lk_ops.rs diff --git a/ceno_zkvm/src/instructions/gpu/utils/mod.rs b/ceno_zkvm/src/instructions/gpu/utils/mod.rs new file mode 100644 index 000000000..93798a7d3 --- 
/dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/mod.rs @@ -0,0 +1,60 @@ +//! Host-side operations for GPU-CPU hybrid witness generation. +//! +//! Contains lookup/shard lk_shardram collection abstractions and CPU fallback paths. + +mod lk_ops; +mod sink; +mod emit; +mod fallback; + +// Re-export all public types for convenience +pub use lk_ops::*; +pub use sink::*; +pub use emit::*; +pub use fallback::*; + + +#[cfg(test)] +mod tests { + use super::*; + use crate::witness::LkMultiplicity; + use gkr_iop::tables::LookupTable; + + #[test] + fn test_lk_op_encodings_match_cpu_multiplicity() { + let ops = [ + LkOp::AssertU16 { value: 7 }, + LkOp::DynamicRange { value: 11, bits: 8 }, + LkOp::AssertU14 { value: 5 }, + LkOp::Fetch { pc: 0x1234 }, + LkOp::DoubleU8 { a: 1, b: 2 }, + LkOp::And { a: 3, b: 4 }, + LkOp::Or { a: 5, b: 6 }, + LkOp::Xor { a: 7, b: 8 }, + LkOp::Ltu { a: 9, b: 10 }, + LkOp::Pow2 { value: 12 }, + LkOp::ShrByte { + shift: 3, + carry: 17, + bits: 2, + }, + ]; + + let mut lk = LkMultiplicity::default(); + for op in ops { + for (table, key) in op.encode_all() { + lk.increment(table, key); + } + } + + let finalized = lk.into_finalize_result(); + assert_eq!(finalized[LookupTable::Dynamic as usize].len(), 3); + assert_eq!(finalized[LookupTable::Instruction as usize].len(), 1); + assert_eq!(finalized[LookupTable::DoubleU8 as usize].len(), 3); + assert_eq!(finalized[LookupTable::And as usize].len(), 1); + assert_eq!(finalized[LookupTable::Or as usize].len(), 1); + assert_eq!(finalized[LookupTable::Xor as usize].len(), 1); + assert_eq!(finalized[LookupTable::Ltu as usize].len(), 1); + assert_eq!(finalized[LookupTable::Pow as usize].len(), 1); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/host_ops/sink.rs b/ceno_zkvm/src/instructions/gpu/utils/sink.rs similarity index 100% rename from ceno_zkvm/src/instructions/gpu/host_ops/sink.rs rename to ceno_zkvm/src/instructions/gpu/utils/sink.rs diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs 
b/ceno_zkvm/src/instructions/riscv/arith.rs index f7dd04305..07b2d6500 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ Instruction, - gpu::host_ops::emit_u16_limbs, + gpu::utils::emit_u16_limbs, }, structs::ProgramParams, uint::Value, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 0eade7a25..c87c033b4 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -7,7 +7,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, - gpu::host_ops::emit_u16_limbs, + gpu::utils::emit_u16_limbs, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 7e7da705d..f3d642794 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -13,7 +13,7 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, - gpu::host_ops::{ + gpu::utils::{ LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op, }, diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index b5f7d546a..382c62981 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -9,7 +9,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, - gpu::host_ops::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::{LkMultiplicity, set_val}, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs 
b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 2f1c42975..b7919af31 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -12,7 +12,7 @@ use crate::{ b_insn::BInstructionConfig, constants::{UINT_LIMBS, UInt}, }, - gpu::host_ops::emit_uint_limbs_lt_ops, + gpu::utils::emit_uint_limbs_lt_ops, }, structs::ProgramParams, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index de7476e9d..3ed1bd071 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -18,7 +18,7 @@ use crate::{ instructions::{ Instruction, riscv::constants::LIMB_BITS, - gpu::host_ops::{LkOp, LkShardramSink, emit_u16_limbs}, + gpu::utils::{LkOp, LkShardramSink, emit_u16_limbs}, }, structs::ProgramParams, uint::Value, diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index ddda27c95..2250c5365 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, - gpu::host_ops::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index ce939a00a..1ea3ce137 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -4,7 +4,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, - gpu::host_ops::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs 
b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 6bbdbe583..a6b4dbc87 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -13,7 +13,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, - instructions::gpu::host_ops::{LkOp, SendEvent, LkShardramSink, emit_assert_lt_ops}, + instructions::gpu::utils::{LkOp, SendEvent, LkShardramSink, emit_assert_lt_ops}, structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index da95642e0..8ce03b5ac 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{StateInOut, WriteRD}, - gpu::host_ops::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index d6555533a..fac4243b9 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -13,7 +13,7 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, }, - gpu::host_ops::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, }, structs::ProgramParams, utils::split_to_u8, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 2e4c89916..ac8ddfe20 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -15,7 +15,7 @@ use crate::{ i_insn::IInstructionConfig, insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD}, }, - gpu::host_ops::emit_const_range_op, + gpu::utils::emit_const_range_op, }, structs::ProgramParams, 
tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 8c10683e2..a233bdb6c 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -12,7 +12,7 @@ use crate::{ instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, - gpu::host_ops::emit_logic_u8_ops, + gpu::utils::emit_logic_u8_ops, }, structs::ProgramParams, utils::split_to_u8, @@ -156,7 +156,7 @@ impl LogicConfig { fn emit_lk_and_shardram( &self, - sink: &mut impl crate::instructions::gpu::host_ops::LkShardramSink, + sink: &mut impl crate::instructions::gpu::utils::LkShardramSink, shard_ctx: &ShardContext, step: &StepRecord, ) { diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index bfe36b3ae..010c899fb 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -17,7 +17,7 @@ use crate::{ i_insn::IInstructionConfig, logic_imm::LogicOp, }, - gpu::host_ops::emit_logic_u8_ops, + gpu::utils::emit_logic_u8_ops, }, structs::ProgramParams, tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 866b90f2d..4b491fb62 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -13,7 +13,7 @@ use crate::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, - gpu::host_ops::emit_const_range_op, + gpu::utils::emit_const_range_op, }, structs::ProgramParams, tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index d74214a93..479e0e22c 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ 
b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -13,7 +13,7 @@ use crate::{ memory::gadget::MemWordUtil, s_insn::SInstructionConfig, }, - gpu::host_ops::{emit_const_range_op, emit_u16_limbs}, + gpu::utils::{emit_const_range_op, emit_u16_limbs}, }, structs::ProgramParams, tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index 76719b168..31765b607 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -9,7 +9,7 @@ use crate::{ constants::{LIMB_BITS, UINT_LIMBS, UInt}, r_insn::RInstructionConfig, }, - gpu::host_ops::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink}, }, structs::ProgramParams, uint::Value, diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index b372a8bb2..4ad09b9d6 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -8,7 +8,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, - gpu::host_ops::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index d136101b4..1d849146c 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -5,7 +5,7 @@ use crate::{ error::ZKVMError, instructions::{ riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, - gpu::host_ops::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink}, }, tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index b9496ac20..ec24dc4ff 100644 --- 
a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -10,7 +10,7 @@ use crate::{ i_insn::IInstructionConfig, r_insn::RInstructionConfig, }, - gpu::host_ops::{ + gpu::utils::{ LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op, }, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 562712954..749f9e8a6 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -8,7 +8,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, - gpu::host_ops::emit_uint_limbs_lt_ops, + gpu::utils::emit_uint_limbs_lt_ops, }, structs::ProgramParams, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 19a30c647..081e432cf 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -12,7 +12,7 @@ use crate::{ constants::{UINT_LIMBS, UInt}, i_insn::IInstructionConfig, }, - gpu::host_ops::emit_uint_limbs_lt_ops, + gpu::utils::emit_uint_limbs_lt_ops, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, From bcf8e0c952ccba7a9db1ae843a7aec489de29747 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 12:48:47 +0800 Subject: [PATCH 63/73] naming: ceno_zkvm/src/instructions/gpu --- ceno_zkvm/src/e2e.rs | 8 +-- ceno_zkvm/src/instructions.rs | 6 +- .../gpu/{device_cache.rs => cache.rs} | 2 +- .../src/instructions/gpu/{ => chips}/add.rs | 10 +-- .../src/instructions/gpu/{ => chips}/addi.rs | 4 +- .../src/instructions/gpu/{ => chips}/auipc.rs | 4 +- .../gpu/{ => chips}/branch_cmp.rs | 4 +- .../instructions/gpu/{ => chips}/branch_eq.rs | 4 +- .../src/instructions/gpu/{ => 
chips}/div.rs | 2 +- .../src/instructions/gpu/{ => chips}/jal.rs | 4 +- .../src/instructions/gpu/{ => chips}/jalr.rs | 4 +- .../instructions/gpu/{ => chips}/keccak.rs | 4 +- .../instructions/gpu/{ => chips}/load_sub.rs | 2 +- .../instructions/gpu/{ => chips}/logic_i.rs | 4 +- .../instructions/gpu/{ => chips}/logic_r.rs | 10 +-- .../src/instructions/gpu/{ => chips}/lui.rs | 4 +- .../src/instructions/gpu/{ => chips}/lw.rs | 8 +-- ceno_zkvm/src/instructions/gpu/chips/mod.rs | 48 ++++++++++++++ .../src/instructions/gpu/{ => chips}/mul.rs | 2 +- .../src/instructions/gpu/{ => chips}/sb.rs | 4 +- .../src/instructions/gpu/{ => chips}/sh.rs | 4 +- .../instructions/gpu/{ => chips}/shard_ram.rs | 0 .../instructions/gpu/{ => chips}/shift_i.rs | 4 +- .../instructions/gpu/{ => chips}/shift_r.rs | 4 +- .../src/instructions/gpu/{ => chips}/slt.rs | 4 +- .../src/instructions/gpu/{ => chips}/slti.rs | 4 +- .../src/instructions/gpu/{ => chips}/sub.rs | 4 +- .../src/instructions/gpu/{ => chips}/sw.rs | 4 +- .../gpu/{gpu_config.rs => config.rs} | 2 +- .../gpu/{witgen_gpu.rs => dispatch.rs} | 62 +++++++++---------- ceno_zkvm/src/instructions/gpu/mod.rs | 58 ++--------------- .../gpu/{ => utils}/colmap_base.rs | 0 .../src/instructions/gpu/{ => utils}/d2h.rs | 0 .../gpu/{ => utils}/debug_compare.rs | 2 +- ceno_zkvm/src/instructions/gpu/utils/mod.rs | 7 +++ ceno_zkvm/src/instructions/riscv/arith.rs | 4 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 2 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 2 +- .../riscv/branch/branch_circuit_v2.rs | 12 ++-- .../instructions/riscv/div/div_circuit_v2.rs | 8 +-- .../src/instructions/riscv/ecall/keccak.rs | 2 +- .../src/instructions/riscv/jump/jal_v2.rs | 2 +- .../src/instructions/riscv/jump/jalr_v2.rs | 2 +- .../instructions/riscv/logic/logic_circuit.rs | 2 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 2 +- ceno_zkvm/src/instructions/riscv/lui.rs | 2 +- .../src/instructions/riscv/memory/load.rs | 2 +- .../src/instructions/riscv/memory/load_v2.rs 
| 10 +-- .../src/instructions/riscv/memory/store_v2.rs | 6 +- .../riscv/mulh/mulh_circuit_v2.rs | 8 +-- .../riscv/shift/shift_circuit_v2.rs | 4 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 2 +- .../riscv/slti/slti_circuit_v2.rs | 2 +- ceno_zkvm/src/structs.rs | 4 +- ceno_zkvm/src/tables/shard_ram.rs | 4 +- 55 files changed, 192 insertions(+), 187 deletions(-) rename ceno_zkvm/src/instructions/gpu/{device_cache.rs => cache.rs} (99%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/add.rs (94%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/addi.rs (95%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/auipc.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/branch_cmp.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/branch_eq.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/div.rs (99%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/jal.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/jalr.rs (97%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/keccak.rs (98%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/load_sub.rs (99%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/logic_i.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/logic_r.rs (94%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/lui.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/lw.rs (96%) create mode 100644 ceno_zkvm/src/instructions/gpu/chips/mod.rs rename ceno_zkvm/src/instructions/gpu/{ => chips}/mul.rs (99%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/sb.rs (98%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/sh.rs (97%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/shard_ram.rs (100%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/shift_i.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/shift_r.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/slt.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/slti.rs (96%) rename ceno_zkvm/src/instructions/gpu/{ => chips}/sub.rs (96%) 
rename ceno_zkvm/src/instructions/gpu/{ => chips}/sw.rs (97%) rename ceno_zkvm/src/instructions/gpu/{gpu_config.rs => config.rs} (99%) rename ceno_zkvm/src/instructions/gpu/{witgen_gpu.rs => dispatch.rs} (95%) rename ceno_zkvm/src/instructions/gpu/{ => utils}/colmap_base.rs (100%) rename ceno_zkvm/src/instructions/gpu/{ => utils}/d2h.rs (100%) rename ceno_zkvm/src/instructions/gpu/{ => utils}/debug_compare.rs (99%) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 612835cfe..c4e49e6c2 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1511,7 +1511,7 @@ pub fn generate_witness<'a, E: ExtensionField>( // This batch-D2Hs accumulated EC records and addr_accessed into shard_ctx. #[cfg(feature = "gpu")] info_span!("flush_shared_ec").in_scope(|| { - crate::instructions::gpu::witgen_gpu::flush_shared_ec_buffers( + crate::instructions::gpu::dispatch::flush_shared_ec_buffers( &mut shard_ctx, ) }).unwrap(); @@ -1519,7 +1519,7 @@ pub fn generate_witness<'a, E: ExtensionField>( // Free GPU shard_steps cache after all opcode circuits are done. #[cfg(feature = "gpu")] { - crate::instructions::gpu::witgen_gpu::invalidate_shard_steps_cache(); + crate::instructions::gpu::dispatch::invalidate_shard_steps_cache(); if std::env::var_os("CENO_GPU_TRIM_AFTER_WITGEN").is_some() { use gkr_iop::gpu::gpu_prover::get_cuda_hal; @@ -1554,7 +1554,7 @@ pub fn generate_witness<'a, E: ExtensionField>( // Force CPU path for the debug comparison (thread-local, no env var races). 
#[cfg(feature = "gpu")] - crate::instructions::gpu::witgen_gpu::set_force_cpu_path(true); + crate::instructions::gpu::dispatch::set_force_cpu_path(true); system_config .config @@ -1579,7 +1579,7 @@ pub fn generate_witness<'a, E: ExtensionField>( cpu_witness.finalize_lk_multiplicities(); #[cfg(feature = "gpu")] - crate::instructions::gpu::witgen_gpu::set_force_cpu_path(false); + crate::instructions::gpu::dispatch::set_force_cpu_path(false); log_shard_ctx_diff("post_opcode_assignment", &cpu_shard_ctx, &shard_ctx); diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index ed95ce846..c9e4b534a 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -345,10 +345,10 @@ macro_rules! impl_gpu_assign { ), $crate::error::ZKVMError, > { - use $crate::instructions::gpu::witgen_gpu; - let gpu_kind: Option = $kind_expr; + use $crate::instructions::gpu::dispatch; + let gpu_kind: Option = $kind_expr; if let Some(kind) = gpu_kind { - if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + if let Some(result) = dispatch::try_gpu_assign_instances::( config, shard_ctx, num_witin, diff --git a/ceno_zkvm/src/instructions/gpu/device_cache.rs b/ceno_zkvm/src/instructions/gpu/cache.rs similarity index 99% rename from ceno_zkvm/src/instructions/gpu/device_cache.rs rename to ceno_zkvm/src/instructions/gpu/cache.rs index d2d02c6c0..af437e97b 100644 --- a/ceno_zkvm/src/instructions/gpu/device_cache.rs +++ b/ceno_zkvm/src/instructions/gpu/cache.rs @@ -394,7 +394,7 @@ pub fn gpu_batch_continuation_ec_on_device( // Convert to GpuShardRamRecord format (writes first, reads after) let mut gpu_records: Vec = Vec::with_capacity(total); for (rec, _name) in write_records.iter().chain(read_records.iter()) { - gpu_records.push(super::d2h::shard_ram_record_to_gpu(rec)); + gpu_records.push(super::utils::d2h::shard_ram_record_to_gpu(rec)); } // GPU batch EC, results stay on device diff --git a/ceno_zkvm/src/instructions/gpu/add.rs 
b/ceno_zkvm/src/instructions/gpu/chips/add.rs similarity index 94% rename from ceno_zkvm/src/instructions/gpu/add.rs rename to ceno_zkvm/src/instructions/gpu/chips/add.rs index a83011cae..74ae5567d 100644 --- a/ceno_zkvm/src/instructions/gpu/add.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/add.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::AddColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::arith::ArithConfig; /// Extract column map from a constructed ArithConfig (ADD variant). @@ -131,7 +131,7 @@ mod tests { let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] @@ -213,7 +213,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::witgen_gpu::try_gpu_assign_instances::< + crate::instructions::gpu::dispatch::try_gpu_assign_instances::< E, AddInstruction, >( @@ -223,14 +223,14 @@ mod tests { num_structural_witin, &steps, &indices, - crate::instructions::gpu::witgen_gpu::GpuWitgenKind::Add, + crate::instructions::gpu::dispatch::GpuWitgenKind::Add, ) .unwrap() .expect("GPU path should be available"); // Flush shared EC/addr buffers from GPU device to shard_ctx // (in the e2e pipeline this is called once per shard after all opcode circuits) - crate::instructions::gpu::device_cache::flush_shared_ec_buffers( + crate::instructions::gpu::cache::flush_shared_ec_buffers( &mut shard_ctx_full_gpu, ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/gpu/addi.rs 
b/ceno_zkvm/src/instructions/gpu/chips/addi.rs similarity index 95% rename from ceno_zkvm/src/instructions/gpu/addi.rs rename to ceno_zkvm/src/instructions/gpu/chips/addi.rs index d206f1a4e..81f0e721b 100644 --- a/ceno_zkvm/src/instructions/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/addi.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::AddiColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig; /// Extract column map from a constructed InstructionConfig (ADDI v2). @@ -59,7 +59,7 @@ mod tests { let col_map = extract_addi_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/auipc.rs b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/auipc.rs rename to ceno_zkvm/src/instructions/gpu/chips/auipc.rs index 64d345ec7..74c97fdad 100644 --- a/ceno_zkvm/src/instructions/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::AuipcColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::auipc::AuipcConfig; /// Extract column map from a constructed AuipcConfig. 
@@ -61,7 +61,7 @@ mod tests { let col_map = extract_auipc_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/branch_cmp.rs rename to ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs index b5716aa9c..53902abde 100644 --- a/ceno_zkvm/src/instructions/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::BranchCmpColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; /// Extract column map from a constructed BranchConfig (BLT/BGE/BLTU/BGEU variant). 
@@ -70,7 +70,7 @@ mod tests { let col_map = extract_branch_cmp_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/branch_eq.rs rename to ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs index 300fef35d..a720f47aa 100644 --- a/ceno_zkvm/src/instructions/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::BranchEqColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; /// Extract column map from a constructed BranchConfig (BEQ/BNE variant). 
@@ -63,7 +63,7 @@ mod tests { let col_map = extract_branch_eq_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/div.rs b/ceno_zkvm/src/instructions/gpu/chips/div.rs similarity index 99% rename from ceno_zkvm/src/instructions/gpu/div.rs rename to ceno_zkvm/src/instructions/gpu/chips/div.rs index 6d6bd11c9..8f7442934 100644 --- a/ceno_zkvm/src/instructions/gpu/div.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/div.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::DivColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::div::div_circuit_v2::DivRemConfig; /// Extract column map from a constructed DivRemConfig. diff --git a/ceno_zkvm/src/instructions/gpu/jal.rs b/ceno_zkvm/src/instructions/gpu/chips/jal.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/jal.rs rename to ceno_zkvm/src/instructions/gpu/chips/jal.rs index dc7da0bb6..82dd428c4 100644 --- a/ceno_zkvm/src/instructions/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jal.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::JalColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_state_branching, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_state_branching, extract_uint_limbs}; use crate::instructions::riscv::jump::jal_v2::JalConfig; /// Extract column map from a constructed JalConfig. 
@@ -49,7 +49,7 @@ mod tests { let col_map = extract_jal_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/jalr.rs b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs similarity index 97% rename from ceno_zkvm/src/instructions/gpu/jalr.rs rename to ceno_zkvm/src/instructions/gpu/chips/jalr.rs index e68d961ca..762ae2cf4 100644 --- a/ceno_zkvm/src/instructions/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::JalrColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{ +use crate::instructions::gpu::utils::colmap_base::{ extract_rd, extract_rs1, extract_state_branching, extract_uint_limbs, extract_wit_ids, }; use crate::instructions::riscv::jump::jalr_v2::JalrConfig; @@ -68,7 +68,7 @@ mod tests { let col_map = extract_jalr_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs similarity index 98% rename from ceno_zkvm/src/instructions/gpu/keccak.rs rename to ceno_zkvm/src/instructions/gpu/chips/keccak.rs index 2e0654156..2026a0a75 100644 --- a/ceno_zkvm/src/instructions/gpu/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -219,7 +219,7 @@ mod tests { let step_indices: Vec = vec![0]; // --- CPU path (force CPU via thread-local flag) --- - use super::super::witgen_gpu::set_force_cpu_path; + use crate::instructions::gpu::dispatch::set_force_cpu_path; set_force_cpu_path(true); let mut shard_ctx = 
ShardContext::default(); shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses.clone()); @@ -237,7 +237,7 @@ mod tests { let cpu_structural = &cpu_rmms[1]; // --- GPU path (full pipeline via gpu_assign_keccak_instances) --- - use super::super::witgen_gpu::gpu_assign_keccak_instances; + use crate::instructions::gpu::dispatch::gpu_assign_keccak_instances; let mut shard_ctx_gpu = ShardContext::default(); shard_ctx_gpu.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); let (gpu_rmms, gpu_lk) = gpu_assign_keccak_instances::( diff --git a/ceno_zkvm/src/instructions/gpu/load_sub.rs b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs similarity index 99% rename from ceno_zkvm/src/instructions/gpu/load_sub.rs rename to ceno_zkvm/src/instructions/gpu/chips/load_sub.rs index df3d35937..7390ce113 100644 --- a/ceno_zkvm/src/instructions/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::LoadSubColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::memory::load_v2::LoadConfig; /// Extract column map from a constructed LoadConfig for sub-word loads (LH/LHU/LB/LBU). 
diff --git a/ceno_zkvm/src/instructions/gpu/logic_i.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/logic_i.rs rename to ceno_zkvm/src/instructions/gpu/chips/logic_i.rs index d7bfb068d..671cd90cf 100644 --- a/ceno_zkvm/src/instructions/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::LogicIColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig; /// Extract column map from a constructed LogicConfig (I-type v2: ANDI/ORI/XORI). @@ -59,7 +59,7 @@ mod tests { let col_map = extract_logic_i_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/logic_r.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs similarity index 94% rename from ceno_zkvm/src/instructions/gpu/logic_r.rs rename to ceno_zkvm/src/instructions/gpu/chips/logic_r.rs index bbc4b56e8..77ecf6af9 100644 --- a/ceno_zkvm/src/instructions/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::LogicRColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::logic::logic_circuit::LogicConfig; /// Extract column map from a constructed LogicConfig (R-type: 
AND/OR/XOR). @@ -89,7 +89,7 @@ mod tests { let col_map = extract_logic_r_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] @@ -202,7 +202,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::witgen_gpu::try_gpu_assign_instances::< + crate::instructions::gpu::dispatch::try_gpu_assign_instances::< E, AndInstruction, >( @@ -212,12 +212,12 @@ mod tests { num_structural_witin, &steps, &indices, - crate::instructions::gpu::witgen_gpu::GpuWitgenKind::LogicR(0), + crate::instructions::gpu::dispatch::GpuWitgenKind::LogicR(0), ) .unwrap() .expect("GPU path should be available"); - crate::instructions::gpu::device_cache::flush_shared_ec_buffers( + crate::instructions::gpu::cache::flush_shared_ec_buffers( &mut shard_ctx_full_gpu, ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/gpu/lui.rs b/ceno_zkvm/src/instructions/gpu/chips/lui.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/lui.rs rename to ceno_zkvm/src/instructions/gpu/chips/lui.rs index 3d1d04166..715af5996 100644 --- a/ceno_zkvm/src/instructions/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/lui.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::LuiColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_state}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state}; use crate::instructions::riscv::lui::LuiConfig; /// Extract column map from a constructed LuiConfig. 
@@ -60,7 +60,7 @@ mod tests { let col_map = extract_lui_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/lw.rs b/ceno_zkvm/src/instructions/gpu/chips/lw.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/lw.rs rename to ceno_zkvm/src/instructions/gpu/chips/lw.rs index 56c73a55e..4ed847bfd 100644 --- a/ceno_zkvm/src/instructions/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/lw.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::LwColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::memory::load::LoadConfig; @@ -225,7 +225,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::witgen_gpu::try_gpu_assign_instances::< + crate::instructions::gpu::dispatch::try_gpu_assign_instances::< E, LwInstruction, >( @@ -235,12 +235,12 @@ mod tests { num_structural_witin, &steps, &indices, - crate::instructions::gpu::witgen_gpu::GpuWitgenKind::Lw, + crate::instructions::gpu::dispatch::GpuWitgenKind::Lw, ) .unwrap() .expect("GPU path should be available"); - crate::instructions::gpu::device_cache::flush_shared_ec_buffers( + crate::instructions::gpu::cache::flush_shared_ec_buffers( &mut shard_ctx_full_gpu, ) .unwrap(); diff --git a/ceno_zkvm/src/instructions/gpu/chips/mod.rs b/ceno_zkvm/src/instructions/gpu/chips/mod.rs new file mode 100644 index 000000000..ebb4d36ae --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/mod.rs @@ -0,0 
+1,48 @@ +#[cfg(feature = "gpu")] +pub mod add; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod addi; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod auipc; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_cmp; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_eq; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod div; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jal; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jalr; +#[cfg(feature = "gpu")] +pub mod keccak; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod load_sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod logic_i; +#[cfg(feature = "gpu")] +pub mod logic_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod lui; +#[cfg(feature = "gpu")] +pub mod lw; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod mul; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sb; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sh; +#[cfg(feature = "gpu")] +pub mod shard_ram; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_i; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slt; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slti; +#[cfg(feature = "gpu")] +pub mod sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sw; diff --git a/ceno_zkvm/src/instructions/gpu/mul.rs b/ceno_zkvm/src/instructions/gpu/chips/mul.rs similarity index 99% rename from ceno_zkvm/src/instructions/gpu/mul.rs rename to ceno_zkvm/src/instructions/gpu/chips/mul.rs index 56b3c8b40..826c435b6 100644 --- a/ceno_zkvm/src/instructions/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/mul.rs @@ -1,7 +1,7 @@ use 
ceno_gpu::common::witgen::types::MulColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig; /// Extract column map from a constructed MulhConfig. diff --git a/ceno_zkvm/src/instructions/gpu/sb.rs b/ceno_zkvm/src/instructions/gpu/chips/sb.rs similarity index 98% rename from ceno_zkvm/src/instructions/gpu/sb.rs rename to ceno_zkvm/src/instructions/gpu/chips/sb.rs index 35652a3e1..94cd947ac 100644 --- a/ceno_zkvm/src/instructions/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sb.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::SbColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{ +use crate::instructions::gpu::utils::colmap_base::{ extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, }; use crate::instructions::riscv::memory::store_v2::StoreConfig; @@ -99,7 +99,7 @@ mod tests { let col_map = extract_sb_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/sh.rs b/ceno_zkvm/src/instructions/gpu/chips/sh.rs similarity index 97% rename from ceno_zkvm/src/instructions/gpu/sh.rs rename to ceno_zkvm/src/instructions/gpu/chips/sh.rs index 8e259d037..b2728eaa8 100644 --- a/ceno_zkvm/src/instructions/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sh.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::ShColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{ +use crate::instructions::gpu::utils::colmap_base::{ extract_rs1, extract_rs2, extract_state, extract_uint_limbs, 
extract_write_mem, }; use crate::instructions::riscv::memory::store_v2::StoreConfig; @@ -76,7 +76,7 @@ mod tests { let col_map = extract_sh_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs similarity index 100% rename from ceno_zkvm/src/instructions/gpu/shard_ram.rs rename to ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs diff --git a/ceno_zkvm/src/instructions/gpu/shift_i.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/shift_i.rs rename to ceno_zkvm/src/instructions/gpu/chips/shift_i.rs index 43afca628..ce20db6a9 100644 --- a/ceno_zkvm/src/instructions/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::ShiftIColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig; /// Extract column map from a constructed ShiftImmConfig (I-type: SLLI/SRLI/SRAI). 
@@ -72,7 +72,7 @@ mod tests { let col_map = extract_shift_i_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/shift_r.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/shift_r.rs rename to ceno_zkvm/src/instructions/gpu/chips/shift_r.rs index 54b3e582f..31a97e16f 100644 --- a/ceno_zkvm/src/instructions/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::ShiftRColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig; /// Extract column map from a constructed ShiftRTypeConfig (R-type: SLL/SRL/SRA). 
@@ -77,7 +77,7 @@ mod tests { let col_map = extract_shift_r_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/slt.rs b/ceno_zkvm/src/instructions/gpu/chips/slt.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/slt.rs rename to ceno_zkvm/src/instructions/gpu/chips/slt.rs index 9d90432d7..2251f0a09 100644 --- a/ceno_zkvm/src/instructions/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slt.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::SltColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig; /// Extract column map from a constructed SetLessThanConfig (SLT/SLTU). 
@@ -73,7 +73,7 @@ mod tests { let col_map = extract_slt_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/slti.rs b/ceno_zkvm/src/instructions/gpu/chips/slti.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/slti.rs rename to ceno_zkvm/src/instructions/gpu/chips/slti.rs index 2c8c8aa93..052054d2d 100644 --- a/ceno_zkvm/src/instructions/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slti.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::SltiColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig; /// Extract column map from a constructed SetLessThanImmConfig (SLTI/SLTIU). 
@@ -70,7 +70,7 @@ mod tests { let col_map = extract_slti_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/sub.rs b/ceno_zkvm/src/instructions/gpu/chips/sub.rs similarity index 96% rename from ceno_zkvm/src/instructions/gpu/sub.rs rename to ceno_zkvm/src/instructions/gpu/chips/sub.rs index 6b876e719..ef051cfe5 100644 --- a/ceno_zkvm/src/instructions/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sub.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::SubColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; use crate::instructions::riscv::arith::ArithConfig; /// Extract column map from a constructed ArithConfig (SUB variant). 
@@ -63,7 +63,7 @@ mod tests { let col_map = extract_sub_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/sw.rs b/ceno_zkvm/src/instructions/gpu/chips/sw.rs similarity index 97% rename from ceno_zkvm/src/instructions/gpu/sw.rs rename to ceno_zkvm/src/instructions/gpu/chips/sw.rs index dc7914747..abcda6a6e 100644 --- a/ceno_zkvm/src/instructions/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sw.rs @@ -1,7 +1,7 @@ use ceno_gpu::common::witgen::types::SwColumnMap; use ff_ext::ExtensionField; -use super::colmap_base::{ +use crate::instructions::gpu::utils::colmap_base::{ extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, }; use crate::instructions::riscv::memory::store_v2::StoreConfig; @@ -67,7 +67,7 @@ mod tests { let col_map = extract_sw_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); - crate::instructions::gpu::colmap_base::validate_column_map(&flat, col_map.num_cols); + crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); } #[test] diff --git a/ceno_zkvm/src/instructions/gpu/gpu_config.rs b/ceno_zkvm/src/instructions/gpu/config.rs similarity index 99% rename from ceno_zkvm/src/instructions/gpu/gpu_config.rs rename to ceno_zkvm/src/instructions/gpu/config.rs index 9fe72d4ed..864898ee3 100644 --- a/ceno_zkvm/src/instructions/gpu/gpu_config.rs +++ b/ceno_zkvm/src/instructions/gpu/config.rs @@ -2,7 +2,7 @@ /// environment-variable disable switches. /// /// Extracted from `witgen_gpu.rs` — pure code move, no behavioural changes. 
-use super::witgen_gpu::GpuWitgenKind; +use super::dispatch::GpuWitgenKind; pub(crate) fn kind_tag(kind: GpuWitgenKind) -> &'static str { match kind { diff --git a/ceno_zkvm/src/instructions/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs similarity index 95% rename from ceno_zkvm/src/instructions/gpu/witgen_gpu.rs rename to ceno_zkvm/src/instructions/gpu/dispatch.rs index a44b318d9..7aac6e3a4 100644 --- a/ceno_zkvm/src/instructions/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -16,11 +16,11 @@ use std::cell::Cell; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; -use super::debug_compare::{ +use super::utils::debug_compare::{ debug_compare_final_lk, debug_compare_keccak, debug_compare_shard_ec, debug_compare_shardram, debug_compare_witness, }; -use super::gpu_config::{ +use super::config::{ is_gpu_witgen_disabled, is_kind_disabled, kind_has_verified_lk, kind_has_verified_shard, }; use crate::{ @@ -80,17 +80,17 @@ pub enum GpuWitgenKind { } // Re-exports from device_cache module for external callers (e2e.rs, structs.rs). -pub use super::device_cache::{ +pub use super::cache::{ SharedDeviceBufferSet, flush_shared_ec_buffers, gpu_batch_continuation_ec_on_device, invalidate_shard_meta_cache, invalidate_shard_steps_cache, take_shared_device_buffers, }; // Re-export for external callers (structs.rs). 
-pub use super::d2h::gpu_batch_continuation_ec; -use super::d2h::{ +pub use super::utils::d2h::gpu_batch_continuation_ec; +use super::utils::d2h::{ CompactEcBuf, LkResult, RamBuf, WitResult, gpu_collect_shard_records, gpu_compact_ec_d2h, gpu_lk_counters_to_multiplicity, gpu_witness_to_rmm, }; -use super::device_cache::{ +use super::cache::{ ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, upload_shard_steps_cached, with_cached_shard_meta, with_cached_shard_steps, }; @@ -376,7 +376,7 @@ fn gpu_assign_instances_inner>( Ok(([raw_witin, raw_structural], lk_multiplicity)) } -// Type aliases and D2H conversion functions live in super::d2h. +// Type aliases and D2H conversion functions live in super::utils::d2h. /// Compute fetch counter parameters from step data. pub(crate) fn compute_fetch_params( @@ -441,7 +441,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::arith::ArithConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); + .in_scope(|| super::chips::add::extract_add_column_map(arith_config, num_witin)); info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -471,7 +471,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::arith::ArithConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); + .in_scope(|| super::chips::sub::extract_sub_column_map(arith_config, num_witin)); info_span!("hal_witgen_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -501,7 +501,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); + .in_scope(|| super::chips::logic_r::extract_logic_r_column_map(logic_config, 
num_witin)); info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -533,7 +533,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); + .in_scope(|| super::chips::logic_i::extract_logic_i_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -565,7 +565,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); + .in_scope(|| super::chips::addi::extract_addi_column_map(addi_config, num_witin)); info_span!("hal_witgen_addi").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -596,7 +596,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::lui::LuiConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); + .in_scope(|| super::chips::lui::extract_lui_column_map(lui_config, num_witin)); info_span!("hal_witgen_lui").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -625,7 +625,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::auipc::AuipcConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); + .in_scope(|| super::chips::auipc::extract_auipc_column_map(auipc_config, num_witin)); info_span!("hal_witgen_auipc").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -656,7 +656,7 @@ fn gpu_fill_witness>( as *const 
crate::instructions::riscv::jump::jal_v2::JalConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); + .in_scope(|| super::chips::jal::extract_jal_column_map(jal_config, num_witin)); info_span!("hal_witgen_jal").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -687,7 +687,7 @@ fn gpu_fill_witness>( >) }; let col_map = info_span!("col_map") - .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); + .in_scope(|| super::chips::shift_r::extract_shift_r_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -721,7 +721,7 @@ fn gpu_fill_witness>( >) }; let col_map = info_span!("col_map") - .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); + .in_scope(|| super::chips::shift_i::extract_shift_i_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -753,7 +753,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); + .in_scope(|| super::chips::slt::extract_slt_column_map(slt_config, num_witin)); info_span!("hal_witgen_slt").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -785,7 +785,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); + .in_scope(|| super::chips::slti::extract_slti_column_map(slti_config, num_witin)); info_span!("hal_witgen_slti").in_scope(|| { with_cached_shard_steps(|gpu_records| { 
with_cached_shard_meta(|shard_bufs| { @@ -819,7 +819,7 @@ fn gpu_fill_witness>( >) }; let col_map = info_span!("col_map").in_scope(|| { - super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin) + super::chips::branch_eq::extract_branch_eq_column_map(branch_config, num_witin) }); info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_shard_steps(|gpu_records| { @@ -854,7 +854,7 @@ fn gpu_fill_witness>( >) }; let col_map = info_span!("col_map").in_scope(|| { - super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin) + super::chips::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin) }); info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_shard_steps(|gpu_records| { @@ -887,7 +887,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::jump::jalr_v2::JalrConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); + .in_scope(|| super::chips::jalr::extract_jalr_column_map(jalr_config, num_witin)); info_span!("hal_witgen_jalr").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -917,7 +917,7 @@ fn gpu_fill_witness>( }; let mem_max_bits = sw_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") - .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); + .in_scope(|| super::chips::sw::extract_sw_column_map(sw_config, num_witin)); info_span!("hal_witgen_sw").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -948,7 +948,7 @@ fn gpu_fill_witness>( }; let mem_max_bits = sh_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") - .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); + .in_scope(|| super::chips::sh::extract_sh_column_map(sh_config, num_witin)); info_span!("hal_witgen_sh").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -979,7 
+979,7 @@ fn gpu_fill_witness>( }; let mem_max_bits = sb_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") - .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); + .in_scope(|| super::chips::sb::extract_sb_column_map(sb_config, num_witin)); info_span!("hal_witgen_sb").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -1014,7 +1014,7 @@ fn gpu_fill_witness>( let is_byte = load_width == 8; let is_signed_bool = is_signed != 0; let col_map = info_span!("col_map").in_scope(|| { - super::load_sub::extract_load_sub_column_map( + super::chips::load_sub::extract_load_sub_column_map( load_config, num_witin, is_byte, @@ -1055,7 +1055,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); + .in_scope(|| super::chips::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -1087,7 +1087,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::div::div_circuit_v2::DivRemConfig) }; let col_map = info_span!("col_map") - .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); + .in_scope(|| super::chips::div::extract_div_column_map(div_config, num_witin)); info_span!("hal_witgen_div").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { @@ -1125,7 +1125,7 @@ fn gpu_fill_witness>( }; let mem_max_bits = load_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") - .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); + .in_scope(|| super::chips::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { with_cached_shard_steps(|gpu_records| { 
with_cached_shard_meta(|shard_bufs| { @@ -1273,12 +1273,12 @@ fn gpu_assign_keccak_inner( // Step 1: Extract column map let col_map = info_span!("col_map") - .in_scope(|| super::keccak::extract_keccak_column_map(config, num_witin)); + .in_scope(|| super::chips::keccak::extract_keccak_column_map(config, num_witin)); // Step 2: Pack instances let packed_instances = info_span!("pack_instances") .in_scope(|| { - super::keccak::pack_keccak_instances( + super::chips::keccak::pack_keccak_instances( steps, step_indices, &shard_ctx.syscall_witnesses, diff --git a/ceno_zkvm/src/instructions/gpu/mod.rs b/ceno_zkvm/src/instructions/gpu/mod.rs index 3983cfcd7..aa8eae109 100644 --- a/ceno_zkvm/src/instructions/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/mod.rs @@ -1,64 +1,14 @@ pub mod utils; +pub mod chips; #[cfg(feature = "gpu")] -pub mod add; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod addi; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod auipc; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod branch_cmp; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod branch_eq; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod div; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod jal; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod jalr; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod load_sub; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod logic_i; #[cfg(feature = "gpu")] -pub mod logic_r; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod lui; #[cfg(feature = "gpu")] -pub mod lw; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod mul; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod sb; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod sh; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod shift_i; -#[cfg(all(feature = 
"gpu", feature = "u16limb_circuit"))] -pub mod shift_r; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod slt; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod slti; +pub mod config; #[cfg(feature = "gpu")] -pub mod sub; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod sw; #[cfg(feature = "gpu")] -pub mod keccak; +pub mod cache; #[cfg(feature = "gpu")] -pub mod shard_ram; -#[cfg(feature = "gpu")] -pub mod colmap_base; -#[cfg(feature = "gpu")] -pub mod debug_compare; -#[cfg(feature = "gpu")] -pub mod gpu_config; -#[cfg(feature = "gpu")] -pub mod d2h; -#[cfg(feature = "gpu")] -pub mod device_cache; -#[cfg(feature = "gpu")] -pub mod witgen_gpu; +pub mod dispatch; #[cfg(test)] mod tests { diff --git a/ceno_zkvm/src/instructions/gpu/colmap_base.rs b/ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs similarity index 100% rename from ceno_zkvm/src/instructions/gpu/colmap_base.rs rename to ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs diff --git a/ceno_zkvm/src/instructions/gpu/d2h.rs b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs similarity index 100% rename from ceno_zkvm/src/instructions/gpu/d2h.rs rename to ceno_zkvm/src/instructions/gpu/utils/d2h.rs diff --git a/ceno_zkvm/src/instructions/gpu/debug_compare.rs b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs similarity index 99% rename from ceno_zkvm/src/instructions/gpu/debug_compare.rs rename to ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs index 12b9f3550..43f3fe91d 100644 --- a/ceno_zkvm/src/instructions/gpu/debug_compare.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs @@ -20,7 +20,7 @@ use crate::{ instructions::{Instruction, cpu_collect_shardram, cpu_collect_lk_and_shardram}, }; -use super::witgen_gpu::{GpuWitgenKind, set_force_cpu_path}; +use crate::instructions::gpu::dispatch::{GpuWitgenKind, set_force_cpu_path}; pub(crate) fn debug_compare_final_lk>( config: &I::InstructionConfig, diff --git 
a/ceno_zkvm/src/instructions/gpu/utils/mod.rs b/ceno_zkvm/src/instructions/gpu/utils/mod.rs index 93798a7d3..0c120c4fa 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/mod.rs @@ -13,6 +13,13 @@ pub use sink::*; pub use emit::*; pub use fallback::*; +#[cfg(feature = "gpu")] +pub mod colmap_base; +#[cfg(feature = "gpu")] +pub mod d2h; +#[cfg(feature = "gpu")] +pub mod debug_compare; + #[cfg(test)] mod tests { diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 07b2d6500..b179645b9 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -160,8 +160,8 @@ impl Instruction for ArithInstruction Some(witgen_gpu::GpuWitgenKind::Add), - InsnKind::SUB => Some(witgen_gpu::GpuWitgenKind::Sub), + InsnKind::ADD => Some(dispatch::GpuWitgenKind::Add), + InsnKind::SUB => Some(dispatch::GpuWitgenKind::Sub), _ => None, }); } diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index c87c033b4..7ba58cea2 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -115,5 +115,5 @@ impl Instruction for AddiInstruction { impl_collect_shardram!(i_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Addi); + impl_gpu_assign!(dispatch::GpuWitgenKind::Addi); } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index f3d642794..cf40c1eff 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -223,7 +223,7 @@ impl Instruction for AuipcInstruction { impl_collect_shardram!(i_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Auipc); + impl_gpu_assign!(dispatch::GpuWitgenKind::Auipc); } #[cfg(test)] diff --git 
a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index b7919af31..e0c3d2fc1 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -227,12 +227,12 @@ impl Instruction for BranchCircuit Some(witgen_gpu::GpuWitgenKind::BranchEq(1)), - InsnKind::BNE => Some(witgen_gpu::GpuWitgenKind::BranchEq(0)), - InsnKind::BLT => Some(witgen_gpu::GpuWitgenKind::BranchCmp(1)), - InsnKind::BGE => Some(witgen_gpu::GpuWitgenKind::BranchCmp(1)), - InsnKind::BLTU => Some(witgen_gpu::GpuWitgenKind::BranchCmp(0)), - InsnKind::BGEU => Some(witgen_gpu::GpuWitgenKind::BranchCmp(0)), + InsnKind::BEQ => Some(dispatch::GpuWitgenKind::BranchEq(1)), + InsnKind::BNE => Some(dispatch::GpuWitgenKind::BranchEq(0)), + InsnKind::BLT => Some(dispatch::GpuWitgenKind::BranchCmp(1)), + InsnKind::BGE => Some(dispatch::GpuWitgenKind::BranchCmp(1)), + InsnKind::BLTU => Some(dispatch::GpuWitgenKind::BranchCmp(0)), + InsnKind::BGEU => Some(dispatch::GpuWitgenKind::BranchCmp(0)), _ => unreachable!(), }); } diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index 3ed1bd071..c0e3b2381 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -384,10 +384,10 @@ impl Instruction for ArithInstruction Some(witgen_gpu::GpuWitgenKind::Div(0u32)), - InsnKind::DIVU => Some(witgen_gpu::GpuWitgenKind::Div(1u32)), - InsnKind::REM => Some(witgen_gpu::GpuWitgenKind::Div(2u32)), - InsnKind::REMU => Some(witgen_gpu::GpuWitgenKind::Div(3u32)), + InsnKind::DIV => Some(dispatch::GpuWitgenKind::Div(0u32)), + InsnKind::DIVU => Some(dispatch::GpuWitgenKind::Div(1u32)), + InsnKind::REM => Some(dispatch::GpuWitgenKind::Div(2u32)), + InsnKind::REMU => Some(dispatch::GpuWitgenKind::Div(3u32)), _ => None, }); diff --git 
a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index f94c274cb..5ed5455d0 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -180,7 +180,7 @@ impl Instruction for KeccakInstruction { ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { #[cfg(feature = "gpu")] { - use crate::instructions::gpu::witgen_gpu::gpu_assign_keccak_instances; + use crate::instructions::gpu::dispatch::gpu_assign_keccak_instances; if let Some(result) = gpu_assign_keccak_instances::( config, shard_ctx, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index fac4243b9..f84cc9696 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -141,5 +141,5 @@ impl Instruction for JalInstruction { impl_collect_shardram!(j_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Jal); + impl_gpu_assign!(dispatch::GpuWitgenKind::Jal); } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index ac8ddfe20..e133033b5 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -206,5 +206,5 @@ impl Instruction for JalrInstruction { impl_collect_shardram!(i_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Jalr); + impl_gpu_assign!(dispatch::GpuWitgenKind::Jalr); } diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index a233bdb6c..2b42cdfaf 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -88,7 +88,7 @@ impl Instruction for LogicInstruction { impl_collect_shardram!(r_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::LogicR(match I::INST_KIND { + 
impl_gpu_assign!(dispatch::GpuWitgenKind::LogicR(match I::INST_KIND { InsnKind::AND => 0, InsnKind::OR => 1, InsnKind::XOR => 2, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 010c899fb..d9624b68f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -143,7 +143,7 @@ impl Instruction for LogicInstruction { impl_collect_shardram!(i_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::LogicI(match I::INST_KIND { + impl_gpu_assign!(dispatch::GpuWitgenKind::LogicI(match I::INST_KIND { InsnKind::ANDI => 0, InsnKind::ORI => 1, InsnKind::XORI => 2, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 4b491fb62..9d8e67f95 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -127,7 +127,7 @@ impl Instruction for LuiInstruction { impl_collect_shardram!(i_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Lui); + impl_gpu_assign!(dispatch::GpuWitgenKind::Lui); } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index d3fe72181..b19e94c01 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -242,7 +242,7 @@ impl Instruction for LoadInstruction Some(witgen_gpu::GpuWitgenKind::Lw), + InsnKind::LW => Some(dispatch::GpuWitgenKind::Lw), _ => None, }); } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 6ce464ed9..e6ec9d6c7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -272,20 +272,20 @@ impl Instruction for LoadInstruction Some(witgen_gpu::GpuWitgenKind::Lw), - InsnKind::LH => 
Some(witgen_gpu::GpuWitgenKind::LoadSub { + InsnKind::LW => Some(dispatch::GpuWitgenKind::Lw), + InsnKind::LH => Some(dispatch::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 1, }), - InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + InsnKind::LHU => Some(dispatch::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 0, }), - InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { + InsnKind::LB => Some(dispatch::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 1, }), - InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + InsnKind::LBU => Some(dispatch::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 0, }), diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index 479e0e22c..05851299d 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -204,9 +204,9 @@ impl Instruction impl_collect_shardram!(s_insn); impl_gpu_assign!(match I::INST_KIND { - InsnKind::SW => Some(witgen_gpu::GpuWitgenKind::Sw), - InsnKind::SH => Some(witgen_gpu::GpuWitgenKind::Sh), - InsnKind::SB => Some(witgen_gpu::GpuWitgenKind::Sb), + InsnKind::SW => Some(dispatch::GpuWitgenKind::Sw), + InsnKind::SH => Some(dispatch::GpuWitgenKind::Sh), + InsnKind::SB => Some(dispatch::GpuWitgenKind::Sb), _ => None, }); } diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index 31765b607..7cc6ce0e9 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -416,10 +416,10 @@ impl Instruction for MulhInstructionBas impl_collect_shardram!(r_insn); impl_gpu_assign!(match I::INST_KIND { - InsnKind::MUL => Some(witgen_gpu::GpuWitgenKind::Mul(0u32)), - InsnKind::MULH => Some(witgen_gpu::GpuWitgenKind::Mul(1u32)), - InsnKind::MULHU => Some(witgen_gpu::GpuWitgenKind::Mul(2u32)), - InsnKind::MULHSU => 
Some(witgen_gpu::GpuWitgenKind::Mul(3u32)), + InsnKind::MUL => Some(dispatch::GpuWitgenKind::Mul(0u32)), + InsnKind::MULH => Some(dispatch::GpuWitgenKind::Mul(1u32)), + InsnKind::MULHU => Some(dispatch::GpuWitgenKind::Mul(2u32)), + InsnKind::MULHSU => Some(dispatch::GpuWitgenKind::Mul(3u32)), _ => None, }); } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index ec24dc4ff..90f9ac102 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -423,7 +423,7 @@ impl Instruction for ShiftLogicalInstru impl_collect_shardram!(r_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::ShiftR(match I::INST_KIND { + impl_gpu_assign!(dispatch::GpuWitgenKind::ShiftR(match I::INST_KIND { InsnKind::SLL => 0u32, InsnKind::SRL => 1u32, InsnKind::SRA => 2u32, @@ -548,7 +548,7 @@ impl Instruction for ShiftImmInstructio impl_collect_shardram!(i_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::ShiftI(match I::INST_KIND { + impl_gpu_assign!(dispatch::GpuWitgenKind::ShiftI(match I::INST_KIND { InsnKind::SLLI => 0u32, InsnKind::SRLI => 1u32, InsnKind::SRAI => 2u32, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index 749f9e8a6..a61d30578 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -133,7 +133,7 @@ impl Instruction for SetLessThanInstruc impl_collect_shardram!(r_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Slt(match I::INST_KIND { + impl_gpu_assign!(dispatch::GpuWitgenKind::Slt(match I::INST_KIND { InsnKind::SLT => 1u32, InsnKind::SLTU => 0u32, _ => unreachable!(), diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 081e432cf..5505f8c5d 100644 --- 
a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -152,7 +152,7 @@ impl Instruction for SetLessThanImmInst impl_collect_shardram!(i_insn); - impl_gpu_assign!(witgen_gpu::GpuWitgenKind::Slti(match I::INST_KIND { + impl_gpu_assign!(dispatch::GpuWitgenKind::Slti(match I::INST_KIND { InsnKind::SLTI => 1u32, InsnKind::SLTIU => 0u32, _ => unreachable!(), diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 80bc04304..f296de9c2 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -587,7 +587,7 @@ impl ZKVMWitnesses { let global_input = { #[cfg(feature = "gpu")] let ec_result = { - use crate::instructions::gpu::witgen_gpu::gpu_batch_continuation_ec; + use crate::instructions::gpu::dispatch::gpu_batch_continuation_ec; gpu_batch_continuation_ec::(&write_record_pairs, &read_record_pairs) .ok() }; @@ -733,7 +733,7 @@ impl ZKVMWitnesses { final_mem: &[(&'static str, Option>, &[MemFinalRecord])], config: & as TableCircuit>::TableConfig, ) -> Result { - use crate::instructions::gpu::witgen_gpu::{ + use crate::instructions::gpu::dispatch::{ gpu_batch_continuation_ec_on_device, take_shared_device_buffers, }; use ceno_gpu::Buffer; diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 1c82ee198..dd8d06afd 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -720,7 +720,7 @@ impl ShardRamCircuit { .collect()); // 2. Extract column map - let col_map = crate::instructions::gpu::shard_ram::extract_shard_ram_column_map( + let col_map = crate::instructions::gpu::chips::shard_ram::extract_shard_ram_column_map( config, num_witin, ); @@ -931,7 +931,7 @@ impl ShardRamCircuit { let num_rows_padded = 2 * n; // 1. 
Extract column map (same as regular path) - let col_map = crate::instructions::gpu::shard_ram::extract_shard_ram_column_map( + let col_map = crate::instructions::gpu::chips::shard_ram::extract_shard_ram_column_map( config, num_witin, ); From bc6f3ebb9d42e996c740abcd3fc9e250d6391679 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 18:38:11 +0800 Subject: [PATCH 64/73] lints, fmt --- ceno_emul/src/lib.rs | 3 +- ceno_emul/src/test_utils.rs | 4 +- ceno_emul/src/tracer.rs | 15 +- ceno_zkvm/src/e2e.rs | 125 +- ceno_zkvm/src/gadgets/poseidon2.rs | 2 +- ceno_zkvm/src/gadgets/signed_ext.rs | 1 + ceno_zkvm/src/instructions.rs | 4 +- ceno_zkvm/src/instructions/gpu/cache.rs | 124 +- ceno_zkvm/src/instructions/gpu/chips/add.rs | 32 +- ceno_zkvm/src/instructions/gpu/chips/addi.rs | 22 +- ceno_zkvm/src/instructions/gpu/chips/auipc.rs | 20 +- .../src/instructions/gpu/chips/branch_cmp.rs | 23 +- .../src/instructions/gpu/chips/branch_eq.rs | 23 +- ceno_zkvm/src/instructions/gpu/chips/div.rs | 14 +- ceno_zkvm/src/instructions/gpu/chips/jal.rs | 20 +- ceno_zkvm/src/instructions/gpu/chips/jalr.rs | 25 +- .../src/instructions/gpu/chips/keccak.rs | 26 +- .../src/instructions/gpu/chips/load_sub.rs | 11 +- .../src/instructions/gpu/chips/logic_i.rs | 21 +- .../src/instructions/gpu/chips/logic_r.rs | 33 +- ceno_zkvm/src/instructions/gpu/chips/lui.rs | 20 +- ceno_zkvm/src/instructions/gpu/chips/lw.rs | 29 +- ceno_zkvm/src/instructions/gpu/chips/mul.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/sb.rs | 26 +- ceno_zkvm/src/instructions/gpu/chips/sh.rs | 26 +- .../src/instructions/gpu/chips/shard_ram.rs | 3 +- .../src/instructions/gpu/chips/shift_i.rs | 21 +- .../src/instructions/gpu/chips/shift_r.rs | 23 +- ceno_zkvm/src/instructions/gpu/chips/slt.rs | 23 +- ceno_zkvm/src/instructions/gpu/chips/slti.rs | 21 +- ceno_zkvm/src/instructions/gpu/chips/sub.rs | 22 +- ceno_zkvm/src/instructions/gpu/chips/sw.rs | 26 +- ceno_zkvm/src/instructions/gpu/config.rs | 7 +- 
ceno_zkvm/src/instructions/gpu/dispatch.rs | 1043 +++++++++-------- ceno_zkvm/src/instructions/gpu/mod.rs | 53 +- .../src/instructions/gpu/utils/colmap_base.rs | 29 +- ceno_zkvm/src/instructions/gpu/utils/d2h.rs | 61 +- .../instructions/gpu/utils/debug_compare.rs | 169 ++- .../src/instructions/gpu/utils/fallback.rs | 12 +- ceno_zkvm/src/instructions/gpu/utils/mod.rs | 9 +- ceno_zkvm/src/instructions/gpu/utils/sink.rs | 3 + ceno_zkvm/src/instructions/riscv.rs | 1 - ceno_zkvm/src/instructions/riscv/arith.rs | 7 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 7 +- ceno_zkvm/src/instructions/riscv/b_insn.rs | 2 +- .../riscv/branch/branch_circuit_v2.rs | 8 +- .../instructions/riscv/div/div_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/i_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 16 +- ceno_zkvm/src/instructions/riscv/j_insn.rs | 2 +- .../src/instructions/riscv/jump/jal_v2.rs | 4 +- .../src/instructions/riscv/jump/jalr_v2.rs | 4 +- .../instructions/riscv/logic/logic_circuit.rs | 13 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/lui.rs | 4 +- .../src/instructions/riscv/memory/load.rs | 2 +- .../src/instructions/riscv/memory/load_v2.rs | 2 +- .../src/instructions/riscv/memory/store_v2.rs | 8 +- .../riscv/mulh/mulh_circuit_v2.rs | 4 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 2 +- ceno_zkvm/src/instructions/riscv/s_insn.rs | 2 +- .../riscv/shift/shift_circuit_v2.rs | 7 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 8 +- .../riscv/slti/slti_circuit_v2.rs | 6 +- ceno_zkvm/src/scheme/septic_curve.rs | 172 ++- ceno_zkvm/src/structs.rs | 382 +++--- ceno_zkvm/src/tables/shard_ram.rs | 472 ++++---- 69 files changed, 1966 insertions(+), 1370 deletions(-) diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 268fe78e8..607f451e0 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -12,8 
+12,7 @@ mod tracer; pub use tracer::{ Change, FullTracer, FullTracerConfig, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, PackedNextAccessEntry, PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, - StepCellExtractor, StepIndex, - StepRecord, Tracer, WriteOp, + StepCellExtractor, StepIndex, StepRecord, Tracer, WriteOp, }; mod vm_state; diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 2c906c7b4..0fd9f03d1 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -28,9 +28,7 @@ pub fn keccak_step() -> (StepRecord, Vec, Vec) { let mut vm: VMState = VMState::new_with_tracer_config( CENO_PLATFORM.clone(), program.into(), - FullTracerConfig { - max_step_shard: 10, - }, + FullTracerConfig { max_step_shard: 10 }, ); vm.iter_until_halt().collect::>>().unwrap(); let steps = vm.tracer().recorded_steps(); diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 164f1c6c4..c345b90bb 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -128,8 +128,11 @@ impl PartialEq for PackedNextAccessEntry { impl Ord for PackedNextAccessEntry { #[inline] fn cmp(&self, other: &Self) -> std::cmp::Ordering { - (self.cycles_hi, self.cycles_lo, self.addr) - .cmp(&(other.cycles_hi, other.cycles_lo, other.addr)) + (self.cycles_hi, self.cycles_lo, self.addr).cmp(&( + other.cycles_hi, + other.cycles_lo, + other.addr, + )) } } @@ -1241,7 +1244,13 @@ impl PreflightTracer { tracer } - pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess, Vec) { + pub fn into_shard_plan( + self, + ) -> ( + ShardPlanBuilder, + NextCycleAccess, + Vec, + ) { let Some(mut planner) = self.planner else { panic!("shard planner missing") }; diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index c4e49e6c2..ea6615304 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,10 +24,9 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, 
IterAddresses, - NextCycleAccess, PackedNextAccessEntry, Platform, PreflightTracer, PreflightTracerConfig, Program, - RegIdx, - StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, - WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, + NextCycleAccess, PackedNextAccessEntry, Platform, PreflightTracer, PreflightTracerConfig, + Program, RegIdx, StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, + VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; #[cfg(feature = "gpu")] use ceno_gpu::CudaHal; @@ -230,9 +229,7 @@ impl<'a> Default for ShardContext<'a> { num_shards: 1, max_cycle: Cycle::MAX, addr_future_accesses: Arc::new(Default::default()), - sorted_next_accesses: Arc::new(SortedNextAccesses { - packed: vec![], - }), + sorted_next_accesses: Arc::new(SortedNextAccesses { packed: vec![] }), addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), read_records_tbs: Either::Left( (0..max_threads) @@ -735,9 +732,7 @@ impl Default for ShardContextBuilder { ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(Default::default()), - sorted_next_accesses: Arc::new(SortedNextAccesses { - packed: vec![], - }), + sorted_next_accesses: Arc::new(SortedNextAccesses { packed: vec![] }), prev_shard_cycle_range: vec![], prev_shard_heap_range: vec![], prev_shard_hint_range: vec![], @@ -760,35 +755,37 @@ impl ShardContextBuilder { assert_eq!(multi_prover.max_provers, 1); assert_eq!(multi_prover.prover_id, 0); - let sorted_next_accesses = - info_span!("next_access_presort").in_scope(|| { - let source = std::env::var("CENO_NEXT_ACCESS_SOURCE").unwrap_or_default(); - let mut entries = if source == "hashmap" { - tracing::info!("[next-access presort] converting from HashMap"); - info_span!("next_access_from_hashmap").in_scope(|| { - let mut entries = Vec::new(); - for (cycle, pairs) in addr_future_accesses.iter() { - for &(addr, next_cycle) in pairs.iter() { - 
entries.push(PackedNextAccessEntry::new(*cycle, addr.0, next_cycle)); - } + let sorted_next_accesses = info_span!("next_access_presort").in_scope(|| { + let source = std::env::var("CENO_NEXT_ACCESS_SOURCE").unwrap_or_default(); + let mut entries = if source == "hashmap" { + tracing::info!("[next-access presort] converting from HashMap"); + info_span!("next_access_from_hashmap").in_scope(|| { + let mut entries = Vec::new(); + for (cycle, pairs) in addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push(PackedNextAccessEntry::new(*cycle, addr.0, next_cycle)); } - entries - }) - } else { - tracing::info!( - "[next-access presort] using preflight-appended vec ({} entries)", - next_accesses_vec.len() - ); - next_accesses_vec - }; - let len = entries.len(); - info_span!("next_access_par_sort", n = len).in_scope(|| { - entries.par_sort_unstable(); - }); - tracing::info!("[next-access presort] sorted {} entries ({:.2} MB)", - len, len * 16 / (1024 * 1024)); - Arc::new(SortedNextAccesses { packed: entries }) + } + entries + }) + } else { + tracing::info!( + "[next-access presort] using preflight-appended vec ({} entries)", + next_accesses_vec.len() + ); + next_accesses_vec + }; + let len = entries.len(); + info_span!("next_access_par_sort", n = len).in_scope(|| { + entries.par_sort_unstable(); }); + tracing::info!( + "[next-access presort] sorted {} entries ({:.2} MB)", + len, + len * 16 / (1024 * 1024) + ); + Arc::new(SortedNextAccesses { packed: entries }) + }); ShardContextBuilder { cur_shard_id: 0, @@ -2267,28 +2264,31 @@ pub fn run_e2e_verify>( } fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { - let mut cloned = ShardContext::default(); - cloned.shard_id = src.shard_id; - cloned.num_shards = src.num_shards; - cloned.max_cycle = src.max_cycle; - cloned.addr_future_accesses = src.addr_future_accesses.clone(); - cloned.sorted_next_accesses = src.sorted_next_accesses.clone(); - cloned.cur_shard_cycle_range = 
src.cur_shard_cycle_range.clone(); - cloned.expected_inst_per_shard = src.expected_inst_per_shard; - cloned.max_num_cross_shard_accesses = src.max_num_cross_shard_accesses; - cloned.prev_shard_cycle_range = src.prev_shard_cycle_range.clone(); - cloned.prev_shard_heap_range = src.prev_shard_heap_range.clone(); - cloned.prev_shard_hint_range = src.prev_shard_hint_range.clone(); - cloned.platform = src.platform.clone(); - cloned.shard_heap_addr_range = src.shard_heap_addr_range.clone(); - cloned.shard_hint_addr_range = src.shard_hint_addr_range.clone(); - cloned.syscall_witnesses = src.syscall_witnesses.clone(); - cloned + ShardContext { + shard_id: src.shard_id, + num_shards: src.num_shards, + max_cycle: src.max_cycle, + addr_future_accesses: src.addr_future_accesses.clone(), + sorted_next_accesses: src.sorted_next_accesses.clone(), + cur_shard_cycle_range: src.cur_shard_cycle_range.clone(), + expected_inst_per_shard: src.expected_inst_per_shard, + max_num_cross_shard_accesses: src.max_num_cross_shard_accesses, + prev_shard_cycle_range: src.prev_shard_cycle_range.clone(), + prev_shard_heap_range: src.prev_shard_heap_range.clone(), + prev_shard_hint_range: src.prev_shard_hint_range.clone(), + platform: src.platform.clone(), + shard_heap_addr_range: src.shard_heap_addr_range.clone(), + shard_hint_addr_range: src.shard_hint_addr_range.clone(), + syscall_witnesses: src.syscall_witnesses.clone(), + ..Default::default() + } } +type FlatRecord = (u32, u64, u64, u64, u64, Option, u32, usize); + fn flatten_ram_records( records: &[BTreeMap], -) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { +) -> Vec { let mut flat = Vec::new(); for table in records { for (addr, record) in table { @@ -2350,18 +2350,21 @@ fn log_combined_lk_diff( let gpu_combined = gpu_witness.combined_lk_mlt().expect("gpu combined_lk_mlt"); let table_names = [ - "Dynamic", "DoubleU8", "And", "Or", "Xor", "Ltu", "Pow", "Instruction", + "Dynamic", + "DoubleU8", + "And", + "Or", + "Xor", + "Ltu", + 
"Pow", + "Instruction", ]; let mut total_diffs = 0usize; for (table_idx, (cpu_table, gpu_table)) in cpu_combined.iter().zip(gpu_combined.iter()).enumerate() { - let mut keys: Vec = cpu_table - .keys() - .chain(gpu_table.keys()) - .copied() - .collect(); + let mut keys: Vec = cpu_table.keys().chain(gpu_table.keys()).copied().collect(); keys.sort_unstable(); keys.dedup(); diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index fdb6be1a8..5e079489b 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -78,7 +78,7 @@ pub struct Poseidon2Config< const HALF_FULL_ROUNDS: usize, const PARTIAL_ROUNDS: usize, > { - pub(crate) p3_cols: Vec, // columns in the plonky3-air + pub(crate) p3_cols: Vec, // columns in the plonky3-air pub(crate) post_linear_layer_cols: Vec, /* additional columns to hold the state after linear layers */ constants: RoundConstants, } diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index 40683274d..90d489f2c 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -44,6 +44,7 @@ impl SignedExtendConfig { self.msb.expr() } + #[allow(dead_code)] // used by GPU column map extraction (cfg gated) pub(crate) fn msb(&self) -> WitIn { self.msb } diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index c9e4b534a..da7bb43b9 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -19,9 +19,8 @@ use rayon::{ }; use witness::{InstancePaddingStrategy, RowMajorMatrix}; -pub mod riscv; pub mod gpu; - +pub mod riscv; pub use gpu::utils::{cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram}; @@ -371,4 +370,3 @@ macro_rules! 
impl_gpu_assign { } }; } - diff --git a/ceno_zkvm/src/instructions/gpu/cache.rs b/ceno_zkvm/src/instructions/gpu/cache.rs index af437e97b..ae258b409 100644 --- a/ceno_zkvm/src/instructions/gpu/cache.rs +++ b/ceno_zkvm/src/instructions/gpu/cache.rs @@ -5,16 +5,14 @@ /// the same shard. use ceno_emul::{StepRecord, WordAddr}; use ceno_gpu::{ - Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, bb31::ShardDeviceBuffers, - common::witgen::types::GpuShardRamRecord, common::witgen::types::GpuShardScalars, + Buffer, CudaHal, CudaSlice, + bb31::{CudaHalBB31, ShardDeviceBuffers}, + common::witgen::types::{GpuShardRamRecord, GpuShardScalars}, }; use std::cell::RefCell; use tracing::info_span; -use crate::{ - e2e::ShardContext, - error::ZKVMError, -}; +use crate::{e2e::ShardContext, error::ZKVMError}; /// Cached shard_steps device buffer with metadata for logging. struct ShardStepsCache { @@ -170,13 +168,17 @@ pub(crate) fn ensure_shard_metadata_cached( unsafe { std::slice::from_raw_parts( sorted.packed.as_ptr() as *const u8, - sorted.packed.len() * std::mem::size_of::(), + sorted.packed.len() + * std::mem::size_of::(), ) } }; - let buf = hal.inner.htod_copy_stream(None, packed_bytes).map_err(|e| { - ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) - })?; + let buf = hal + .inner + .htod_copy_stream(None, packed_bytes) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) + })?; let next_access_device = ceno_gpu::common::buffer::BufferImpl::new(buf); let mb = packed_bytes.len() as f64 / (1024.0 * 1024.0); tracing::info!( @@ -266,24 +268,21 @@ pub(crate) fn ensure_shard_metadata_cached( let ec_u32s = ec_capacity * 26; // 26 u32s per GpuShardRamRecord (104 bytes) let addr_capacity = total_ops_estimate.min(256 * 1024 * 1024) as usize; - let shared_ec_buf = hal.witgen + let shared_ec_buf = hal + .witgen .alloc_u32_zeroed(ec_u32s, None) .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf alloc: {e}").into()))?; - 
let shared_ec_count = hal.witgen + let shared_ec_count = hal + .witgen .alloc_u32_zeroed(1, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_ec_count alloc: {e}").into()) - })?; - let shared_addr_buf = hal.witgen + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_count alloc: {e}").into()))?; + let shared_addr_buf = hal + .witgen .alloc_u32_zeroed(addr_capacity, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_addr_buf alloc: {e}").into()) - })?; - let shared_addr_count = hal.witgen - .alloc_u32_zeroed(1, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_addr_count alloc: {e}").into()) - })?; + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_addr_buf alloc: {e}").into()))?; + let shared_addr_count = hal.witgen.alloc_u32_zeroed(1, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count alloc: {e}").into()) + })?; let shared_ec_out_ptr = shared_ec_buf.device_ptr() as u64; let shared_ec_count_ptr = shared_ec_count.device_ptr() as u64; @@ -292,7 +291,9 @@ pub(crate) fn ensure_shard_metadata_cached( tracing::info!( "[GPU shard] shard_id={}: shared buffers allocated: ec_capacity={}, addr_capacity={}", - shard_id, ec_capacity, addr_capacity, + shard_id, + ec_capacity, + addr_capacity, ); *cache = Some(ShardMetadataCache { @@ -374,7 +375,14 @@ pub struct SharedDeviceBufferSet { pub fn gpu_batch_continuation_ec_on_device( write_records: &[(crate::tables::ShardRamRecord, &'static str)], read_records: &[(crate::tables::ShardRamRecord, &'static str)], -) -> Result<(ceno_gpu::common::buffer::BufferImpl<'static, u32>, usize, usize), ZKVMError> { +) -> Result< + ( + ceno_gpu::common::buffer::BufferImpl<'static, u32>, + usize, + usize, + ), + ZKVMError, +> { use gkr_iop::gpu::get_cuda_hal; let hal = get_cuda_hal().map_err(|e| { @@ -385,9 +393,10 @@ pub fn gpu_batch_continuation_ec_on_device( let n_reads = read_records.len(); let total = n_writes + n_reads; if total == 0 { - let empty = 
hal.witgen.alloc_u32_zeroed(1, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("alloc: {e}").into()) - })?; + let empty = hal + .witgen + .alloc_u32_zeroed(1, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("alloc: {e}").into()))?; return Ok((empty, 0, 0)); } @@ -398,11 +407,14 @@ pub fn gpu_batch_continuation_ec_on_device( } // GPU batch EC, results stay on device - let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total).in_scope(|| { - hal.witgen.batch_continuation_ec_on_device(&gpu_records, None) - }).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) - })?; + let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total) + .in_scope(|| { + hal.witgen + .batch_continuation_ec_on_device(&gpu_records, None) + }) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) + })?; Ok((device_buf, n_writes, n_reads)) } @@ -414,7 +426,10 @@ pub(crate) fn read_shared_addr_count() -> usize { SHARD_META_CACHE.with(|cache| { let cache = cache.borrow(); let c = cache.as_ref().expect("shard metadata not cached"); - let buf = c.shared_addr_count.as_ref().expect("shared_addr_count not allocated"); + let buf = c + .shared_addr_count + .as_ref() + .expect("shared_addr_count not allocated"); let v: Vec = buf.to_vec().expect("shared_addr_count D2H failed"); v[0] as usize }) @@ -429,7 +444,10 @@ pub(crate) fn read_shared_addr_range(start: usize, end: usize) -> Vec { SHARD_META_CACHE.with(|cache| { let cache = cache.borrow(); let c = cache.as_ref().expect("shard metadata not cached"); - let buf = c.shared_addr_buf.as_ref().expect("shared_addr_buf not allocated"); + let buf = c + .shared_addr_buf + .as_ref() + .expect("shared_addr_buf not allocated"); let all: Vec = buf.to_vec_n(end).expect("shared_addr_buf D2H failed"); all[start..end].to_vec() }) @@ -454,13 +472,15 @@ pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVME let 
ec_count_buf = match c.shared_ec_count.as_ref() { Some(b) => b, None => { - tracing::debug!("[GPU shard] flush_shared_ec_buffers: buffers already taken, no-op"); + tracing::debug!( + "[GPU shard] flush_shared_ec_buffers: buffers already taken, no-op" + ); return Ok(()); } }; - let ec_count_vec: Vec = ec_count_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) - })?; + let ec_count_vec: Vec = ec_count_buf + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()))?; let ec_count = ec_count_vec[0] as usize; let ec_capacity = c.device_bufs.shared_ec_capacity as usize; @@ -476,14 +496,11 @@ pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVME // D2H EC records (only the active portion) let ec_buf = c.shared_ec_buf.as_ref().unwrap(); let ec_u32s = ec_count * 26; // 26 u32s per GpuShardRamRecord - let raw_u32: Vec = ec_buf.to_vec_n(ec_u32s).map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_ec_buf D2H: {e}").into()) - })?; + let raw_u32: Vec = ec_buf + .to_vec_n(ec_u32s) + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf D2H: {e}").into()))?; let raw_bytes = unsafe { - std::slice::from_raw_parts( - raw_u32.as_ptr() as *const u8, - raw_u32.len() * 4, - ) + std::slice::from_raw_parts(raw_u32.as_ptr() as *const u8, raw_u32.len() * 4) }; tracing::info!( "[GPU shard] flush_shared_ec_buffers: {} EC records, {:.2} MB", @@ -494,12 +511,13 @@ pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVME } // D2H addr_accessed count - let addr_count_buf = c.shared_addr_count.as_ref().ok_or_else(|| { - ZKVMError::InvalidWitness("shared_addr_count not allocated".into()) - })?; - let addr_count_vec: Vec = addr_count_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()) - })?; + let addr_count_buf = c + .shared_addr_count + .as_ref() + .ok_or_else(|| 
ZKVMError::InvalidWitness("shared_addr_count not allocated".into()))?; + let addr_count_vec: Vec = addr_count_buf + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()))?; let addr_count = addr_count_vec[0] as usize; let addr_capacity = c.device_bufs.shared_addr_capacity as usize; diff --git a/ceno_zkvm/src/instructions/gpu/chips/add.rs b/ceno_zkvm/src/instructions/gpu/chips/add.rs index 74ae5567d..cf5b7a1c8 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/add.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/add.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::AddColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::arith::ArithConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::arith::ArithConfig, +}; /// Extract column map from a constructed ArithConfig (ADD variant). 
/// @@ -181,8 +185,18 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_add( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) .unwrap(); // D2H copy (GPU output is column-major) @@ -213,10 +227,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::dispatch::try_gpu_assign_instances::< - E, - AddInstruction, - >( + crate::instructions::gpu::dispatch::try_gpu_assign_instances::>( &config, &mut shard_ctx_full_gpu, num_witin, @@ -230,10 +241,7 @@ mod tests { // Flush shared EC/addr buffers from GPU device to shard_ctx // (in the e2e pipeline this is called once per shard after all opcode circuits) - crate::instructions::gpu::cache::flush_shared_ec_buffers( - &mut shard_ctx_full_gpu, - ) - .unwrap(); + crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut shard_ctx_full_gpu).unwrap(); assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); diff --git a/ceno_zkvm/src/instructions/gpu/chips/addi.rs b/ceno_zkvm/src/instructions/gpu/chips/addi.rs index 81f0e721b..234ac00ff 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/addi.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/addi.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::AddiColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs, + }, + 
riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig, +}; /// Extract column map from a constructed InstructionConfig (ADDI v2). pub fn extract_addi_column_map( @@ -122,8 +126,18 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_addi( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/auipc.rs b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs index 74c97fdad..bf0c0fefc 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/auipc.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs @@ -1,8 +1,10 @@ use ceno_gpu::common::witgen::types::AuipcColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::auipc::AuipcConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::auipc::AuipcConfig, +}; /// Extract column map from a constructed AuipcConfig. 
pub fn extract_auipc_column_map( @@ -124,8 +126,18 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_auipc( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs index 53902abde..69407ce0b 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::BranchCmpColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; -use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs, + }, + riscv::branch::branch_circuit_v2::BranchConfig, +}; /// Extract column map from a constructed BranchConfig (BLT/BGE/BLTU/BGEU variant). 
pub fn extract_branch_cmp_column_map( @@ -138,8 +142,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_branch_cmp( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs index a720f47aa..33373bef8 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::BranchEqColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs}; -use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs, + }, + riscv::branch::branch_circuit_v2::BranchConfig, +}; /// Extract column map from a constructed BranchConfig (BEQ/BNE variant). 
pub fn extract_branch_eq_column_map( @@ -135,8 +139,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_branch_eq( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/div.rs b/ceno_zkvm/src/instructions/gpu/chips/div.rs index 8f7442934..9f998ae02 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/div.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/div.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::DivColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::div::div_circuit_v2::DivRemConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::div::div_circuit_v2::DivRemConfig, +}; /// Extract column map from a constructed DivRemConfig. 
/// div_kind: 0=DIV, 1=DIVU, 2=REM, 3=REMU @@ -38,7 +42,8 @@ pub fn extract_div_column_map( // sign_xor let sign_xor = config.sign_xor.id as u32; - let remainder_prime = extract_uint_limbs::(&config.remainder_prime, "remainder_prime"); + let remainder_prime = + extract_uint_limbs::(&config.remainder_prime, "remainder_prime"); // lt_marker let lt_marker: [u32; 2] = [config.lt_marker[0].id as u32, config.lt_marker[1].id as u32]; @@ -337,7 +342,8 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen + let gpu_result = hal + .witgen .witgen_div( &col_map, &gpu_records, diff --git a/ceno_zkvm/src/instructions/gpu/chips/jal.rs b/ceno_zkvm/src/instructions/gpu/chips/jal.rs index 82dd428c4..a3e965098 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/jal.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jal.rs @@ -1,8 +1,10 @@ use ceno_gpu::common::witgen::types::JalColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_state_branching, extract_uint_limbs}; -use crate::instructions::riscv::jump::jal_v2::JalConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{extract_rd, extract_state_branching, extract_uint_limbs}, + riscv::jump::jal_v2::JalConfig, +}; /// Extract column map from a constructed JalConfig. 
pub fn extract_jal_column_map( @@ -112,8 +114,18 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_jal( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/jalr.rs b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs index 762ae2cf4..d1da5a6af 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/jalr.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs @@ -1,10 +1,12 @@ use ceno_gpu::common::witgen::types::JalrColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{ - extract_rd, extract_rs1, extract_state_branching, extract_uint_limbs, extract_wit_ids, +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rd, extract_rs1, extract_state_branching, extract_uint_limbs, extract_wit_ids, + }, + riscv::jump::jalr_v2::JalrConfig, }; -use crate::instructions::riscv::jump::jalr_v2::JalrConfig; /// Extract column map from a constructed JalrConfig. 
pub fn extract_jalr_column_map( @@ -21,7 +23,8 @@ pub fn extract_jalr_column_map( let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; let jump_pc_addr = extract_uint_limbs::(&config.jump_pc_addr.addr, "jump_pc_addr"); - let jump_pc_addr_bit = extract_wit_ids::<2>(&config.jump_pc_addr.low_bits, "jump_pc_addr low_bits"); + let jump_pc_addr_bit = + extract_wit_ids::<2>(&config.jump_pc_addr.low_bits, "jump_pc_addr low_bits"); // rd_high let rd_high = config.rd_high.id as u32; @@ -133,8 +136,18 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_jalr( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs index 2026a0a75..c4450b55a 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -23,7 +23,11 @@ pub fn extract_keccak_column_map( let ecall_prev_ts = config.ecall_id.prev_ts.id as u32; let ecall_lt_diff = { let diffs = &config.ecall_id.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for ecall_id"); + assert_eq!( + diffs.len(), + 2, + "Expected 2 AssertLt diff limbs for ecall_id" + ); [diffs[0].id as u32, diffs[1].id as u32] }; @@ -55,7 +59,11 @@ pub fn extract_keccak_column_map( }; let sptr_lt_diff = { let diffs = &config.state_ptr.0.lt_cfg.0.diff; - assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for state_ptr"); + assert_eq!( + diffs.len(), + 2, + "Expected 2 AssertLt diff limbs for state_ptr" + ); [diffs[0].id as u32, diffs[1].id as u32] }; @@ -83,7 +91,8 @@ pub fn extract_keccak_column_map( 
#[cfg(debug_assertions)] { let base = keccak_base_col as usize; - let expected_size = std::mem::size_of::>(); + let expected_size = + std::mem::size_of::>(); // Check that the last keccak column is at base + expected_size - 1 let last_rc = config.layout.layer_exprs.wits.rc.last().unwrap(); assert_eq!( @@ -288,7 +297,11 @@ mod tests { ); let mut struct_mismatches = 0; - for (i, (g, c)) in gpu_struct_data.iter().zip(cpu_struct_data.iter()).enumerate() { + for (i, (g, c)) in gpu_struct_data + .iter() + .zip(cpu_struct_data.iter()) + .enumerate() + { if g != c { if struct_mismatches < 20 { let row = i / num_structural_witin; @@ -337,7 +350,10 @@ mod tests { eprintln!("Keccak LK: {} mismatches", lk_mismatches); assert_eq!(mismatches, 0, "GPU vs CPU witness mismatch"); - assert_eq!(struct_mismatches, 0, "GPU vs CPU structural witness mismatch"); + assert_eq!( + struct_mismatches, 0, + "GPU vs CPU structural witness mismatch" + ); assert_eq!(lk_mismatches, 0, "GPU vs CPU LK multiplicity mismatch"); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs index 7390ce113..36694fdc8 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::LoadSubColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::memory::load_v2::LoadConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs, + }, + riscv::memory::load_v2::LoadConfig, +}; /// Extract column map from a constructed LoadConfig for sub-word loads (LH/LHU/LB/LBU). 
pub fn extract_load_sub_column_map( @@ -349,7 +353,8 @@ mod tests { let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let load_width: u32 = if is_byte { 8 } else { 16 }; let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; - let gpu_result = hal.witgen + let gpu_result = hal + .witgen .witgen_load_sub( &col_map, &gpu_records, diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs index 671cd90cf..3d06c8184 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs @@ -1,8 +1,10 @@ use ceno_gpu::common::witgen::types::LogicIColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::logic_imm::logic_imm_circuit_v2::LogicConfig, +}; /// Extract column map from a constructed LogicConfig (I-type v2: ANDI/ORI/XORI). 
pub fn extract_logic_i_column_map( @@ -139,8 +141,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_logic_i( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs index 77ecf6af9..770d770ba 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::LogicRColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::logic::logic_circuit::LogicConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::logic::logic_circuit::LogicConfig, +}; /// Extract column map from a constructed LogicConfig (R-type: AND/OR/XOR). 
pub fn extract_logic_r_column_map( @@ -173,8 +177,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_logic_r( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = @@ -202,10 +217,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::dispatch::try_gpu_assign_instances::< - E, - AndInstruction, - >( + crate::instructions::gpu::dispatch::try_gpu_assign_instances::>( &config, &mut shard_ctx_full_gpu, num_witin, @@ -217,10 +229,7 @@ mod tests { .unwrap() .expect("GPU path should be available"); - crate::instructions::gpu::cache::flush_shared_ec_buffers( - &mut shard_ctx_full_gpu, - ) - .unwrap(); + crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut shard_ctx_full_gpu).unwrap(); assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); diff --git a/ceno_zkvm/src/instructions/gpu/chips/lui.rs b/ceno_zkvm/src/instructions/gpu/chips/lui.rs index 715af5996..a6b5b6278 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/lui.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/lui.rs @@ -1,8 +1,10 @@ use ceno_gpu::common::witgen::types::LuiColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state}; -use crate::instructions::riscv::lui::LuiConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state}, + riscv::lui::LuiConfig, +}; /// Extract column map from a constructed LuiConfig. 
pub fn extract_lui_column_map( @@ -123,8 +125,18 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_lui( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/lw.rs b/ceno_zkvm/src/instructions/gpu/chips/lw.rs index 4ed847bfd..b865d405f 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/lw.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/lw.rs @@ -1,7 +1,9 @@ use ceno_gpu::common::witgen::types::LwColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs}; +use crate::instructions::gpu::utils::colmap_base::{ + extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs, +}; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::memory::load::LoadConfig; @@ -195,8 +197,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_lw( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = @@ -225,10 +238,7 @@ mod tests { let mut shard_ctx_full_gpu = ShardContext::default(); let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::dispatch::try_gpu_assign_instances::< - E, - LwInstruction, - >( + crate::instructions::gpu::dispatch::try_gpu_assign_instances::( &config, &mut shard_ctx_full_gpu, num_witin, @@ 
-240,10 +250,7 @@ mod tests { .unwrap() .expect("GPU path should be available"); - crate::instructions::gpu::cache::flush_shared_ec_buffers( - &mut shard_ctx_full_gpu, - ) - .unwrap(); + crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut shard_ctx_full_gpu).unwrap(); assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); diff --git a/ceno_zkvm/src/instructions/gpu/chips/mul.rs b/ceno_zkvm/src/instructions/gpu/chips/mul.rs index 826c435b6..9da0cc5a8 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/mul.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/mul.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::MulColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::mulh::mulh_circuit_v2::MulhConfig, +}; /// Extract column map from a constructed MulhConfig. 
/// mul_kind: 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU @@ -298,7 +302,8 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen + let gpu_result = hal + .witgen .witgen_mul( &col_map, &gpu_records, diff --git a/ceno_zkvm/src/instructions/gpu/chips/sb.rs b/ceno_zkvm/src/instructions/gpu/chips/sb.rs index 94cd947ac..d95648b36 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sb.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sb.rs @@ -1,10 +1,12 @@ use ceno_gpu::common::witgen::types::SbColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{ - extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, + }, + riscv::memory::store_v2::StoreConfig, }; -use crate::instructions::riscv::memory::store_v2::StoreConfig; /// Extract column map from a constructed StoreConfig (SB variant, N_ZEROS=0). 
pub fn extract_sb_column_map( @@ -22,7 +24,8 @@ pub fn extract_sb_column_map( let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - let prev_mem_val = extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let prev_mem_val = + extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); // SB-specific: 2 low_bits (bit_0, bit_1) @@ -181,8 +184,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_sb( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/sh.rs b/ceno_zkvm/src/instructions/gpu/chips/sh.rs index b2728eaa8..d9f9a809c 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sh.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sh.rs @@ -1,10 +1,12 @@ use ceno_gpu::common::witgen::types::ShColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{ - extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, + }, + riscv::memory::store_v2::StoreConfig, }; -use crate::instructions::riscv::memory::store_v2::StoreConfig; /// Extract column map from a constructed StoreConfig (SH variant, N_ZEROS=1). 
pub fn extract_sh_column_map( @@ -22,7 +24,8 @@ pub fn extract_sh_column_map( let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - let prev_mem_val = extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let prev_mem_val = + extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); // SH-specific: 1 low_bit (bit_1 for halfword select) @@ -158,8 +161,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_sh( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs index 7981bef47..e512ab2fb 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -101,8 +101,7 @@ mod tests { ShardRamCircuit::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) .unwrap(); - let col_map = - extract_shard_ram_column_map(&config, cb.cs.num_witin as usize); + let col_map = extract_shard_ram_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); // Basic columns should be in range diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs index ce20db6a9..97c54063d 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs @@ -1,8 +1,10 @@ use ceno_gpu::common::witgen::types::ShiftIColumnMap; use ff_ext::ExtensionField; -use 
crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::shift::shift_circuit_v2::ShiftImmConfig, +}; /// Extract column map from a constructed ShiftImmConfig (I-type: SLLI/SRLI/SRAI). pub fn extract_shift_i_column_map( @@ -150,8 +152,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_shift_i( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs index 31a97e16f..19e8a4e58 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::ShiftRColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::shift::shift_circuit_v2::ShiftRTypeConfig, +}; /// Extract column map from a constructed ShiftRTypeConfig (R-type: SLL/SRL/SRA). 
pub fn extract_shift_r_column_map( @@ -155,8 +159,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_shift_r( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/slt.rs b/ceno_zkvm/src/instructions/gpu/chips/slt.rs index 2251f0a09..51e21af9a 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/slt.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slt.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::SltColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::slt::slt_circuit_v2::SetLessThanConfig, +}; /// Extract column map from a constructed SetLessThanConfig (SLT/SLTU). 
pub fn extract_slt_column_map( @@ -142,8 +146,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_slt( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/slti.rs b/ceno_zkvm/src/instructions/gpu/chips/slti.rs index 052054d2d..17b4dad27 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/slti.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slti.rs @@ -1,8 +1,10 @@ use ceno_gpu::common::witgen::types::SltiColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::slti::slti_circuit_v2::SetLessThanImmConfig, +}; /// Extract column map from a constructed SetLessThanImmConfig (SLTI/SLTIU). 
pub fn extract_slti_column_map( @@ -137,8 +139,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_slti( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/sub.rs b/ceno_zkvm/src/instructions/gpu/chips/sub.rs index ef051cfe5..f05fd8bf0 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sub.rs @@ -1,8 +1,12 @@ use ceno_gpu::common::witgen::types::SubColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs}; -use crate::instructions::riscv::arith::ArithConfig; +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::arith::ArithConfig, +}; /// Extract column map from a constructed ArithConfig (SUB variant). 
/// @@ -143,8 +147,18 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_sub( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/chips/sw.rs b/ceno_zkvm/src/instructions/gpu/chips/sw.rs index abcda6a6e..9d94603a9 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sw.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sw.rs @@ -1,10 +1,12 @@ use ceno_gpu::common::witgen::types::SwColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{ - extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, +use crate::instructions::{ + gpu::utils::colmap_base::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, + }, + riscv::memory::store_v2::StoreConfig, }; -use crate::instructions::riscv::memory::store_v2::StoreConfig; /// Extract column map from a constructed StoreConfig (SW variant, N_ZEROS=2). 
pub fn extract_sw_column_map( @@ -22,7 +24,8 @@ pub fn extract_sw_column_map( let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let imm = config.imm.id as u32; let imm_sign = config.imm_sign.id as u32; - let prev_mem_val = extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let prev_mem_val = + extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); SwColumnMap { @@ -141,8 +144,19 @@ mod tests { }; let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); - let gpu_result = hal.witgen - .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + let gpu_result = hal + .witgen + .witgen_sw( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/gpu/config.rs b/ceno_zkvm/src/instructions/gpu/config.rs index 864898ee3..db3c92cb6 100644 --- a/ceno_zkvm/src/instructions/gpu/config.rs +++ b/ceno_zkvm/src/instructions/gpu/config.rs @@ -124,10 +124,9 @@ pub(crate) fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { return false; } match kind { - GpuWitgenKind::Add - | GpuWitgenKind::Sub - | GpuWitgenKind::LogicR(_) - | GpuWitgenKind::Lw => true, + GpuWitgenKind::Add | GpuWitgenKind::Sub | GpuWitgenKind::LogicR(_) | GpuWitgenKind::Lw => { + true + } #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::LogicI(_) | GpuWitgenKind::Addi diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index 7aac6e3a4..a4b4c2d10 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -6,9 +6,13 @@ /// 3. 
Returns the GPU-generated witness + CPU-collected lk and shardram use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::{ - Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose, + Buffer, CudaHal, + bb31::CudaHalBB31, + common::{ + transpose::matrix_transpose, + witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}, + }, }; -use ceno_gpu::common::witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use p3::field::FieldAlgebra; @@ -16,17 +20,19 @@ use std::cell::Cell; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; -use super::utils::debug_compare::{ - debug_compare_final_lk, debug_compare_keccak, debug_compare_shard_ec, - debug_compare_shardram, debug_compare_witness, -}; -use super::config::{ - is_gpu_witgen_disabled, is_kind_disabled, kind_has_verified_lk, kind_has_verified_shard, +use super::{ + config::{ + is_gpu_witgen_disabled, is_kind_disabled, kind_has_verified_lk, kind_has_verified_shard, + }, + utils::debug_compare::{ + debug_compare_final_lk, debug_compare_keccak, debug_compare_shard_ec, + debug_compare_shardram, debug_compare_witness, + }, }; use crate::{ e2e::ShardContext, error::ZKVMError, - instructions::{Instruction, cpu_collect_shardram, cpu_collect_lk_and_shardram}, + instructions::{Instruction, cpu_collect_lk_and_shardram, cpu_collect_shardram}, tables::RMMCollections, witness::LkMultiplicity, }; @@ -86,13 +92,15 @@ pub use super::cache::{ }; // Re-export for external callers (structs.rs). 
pub use super::utils::d2h::gpu_batch_continuation_ec; -use super::utils::d2h::{ - CompactEcBuf, LkResult, RamBuf, WitResult, gpu_collect_shard_records, gpu_compact_ec_d2h, - gpu_lk_counters_to_multiplicity, gpu_witness_to_rmm, -}; -use super::cache::{ - ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, - upload_shard_steps_cached, with_cached_shard_meta, with_cached_shard_steps, +use super::{ + cache::{ + ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, + upload_shard_steps_cached, with_cached_shard_meta, with_cached_shard_steps, + }, + utils::d2h::{ + CompactEcBuf, LkResult, RamBuf, WitResult, gpu_collect_shard_records, gpu_compact_ec_d2h, + gpu_lk_counters_to_multiplicity, gpu_witness_to_rmm, + }, }; thread_local! { @@ -199,24 +207,24 @@ fn gpu_assign_instances_inner>( let total_instances = step_indices.len(); // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) - let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec, gpu_compact_addr) = info_span!("gpu_kernel").in_scope(|| { - gpu_fill_witness::( - hal, - config, - shard_ctx, - num_witin, - shard_steps, - step_indices, - kind, - ) - })?; + let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec, gpu_compact_addr) = + info_span!("gpu_kernel").in_scope(|| { + gpu_fill_witness::( + hal, + config, + shard_ctx, + num_witin, + shard_steps, + step_indices, + kind, + ) + })?; // Step 2: Collect lk and shardram // Priority: GPU shard records > CPU shard records > full CPU lk and shardram let lk_multiplicity = if gpu_lk_counters.is_some() && kind_has_verified_lk(kind) { - let lk_multiplicity = info_span!("gpu_lk_d2h").in_scope(|| { - gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) - })?; + let lk_multiplicity = info_span!("gpu_lk_d2h") + .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()))?; if gpu_compact_ec.is_none() && gpu_compact_addr.is_none() && kind_has_verified_shard(kind) { // 
Shared buffer path: EC records + addr_accessed accumulated on device @@ -227,17 +235,15 @@ fn gpu_assign_instances_inner>( // D2H only the active records (much smaller than full N*3 slot buffer). info_span!("gpu_ec_shard").in_scope(|| { let compact = gpu_compact_ec.unwrap(); - let compact_records = info_span!("compact_d2h") - .in_scope(|| gpu_compact_ec_d2h(&compact))?; + let compact_records = + info_span!("compact_d2h").in_scope(|| gpu_compact_ec_d2h(&compact))?; // D2H ram_slots lazily (only for debug or fallback). // Avoid the 68 MB D2H in the common case. let ram_slots_d2h = || -> Result, ZKVMError> { if let Some(ref ram_buf) = gpu_ram_slots { let sv: Vec = ram_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness( - format!("ram_slots D2H failed: {e}").into(), - ) + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) })?; Ok(unsafe { let ptr = sv.as_ptr() as *const GpuRamRecordSlot; @@ -289,8 +295,13 @@ fn gpu_assign_instances_inner>( if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_some() { let slots = ram_slots_d2h()?; debug_compare_shard_ec::( - &compact_records, &slots, config, shard_ctx, - shard_steps, step_indices, kind, + &compact_records, + &slots, + config, + shard_ctx, + shard_steps, + step_indices, + kind, ); } @@ -337,7 +348,16 @@ fn gpu_assign_instances_inner>( collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) })? 
}; - debug_compare_final_lk::(config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, kind, &lk_multiplicity)?; + debug_compare_final_lk::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &lk_multiplicity, + )?; debug_compare_shardram::(config, shard_ctx, shard_steps, step_indices, kind)?; // Step 3: Build structural witness (just selector = ONE) @@ -352,15 +372,16 @@ fn gpu_assign_instances_inner>( raw_structural.padding_by_strategy(); // Step 4: Transpose (column-major → row-major) on GPU, then D2H copy to RowMajorMatrix - let mut raw_witin = info_span!("transpose_d2h", rows = total_instances, cols = num_witin).in_scope(|| { - gpu_witness_to_rmm::( - hal, - gpu_witness, - total_instances, - num_witin, - I::padding_strategy(), - ) - })?; + let mut raw_witin = info_span!("transpose_d2h", rows = total_instances, cols = num_witin) + .in_scope(|| { + gpu_witness_to_rmm::( + hal, + gpu_witness, + total_instances, + num_witin, + I::padding_strategy(), + ) + })?; raw_witin.padding_by_strategy(); debug_compare_witness::( config, @@ -408,7 +429,16 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result<(WitResult, Option, Option, Option, Option), ZKVMError> { +) -> Result< + ( + WitResult, + Option, + Option, + Option, + Option, + ), + ZKVMError, +> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -423,7 +453,13 @@ fn gpu_fill_witness>( macro_rules! 
split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec, full.compact_addr)) + Ok(( + full.witness, + Some(full.lk_counters), + full.ram_slots, + full.compact_ec, + full.compact_addr, + )) }}; } @@ -445,22 +481,24 @@ fn gpu_fill_witness>( info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_add( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_add failed: {e}").into(), + split_full!( + hal.witgen + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_add failed: {e}").into(), + ) + }) + ) }) }) }) @@ -475,22 +513,24 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_sub( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sub failed: {e}").into(), + split_full!( + hal.witgen + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sub failed: {e}").into(), + ) + }) + ) }) }) }) @@ -500,28 +540,31 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) }; - let col_map = info_span!("col_map") - .in_scope(|| super::chips::logic_r::extract_logic_r_column_map(logic_config, num_witin)); + let 
col_map = info_span!("col_map").in_scope(|| { + super::chips::logic_r::extract_logic_r_column_map(logic_config, num_witin) + }); info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_logic_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_r failed: {e}").into(), + split_full!( + hal.witgen + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), + ) + }) + ) }) }) }) @@ -532,28 +575,31 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) }; - let col_map = info_span!("col_map") - .in_scope(|| super::chips::logic_i::extract_logic_i_column_map(logic_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::chips::logic_i::extract_logic_i_column_map(logic_config, num_witin) + }); info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_logic_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_i failed: {e}").into(), + split_full!( + hal.witgen + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), + ) + }) + ) }) 
}) }) @@ -569,22 +615,24 @@ fn gpu_fill_witness>( info_span!("hal_witgen_addi").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_addi( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_addi failed: {e}").into(), + split_full!( + hal.witgen + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), + ) + }) + ) }) }) }) @@ -600,20 +648,24 @@ fn gpu_fill_witness>( info_span!("hal_witgen_lui").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_lui( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) - })) + split_full!( + hal.witgen + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lui failed: {e}").into(), + ) + }) + ) }) }) }) @@ -624,27 +676,30 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::auipc::AuipcConfig) }; - let col_map = info_span!("col_map") - .in_scope(|| super::chips::auipc::extract_auipc_column_map(auipc_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::chips::auipc::extract_auipc_column_map(auipc_config, num_witin) + }); info_span!("hal_witgen_auipc").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - 
.witgen_auipc( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_auipc failed: {e}").into(), + split_full!( + hal.witgen + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), + ) + }) + ) }) }) }) @@ -660,20 +715,24 @@ fn gpu_fill_witness>( info_span!("hal_witgen_jal").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_jal( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) - })) + split_full!( + hal.witgen + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jal failed: {e}").into(), + ) + }) + ) }) }) }) @@ -686,28 +745,31 @@ fn gpu_fill_witness>( E, >) }; - let col_map = info_span!("col_map") - .in_scope(|| super::chips::shift_r::extract_shift_r_column_map(shift_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::chips::shift_r::extract_shift_r_column_map(shift_config, num_witin) + }); info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_shift_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_r failed: {e}").into(), + split_full!( + 
hal.witgen + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), + ) + }) + ) }) }) }) @@ -720,28 +782,31 @@ fn gpu_fill_witness>( E, >) }; - let col_map = info_span!("col_map") - .in_scope(|| super::chips::shift_i::extract_shift_i_column_map(shift_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::chips::shift_i::extract_shift_i_column_map(shift_config, num_witin) + }); info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_shift_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_i failed: {e}").into(), + split_full!( + hal.witgen + .witgen_shift_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), + ) + }) + ) }) }) }) @@ -757,23 +822,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_slt").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_slt( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slt failed: {e}").into(), + split_full!( + hal.witgen + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + 
ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), + ) + }) + ) }) }) }) @@ -789,23 +856,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_slti").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_slti( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slti failed: {e}").into(), + split_full!( + hal.witgen + .witgen_slti( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), + ) + }) + ) }) }) }) @@ -824,23 +893,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_branch_eq( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_beq, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_branch_eq failed: {e}").into(), + split_full!( + hal.witgen + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), + ) + }) + ) }) }) }) @@ -859,23 +930,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_branch_cmp( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - 
ZKVMError::InvalidWitness( - format!("GPU witgen_branch_cmp failed: {e}").into(), + split_full!( + hal.witgen + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), + ) + }) + ) }) }) }) @@ -891,20 +964,24 @@ fn gpu_fill_witness>( info_span!("hal_witgen_jalr").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_jalr( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) - })) + split_full!( + hal.witgen + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jalr failed: {e}").into(), + ) + }) + ) }) }) }) @@ -921,21 +998,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sw").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_sw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) - })) + split_full!( + hal.witgen + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sw failed: {e}").into(), + ) + }) + ) }) }) }) @@ -952,21 +1033,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sh").in_scope(|| { with_cached_shard_steps(|gpu_records| { 
with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_sh( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) - })) + split_full!( + hal.witgen + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sh failed: {e}").into(), + ) + }) + ) }) }) }) @@ -983,21 +1068,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_sb").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_sb( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) - })) + split_full!( + hal.witgen + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sb failed: {e}").into(), + ) + }) + ) }) }) }) @@ -1025,25 +1114,27 @@ fn gpu_fill_witness>( info_span!("hal_witgen_load_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_load_sub( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - load_width, - is_signed, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_load_sub failed: {e}").into(), + split_full!( + hal.witgen + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + 
fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), + ) + }) + ) }) }) }) @@ -1054,28 +1145,31 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) }; - let col_map = info_span!("col_map") - .in_scope(|| super::chips::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); + let col_map = info_span!("col_map").in_scope(|| { + super::chips::mul::extract_mul_column_map(mul_config, num_witin, mul_kind) + }); info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_mul( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mul_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_mul failed: {e}").into(), + split_full!( + hal.witgen + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_mul failed: {e}").into(), + ) + }) + ) }) }) }) @@ -1091,23 +1185,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_div").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_div( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - div_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_div failed: {e}").into(), + split_full!( + hal.witgen + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + 
format!("GPU witgen_div failed: {e}").into(), + ) + }) + ) }) }) }) @@ -1129,21 +1225,25 @@ fn gpu_fill_witness>( info_span!("hal_witgen_lw").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - split_full!(hal.witgen - .witgen_lw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) - })) + split_full!( + hal.witgen + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lw failed: {e}").into(), + ) + }) + ) }) }) }) @@ -1235,10 +1335,7 @@ pub fn gpu_assign_keccak_instances( } let num_instances = step_indices.len(); - tracing::debug!( - "[GPU witgen] keccak with {} instances", - num_instances - ); + tracing::debug!("[GPU witgen] keccak with {} instances", num_instances); info_span!("gpu_witgen_keccak", n = num_instances).in_scope(|| { gpu_assign_keccak_inner::( @@ -1276,14 +1373,13 @@ fn gpu_assign_keccak_inner( .in_scope(|| super::chips::keccak::extract_keccak_column_map(config, num_witin)); // Step 2: Pack instances - let packed_instances = info_span!("pack_instances") - .in_scope(|| { - super::chips::keccak::pack_keccak_instances( - steps, - step_indices, - &shard_ctx.syscall_witnesses, - ) - }); + let packed_instances = info_span!("pack_instances").in_scope(|| { + super::chips::keccak::pack_keccak_instances( + steps, + step_indices, + &shard_ctx.syscall_witnesses, + ) + }); // Step 3: Compute fetch params let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(steps, step_indices); @@ -1302,21 +1398,20 @@ fn gpu_assign_keccak_inner( // Step 5: Launch GPU kernel let gpu_result = info_span!("gpu_kernel").in_scope(|| { with_cached_shard_meta(|shard_bufs| { - 
hal.witgen.witgen_keccak( - &col_map, - &packed_instances, - num_padded_rows, - shard_ctx.current_shard_offset_cycle(), - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_keccak failed: {e}").into(), + hal.witgen + .witgen_keccak( + &col_map, + &packed_instances, + num_padded_rows, + shard_ctx.current_shard_offset_cycle(), + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - }) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_keccak failed: {e}").into()) + }) }) })?; @@ -1344,8 +1439,8 @@ fn gpu_assign_keccak_inner( // in shared buffers across all kernel invocations. Skip per-kernel D2H. } else if let Some(compact) = gpu_result.compact_ec { info_span!("gpu_ec_shard").in_scope(|| { - let compact_records = info_span!("compact_d2h") - .in_scope(|| gpu_compact_ec_d2h(&compact))?; + let compact_records = + info_span!("compact_d2h").in_scope(|| gpu_compact_ec_d2h(&compact))?; // D2H compact addr_accessed info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { @@ -1386,48 +1481,45 @@ fn gpu_assign_keccak_inner( } // Step 8: Transpose GPU witness (column-major -> row-major) + D2H - let raw_witin = info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin).in_scope(|| { - let mut rmm_buffer = hal - .alloc_elems_on_device(num_padded_rows * num_witin, false, None) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU alloc for transpose failed: {e}").into(), - ) - })?; - matrix_transpose::( - &hal.inner, - &mut rmm_buffer, - &gpu_result.witness.device_buffer, - num_padded_rows, - num_witin, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) - })?; + let raw_witin = info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin) + .in_scope(|| { + let mut rmm_buffer = hal + .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .map_err(|e| { + 
ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.witness.device_buffer, + num_padded_rows, + num_witin, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; - let gpu_data: Vec<::BaseField> = - rmm_buffer.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()) - })?; + let gpu_data: Vec<::BaseField> = + rmm_buffer.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()) + })?; - // Safety: BabyBear is the only supported GPU field, and E::BaseField must match - let data: Vec = unsafe { - let mut data = std::mem::ManuallyDrop::new(gpu_data); - Vec::from_raw_parts( - data.as_mut_ptr() as *mut E::BaseField, - data.len(), - data.capacity(), - ) - }; + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + let data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; - Ok::<_, ZKVMError>(RowMajorMatrix::::from_values_with_rotation( - data, - num_witin, - rotation, - num_instances, - InstancePaddingStrategy::Default, - )) - })?; + Ok::<_, ZKVMError>(RowMajorMatrix::::from_values_with_rotation( + data, + num_witin, + rotation, + num_instances, + InstancePaddingStrategy::Default, + )) + })?; // Step 9: Build structural witness on CPU with selector indices let raw_structural = info_span!("structural_witness").in_scope(|| { @@ -1463,26 +1555,19 @@ fn gpu_assign_keccak_inner( let sel_first_indices = sel_first.sparse_indices(); let sel_last_indices = sel_last.sparse_indices(); - let sel_all_indices = config - .layout - .selector_type_layout - .sel_all - .sparse_indices(); + let sel_all_indices = config.layout.selector_type_layout.sel_all.sparse_indices(); // Only set selectors for real instances, not 
padding ones. for instance_chunk in raw_structural.iter_mut().take(num_instances) { // instance_chunk is a &mut [F] of size 32 * num_structural_witin for &idx in sel_first_indices { - instance_chunk[idx * num_structural_witin + sel_first_id] = - E::BaseField::ONE; + instance_chunk[idx * num_structural_witin + sel_first_id] = E::BaseField::ONE; } for &idx in sel_last_indices { - instance_chunk[idx * num_structural_witin + sel_last_id] = - E::BaseField::ONE; + instance_chunk[idx * num_structural_witin + sel_last_id] = E::BaseField::ONE; } for &idx in sel_all_indices { - instance_chunk[idx * num_structural_witin + sel_all_id] = - E::BaseField::ONE; + instance_chunk[idx * num_structural_witin + sel_all_id] = E::BaseField::ONE; } } raw_structural.padding_by_strategy(); diff --git a/ceno_zkvm/src/instructions/gpu/mod.rs b/ceno_zkvm/src/instructions/gpu/mod.rs index aa8eae109..e18382c1d 100644 --- a/ceno_zkvm/src/instructions/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/mod.rs @@ -1,24 +1,23 @@ -pub mod utils; -pub mod chips; #[cfg(feature = "gpu")] #[cfg(feature = "gpu")] +pub mod cache; +pub mod chips; #[cfg(feature = "gpu")] -pub mod config; #[cfg(feature = "gpu")] #[cfg(feature = "gpu")] -pub mod cache; +pub mod config; #[cfg(feature = "gpu")] pub mod dispatch; +pub mod utils; #[cfg(test)] mod tests { - use super::utils::*; + use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, instructions::{ - Instruction, cpu_assign_instances, cpu_collect_shardram, - cpu_collect_lk_and_shardram, + Instruction, cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram, riscv::{ AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, branch::{BeqInstruction, BltInstruction}, @@ -32,7 +31,6 @@ mod tests { }, }, structs::ProgramParams, - witness::LkMultiplicity, }; use ceno_emul::{ ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, @@ -102,8 +100,7 @@ mod tests { let mut collect_ctx = 
ShardContext::default(); let actual_lk = - cpu_collect_shardram::(config, &mut collect_ctx, steps, &indices) - .unwrap(); + cpu_collect_shardram::(config, &mut collect_ctx, steps, &indices).unwrap(); assert_eq!( expected_lk[LookupTable::Instruction as usize], @@ -171,12 +168,12 @@ mod tests { let rd = 2 + i; let rs1 = 8 + i; let rs2 = 16 + i; - let lhs = 10 + i as u32; - let rhs = 100 + i as u32; + let lhs = 10 + i; + let rhs = 100 + i; let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); StepRecord::new_r_instruction( 4 + (i as u64) * 4, - ByteAddr(0x1000 + i as u32 * 4), + ByteAddr(0x1000 + i * 4), insn, lhs, rhs, @@ -205,12 +202,12 @@ mod tests { let rd = 2 + i; let rs1 = 8 + i; let rs2 = 16 + i; - let lhs = 0xdead_0000 | i as u32; - let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let lhs = 0xdead_0000 | i; + let rhs = 0x00ff_ff00 | (i << 8); let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); StepRecord::new_r_instruction( 4 + (i as u64) * 4, - ByteAddr(0x2000 + i as u32 * 4), + ByteAddr(0x2000 + i * 4), insn, lhs, rhs, @@ -239,12 +236,12 @@ mod tests { let rd = 2 + i; let rs1 = 8 + i; let rs2 = 16 + i; - let lhs = 10 + i as u32; - let rhs = 100 + i as u32; + let lhs = 10 + i; + let rhs = 100 + i; let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); StepRecord::new_r_instruction( 84 + (i as u64) * 4, - ByteAddr(0x5000 + i as u32 * 4), + ByteAddr(0x5000 + i * 4), insn, lhs, rhs, @@ -273,12 +270,12 @@ mod tests { let rd = 2 + i; let rs1 = 8 + i; let rs2 = 16 + i; - let lhs = 0xdead_0000 | i as u32; - let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let lhs = 0xdead_0000 | i; + let rhs = 0x00ff_ff00 | (i << 8); let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); StepRecord::new_r_instruction( 100 + (i as u64) * 4, - ByteAddr(0x5100 + i as u32 * 4), + ByteAddr(0x5100 + i * 4), insn, lhs, rhs, @@ -306,10 +303,10 @@ mod tests { .map(|i| { let rd = 2 + i; let rs1 = 8 + i; - let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs1_val = 0x1000u32 + i * 16; let imm = 
(i as i32) * 4 - 4; let mem_addr = rs1_val.wrapping_add_signed(imm); - let mem_val = 0xabc0_0000 | i as u32; + let mem_val = 0xabc0_0000 | i; let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); let mem_read = ReadOp { addr: WordAddr::from(ByteAddr(mem_addr)), @@ -318,7 +315,7 @@ mod tests { }; StepRecord::new_im_instruction( 4 + (i as u64) * 4, - ByteAddr(0x3000 + i as u32 * 4), + ByteAddr(0x3000 + i * 4), insn, rs1_val, Change::new(0, mem_val), @@ -346,10 +343,10 @@ mod tests { .map(|i| { let rd = 2 + i; let rs1 = 8 + i; - let rs1_val = 0x1400u32 + (i as u32) * 16; + let rs1_val = 0x1400u32 + i * 16; let imm = (i as i32) * 4 - 4; let mem_addr = rs1_val.wrapping_add_signed(imm); - let mem_val = 0xabd0_0000 | i as u32; + let mem_val = 0xabd0_0000 | i; let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); let mem_read = ReadOp { addr: WordAddr::from(ByteAddr(mem_addr)), @@ -358,7 +355,7 @@ mod tests { }; StepRecord::new_im_instruction( 116 + (i as u64) * 4, - ByteAddr(0x5200 + i as u32 * 4), + ByteAddr(0x5200 + i * 4), insn, rs1_val, Change::new(0, mem_val), diff --git a/ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs b/ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs index 3d48e6c08..34445436b 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs @@ -30,7 +30,9 @@ pub fn extract_state(vm: &StateInOut) -> (u32, u32) { pub fn extract_state_branching(vm: &StateInOut) -> (u32, u32, u32) { ( vm.pc.id as u32, - vm.next_pc.expect("branching StateInOut must have next_pc").id as u32, + vm.next_pc + .expect("branching StateInOut must have next_pc") + .id as u32, vm.ts.id as u32, ) } @@ -95,7 +97,12 @@ pub fn extract_write_mem(mem: &WriteMEM) -> (u32, [u32; 2]) { #[inline] pub fn extract_lt_diff(lt: &AssertLtConfig) -> [u32; 2] { let d = <.0.diff; - assert_eq!(d.len(), 2, "Expected 2 AssertLt diff limbs, got {}", d.len()); + assert_eq!( + d.len(), + 2, + "Expected 2 AssertLt diff limbs, got {}", + 
d.len() + ); [d[0].id as u32, d[1].id as u32] } @@ -105,8 +112,15 @@ pub fn extract_uint_limbs, label: &str, ) -> [u32; N] { - let limbs = u.wits_in().unwrap_or_else(|| panic!("{label} should have WitIn limbs")); - assert_eq!(limbs.len(), N, "Expected {N} limbs for {label}, got {}", limbs.len()); + let limbs = u + .wits_in() + .unwrap_or_else(|| panic!("{label} should have WitIn limbs")); + assert_eq!( + limbs.len(), + N, + "Expected {N} limbs for {label}, got {}", + limbs.len() + ); std::array::from_fn(|i| limbs[i].id as u32) } @@ -120,7 +134,12 @@ pub fn extract_carries::BaseField, ->; +pub(crate) type WitBuf = + ceno_gpu::common::BufferImpl<'static, ::BaseField>; pub(crate) type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; pub(crate) type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; pub(crate) type WitResult = ceno_gpu::common::witgen::types::GpuWitnessResult; @@ -38,10 +40,7 @@ pub(crate) type CompactEcBuf = ceno_gpu::common::witgen::types::CompactEcResult< /// /// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, /// replacing the previous `collect_shardram()` CPU loop. 
-pub(crate) fn gpu_collect_shard_records( - shard_ctx: &mut ShardContext, - slots: &[GpuRamRecordSlot], -) { +pub(crate) fn gpu_collect_shard_records(shard_ctx: &mut ShardContext, slots: &[GpuRamRecordSlot]) { let current_shard_id = shard_ctx.shard_id; for slot in slots { @@ -61,7 +60,11 @@ pub(crate) fn gpu_collect_shard_records( _ => continue, }; let has_prev_value = slot.flags & (1 << 3) != 0; - let prev_value = if has_prev_value { Some(slot.prev_value) } else { None }; + let prev_value = if has_prev_value { + Some(slot.prev_value) + } else { + None + }; let addr = WordAddr(slot.addr); // Insert read record (bit 1) @@ -107,9 +110,10 @@ pub(crate) fn gpu_compact_ec_d2h( compact: &CompactEcResult, ) -> Result, ZKVMError> { // D2H the count (1 u32) - let count_vec: Vec = compact.count_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()) - })?; + let count_vec: Vec = compact + .count_buf + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()))?; let count = count_vec[0] as usize; if count == 0 { return Ok(vec![]); @@ -118,15 +122,20 @@ pub(crate) fn gpu_compact_ec_d2h( // Partial D2H: only transfer the first `count` records (not the full allocation) let record_u32s = std::mem::size_of::() / 4; // 26 let total_u32s = count * record_u32s; - let buf_vec: Vec = compact.buffer.to_vec_n(total_u32s).map_err(|e| { - ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) - })?; + let buf_vec: Vec = compact + .buffer + .to_vec_n(total_u32s) + .map_err(|e| ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()))?; let records: Vec = unsafe { let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; std::slice::from_raw_parts(ptr, count).to_vec() }; - tracing::debug!("GPU EC compact D2H: {} active records ({} bytes)", count, total_u32s * 4); + tracing::debug!( + "GPU EC compact D2H: {} active records ({} bytes)", + count, + total_u32s * 4 + ); 
Ok(records) } @@ -167,11 +176,9 @@ pub fn gpu_batch_continuation_ec( } // GPU batch EC computation - let result = info_span!("gpu_batch_ec", n = total).in_scope(|| { - hal.witgen.batch_continuation_ec(&gpu_records) - }).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()) - })?; + let result = info_span!("gpu_batch_ec", n = total) + .in_scope(|| hal.witgen.batch_continuation_ec(&gpu_records)) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()))?; // Convert back to ShardRamInput, split into writes and reads let mut write_inputs = Vec::with_capacity(write_records.len()); @@ -246,7 +253,9 @@ pub(crate) fn gpu_shard_ram_record_to_ec_point( } } -pub(crate) fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { +pub(crate) fn gpu_lk_counters_to_multiplicity( + counters: LkResult, +) -> Result, ZKVMError> { let mut tables: [FxHashMap; 8] = Default::default(); // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) @@ -278,9 +287,7 @@ pub(crate) fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result = buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU {:?} lk D2H failed: {e}", table).into(), - ) + ZKVMError::InvalidWitness(format!("GPU {:?} lk D2H failed: {e}", table).into()) })?; let nnz = counts.iter().filter(|&&c| c != 0).count(); let map = &mut tables[table as usize]; diff --git a/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs index 43f3fe91d..3af3b3856 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs @@ -17,7 +17,7 @@ use witness::RowMajorMatrix; use crate::{ e2e::ShardContext, error::ZKVMError, - instructions::{Instruction, cpu_collect_shardram, cpu_collect_lk_and_shardram}, + instructions::{Instruction, cpu_collect_lk_and_shardram, cpu_collect_shardram}, }; use 
crate::instructions::gpu::dispatch::{GpuWitgenKind, set_force_cpu_path}; @@ -51,7 +51,11 @@ pub(crate) fn debug_compare_final_lk>( Ok(()) } -pub(crate) fn log_lk_diff(kind: GpuWitgenKind, cpu_lk: &Multiplicity, actual_lk: &Multiplicity) { +pub(crate) fn log_lk_diff( + kind: GpuWitgenKind, + cpu_lk: &Multiplicity, + actual_lk: &Multiplicity, +) { let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_LK_LIMIT") .ok() .and_then(|v| v.parse::().ok()) @@ -176,8 +180,7 @@ pub(crate) fn debug_compare_shardram>( let _ = cpu_collect_lk_and_shardram::(config, &mut cpu_ctx, shard_steps, step_indices)?; let mut mixed_ctx = shard_ctx.new_empty_like(); - let _ = - cpu_collect_shardram::(config, &mut mixed_ctx, shard_steps, step_indices)?; + let _ = cpu_collect_shardram::(config, &mut mixed_ctx, shard_steps, step_indices)?; let cpu_addr = cpu_ctx.get_addr_accessed(); let mixed_addr = mixed_ctx.get_addr_accessed(); @@ -231,8 +234,10 @@ pub(crate) fn debug_compare_shard_ec>( return; } - use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; - use crate::tables::{ECPoint, ShardRamRecord}; + use crate::{ + scheme::septic_curve::{SepticExtension, SepticPoint}, + tables::{ECPoint, ShardRamRecord}, + }; use ff_ext::{PoseidonField, SmallField}; let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") @@ -242,9 +247,7 @@ pub(crate) fn debug_compare_shard_ec>( // ========== Build CPU shard context (independent, isolated) ========== let mut cpu_ctx = shard_ctx.new_empty_like(); - if let Err(e) = cpu_collect_shardram::( - config, &mut cpu_ctx, shard_steps, step_indices, - ) { + if let Err(e) = cpu_collect_shardram::(config, &mut cpu_ctx, shard_steps, step_indices) { tracing::error!("[GPU EC debug] kind={kind:?} CPU shardram records failed: {e:?}"); return; } @@ -287,7 +290,11 @@ pub(crate) fn debug_compare_shard_ec>( .map(|g| { let rec = ShardRamRecord { addr: g.addr, - ram_type: if g.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + ram_type: if g.ram_type == 1 { + 
RAMType::Register + } else { + RAMType::Memory + }, value: g.value, shard: g.shard, local_clk: g.local_clk, @@ -297,7 +304,10 @@ pub(crate) fn debug_compare_shard_ec>( let x = SepticExtension(g.point_x.map(|v| E::BaseField::from_canonical_u32(v))); let y = SepticExtension(g.point_y.map(|v| E::BaseField::from_canonical_u32(v))); let point = SepticPoint::from_affine(x, y); - let ec = ECPoint:: { nonce: g.nonce, point }; + let ec = ECPoint:: { + nonce: g.nonce, + point, + }; (rec, ec) }) .collect(); @@ -310,15 +320,28 @@ pub(crate) fn debug_compare_shard_ec>( tracing::error!( "[GPU EC debug] kind={kind:?} ADDR_ACCESSED MISMATCH: cpu={} gpu={} \ cpu_only={} gpu_only={}", - cpu_addr.len(), gpu_addr.len(), cpu_only.len(), gpu_only.len() + cpu_addr.len(), + gpu_addr.len(), + cpu_only.len(), + gpu_only.len() ); for (i, addr) in cpu_only.iter().enumerate() { - if i >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", addr.0); + if i >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", + addr.0 + ); } for (i, addr) in gpu_only.iter().enumerate() { - if i >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", addr.0); + if i >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", + addr.0 + ); } } @@ -328,21 +351,38 @@ pub(crate) fn debug_compare_shard_ec>( if cpu_entries.len() != gpu_entries.len() { tracing::error!( "[GPU EC debug] kind={kind:?} RECORD COUNT MISMATCH: cpu={} gpu={}", - cpu_entries.len(), gpu_entries.len() + cpu_entries.len(), + gpu_entries.len() ); let cpu_keys: std::collections::BTreeSet<_> = cpu_entries - .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + .iter() + .map(|(r, _)| (r.addr, r.is_to_write_set)) + .collect(); let gpu_keys: std::collections::BTreeSet<_> = gpu_entries - .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + .iter() + .map(|(r, 
_)| (r.addr, r.is_to_write_set)) + .collect(); let mut logged = 0usize; for key in cpu_keys.difference(&gpu_keys) { - if logged >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", key.0, key.1); + if logged >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", + key.0, + key.1 + ); logged += 1; } for key in gpu_keys.difference(&cpu_keys) { - if logged >= limit { break; } - tracing::error!("[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", key.0, key.1); + if logged >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", + key.0, + key.1 + ); logged += 1; } } @@ -358,7 +398,9 @@ pub(crate) fn debug_compare_shard_ec>( if gpu_dup_count <= limit { tracing::error!( "[GPU EC debug] kind={kind:?} GPU DUPLICATE: addr={} is_write={} ram_type={:?}", - w[0].0.addr, w[0].0.is_to_write_set, w[0].0.ram_type + w[0].0.addr, + w[0].0.is_to_write_set, + w[0].0.ram_type ); } } @@ -382,7 +424,12 @@ pub(crate) fn debug_compare_shard_ec>( if record_mismatches < limit { tracing::error!( "[GPU EC debug] kind={kind:?} MISSING in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", - cr.addr, cr.is_to_write_set, cr.ram_type, cr.value, cr.shard, cr.global_clk + cr.addr, + cr.is_to_write_set, + cr.ram_type, + cr.value, + cr.shard, + cr.global_clk ); } record_mismatches += 1; @@ -393,7 +440,12 @@ pub(crate) fn debug_compare_shard_ec>( if record_mismatches < limit { tracing::error!( "[GPU EC debug] kind={kind:?} EXTRA in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", - gr.addr, gr.is_to_write_set, gr.ram_type, gr.value, gr.shard, gr.global_clk + gr.addr, + gr.is_to_write_set, + gr.ram_type, + gr.value, + gr.shard, + gr.global_clk ); } record_mismatches += 1; @@ -432,7 +484,9 @@ pub(crate) fn debug_compare_shard_ec>( if ec_mismatches < limit { tracing::error!( "[GPU EC debug] kind={kind:?} addr={} nonce: cpu={} 
gpu={}", - cr.addr, ce.nonce, ge.nonce + cr.addr, + ce.nonce, + ge.nonce ); } } @@ -443,7 +497,8 @@ pub(crate) fn debug_compare_shard_ec>( ec_diff = true; if ec_mismatches < limit { tracing::error!( - "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", cr.addr + "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", + cr.addr ); } } @@ -455,7 +510,8 @@ pub(crate) fn debug_compare_shard_ec>( ec_diff = true; if ec_mismatches < limit { tracing::error!( - "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", cr.addr + "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", + cr.addr ); } } @@ -475,7 +531,9 @@ pub(crate) fn debug_compare_shard_ec>( let (cr, _) = &cpu_entries[ci]; tracing::error!( "[GPU EC debug] kind={kind:?} MISSING in GPU (tail): addr={} is_write={} val={}", - cr.addr, cr.is_to_write_set, cr.value + cr.addr, + cr.is_to_write_set, + cr.value ); } record_mismatches += 1; @@ -486,7 +544,9 @@ pub(crate) fn debug_compare_shard_ec>( let (gr, _) = &gpu_entries[gi]; tracing::error!( "[GPU EC debug] kind={kind:?} EXTRA in GPU (tail): addr={} is_write={} val={}", - gr.addr, gr.is_to_write_set, gr.value + gr.addr, + gr.is_to_write_set, + gr.value ); } record_mismatches += 1; @@ -498,14 +558,18 @@ pub(crate) fn debug_compare_shard_ec>( if addr_ok && record_mismatches == 0 && ec_mismatches == 0 && gpu_dup_count == 0 { tracing::info!( "[GPU EC debug] kind={kind:?} ALL MATCH: {} records, {} addr_accessed, EC points OK", - matched, cpu_addr.len() + matched, + cpu_addr.len() ); } else { tracing::error!( "[GPU EC debug] kind={kind:?} MISMATCH: matched={matched} record_diffs={record_mismatches} \ ec_diffs={ec_mismatches} gpu_dups={gpu_dup_count} addr_ok={addr_ok} \ (cpu_records={} gpu_records={} cpu_addrs={} gpu_addrs={})", - cpu_entries.len(), gpu_entries.len(), cpu_addr.len(), gpu_addr.len() + cpu_entries.len(), + gpu_entries.len(), + cpu_addr.len(), + gpu_addr.len() ); } } @@ -631,14 +695,15 @@ pub(crate) fn 
debug_compare_keccak( use crate::instructions::riscv::ecall::keccak::KeccakInstruction; // Set force-CPU flag so gpu_assign_keccak_instances returns None set_force_cpu_path(true); - let result = as crate::instructions::Instruction>::assign_instances( - config, - &mut cpu_ctx, - num_witin, - num_structural_witin, - steps, - step_indices, - ); + let result = + as crate::instructions::Instruction>::assign_instances( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + ); set_force_cpu_path(false); IN_DEBUG_COMPARE.with(|f| f.set(false)); result? @@ -667,16 +732,26 @@ pub(crate) fn debug_compare_keccak( let col = i % num_witin; tracing::error!( "[GPU keccak witness] row={} col={} gpu={:?} cpu={:?}", - row, col, g, c + row, + col, + g, + c ); } diffs += 1; } } if diffs == 0 { - tracing::info!("[GPU keccak debug] witness matrices match ({} elements)", gpu_vals.len()); + tracing::info!( + "[GPU keccak debug] witness matrices match ({} elements)", + gpu_vals.len() + ); } else { - tracing::error!("[GPU keccak debug] witness mismatch: {} diffs out of {}", diffs, gpu_vals.len()); + tracing::error!( + "[GPU keccak debug] witness mismatch: {} diffs out of {}", + diffs, + gpu_vals.len() + ); } } @@ -690,7 +765,8 @@ pub(crate) fn debug_compare_keccak( if cpu_addr.len() != gpu_addr_set.len() { tracing::error!( "[GPU keccak shard] addr_accessed count mismatch: cpu={} gpu={}", - cpu_addr.len(), gpu_addr_set.len() + cpu_addr.len(), + gpu_addr_set.len() ); } let mut missing_from_gpu = 0usize; @@ -720,7 +796,8 @@ pub(crate) fn debug_compare_keccak( } else { tracing::error!( "[GPU keccak shard] addr_accessed diff: missing_from_gpu={} extra_in_gpu={}", - missing_from_gpu, extra_in_gpu + missing_from_gpu, + extra_in_gpu ); } } diff --git a/ceno_zkvm/src/instructions/gpu/utils/fallback.rs b/ceno_zkvm/src/instructions/gpu/utils/fallback.rs index 376e9872e..bd15be32c 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/fallback.rs +++ 
b/ceno_zkvm/src/instructions/gpu/utils/fallback.rs @@ -10,9 +10,7 @@ use rayon::{ }; use witness::RowMajorMatrix; -use crate::{ - e2e::ShardContext, error::ZKVMError, tables::RMMCollections, witness::LkMultiplicity, -}; +use crate::{e2e::ShardContext, error::ZKVMError, tables::RMMCollections, witness::LkMultiplicity}; use crate::instructions::Instruction; @@ -24,13 +22,7 @@ pub fn cpu_assign_instances>( num_structural_witin: usize, shard_steps: &[ceno_emul::StepRecord], step_indices: &[StepIndex], -) -> Result< - ( - RMMCollections, - Multiplicity, - ), - ZKVMError, -> { +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assert!(num_structural_witin == 0 || num_structural_witin == 1); let num_structural_witin = num_structural_witin.max(1); diff --git a/ceno_zkvm/src/instructions/gpu/utils/mod.rs b/ceno_zkvm/src/instructions/gpu/utils/mod.rs index 0c120c4fa..46ce4b6e6 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/mod.rs @@ -2,16 +2,16 @@ //! //! Contains lookup/shard lk_shardram collection abstractions and CPU fallback paths. 
-mod lk_ops; -mod sink; mod emit; mod fallback; +mod lk_ops; +mod sink; // Re-export all public types for convenience -pub use lk_ops::*; -pub use sink::*; pub use emit::*; pub use fallback::*; +pub use lk_ops::*; +pub use sink::*; #[cfg(feature = "gpu")] pub mod colmap_base; @@ -20,7 +20,6 @@ pub mod d2h; #[cfg(feature = "gpu")] pub mod debug_compare; - #[cfg(test)] mod tests { use super::*; diff --git a/ceno_zkvm/src/instructions/gpu/utils/sink.rs b/ceno_zkvm/src/instructions/gpu/utils/sink.rs index 3c19f44bc..4e7ff8fde 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/sink.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/sink.rs @@ -18,6 +18,9 @@ pub struct CpuLkShardramSink<'ctx, 'shard, 'lk> { } impl<'ctx, 'shard, 'lk> CpuLkShardramSink<'ctx, 'shard, 'lk> { + /// # Safety + /// Caller must ensure `shard_ctx` points to a valid, live `ShardContext` + /// for the duration of the returned sink's lifetime. pub unsafe fn from_raw( shard_ctx: *mut ShardContext<'shard>, lk: &'lk mut LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index ab15d2295..8798e01aa 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -32,7 +32,6 @@ mod r_insn; mod ecall_insn; - #[cfg(feature = "u16limb_circuit")] pub(crate) mod auipc; mod im_insn; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index b179645b9..198697fca 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,14 +2,11 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::{ - Instruction, - gpu::utils::emit_u16_limbs, - }, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, + 
instructions::{Instruction, gpu::utils::emit_u16_limbs}, structs::ProgramParams, uint::Value, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 7ba58cea2..2d0e62969 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -3,11 +3,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, - riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, gpu::utils::emit_u16_limbs, + riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index cf40c1eff..e5aea4394 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -6,17 +6,14 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op}, riscv::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, - gpu::utils::{ - LkOp, LkShardramSink, emit_byte_decomposition_ops, - emit_const_range_op, - }, }, structs::ProgramParams, tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index 382c62981..a94266f63 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ 
b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -8,8 +8,8 @@ use crate::{ e2e::ShardContext, error::ZKVMError, instructions::{ - riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, }, tables::InsnRecord, witness::{LkMultiplicity, set_val}, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index e0c3d2fc1..34b6ebe68 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -4,15 +4,15 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_uint_limbs_lt_ops, riscv::{ RIVInstruction, b_insn::BInstructionConfig, constants::{UINT_LIMBS, UInt}, }, - gpu::utils::emit_uint_limbs_lt_ops, }, structs::ProgramParams, witness::LkMultiplicity, @@ -218,8 +218,8 @@ impl Instruction for BranchCircuit WriteRD { self.emit_op_shardram(shard_ctx, step.cycle(), &op) } - pub fn emit_op_shardram( - &self, - shard_ctx: &mut ShardContext, - cycle: Cycle, - op: &WriteOp, - ) { + pub fn emit_op_shardram(&self, shard_ctx: &mut ShardContext, cycle: Cycle, op: &WriteOp) { shard_ctx.record_send_without_touch( RAMType::Register, op.addr, @@ -661,12 +656,7 @@ impl WriteMEM { self.emit_op_shardram(shard_ctx, step.cycle(), &op) } - pub fn emit_op_shardram( - &self, - shard_ctx: &mut ShardContext, - cycle: Cycle, - op: &WriteOp, - ) { + pub fn emit_op_shardram(&self, shard_ctx: &mut ShardContext, cycle: Cycle, op: &WriteOp) { shard_ctx.record_send_without_touch( RAMType::Memory, op.addr, diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 8ce03b5ac..6cc8b18fe 
100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -7,8 +7,8 @@ use crate::{ e2e::ShardContext, error::ZKVMError, instructions::{ - riscv::insn_base::{StateInOut, WriteRD}, gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{StateInOut, WriteRD}, }, tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index f84cc9696..bb8ba0abe 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -6,14 +6,14 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, riscv::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, }, - gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, }, structs::ProgramParams, utils::split_to_u8, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index e133033b5..75c0d28cf 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -7,15 +7,15 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_const_range_op, riscv::{ constants::{PC_BITS, UINT_LIMBS, UInt}, i_insn::IInstructionConfig, insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD}, }, - gpu::utils::emit_const_range_op, }, structs::ProgramParams, tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs 
b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 2b42cdfaf..6dad59ca7 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -8,11 +8,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, - riscv::{constants::UInt8, r_insn::RInstructionConfig}, gpu::utils::emit_logic_u8_ops, + riscv::{constants::UInt8, r_insn::RInstructionConfig}, }, structs::ProgramParams, utils::split_to_u8, @@ -153,13 +153,4 @@ impl LogicConfig { Ok(()) } - - fn emit_lk_and_shardram( - &self, - sink: &mut impl crate::instructions::gpu::utils::LkShardramSink, - shard_ctx: &ShardContext, - step: &StepRecord, - ) { - self.r_insn.emit_lk_and_shardram(sink, shard_ctx, step); - } } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index d9624b68f..1137974f8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -9,15 +9,15 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_logic_u8_ops, riscv::{ constants::{LIMB_BITS, LIMB_MASK, UInt8}, i_insn::IInstructionConfig, logic_imm::LogicOp, }, - gpu::utils::emit_logic_u8_ops, }, structs::ProgramParams, tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 9d8e67f95..814924fda 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ 
-6,14 +6,14 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_const_range_op, riscv::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, - gpu::utils::emit_const_range_op, }, structs::ProgramParams, tables::InsnRecord, diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index b19e94c01..fc37371bc 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index e6ec9d6c7..1810319eb 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -4,7 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index 05851299d..a6afeb93a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -3,9 +3,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + 
impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{emit_const_range_op, emit_u16_limbs}, riscv::{ RIVInstruction, constants::{MEM_BITS, UInt}, @@ -13,7 +14,6 @@ use crate::{ memory::gadget::MemWordUtil, s_insn::SInstructionConfig, }, - gpu::utils::{emit_const_range_op, emit_u16_limbs}, }, structs::ProgramParams, tables::InsnRecord, @@ -181,9 +181,7 @@ impl Instruction let imm = InsnRecord::::imm_internal(&step.insn()); let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); - config - .memory_addr - .emit_lk_and_shardram(sink, addr.into()); + config.memory_addr.emit_lk_and_shardram(sink, addr.into()); if N_ZEROS == 0 { let memory_op = step.memory_op().unwrap(); diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index 7cc6ce0e9..2ed8358a6 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -1,15 +1,15 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink}, riscv::{ RIVInstruction, constants::{LIMB_BITS, UINT_LIMBS, UInt}, r_insn::RInstructionConfig, }, - gpu::utils::{LkOp, LkShardramSink}, }, structs::ProgramParams, uint::Value, diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index 4ad09b9d6..525998e41 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -7,8 +7,8 @@ use crate::{ e2e::ShardContext, error::ZKVMError, instructions::{ - riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, }, 
tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index 1d849146c..38dd29555 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -4,8 +4,8 @@ use crate::{ e2e::ShardContext, error::ZKVMError, instructions::{ - riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, }, tables::InsnRecord, witness::LkMultiplicity, diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 90f9ac102..493f76f4e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,19 +1,16 @@ use crate::e2e::ShardContext; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op}, riscv::{ RIVInstruction, constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, r_insn::RInstructionConfig, }, - gpu::utils::{ - LkOp, LkShardramSink, emit_byte_decomposition_ops, - emit_const_range_op, - }, }, structs::ProgramParams, utils::{split_to_limb, split_to_u8}, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index a61d30578..f397dcb5e 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -4,11 +4,11 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - 
impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, - riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, gpu::utils::emit_uint_limbs_lt_ops, + riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, }, structs::ProgramParams, witness::LkMultiplicity, @@ -126,8 +126,8 @@ impl Instruction for SetLessThanInstruc emit_uint_limbs_lt_ops( sink, matches!(I::INST_KIND, InsnKind::SLT), - &rs1_limbs, - &rs2_limbs, + rs1_limbs, + rs2_limbs, ); }); diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 5505f8c5d..4483ea4d4 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -4,15 +4,15 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, - impl_collect_shardram, impl_collect_lk_and_shardram, impl_gpu_assign, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_uint_limbs_lt_ops, riscv::{ RIVInstruction, constants::{UINT_LIMBS, UInt}, i_insn::IInstructionConfig, }, - gpu::utils::emit_uint_limbs_lt_ops, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -145,7 +145,7 @@ impl Instruction for SetLessThanImmInst emit_uint_limbs_lt_ops( sink, matches!(I::INST_KIND, InsnKind::SLTI), - &rs1_limbs, + rs1_limbs, &imm_sign_extend, ); }); diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 98edd8615..8714ba8fe 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1112,7 +1112,10 @@ impl SepticJacobianPoint { mod tests { use super::SepticExtension; use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint}; - use p3::{babybear::BabyBear, 
field::{Field, FieldAlgebra}}; + use p3::{ + babybear::BabyBear, + field::Field, + }; use rand::thread_rng; type F = BabyBear; @@ -1179,8 +1182,10 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_ec_point_matches_cpu() { use crate::tables::{ECPoint, ShardRamRecord}; - use ceno_gpu::bb31::test_impl::{TestEcInput, run_gpu_ec_test}; - use ceno_gpu::bb31::CudaHalBB31; + use ceno_gpu::bb31::{ + CudaHalBB31, + test_impl::{TestEcInput, run_gpu_ec_test}, + }; use ff_ext::{PoseidonField, SmallField}; use gkr_iop::RAMType; @@ -1189,14 +1194,70 @@ mod tests { // Test cases: various write/read, register/memory, edge cases let test_inputs = vec![ - TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 1, shard: 1, global_clk: 100 }, - TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 0, shard: 0, global_clk: 50 }, - TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 1, shard: 2, global_clk: 200 }, - TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 0, shard: 1, global_clk: 150 }, - TestEcInput { addr: 0, ram_type: 1, value: 0, is_write: 1, shard: 0, global_clk: 1 }, - TestEcInput { addr: 31, ram_type: 1, value: 0xFFFFFFFF, is_write: 0, shard: 3, global_clk: 999 }, - TestEcInput { addr: 0x40000000, ram_type: 2, value: 42, is_write: 1, shard: 5, global_clk: 500000 }, - TestEcInput { addr: 10, ram_type: 1, value: 1, is_write: 0, shard: 100, global_clk: 1000000 }, + TestEcInput { + addr: 5, + ram_type: 1, + value: 0x12345678, + is_write: 1, + shard: 1, + global_clk: 100, + }, + TestEcInput { + addr: 5, + ram_type: 1, + value: 0x12345678, + is_write: 0, + shard: 0, + global_clk: 50, + }, + TestEcInput { + addr: 0x80000, + ram_type: 2, + value: 0xDEADBEEF, + is_write: 1, + shard: 2, + global_clk: 200, + }, + TestEcInput { + addr: 0x80000, + ram_type: 2, + value: 0xDEADBEEF, + is_write: 0, + shard: 1, + global_clk: 150, + }, + TestEcInput { + addr: 0, + ram_type: 1, + value: 0, + is_write: 1, + shard: 0, + 
global_clk: 1, + }, + TestEcInput { + addr: 31, + ram_type: 1, + value: 0xFFFFFFFF, + is_write: 0, + shard: 3, + global_clk: 999, + }, + TestEcInput { + addr: 0x40000000, + ram_type: 2, + value: 42, + is_write: 1, + shard: 5, + global_clk: 500000, + }, + TestEcInput { + addr: 10, + ram_type: 1, + value: 1, + is_write: 0, + shard: 100, + global_clk: 1000000, + }, ]; let gpu_results = run_gpu_ec_test(&hal, &test_inputs); @@ -1206,13 +1267,25 @@ mod tests { // Build CPU ShardRamRecord and compute EC point. // GPU kernel treats slot.addr as word address and converts to byte address // via `<< 2` for Memory type; Register type uses reg_id directly. - let addr = if input.ram_type == 2 { input.addr << 2 } else { input.addr }; + let addr = if input.ram_type == 2 { + input.addr << 2 + } else { + input.addr + }; let cpu_record = ShardRamRecord { addr, - ram_type: if input.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + ram_type: if input.ram_type == 1 { + RAMType::Register + } else { + RAMType::Memory + }, value: input.value, shard: input.shard, - local_clk: if input.is_write != 0 { input.global_clk } else { 0 }, + local_clk: if input.is_write != 0 { + input.global_clk + } else { + 0 + }, global_clk: input.global_clk, is_to_write_set: input.is_write != 0, }; @@ -1241,16 +1314,29 @@ mod tests { } if has_diff { - eprintln!("MISMATCH [{i}]: addr={} ram_type={} value={:#x} is_write={} shard={} clk={}", - input.addr, input.ram_type, input.value, input.is_write, input.shard, input.global_clk); + eprintln!( + "MISMATCH [{i}]: addr={} ram_type={} value={:#x} is_write={} shard={} clk={}", + input.addr, + input.ram_type, + input.value, + input.is_write, + input.shard, + input.global_clk + ); mismatches += 1; } } - assert_eq!(mismatches, 0, + assert_eq!( + mismatches, + 0, "{mismatches}/{} test cases had GPU/CPU EC point mismatches", - test_inputs.len()); - eprintln!("All {} GPU EC point test cases match CPU!", test_inputs.len()); + test_inputs.len() + ); + eprintln!( + 
"All {} GPU EC point test cases match CPU!", + test_inputs.len() + ); } /// Verify GPU Poseidon2 permutation matches CPU on the exact sponge packing @@ -1258,8 +1344,11 @@ mod tests { #[test] #[cfg(feature = "gpu")] fn test_gpu_poseidon2_sponge_matches_cpu() { - use ceno_gpu::bb31::test_impl::{run_gpu_poseidon2_sponge, SPONGE_WIDTH}; - use ceno_gpu::bb31::CudaHalBB31; + use p3::field::FieldAlgebra; + use ceno_gpu::bb31::{ + CudaHalBB31, + test_impl::{SPONGE_WIDTH, run_gpu_poseidon2_sponge}, + }; use ff_ext::{PoseidonField, SmallField}; use p3::symmetric::Permutation; @@ -1272,25 +1361,25 @@ mod tests { // Case 1: typical write record { let mut s = [0u32; SPONGE_WIDTH]; - s[0] = 5; // addr - s[1] = 1; // ram_type (Register) - s[2] = 0x5678; // value lo16 - s[3] = 0x1234; // value hi16 - s[4] = 1; // shard - s[5] = 100; // global_clk - s[6] = 0; // nonce + s[0] = 5; // addr + s[1] = 1; // ram_type (Register) + s[2] = 0x5678; // value lo16 + s[3] = 0x1234; // value hi16 + s[4] = 1; // shard + s[5] = 100; // global_clk + s[6] = 0; // nonce s }, // Case 2: memory read, different values { let mut s = [0u32; SPONGE_WIDTH]; s[0] = 0x80000; - s[1] = 2; // Memory + s[1] = 2; // Memory s[2] = 0xBEEF; s[3] = 0xDEAD; s[4] = 2; s[5] = 200; - s[6] = 3; // nonce=3 + s[6] = 3; // nonce=3 s }, // Case 3: all zeros (edge case) @@ -1317,7 +1406,10 @@ mod tests { } } - assert_eq!(mismatches, 0, "{mismatches} Poseidon2 output elements differ between GPU and CPU"); + assert_eq!( + mismatches, 0, + "{mismatches} Poseidon2 output elements differ between GPU and CPU" + ); eprintln!("All {} Poseidon2 sponge test cases match!", count); } @@ -1326,8 +1418,8 @@ mod tests { #[test] #[cfg(feature = "gpu")] fn test_gpu_septic_from_x_matches_cpu() { - use ceno_gpu::bb31::test_impl::run_gpu_septic_from_x; - use ceno_gpu::bb31::CudaHalBB31; + use p3::field::FieldAlgebra; + use ceno_gpu::bb31::{CudaHalBB31, test_impl::run_gpu_septic_from_x}; use ff_ext::SmallField; let hal = 
CudaHalBB31::new(0).unwrap(); @@ -1336,7 +1428,9 @@ mod tests { // (ensures we get a mix of points-exist and points-don't-exist cases) let test_xs: Vec<[u32; 7]> = vec![ // x from Poseidon2([5,1,0x5678,0x1234,1,100,0, 0..]) — known to have a point - [1594766074, 868528894, 1733778006, 1242721508, 1690833816, 1437202757, 1753525271], + [ + 1594766074, 868528894, 1733778006, 1242721508, 1690833816, 1437202757, 1753525271, + ], // Simple: x = [1,0,0,0,0,0,0] [1, 0, 0, 0, 0, 0, 0], // x = [0,0,0,0,0,0,0] (zero) @@ -1344,7 +1438,9 @@ mod tests { // x = [42, 17, 999, 0, 0, 0, 0] [42, 17, 999, 0, 0, 0, 0], // Random-ish values - [1000000007, 123456789, 987654321, 111111111, 222222222, 333333333, 444444444], + [ + 1000000007, 123456789, 987654321, 111111111, 222222222, 333333333, 444444444, + ], ]; let count = test_xs.len(); @@ -1380,8 +1476,10 @@ mod tests { } } - assert_eq!(mismatches, 0, - "{mismatches} septic_from_x results differ between GPU and CPU"); + assert_eq!( + mismatches, 0, + "{mismatches} septic_from_x results differ between GPU and CPU" + ); eprintln!("All {} septic_from_x test cases match!", count); } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index f296de9c2..680ff4184 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -483,21 +483,18 @@ impl ZKVMWitnesses { // Falls back to the traditional path on failure. 
#[cfg(feature = "gpu")] { - let gpu_result = self.try_assign_shared_circuit_gpu( - cs, shard_ctx, final_mem, config, - ); + let gpu_result = self.try_assign_shared_circuit_gpu(cs, shard_ctx, final_mem, config); match gpu_result { - Ok(true) => return Ok(()), // GPU pipeline succeeded - Ok(false) => {} // GPU pipeline unavailable, fall through + Ok(true) => return Ok(()), // GPU pipeline succeeded + Ok(false) => {} // GPU pipeline unavailable, fall through Err(e) => { tracing::warn!("GPU full pipeline failed, falling back: {e:?}"); } } } - let addr_accessed = info_span!("get_addr_accessed").in_scope(|| { - shard_ctx.get_addr_accessed_sorted() - }); + let addr_accessed = + info_span!("get_addr_accessed").in_scope(|| shard_ctx.get_addr_accessed_sorted()); // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) // Partition into writes and reads to maintain the ordering invariant required by @@ -512,15 +509,42 @@ impl ZKVMWitnesses { }); // Collect cross-shard records (filter only, no EC computation yet) - let (write_record_pairs, read_record_pairs) = info_span!("collect_records").in_scope(|| { - let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = - if shard_ctx.is_first_shard() { + let (write_record_pairs, read_record_pairs) = + info_span!("collect_records").in_scope(|| { + let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_none()) + .flat_map(|(mem_name, _, final_mem)| { + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect() + } else { + vec![] + }; + + let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem .par_iter() - .filter(|(_, range, _)| range.is_none()) - .flat_map(|(mem_name, _, 
final_mem)| { + .filter(|(_, range, _)| range.is_some()) + .flat_map(|(mem_name, range, final_mem)| { + let range = range.as_ref().unwrap(); final_mem.par_iter().filter_map(|mem_record| { let (waddr, addr) = Self::mem_addresses(mem_record); + if !range.contains(&addr) { + return None; + } Self::make_cross_shard_record( mem_name, mem_record, @@ -531,65 +555,39 @@ impl ZKVMWitnesses { ) }) }) - .collect() - } else { - vec![] - }; - - let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem - .par_iter() - .filter(|(_, range, _)| range.is_some()) - .flat_map(|(mem_name, range, final_mem)| { - let range = range.as_ref().unwrap(); - final_mem.par_iter().filter_map(|mem_record| { - let (waddr, addr) = Self::mem_addresses(mem_record); - if !range.contains(&addr) { - return None; - } - Self::make_cross_shard_record( - mem_name, - mem_record, - waddr, - addr, - shard_ctx, - &addr_accessed, - ) - }) - }) - .collect(); + .collect(); - let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx - .write_records() - .iter() - .flat_map(|records| { - records.iter().map(|(vma, record)| { - ((vma, record, true).into(), "current_shard_external_write") + let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .write_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, true).into(), "current_shard_external_write") + }) }) - }) - .chain(first_shard_access_later_recs) - .chain(current_shard_access_later_recs) - .collect(); + .chain(first_shard_access_later_recs) + .chain(current_shard_access_later_recs) + .collect(); - let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx - .read_records() - .iter() - .flat_map(|records| { - records.iter().map(|(vma, record)| { - ((vma, record, false).into(), "current_shard_external_read") + let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .read_records() + .iter() + .flat_map(|records| { + 
records.iter().map(|(vma, record)| { + ((vma, record, false).into(), "current_shard_external_read") + }) }) - }) - .collect(); + .collect(); - (write_record_pairs, read_record_pairs) - }); + (write_record_pairs, read_record_pairs) + }); // Compute EC points: GPU path (fast) or CPU fallback let global_input = { #[cfg(feature = "gpu")] let ec_result = { use crate::instructions::gpu::dispatch::gpu_batch_continuation_ec; - gpu_batch_continuation_ec::(&write_record_pairs, &read_record_pairs) - .ok() + gpu_batch_continuation_ec::(&write_record_pairs, &read_record_pairs).ok() }; #[cfg(not(feature = "gpu"))] let ec_result: Option<(Vec>, Vec>)> = None; @@ -609,14 +607,22 @@ impl ZKVMWitnesses { .into_par_iter() .map(|(record, name)| { let ec_point = record.to_ec_point(&perm); - ShardRamInput { name, record, ec_point } + ShardRamInput { + name, + record, + ec_point, + } }) .collect(); let cpu_reads: Vec> = read_record_pairs .into_par_iter() .map(|(record, name)| { let ec_point = record.to_ec_point(&perm); - ShardRamInput { name, record, ec_point } + ShardRamInput { + name, + record, + ec_point, + } }) .collect(); cpu_writes @@ -673,8 +679,14 @@ impl ZKVMWitnesses { shard_ctx.shard_id, input.record.global_clk, global_input.len(), - global_input.iter().filter(|x| x.record.is_to_write_set).count(), - global_input.iter().filter(|x| !x.record.is_to_write_set).count(), + global_input + .iter() + .filter(|x| x.record.is_to_write_set) + .count(), + global_input + .iter() + .filter(|x| !x.record.is_to_write_set) + .count(), ); break; } @@ -687,31 +699,32 @@ impl ZKVMWitnesses { assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); let n_global = global_input.len(); - let circuit_inputs = info_span!("shard_ram_assign_instances", n = n_global).in_scope(|| { - global_input - .par_chunks(shard_ctx.max_num_cross_shard_accesses) - .map(|shard_accesses| { - let witness = ShardRamCircuit::assign_instances( - config, - cs.zkvm_v1_css.num_witin as 
usize, - cs.zkvm_v1_css.num_structural_witin as usize, - self.combined_lk_mlt.as_ref().unwrap(), - shard_accesses, - )?; - let num_reads = shard_accesses - .par_iter() - .filter(|access| access.record.is_to_write_set) - .count(); - let num_writes = shard_accesses.len() - num_reads; - - Ok(ChipInput::new( - ShardRamCircuit::::name(), - witness, - vec![num_reads, num_writes], - )) - }) - .collect::, ZKVMError>>() - })?; + let circuit_inputs = + info_span!("shard_ram_assign_instances", n = n_global).in_scope(|| { + global_input + .par_chunks(shard_ctx.max_num_cross_shard_accesses) + .map(|shard_accesses| { + let witness = ShardRamCircuit::assign_instances( + config, + cs.zkvm_v1_css.num_witin as usize, + cs.zkvm_v1_css.num_structural_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + shard_accesses, + )?; + let num_reads = shard_accesses + .par_iter() + .filter(|access| access.record.is_to_write_set) + .count(); + let num_writes = shard_accesses.len() - num_reads; + + Ok(ChipInput::new( + ShardRamCircuit::::name(), + witness, + vec![num_reads, num_writes], + )) + }) + .collect::, ZKVMError>>() + })?; // set num_read, num_write as separate instance assert!( self.witnesses @@ -769,17 +782,17 @@ impl ZKVMWitnesses { tracing::info!( "[GPU full pipeline] shared buffers: {} EC records, {} addr_accessed", - ec_count, addr_count, + ec_count, + addr_count, ); // 3. 
GPU sort addr_accessed + dedup, then D2H sorted unique addrs let addr_accessed: Vec = if addr_count > 0 { info_span!("gpu_sort_addr").in_scope(|| { - let (deduped, unique_count) = hal.witgen + let (deduped, unique_count) = hal + .witgen .sort_and_dedup_u32(&mut shared.addr_buf, addr_count, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU sort addr: {e}").into()) - })?; + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU sort addr: {e}").into()))?; if unique_count == 0 { return Ok::, ZKVMError>(vec![]); } @@ -787,7 +800,8 @@ impl ZKVMWitnesses { let addrs: Vec = deduped.into_iter().map(WordAddr).collect(); tracing::info!( "[GPU full pipeline] sorted {} addrs → {} unique", - addr_count, unique_count, + addr_count, + unique_count, ); Ok(addrs) })? @@ -796,16 +810,43 @@ impl ZKVMWitnesses { }; // 4. CPU collect_records (24ms, uses sorted unique addrs) - let (write_record_pairs, read_record_pairs) = info_span!("collect_records").in_scope(|| { - // This is the same logic as the existing path - let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = - if shard_ctx.is_first_shard() { + let (write_record_pairs, read_record_pairs) = + info_span!("collect_records").in_scope(|| { + // This is the same logic as the existing path + let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_none()) + .flat_map(|(mem_name, _, final_mem)| { + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect() + } else { + vec![] + }; + + let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem .par_iter() - .filter(|(_, range, _)| range.is_none()) - .flat_map(|(mem_name, _, final_mem)| { + .filter(|(_, range, _)| range.is_some()) + 
.flat_map(|(mem_name, range, final_mem)| { + let range = range.as_ref().unwrap(); final_mem.par_iter().filter_map(|mem_record| { let (waddr, addr) = Self::mem_addresses(mem_record); + if !range.contains(&addr) { + return None; + } Self::make_cross_shard_record( mem_name, mem_record, @@ -816,88 +857,68 @@ impl ZKVMWitnesses { ) }) }) - .collect() - } else { - vec![] - }; - - let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = final_mem - .par_iter() - .filter(|(_, range, _)| range.is_some()) - .flat_map(|(mem_name, range, final_mem)| { - let range = range.as_ref().unwrap(); - final_mem.par_iter().filter_map(|mem_record| { - let (waddr, addr) = Self::mem_addresses(mem_record); - if !range.contains(&addr) { - return None; - } - Self::make_cross_shard_record( - mem_name, - mem_record, - waddr, - addr, - shard_ctx, - &addr_accessed, - ) - }) - }) - .collect(); + .collect(); - let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx - .write_records() - .iter() - .flat_map(|records| { - records.iter().map(|(vma, record)| { - ((vma, record, true).into(), "current_shard_external_write") + let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .write_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, true).into(), "current_shard_external_write") + }) }) - }) - .chain(first_shard_access_later_recs) - .chain(current_shard_access_later_recs) - .collect(); + .chain(first_shard_access_later_recs) + .chain(current_shard_access_later_recs) + .collect(); - let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx - .read_records() - .iter() - .flat_map(|records| { - records.iter().map(|(vma, record)| { - ((vma, record, false).into(), "current_shard_external_read") + let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .read_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, false).into(), 
"current_shard_external_read") + }) }) - }) - .collect(); + .collect(); - (write_record_pairs, read_record_pairs) - }); + (write_record_pairs, read_record_pairs) + }); // 5. GPU batch EC on device for continuation records (25ms, results stay on GPU) - let (cont_ec_buf, cont_n_writes, cont_n_reads) = - info_span!("gpu_batch_ec_on_device").in_scope(|| { + let (cont_ec_buf, cont_n_writes, cont_n_reads) = info_span!("gpu_batch_ec_on_device") + .in_scope(|| { gpu_batch_continuation_ec_on_device(&write_record_pairs, &read_record_pairs) })?; let cont_total = cont_n_writes + cont_n_reads; tracing::info!( "[GPU full pipeline] batch EC on device: {} writes + {} reads = {} continuation records", - cont_n_writes, cont_n_reads, cont_total, + cont_n_writes, + cont_n_reads, + cont_total, ); // 6. GPU merge shared_ec + batch_ec, then partition by is_to_write_set - let (partitioned_buf, num_writes, total_records) = - info_span!("gpu_merge_partition").in_scope(|| { - hal.witgen.merge_and_partition_records( - &shared.ec_buf, - ec_count, - &cont_ec_buf, - cont_total, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU merge+partition: {e}").into()) - }) + let (partitioned_buf, num_writes, total_records) = info_span!("gpu_merge_partition") + .in_scope(|| { + hal.witgen + .merge_and_partition_records( + &shared.ec_buf, + ec_count, + &cont_ec_buf, + cont_total, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU merge+partition: {e}").into()) + }) })?; tracing::info!( "[GPU full pipeline] merged+partitioned: {} total ({} writes, {} reads)", - total_records, num_writes, total_records - num_writes, + total_records, + num_writes, + total_records - num_writes, ); // 7. 
GPU assign_instances from device buffer (chunked by max_cross_shard) @@ -908,7 +929,8 @@ impl ZKVMWitnesses { let max_chunk = shard_ctx.max_num_cross_shard_accesses; // Record sizes needed for chunking - let record_u32s = std::mem::size_of::() / 4; + let record_u32s = + std::mem::size_of::() / 4; let circuit_inputs = info_span!("shard_ram_assign_from_device", n = total_records) .in_scope(|| { @@ -928,9 +950,13 @@ impl ZKVMWitnesses { // out of scope. The 'static lifetime is required by the HAL API. let chunk_byte_start = records_offset * record_u32s * 4; let chunk_byte_end = (records_offset + chunk_size) * record_u32s * 4; - let chunk_view = partitioned_buf.as_slice_range(chunk_byte_start..chunk_byte_end); - let chunk_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32> = - unsafe { std::mem::transmute(ceno_gpu::common::buffer::BufferImpl::::new_from_view(chunk_view)) }; + let chunk_view = + partitioned_buf.as_slice_range(chunk_byte_start..chunk_byte_end); + let chunk_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32> = unsafe { + std::mem::transmute( + ceno_gpu::common::buffer::BufferImpl::::new_from_view(chunk_view), + ) + }; let witness = ShardRamCircuit::::try_gpu_assign_instances_from_device( config, @@ -1189,7 +1215,7 @@ where fn gpu_ec_records_to_shard_ram_inputs( raw: &[u8], ) -> (Vec>, Vec>) { - assert!(raw.len() % GPU_SHARD_RAM_RECORD_SIZE == 0); + assert!(raw.len().is_multiple_of(GPU_SHARD_RAM_RECORD_SIZE)); let count = raw.len() / GPU_SHARD_RAM_RECORD_SIZE; #[inline(always)] @@ -1218,17 +1244,21 @@ fn gpu_ec_records_to_shard_ram_inputs( let mut point_x_arr = [E::BaseField::ZERO; 7]; let mut point_y_arr = [E::BaseField::ZERO; 7]; for j in 0..7 { - point_x_arr[j] = E::BaseField::from_canonical_u32( - u32::from_le_bytes(r[48 + j*4..52 + j*4].try_into().unwrap()), - ); - point_y_arr[j] = E::BaseField::from_canonical_u32( - u32::from_le_bytes(r[76 + j*4..80 + j*4].try_into().unwrap()), - ); + point_x_arr[j] = 
E::BaseField::from_canonical_u32(u32::from_le_bytes( + r[48 + j * 4..52 + j * 4].try_into().unwrap(), + )); + point_y_arr[j] = E::BaseField::from_canonical_u32(u32::from_le_bytes( + r[76 + j * 4..80 + j * 4].try_into().unwrap(), + )); } let record = ShardRamRecord { addr, - ram_type: if ram_type_val == 1 { RAMType::Register } else { RAMType::Memory }, + ram_type: if ram_type_val == 1 { + RAMType::Register + } else { + RAMType::Memory + }, value, shard, local_clk, diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index dd8d06afd..22e2c72a3 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -487,7 +487,9 @@ impl TableCircuit for ShardRamCircuit { #[cfg(feature = "gpu")] { - if let Some(result) = Self::try_gpu_assign_instances(config, num_witin, num_structural_witin, steps)? { + if let Some(result) = + Self::try_gpu_assign_instances(config, num_witin, num_structural_witin, steps)? + { return Ok(result); } } @@ -663,10 +665,7 @@ impl ShardRamCircuit { use ceno_gpu::{ Buffer, CudaHal, bb31::CudaHalBB31, - common::{ - transpose::matrix_transpose, - witgen::types::GpuShardRamRecord, - }, + common::{transpose::matrix_transpose, witgen::types::GpuShardRamRecord}, }; use gkr_iop::gpu::gpu_prover::get_cuda_hal; use p3::field::PrimeField32; @@ -692,32 +691,35 @@ impl ShardRamCircuit { let num_rows_padded = 2 * n; // 1. 
Convert ShardRamInput → GpuShardRamRecord - let gpu_records: Vec = tracing::info_span!( - "gpu_shard_ram_pack_records", n = steps.len() - ).in_scope(|| steps - .iter() - .map(|step| { - let r = &step.record; - let ec = &step.ec_point; - let mut rec = GpuShardRamRecord::default(); - rec.addr = r.addr; - rec.ram_type = r.ram_type as u32; - rec.value = r.value; - rec.shard = r.shard; - rec.local_clk = r.local_clk; - rec.global_clk = r.global_clk; - rec.is_to_write_set = if r.is_to_write_set { 1 } else { 0 }; - rec.nonce = ec.nonce; - for i in 0..7 { - // Safe: TypeId check above guarantees E::BaseField == BB - let px: BB = unsafe { *(&ec.point.x.0[i] as *const E::BaseField as *const BB) }; - let py: BB = unsafe { *(&ec.point.y.0[i] as *const E::BaseField as *const BB) }; - rec.point_x[i] = px.as_canonical_u32(); - rec.point_y[i] = py.as_canonical_u32(); - } - rec - }) - .collect()); + let gpu_records: Vec = + tracing::info_span!("gpu_shard_ram_pack_records", n = steps.len()).in_scope(|| { + steps + .iter() + .map(|step| { + let r = &step.record; + let ec = &step.ec_point; + let mut rec = GpuShardRamRecord::default(); + rec.addr = r.addr; + rec.ram_type = r.ram_type as u32; + rec.value = r.value; + rec.shard = r.shard; + rec.local_clk = r.local_clk; + rec.global_clk = r.global_clk; + rec.is_to_write_set = if r.is_to_write_set { 1 } else { 0 }; + rec.nonce = ec.nonce; + for i in 0..7 { + // Safe: TypeId check above guarantees E::BaseField == BB + let px: BB = + unsafe { *(&ec.point.x.0[i] as *const E::BaseField as *const BB) }; + let py: BB = + unsafe { *(&ec.point.y.0[i] as *const E::BaseField as *const BB) }; + rec.point_x[i] = px.as_canonical_u32(); + rec.point_y[i] = py.as_canonical_u32(); + } + rec + }) + .collect() + }); // 2. 
Extract column map let col_map = crate::instructions::gpu::chips::shard_ram::extract_shard_ram_column_map( @@ -730,151 +732,165 @@ impl ShardRamCircuit { n = steps.len(), num_rows_padded, num_witin, - ).in_scope(|| hal.witgen - .witgen_shard_ram_per_row( - &col_map, - &gpu_records, - num_local_reads as u32, - num_witin as u32, - num_structural_witin as u32, - num_rows_padded as u32, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU shard_ram per-row kernel failed: {e}").into(), + ) + .in_scope(|| { + hal.witgen + .witgen_shard_ram_per_row( + &col_map, + &gpu_records, + num_local_reads as u32, + num_witin as u32, + num_structural_witin as u32, + num_rows_padded as u32, + None, ) - }))?; + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU shard_ram per-row kernel failed: {e}").into(), + ) + }) + })?; // 4. GPU Phase 2: EC binary tree - let witness_buf = tracing::info_span!( - "gpu_shard_ram_ec_tree", n - ).in_scope(|| -> Result<_, ZKVMError> { - let col_offsets = col_map.to_flat(); - let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) - })?; + let witness_buf = tracing::info_span!("gpu_shard_ram_ec_tree", n).in_scope( + || -> Result<_, ZKVMError> { + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; - // Build initial layer points (padded to n) using BB (BabyBear) directly - let mut init_x = vec![BB::ZERO; n * 7]; - let mut init_y = vec![BB::ZERO; n * 7]; - for (i, step) in steps.iter().enumerate() { - for j in 0..7 { - // E::BaseField == BB at runtime (checked above), safe to transmute - init_x[i * 7 + j] = - unsafe { *(&step.ec_point.point.x.0[j] as *const E::BaseField as *const BB) }; - init_y[i * 7 + j] = - unsafe { *(&step.ec_point.point.y.0[j] as *const E::BaseField as *const BB) }; 
+ // Build initial layer points (padded to n) using BB (BabyBear) directly + let mut init_x = vec![BB::ZERO; n * 7]; + let mut init_y = vec![BB::ZERO; n * 7]; + for (i, step) in steps.iter().enumerate() { + for j in 0..7 { + // E::BaseField == BB at runtime (checked above), safe to transmute + init_x[i * 7 + j] = unsafe { + *(&step.ec_point.point.x.0[j] as *const E::BaseField as *const BB) + }; + init_y[i * 7 + j] = unsafe { + *(&step.ec_point.point.y.0[j] as *const E::BaseField as *const BB) + }; + } } - } - - let mut cur_x = hal.alloc_elems_from_host(&init_x, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc init_x failed: {e}").into()) - })?; - let mut cur_y = hal.alloc_elems_from_host(&init_y, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc init_y failed: {e}").into()) - })?; - let mut witness_buf = gpu_witness.device_buffer; - let mut offset = num_rows_padded / 2; // n - let mut current_layer_len = n; + let mut cur_x = hal.alloc_elems_from_host(&init_x, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_x failed: {e}").into()) + })?; + let mut cur_y = hal.alloc_elems_from_host(&init_y, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_y failed: {e}").into()) + })?; - loop { - if current_layer_len <= 1 { - break; + let mut witness_buf = gpu_witness.device_buffer; + let mut offset = num_rows_padded / 2; // n + let mut current_layer_len = n; + + loop { + if current_layer_len <= 1 { + break; + } + + let (next_x, next_y) = hal + .witgen + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut witness_buf, + current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree layer failed: {e}").into(), + ) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; } - let (next_x, next_y) = hal.witgen - .shard_ram_ec_tree_layer( - &gpu_cols, - &cur_x, - &cur_y, - &mut 
witness_buf, - current_layer_len, - offset, - num_rows_padded, - None, + Ok(witness_buf) + }, + )?; + + // 5. GPU transpose: column-major → row-major + D2H + let (wit_data, struct_data) = + tracing::info_span!("gpu_shard_ram_transpose_d2h", num_rows_padded, num_witin,) + .in_scope(|| -> Result<_, ZKVMError> { + let wit_num_rows = num_rows_padded; + let wit_num_cols = num_witin; + let mut rmm_buf = hal + .witgen + .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buf, + &witness_buf, + wit_num_rows, + wit_num_cols, ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU EC tree layer failed: {e}").into(), - ) + ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) })?; - current_layer_len /= 2; - offset += current_layer_len; - cur_x = next_x; - cur_y = next_y; - } - - Ok(witness_buf) - })?; - - // 5. 
GPU transpose: column-major → row-major + D2H - let (wit_data, struct_data) = tracing::info_span!( - "gpu_shard_ram_transpose_d2h", num_rows_padded, num_witin, - ).in_scope(|| -> Result<_, ZKVMError> { - let wit_num_rows = num_rows_padded; - let wit_num_cols = num_witin; - let mut rmm_buf = hal.witgen - .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) - })?; - matrix_transpose::( - &hal.inner, - &mut rmm_buf, - &witness_buf, - wit_num_rows, - wit_num_cols, - ) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; - - let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) - })?; - let wit_data: Vec = unsafe { - let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); - Vec::from_raw_parts( - data.as_mut_ptr() as *mut E::BaseField, - data.len(), - data.capacity(), - ) - }; + let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) + })?; + let wit_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; - let struct_num_cols = num_structural_witin; - let mut struct_rmm_buf = hal.witgen - .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU alloc for struct transpose failed: {e}").into(), + let struct_num_cols = num_structural_witin; + let mut struct_rmm_buf = hal + .witgen + .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for struct transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut struct_rmm_buf, + &gpu_structural.device_buffer, + wit_num_rows, + 
struct_num_cols, ) - })?; - matrix_transpose::( - &hal.inner, - &mut struct_rmm_buf, - &gpu_structural.device_buffer, - wit_num_rows, - struct_num_cols, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU struct transpose failed: {e}").into()) - })?; + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU struct transpose failed: {e}").into(), + ) + })?; - let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) - })?; - let struct_data: Vec = unsafe { - let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); - Vec::from_raw_parts( - data.as_mut_ptr() as *mut E::BaseField, - data.len(), - data.capacity(), - ) - }; + let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) + })?; + let struct_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; - Ok((wit_data, struct_data)) - })?; + Ok((wit_data, struct_data)) + })?; let raw_witin = witness::RowMajorMatrix::new_by_values( wit_data, @@ -909,11 +925,7 @@ impl ShardRamCircuit { num_records: usize, num_local_writes: usize, ) -> Result>, ZKVMError> { - use ceno_gpu::{ - Buffer, CudaHal, - bb31::CudaHalBB31, - common::transpose::matrix_transpose, - }; + use ceno_gpu::{Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose}; use gkr_iop::gpu::gpu_prover::get_cuda_hal; type BB = ::BaseField; @@ -941,83 +953,92 @@ impl ShardRamCircuit { n = num_records, num_rows_padded, num_witin, - ).in_scope(|| hal.witgen - .witgen_shard_ram_per_row_from_device( - &col_map, - device_records, - num_records, - num_local_writes as u32, - num_witin as u32, - num_structural_witin as u32, - num_rows_padded as u32, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU shard_ram per-row 
(from_device) kernel failed: {e}").into(), + ) + .in_scope(|| { + hal.witgen + .witgen_shard_ram_per_row_from_device( + &col_map, + device_records, + num_records, + num_local_writes as u32, + num_witin as u32, + num_structural_witin as u32, + num_rows_padded as u32, + None, ) - }))?; - - // 3. GPU: extract EC points from device records (replaces CPU loop) - let witness_buf = tracing::info_span!( - "gpu_shard_ram_ec_tree_from_device", n - ).in_scope(|| -> Result<_, ZKVMError> { - let col_offsets = col_map.to_flat(); - let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) - })?; - - // Extract point_x/y from device records into flat arrays - let (mut cur_x, mut cur_y) = hal.witgen - .extract_ec_points_from_device(device_records, num_records, n, None) .map_err(|e| { ZKVMError::InvalidWitness( - format!("GPU extract_ec_points failed: {e}").into(), + format!("GPU shard_ram per-row (from_device) kernel failed: {e}").into(), ) - })?; - - let mut witness_buf = gpu_witness.device_buffer; - let mut offset = num_rows_padded / 2; // n - let mut current_layer_len = n; + }) + })?; - loop { - if current_layer_len <= 1 { - break; - } + // 3. 
GPU: extract EC points from device records (replaces CPU loop) + let witness_buf = tracing::info_span!("gpu_shard_ram_ec_tree_from_device", n).in_scope( + || -> Result<_, ZKVMError> { + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; - let (next_x, next_y) = hal.witgen - .shard_ram_ec_tree_layer( - &gpu_cols, - &cur_x, - &cur_y, - &mut witness_buf, - current_layer_len, - offset, - num_rows_padded, - None, - ) + // Extract point_x/y from device records into flat arrays + let (mut cur_x, mut cur_y) = hal + .witgen + .extract_ec_points_from_device(device_records, num_records, n, None) .map_err(|e| { ZKVMError::InvalidWitness( - format!("GPU EC tree layer failed: {e}").into(), + format!("GPU extract_ec_points failed: {e}").into(), ) })?; - current_layer_len /= 2; - offset += current_layer_len; - cur_x = next_x; - cur_y = next_y; - } + let mut witness_buf = gpu_witness.device_buffer; + let mut offset = num_rows_padded / 2; // n + let mut current_layer_len = n; + + loop { + if current_layer_len <= 1 { + break; + } + + let (next_x, next_y) = hal + .witgen + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut witness_buf, + current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree layer failed: {e}").into(), + ) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; + } - Ok(witness_buf) - })?; + Ok(witness_buf) + }, + )?; // 4. 
GPU transpose + D2H (same as regular path) let (wit_data, struct_data) = tracing::info_span!( - "gpu_shard_ram_transpose_d2h_from_device", num_rows_padded, num_witin, - ).in_scope(|| -> Result<_, ZKVMError> { + "gpu_shard_ram_transpose_d2h_from_device", + num_rows_padded, + num_witin, + ) + .in_scope(|| -> Result<_, ZKVMError> { let wit_num_rows = num_rows_padded; let wit_num_cols = num_witin; - let mut rmm_buf = hal.witgen + let mut rmm_buf = hal + .witgen .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) @@ -1044,7 +1065,8 @@ impl ShardRamCircuit { }; let struct_num_cols = num_structural_witin; - let mut struct_rmm_buf = hal.witgen + let mut struct_rmm_buf = hal + .witgen .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) .map_err(|e| { ZKVMError::InvalidWitness( From f7dc9b400336a20387b169e30fcc04753b3a919d Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 20:06:05 +0800 Subject: [PATCH 65/73] macro: test --- ceno_zkvm/src/instructions/gpu/chips/add.rs | 115 +++-------------- ceno_zkvm/src/instructions/gpu/chips/addi.rs | 37 +----- ceno_zkvm/src/instructions/gpu/chips/auipc.rs | 37 +----- .../src/instructions/gpu/chips/branch_cmp.rs | 37 +----- .../src/instructions/gpu/chips/branch_eq.rs | 37 +----- ceno_zkvm/src/instructions/gpu/chips/div.rs | 86 ++----------- ceno_zkvm/src/instructions/gpu/chips/jal.rs | 37 +----- ceno_zkvm/src/instructions/gpu/chips/jalr.rs | 37 +----- .../src/instructions/gpu/chips/load_sub.rs | 119 ++--------------- .../src/instructions/gpu/chips/logic_i.rs | 37 +----- .../src/instructions/gpu/chips/logic_r.rs | 110 +++------------- ceno_zkvm/src/instructions/gpu/chips/lui.rs | 37 +----- ceno_zkvm/src/instructions/gpu/chips/lw.rs | 121 +++--------------- ceno_zkvm/src/instructions/gpu/chips/mul.rs | 108 ++-------------- ceno_zkvm/src/instructions/gpu/chips/sb.rs | 37 +----- 
ceno_zkvm/src/instructions/gpu/chips/sh.rs | 37 +----- .../src/instructions/gpu/chips/shift_i.rs | 37 +----- .../src/instructions/gpu/chips/shift_r.rs | 38 +----- ceno_zkvm/src/instructions/gpu/chips/slt.rs | 38 +----- ceno_zkvm/src/instructions/gpu/chips/slti.rs | 37 +----- ceno_zkvm/src/instructions/gpu/chips/sub.rs | 38 +----- ceno_zkvm/src/instructions/gpu/chips/sw.rs | 37 +----- ceno_zkvm/src/instructions/gpu/dispatch.rs | 11 +- .../utils/{colmap_base.rs => column_map.rs} | 39 +++++- ceno_zkvm/src/instructions/gpu/utils/mod.rs | 4 +- .../instructions/gpu/utils/test_helpers.rs | 119 +++++++++++++++++ 26 files changed, 331 insertions(+), 1096 deletions(-) rename ceno_zkvm/src/instructions/gpu/utils/{colmap_base.rs => column_map.rs} (79%) create mode 100644 ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs diff --git a/ceno_zkvm/src/instructions/gpu/chips/add.rs b/ceno_zkvm/src/instructions/gpu/chips/add.rs index cf5b7a1c8..73eec76c5 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/add.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/add.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::AddColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, }, riscv::arith::ArithConfig, @@ -58,35 +58,6 @@ mod tests { type E = BabyBearExt4; - fn flatten_records( - records: &[std::collections::BTreeMap], - ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { - records - .iter() - .flat_map(|table| { - table - .iter() - .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) - }) - .collect() - } - - fn flatten_lk( - multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, - ) -> Vec> { - multiplicity - .iter() - .map(|table| { - let mut entries = table - .iter() - .map(|(key, count)| (*key, *count)) - .collect::>(); - entries.sort_unstable(); - entries - }) - .collect() - } - fn 
make_test_steps(n: usize) -> Vec { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), @@ -125,23 +96,17 @@ mod tests { .collect() } - #[test] - fn test_extract_add_column_map() { - let mut cs = ConstraintSystem::::new(|| "test"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_add_column_map, AddInstruction, extract_add_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_add_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::dispatch; + use crate::instructions::gpu::utils::test_helpers::{ + assert_full_gpu_pipeline, assert_witness_colmajor_eq, + }; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -199,63 +164,19 @@ mod tests { ) .unwrap(); - // D2H copy (GPU output is column-major) let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - - // Compare element by element (GPU is column-major, CPU is row-major) - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for col in 0..num_witin { - let gpu_val = gpu_data[col * n + row]; // column-major - let cpu_val = cpu_data[row * num_witin + col]; // row-major - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, col, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); - - let mut shard_ctx_full_gpu = ShardContext::default(); - let (gpu_rmms, gpu_lkm) = - 
crate::instructions::gpu::dispatch::try_gpu_assign_instances::>( - &config, - &mut shard_ctx_full_gpu, - num_witin, - num_structural_witin, - &steps, - &indices, - crate::instructions::gpu::dispatch::GpuWitgenKind::Add, - ) - .unwrap() - .expect("GPU path should be available"); - - // Flush shared EC/addr buffers from GPU device to shard_ctx - // (in the e2e pipeline this is called once per shard after all opcode circuits) - crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut shard_ctx_full_gpu).unwrap(); - - assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); - assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); - assert_eq!( - shard_ctx_full_gpu.get_addr_accessed(), - shard_ctx.get_addr_accessed() - ); - assert_eq!( - flatten_records(shard_ctx_full_gpu.read_records()), - flatten_records(shard_ctx.read_records()) - ); - assert_eq!( - flatten_records(shard_ctx_full_gpu.write_records()), - flatten_records(shard_ctx.write_records()) + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + + assert_full_gpu_pipeline::>( + &config, + &steps, + dispatch::GpuWitgenKind::Add, + &cpu_rmms, + &cpu_lkm, + &shard_ctx, + num_witin, + num_structural_witin, ); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/addi.rs b/ceno_zkvm/src/instructions/gpu/chips/addi.rs index 234ac00ff..9302a2f7a 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/addi.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/addi.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::AddiColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs, }, riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig, @@ -54,22 +54,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_addi_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_addi"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - 
AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_addi_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_addi_column_map, AddiInstruction, extract_addi_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_addi_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -142,25 +134,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/auipc.rs b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs index bf0c0fefc..b4948b7d3 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/auipc.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::AuipcColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, 
extract_uint_limbs}, riscv::auipc::AuipcConfig, }; @@ -54,22 +54,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_auipc_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_auipc"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_auipc_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_auipc_column_map, AuipcInstruction, extract_auipc_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_auipc_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -142,25 +134,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs index 69407ce0b..6ef8c2034 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs +++ 
b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::BranchCmpColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs, }, riscv::branch::branch_circuit_v2::BranchConfig, @@ -65,22 +65,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_branch_cmp_column_map() { - let mut cs = ConstraintSystem::::new(|| "test"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_branch_cmp_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_branch_cmp_column_map, BltInstruction, extract_branch_cmp_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_branch_cmp_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -159,25 +151,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + 
assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs index 33373bef8..3fa316740 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::BranchEqColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs, }, riscv::branch::branch_circuit_v2::BranchConfig, @@ -58,22 +58,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_branch_eq_column_map() { - let mut cs = ConstraintSystem::::new(|| "test"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_branch_eq_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_branch_eq_column_map, BeqInstruction, extract_branch_eq_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_branch_eq_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -156,25 +148,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row 
* num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/div.rs b/ceno_zkvm/src/instructions/gpu/chips/div.rs index 9f998ae02..24f91457e 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/div.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/div.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::DivColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, }, riscv::div::div_circuit_v2::DivRemConfig, @@ -99,68 +99,17 @@ mod tests { type E = BabyBearExt4; - fn test_column_map_validity(col_map: &DivColumnMap) { - let (n_entries, flat) = col_map.to_flat(); - for (i, &col) in flat[..n_entries].iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat[..n_entries] { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } - } - - #[test] - fn test_extract_div_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_div"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); - test_column_map_validity(&col_map); - } - - #[test] - fn test_extract_divu_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_divu"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - DivuInstruction::::construct_circuit(&mut cb, 
&ProgramParams::default()).unwrap(); - let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); - test_column_map_validity(&col_map); - } - - #[test] - fn test_extract_rem_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_rem"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); - test_column_map_validity(&col_map); - } - - #[test] - fn test_extract_remu_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_remu"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); - test_column_map_validity(&col_map); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_div_column_map, DivInstruction, extract_div_column_map); + test_colmap!(test_extract_divu_column_map, DivuInstruction, extract_div_column_map); + test_colmap!(test_extract_rem_column_map, RemInstruction, extract_div_column_map); + test_colmap!(test_extract_remu_column_map, RemuInstruction, extract_div_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_div_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -359,26 +308,7 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != 
cpu_val { - if mismatches < 10 { - eprintln!( - "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - name, row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/jal.rs b/ceno_zkvm/src/instructions/gpu/chips/jal.rs index a3e965098..e364a2e0c 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/jal.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jal.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::JalColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{extract_rd, extract_state_branching, extract_uint_limbs}, + gpu::utils::column_map::{extract_rd, extract_state_branching, extract_uint_limbs}, riscv::jump::jal_v2::JalConfig, }; @@ -42,22 +42,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_jal_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_jal"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_jal_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_jal_column_map, JalInstruction, extract_jal_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_jal_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -130,25 +122,6 @@ mod tests { let gpu_data: Vec<::BaseField> = 
gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/jalr.rs b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs index d1da5a6af..b3a52d8ed 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/jalr.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::JalrColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rd, extract_rs1, extract_state_branching, extract_uint_limbs, extract_wit_ids, }, riscv::jump::jalr_v2::JalrConfig, @@ -62,22 +62,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_jalr_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_jalr"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_jalr_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_jalr_column_map, JalrInstruction, extract_jalr_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_jalr_correctness() { use crate::e2e::ShardContext; + use 
crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -152,25 +144,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs index 36694fdc8..608930e7d 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::LoadSubColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs, }, riscv::memory::load_v2::LoadConfig, @@ -12,8 +12,6 @@ use crate::instructions::{ pub fn extract_load_sub_column_map( config: &LoadConfig, num_witin: usize, - is_byte: bool, // true for LB/LBU - is_signed: bool, // true for LH/LB ) -> LoadSubColumnMap { let im = &config.im_insn; @@ -28,14 +26,13 @@ pub fn extract_load_sub_column_map( let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); let mem_read = extract_uint_limbs::(&config.memory_read, "memory_read"); - // Sub-word specific: addr_bit_1 (all sub-word loads have at least 1 low_bit) + // 
Infer variant from config: LB/LBU have 2 low_bits, LH/LHU have 1. let low_bits = &config.memory_addr.low_bits; + let is_byte = low_bits.len() == 2; + let addr_bit_1 = if is_byte { - // LB/LBU: 2 low_bits, [0]=bit_0, [1]=bit_1 - assert_eq!(low_bits.len(), 2, "LB/LBU should have 2 low_bits"); low_bits[1].id as u32 } else { - // LH/LHU: 1 low_bit, [0]=bit_1 assert_eq!(low_bits.len(), 1, "LH/LHU should have 1 low_bit"); low_bits[0].id as u32 }; @@ -61,16 +58,8 @@ pub fn extract_load_sub_column_map( (None, None, None) }; - // Signed: msb - let msb = if is_signed { - let sec = config - .signed_extend_config - .as_ref() - .expect("signed loads must have signed_extend_config"); - Some(sec.msb().id as u32) - } else { - None - }; + // Signed loads have signed_extend_config + let msb = config.signed_extend_config.as_ref().map(|sec| sec.msb().id as u32); LoadSubColumnMap { pc, @@ -114,78 +103,17 @@ mod tests { type E = BabyBearExt4; - fn test_column_map_validity(col_map: &LoadSubColumnMap) { - let (n_entries, flat) = col_map.to_flat(); - for (i, &col) in flat[..n_entries].iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat[..n_entries] { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } - } - - #[test] - fn test_extract_lh_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_lh"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, true); - test_column_map_validity(&col_map); - assert!(col_map.msb.is_some()); - assert!(col_map.addr_bit_0.is_none()); - } - - #[test] - fn test_extract_lhu_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_lhu"); - let mut cb = 
CircuitBuilder::new(&mut cs); - let config = - LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, false); - test_column_map_validity(&col_map); - assert!(col_map.msb.is_none()); - assert!(col_map.addr_bit_0.is_none()); - } - - #[test] - fn test_extract_lb_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_lb"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, true); - test_column_map_validity(&col_map); - assert!(col_map.msb.is_some()); - assert!(col_map.addr_bit_0.is_some()); - assert!(col_map.target_byte.is_some()); - } - - #[test] - fn test_extract_lbu_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_lbu"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, false); - test_column_map_validity(&col_map); - assert!(col_map.msb.is_none()); - assert!(col_map.addr_bit_0.is_some()); - assert!(col_map.target_byte.is_some()); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_lh_column_map, LhInstruction, extract_load_sub_column_map); + test_colmap!(test_extract_lhu_column_map, LhuInstruction, extract_load_sub_column_map); + test_colmap!(test_extract_lb_column_map, LbInstruction, extract_load_sub_column_map); + test_colmap!(test_extract_lbu_column_map, LbuInstruction, extract_load_sub_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_load_sub_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, 
StepRecord, WordAddr, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -340,7 +268,7 @@ mod tests { let cpu_witness = &cpu_rmms[0]; // GPU path - let col_map = extract_load_sub_column_map(&config, num_witin, is_byte, is_signed); + let col_map = extract_load_sub_column_map(&config, num_witin); let shard_ctx_gpu = ShardContext::default(); let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); let steps_bytes: &[u8] = unsafe { @@ -372,26 +300,7 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - name, row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs index 3d06c8184..914437620 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::LogicIColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, riscv::logic_imm::logic_imm_circuit_v2::LogicConfig, }; @@ -52,22 +52,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_logic_i_column_map() { - let mut cs = ConstraintSystem::::new(|| 
"test_logic_i"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_logic_i_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_logic_i_column_map, AndiInstruction, extract_logic_i_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_logic_i_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -158,25 +150,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs index 770d770ba..df5d64f15 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::LogicRColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rd, 
extract_rs1, extract_rs2, extract_state, extract_uint_limbs, }, riscv::logic::logic_circuit::LogicConfig, @@ -54,52 +54,17 @@ mod tests { type E = BabyBearExt4; - fn flatten_records( - records: &[std::collections::BTreeMap], - ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { - records - .iter() - .flat_map(|table| { - table - .iter() - .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) - }) - .collect() - } - - fn flatten_lk( - multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, - ) -> Vec> { - multiplicity - .iter() - .map(|table| { - let mut entries = table - .iter() - .map(|(key, count)| (*key, *count)) - .collect::>(); - entries.sort_unstable(); - entries - }) - .collect() - } - - #[test] - fn test_extract_logic_r_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_logic_r"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_logic_r_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_logic_r_column_map, AndInstruction, extract_logic_r_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_logic_r_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::dispatch; + use crate::instructions::gpu::utils::test_helpers::{ + assert_full_gpu_pipeline, assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -194,56 +159,17 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 
0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); - - let mut shard_ctx_full_gpu = ShardContext::default(); - let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::dispatch::try_gpu_assign_instances::>( - &config, - &mut shard_ctx_full_gpu, - num_witin, - num_structural_witin, - &steps, - &indices, - crate::instructions::gpu::dispatch::GpuWitgenKind::LogicR(0), - ) - .unwrap() - .expect("GPU path should be available"); - - crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut shard_ctx_full_gpu).unwrap(); - - assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); - assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); - assert_eq!( - shard_ctx_full_gpu.get_addr_accessed(), - shard_ctx.get_addr_accessed() - ); - assert_eq!( - flatten_records(shard_ctx_full_gpu.read_records()), - flatten_records(shard_ctx.read_records()) - ); - assert_eq!( - flatten_records(shard_ctx_full_gpu.write_records()), - flatten_records(shard_ctx.write_records()) + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + + assert_full_gpu_pipeline::>( + &config, + &steps, + dispatch::GpuWitgenKind::LogicR(0), + &cpu_rmms, + &cpu_lkm, + &shard_ctx, + num_witin, + num_structural_witin, ); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/lui.rs b/ceno_zkvm/src/instructions/gpu/chips/lui.rs index a6b5b6278..52ea86960 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/lui.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/lui.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::LuiColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state}, + gpu::utils::column_map::{extract_rd, 
extract_rs1, extract_state}, riscv::lui::LuiConfig, }; @@ -53,22 +53,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_lui_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_lui"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_lui_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_lui_column_map, LuiInstruction, extract_lui_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_lui_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -141,25 +133,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/lw.rs b/ceno_zkvm/src/instructions/gpu/chips/lw.rs index b865d405f..d992486fa 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/lw.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/lw.rs @@ -1,7 +1,7 @@ use 
ceno_gpu::common::witgen::types::LwColumnMap; use ff_ext::ExtensionField; -use crate::instructions::gpu::utils::colmap_base::{ +use crate::instructions::gpu::utils::column_map::{ extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs, }; @@ -66,35 +66,6 @@ mod tests { type E = BabyBearExt4; type LwInstruction = crate::instructions::riscv::LwInstruction; - fn flatten_records( - records: &[std::collections::BTreeMap], - ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { - records - .iter() - .flat_map(|table| { - table - .iter() - .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) - }) - .collect() - } - - fn flatten_lk( - multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, - ) -> Vec> { - multiplicity - .iter() - .map(|table| { - let mut entries = table - .iter() - .map(|(key, count)| (*key, *count)) - .collect::>(); - entries.sort_unstable(); - entries - }) - .collect() - } - fn make_lw_test_steps(n: usize) -> Vec { let pc_start = 0x1000u32; // Use varying immediates including negative values to test imm_field encoding @@ -129,35 +100,17 @@ mod tests { .collect() } - #[test] - fn test_extract_lw_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_lw"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); - let (n_entries, flat) = col_map.to_flat(); - - for (i, &col) in flat[..n_entries].iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat[..n_entries] { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_lw_column_map, LwInstruction, 
extract_lw_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_lw_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::dispatch; + use crate::instructions::gpu::utils::test_helpers::{ + assert_full_gpu_pipeline, assert_witness_colmajor_eq, + }; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -214,57 +167,17 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); - - let mut shard_ctx_full_gpu = ShardContext::default(); - let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::dispatch::try_gpu_assign_instances::( - &config, - &mut shard_ctx_full_gpu, - num_witin, - num_structural_witin, - &steps, - &indices, - crate::instructions::gpu::dispatch::GpuWitgenKind::Lw, - ) - .unwrap() - .expect("GPU path should be available"); - - crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut shard_ctx_full_gpu).unwrap(); - - assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); - assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); - assert_eq!( - shard_ctx_full_gpu.get_addr_accessed(), - shard_ctx.get_addr_accessed() - ); - assert_eq!( - flatten_records(shard_ctx_full_gpu.read_records()), - flatten_records(shard_ctx.read_records()) - ); - assert_eq!( - flatten_records(shard_ctx_full_gpu.write_records()), - flatten_records(shard_ctx.write_records()) + 
assert_full_gpu_pipeline::( + &config, + &steps, + dispatch::GpuWitgenKind::Lw, + &cpu_rmms, + &cpu_lkm, + &shard_ctx, + num_witin, + num_structural_witin, ); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/mul.rs b/ceno_zkvm/src/instructions/gpu/chips/mul.rs index 9da0cc5a8..394efe689 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/mul.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/mul.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::MulColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, }, riscv::mulh::mulh_circuit_v2::MulhConfig, @@ -13,7 +13,6 @@ use crate::instructions::{ pub fn extract_mul_column_map( config: &MulhConfig, num_witin: usize, - mul_kind: u32, ) -> MulColumnMap { let (pc, ts) = extract_state(&config.r_insn.vm_state); let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); @@ -24,19 +23,14 @@ pub fn extract_mul_column_map( let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); let rd_low: [u32; 2] = [config.rd_low[0].id as u32, config.rd_low[1].id as u32]; - // MULH/MULHU/MULHSU have rd_high + extensions - let (rd_high, rs1_ext, rs2_ext) = if mul_kind != 0 { - let h = config - .rd_high - .as_ref() - .expect("MULH variants must have rd_high"); - ( + // MULH/MULHU/MULHSU have rd_high + extensions; MUL does not. 
+ let (rd_high, rs1_ext, rs2_ext) = match config.rd_high.as_ref() { + Some(h) => ( Some([h[0].id as u32, h[1].id as u32]), Some(config.rs1_ext.expect("MULH variants must have rs1_ext").id as u32), Some(config.rs2_ext.expect("MULH variants must have rs2_ext").id as u32), - ) - } else { - (None, None, None) + ), + None => (None, None, None), }; MulColumnMap { @@ -77,72 +71,17 @@ mod tests { type E = BabyBearExt4; - fn test_column_map_validity(col_map: &MulColumnMap) { - let (n_entries, flat) = col_map.to_flat(); - for (i, &col) in flat[..n_entries].iter().enumerate() { - assert!( - (col as usize) < col_map.num_cols as usize, - "Column {} (index {}) out of range: {} >= {}", - i, - col, - col, - col_map.num_cols - ); - } - let mut seen = std::collections::HashSet::new(); - for &col in &flat[..n_entries] { - assert!(seen.insert(col), "Duplicate column ID: {}", col); - } - } - - #[test] - fn test_extract_mul_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_mul"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 0); - test_column_map_validity(&col_map); - assert!(col_map.rd_high.is_none()); - } - - #[test] - fn test_extract_mulh_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_mulh"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 1); - test_column_map_validity(&col_map); - assert!(col_map.rd_high.is_some()); - } - - #[test] - fn test_extract_mulhu_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_mulhu"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_mul_column_map(&config, 
cb.cs.num_witin as usize, 2); - test_column_map_validity(&col_map); - assert!(col_map.rd_high.is_some()); - } - - #[test] - fn test_extract_mulhsu_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_mulhsu"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 3); - test_column_map_validity(&col_map); - assert!(col_map.rd_high.is_some()); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_mul_column_map, MulInstruction, extract_mul_column_map); + test_colmap!(test_extract_mulh_column_map, MulhInstruction, extract_mul_column_map); + test_colmap!(test_extract_mulhu_column_map, MulhuInstruction, extract_mul_column_map); + test_colmap!(test_extract_mulhsu_column_map, MulhsuInstruction, extract_mul_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_mul_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -291,7 +230,7 @@ mod tests { let cpu_witness = &cpu_rmms[0]; // GPU path - let col_map = extract_mul_column_map(&config, num_witin, mul_kind); + let col_map = extract_mul_column_map(&config, num_witin); let shard_ctx_gpu = ShardContext::default(); let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); let steps_bytes: &[u8] = unsafe { @@ -319,26 +258,7 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { 
- if mismatches < 10 { - eprintln!( - "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - name, row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/sb.rs b/ceno_zkvm/src/instructions/gpu/chips/sb.rs index d95648b36..dd3484610 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sb.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sb.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::SbColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, }, riscv::memory::store_v2::StoreConfig, @@ -93,22 +93,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_sb_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_sb"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_sb_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_sb_column_map, SbInstruction, extract_sb_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sb_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -201,25 +193,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let 
cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/sh.rs b/ceno_zkvm/src/instructions/gpu/chips/sh.rs index d9f9a809c..7c59366d7 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sh.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sh.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::ShColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, }, riscv::memory::store_v2::StoreConfig, @@ -70,22 +70,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_sh_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_sh"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_sh_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_sh_column_map, ShInstruction, extract_sh_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sh_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, 
StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -178,25 +170,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs index 97c54063d..da0181dc4 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::ShiftIColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, riscv::shift::shift_circuit_v2::ShiftImmConfig, }; @@ -65,22 +65,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_shift_i_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_shift_i"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_shift_i_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + 
test_colmap!(test_extract_shift_i_column_map, SlliInstruction, extract_shift_i_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_shift_i_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -169,25 +161,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs index 19e8a4e58..c2efab990 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::ShiftRColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, }, riscv::shift::shift_circuit_v2::ShiftRTypeConfig, @@ -71,23 +71,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_shift_r_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_shift_r"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = 
extract_shift_r_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_shift_r_column_map, SllInstruction, extract_shift_r_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_shift_r_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -176,25 +167,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/slt.rs b/ceno_zkvm/src/instructions/gpu/chips/slt.rs index 51e21af9a..e2816af9b 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/slt.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slt.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::SltColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, }, riscv::slt::slt_circuit_v2::SetLessThanConfig, @@ -67,23 +67,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_slt_column_map() { - 
let mut cs = ConstraintSystem::::new(|| "test"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_slt_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_slt_column_map, SltInstruction, extract_slt_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_slt_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -163,25 +154,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/slti.rs b/ceno_zkvm/src/instructions/gpu/chips/slti.rs index 17b4dad27..cbd21db06 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/slti.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slti.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::SltiColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{extract_rd, extract_rs1, extract_state, 
extract_uint_limbs}, + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, riscv::slti::slti_circuit_v2::SetLessThanImmConfig, }; @@ -63,22 +63,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_slti_column_map() { - let mut cs = ConstraintSystem::::new(|| "test"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_slti_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_slti_column_map, SltiInstruction, extract_slti_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_slti_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -156,25 +148,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/sub.rs b/ceno_zkvm/src/instructions/gpu/chips/sub.rs index f05fd8bf0..2861ee9b9 100644 --- 
a/ceno_zkvm/src/instructions/gpu/chips/sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sub.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::SubColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, }, riscv::arith::ArithConfig, @@ -57,23 +57,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_sub_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_sub"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_sub_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_sub_column_map, SubInstruction, extract_sub_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sub_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -163,25 +154,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + 
assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/chips/sw.rs b/ceno_zkvm/src/instructions/gpu/chips/sw.rs index 9d94603a9..481f8fd6a 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sw.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sw.rs @@ -2,7 +2,7 @@ use ceno_gpu::common::witgen::types::SwColumnMap; use ff_ext::ExtensionField; use crate::instructions::{ - gpu::utils::colmap_base::{ + gpu::utils::column_map::{ extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, }, riscv::memory::store_v2::StoreConfig, @@ -61,22 +61,14 @@ mod tests { type E = BabyBearExt4; - #[test] - fn test_extract_sw_column_map() { - let mut cs = ConstraintSystem::::new(|| "test_sw"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = - SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - - let col_map = extract_sw_column_map(&config, cb.cs.num_witin as usize); - let flat = col_map.to_flat(); - crate::instructions::gpu::utils::colmap_base::validate_column_map(&flat, col_map.num_cols); - } + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!(test_extract_sw_column_map, SwInstruction, extract_sw_column_map); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sw_correctness() { use crate::e2e::ShardContext; + use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; @@ -161,25 +153,6 @@ mod tests { let gpu_data: Vec<::BaseField> = gpu_result.witness.device_buffer.to_vec().unwrap(); - let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - - let mut mismatches = 0; - for row in 0..n { - for c in 0..num_witin { - let gpu_val = gpu_data[c * n + row]; - let cpu_val = cpu_data[row * num_witin + c]; - if gpu_val != cpu_val { - if mismatches < 
10 { - eprintln!( - "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", - row, c, gpu_val, cpu_val - ); - } - mismatches += 1; - } - } - } - assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); } } diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index a4b4c2d10..c2cd131e9 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -1100,15 +1100,8 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load_v2::LoadConfig) }; - let is_byte = load_width == 8; - let is_signed_bool = is_signed != 0; let col_map = info_span!("col_map").in_scope(|| { - super::chips::load_sub::extract_load_sub_column_map( - load_config, - num_witin, - is_byte, - is_signed_bool, - ) + super::chips::load_sub::extract_load_sub_column_map(load_config, num_witin) }); let mem_max_bits = load_config.memory_addr.max_bits as u32; info_span!("hal_witgen_load_sub").in_scope(|| { @@ -1146,7 +1139,7 @@ fn gpu_fill_witness>( as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) }; let col_map = info_span!("col_map").in_scope(|| { - super::chips::mul::extract_mul_column_map(mul_config, num_witin, mul_kind) + super::chips::mul::extract_mul_column_map(mul_config, num_witin) }); info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { diff --git a/ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs b/ceno_zkvm/src/instructions/gpu/utils/column_map.rs similarity index 79% rename from ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs rename to ceno_zkvm/src/instructions/gpu/utils/column_map.rs index 34445436b..789df64f4 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/colmap_base.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/column_map.rs @@ -156,7 +156,7 @@ pub fn extract_wit_ids(wits: &[WitIn], label: &str) -> [u32; N] } 
// --------------------------------------------------------------------------- -// Test helper +// Test helpers // --------------------------------------------------------------------------- #[cfg(test)] @@ -172,3 +172,40 @@ pub fn validate_column_map(flat: &[u32], num_cols: u32) { assert!(seen.insert(col), "Duplicate column ID: {col}"); } } + +/// Generate a `test_extract_*_column_map` test that: +/// 1. Constructs the circuit via `$Instruction::construct_circuit` +/// 2. Extracts the column map via `$extract_fn` +/// 3. Validates all column IDs are in-range and unique +/// +/// Usage: +/// ```ignore +/// test_colmap!(test_extract_add_column_map, AddInstruction, extract_add_column_map); +/// // Multi-variant (e.g. div/divu/rem/remu sharing one extractor): +/// test_colmap!(test_extract_divu_column_map, DivuInstruction, extract_div_column_map); +/// ``` +#[cfg(test)] +macro_rules! test_colmap { + ($test_name:ident, $Instruction:ty, $extract_fn:ident) => { + #[test] + fn $test_name() { + let mut cs = crate::circuit_builder::ConstraintSystem::::new(|| stringify!($test_name)); + let mut cb = crate::circuit_builder::CircuitBuilder::new(&mut cs); + let config = <$Instruction as crate::instructions::Instruction>::construct_circuit( + &mut cb, + &crate::structs::ProgramParams::default(), + ) + .unwrap(); + let col_map = $extract_fn(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + crate::instructions::gpu::utils::column_map::validate_column_map( + &flat, + col_map.num_cols, + ); + } + }; +} + +#[cfg(test)] +pub(crate) use test_colmap; + diff --git a/ceno_zkvm/src/instructions/gpu/utils/mod.rs b/ceno_zkvm/src/instructions/gpu/utils/mod.rs index 46ce4b6e6..5811cf199 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/mod.rs @@ -14,11 +14,13 @@ pub use lk_ops::*; pub use sink::*; #[cfg(feature = "gpu")] -pub mod colmap_base; +pub mod column_map; #[cfg(feature = "gpu")] pub mod d2h; #[cfg(feature = "gpu")] 
pub mod debug_compare; +#[cfg(test)] +pub mod test_helpers; #[cfg(test)] mod tests { diff --git a/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs new file mode 100644 index 000000000..f8c48f598 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs @@ -0,0 +1,119 @@ +// --------------------------------------------------------------------------- +// GPU correctness test helpers +// --------------------------------------------------------------------------- + +/// Compare GPU column-major witness data against CPU row-major reference. +/// Panics with detailed mismatch info if any element differs. +#[cfg(test)] +pub fn assert_witness_colmajor_eq( + gpu_colmajor: &[F], + cpu_rowmajor: &[F], + n_rows: usize, + n_cols: usize, +) { + assert_eq!( + gpu_colmajor.len(), + cpu_rowmajor.len(), + "Size mismatch: gpu={} cpu={}", + gpu_colmajor.len(), + cpu_rowmajor.len() + ); + let mut mismatches = 0; + for row in 0..n_rows { + for col in 0..n_cols { + let gpu_val = &gpu_colmajor[col * n_rows + row]; + let cpu_val = &cpu_rowmajor[row * n_cols + col]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={row}, col={col}: GPU={gpu_val:?}, CPU={cpu_val:?}" + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {mismatches} mismatches"); +} + +/// Run `try_gpu_assign_instances` + `flush_shared_ec_buffers`, then assert +/// witness, LK multiplicity, addr_accessed, and read/write records all match +/// the CPU reference in `cpu_ctx`. 
+#[cfg(test)] +pub fn assert_full_gpu_pipeline>( + config: &I::InstructionConfig, + steps: &[ceno_emul::StepRecord], + kind: crate::instructions::gpu::dispatch::GpuWitgenKind, + cpu_rmms: &crate::tables::RMMCollections, + cpu_lkm: &gkr_iop::utils::lk_multiplicity::Multiplicity, + cpu_ctx: &crate::e2e::ShardContext, + num_witin: usize, + num_structural_witin: usize, +) { + let indices: Vec = (0..steps.len()).collect(); + + let mut gpu_ctx = crate::e2e::ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::gpu::dispatch::try_gpu_assign_instances::( + config, + &mut gpu_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + kind, + ) + .unwrap() + .expect("GPU path should be available"); + + crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut gpu_ctx).unwrap(); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values(), "witness mismatch"); + assert_eq!( + flatten_lk_for_test(&gpu_lkm), + flatten_lk_for_test(cpu_lkm), + "LK multiplicity mismatch" + ); + assert_eq!( + gpu_ctx.get_addr_accessed(), + cpu_ctx.get_addr_accessed(), + "addr_accessed mismatch" + ); + assert_eq!( + flatten_records_for_test(gpu_ctx.read_records()), + flatten_records_for_test(cpu_ctx.read_records()), + "read_records mismatch" + ); + assert_eq!( + flatten_records_for_test(gpu_ctx.write_records()), + flatten_records_for_test(cpu_ctx.write_records()), + "write_records mismatch" + ); +} + +#[cfg(test)] +fn flatten_lk_for_test( + m: &gkr_iop::utils::lk_multiplicity::Multiplicity, +) -> Vec> { + m.iter() + .map(|table| { + let mut entries: Vec<_> = table.iter().map(|(k, v)| (*k, *v)).collect(); + entries.sort_unstable(); + entries + }) + .collect() +} + +#[cfg(test)] +fn flatten_records_for_test( + records: &[std::collections::BTreeMap], +) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, r)| (*addr, r.prev_cycle, r.cycle, r.shard_id)) + }) + .collect() +} From 
6ca050b41c76ae04fa9438e95189debbb644fefc Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 22:21:03 +0800 Subject: [PATCH 66/73] config, dispatch --- ceno_zkvm/src/e2e.rs | 13 +- ceno_zkvm/src/instructions/gpu/cache.rs | 16 + .../src/instructions/gpu/chips/keccak.rs | 347 +++++++++++- ceno_zkvm/src/instructions/gpu/config.rs | 195 +------ ceno_zkvm/src/instructions/gpu/dispatch.rs | 524 +++--------------- .../instructions/gpu/utils/debug_compare.rs | 53 +- .../src/instructions/riscv/ecall/keccak.rs | 2 +- 7 files changed, 478 insertions(+), 672 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index ea6615304..7b587d995 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1490,7 +1490,7 @@ pub fn generate_witness<'a, E: ExtensionField>( } let debug_compare_e2e_shard = - std::env::var_os("CENO_GPU_DEBUG_COMPARE_E2E_SHARD").is_some(); + crate::instructions::gpu::config::is_debug_compare_enabled(); let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); info_span!("assign_opcode_circuits").in_scope(|| { system_config @@ -1515,16 +1515,7 @@ pub fn generate_witness<'a, E: ExtensionField>( // Free GPU shard_steps cache after all opcode circuits are done. 
#[cfg(feature = "gpu")] - { - crate::instructions::gpu::dispatch::invalidate_shard_steps_cache(); - if std::env::var_os("CENO_GPU_TRIM_AFTER_WITGEN").is_some() { - use gkr_iop::gpu::gpu_prover::get_cuda_hal; - - let cuda_hal = get_cuda_hal().unwrap(); - cuda_hal.inner().trim_mem_pool().unwrap(); - cuda_hal.inner().synchronize().unwrap(); - } - } + crate::instructions::gpu::cache::invalidate_shard_steps_cache(); info_span!("assign_dummy_circuits").in_scope(|| { system_config diff --git a/ceno_zkvm/src/instructions/gpu/cache.rs b/ceno_zkvm/src/instructions/gpu/cache.rs index ae258b409..e42b1cb88 100644 --- a/ceno_zkvm/src/instructions/gpu/cache.rs +++ b/ceno_zkvm/src/instructions/gpu/cache.rs @@ -330,6 +330,22 @@ pub(crate) fn with_cached_shard_meta(f: impl FnOnce(&ShardDeviceBuffers) -> R }) } +/// Borrow both cached device buffers (shard_steps + shard_meta) in one call. +/// Eliminates the nested `with_cached_shard_steps(|s| with_cached_shard_meta(|m| ...))` pattern. +pub(crate) fn with_cached_gpu_ctx( + f: impl FnOnce(&CudaSlice, &ShardDeviceBuffers) -> R, +) -> R { + SHARD_STEPS_DEVICE.with(|steps_cache| { + let steps = steps_cache.borrow(); + let s = steps.as_ref().expect("shard_steps not uploaded"); + SHARD_META_CACHE.with(|meta_cache| { + let meta = meta_cache.borrow(); + let m = meta.as_ref().expect("shard metadata not uploaded"); + f(&s.device_buf, &m.device_bufs) + }) + }) +} + /// Invalidate the shard metadata cache (call when shard processing is complete). 
pub fn invalidate_shard_meta_cache() { SHARD_META_CACHE.with(|cache| { diff --git a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs index c4450b55a..2f35b167c 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -7,6 +7,26 @@ use crate::instructions::riscv::ecall::keccak::EcallKeccakConfig; use ceno_emul::SyscallWitness; +use ceno_gpu::{Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose}; +use ceno_gpu::common::witgen::types::GpuShardRamRecord; +use ceno_emul::WordAddr; +use gkr_iop::utils::lk_multiplicity::Multiplicity; +use p3::field::FieldAlgebra; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::{ + e2e::ShardContext, + error::ZKVMError, + tables::RMMCollections, + witness::LkMultiplicity, +}; +use crate::instructions::gpu::config::{is_gpu_witgen_disabled, is_kind_disabled, is_debug_compare_enabled}; +use crate::instructions::gpu::dispatch::{GpuWitgenKind, compute_fetch_params, is_force_cpu_path}; +use crate::instructions::gpu::cache::{ensure_shard_metadata_cached, with_cached_shard_meta, read_shared_addr_count, read_shared_addr_range}; +use crate::instructions::gpu::utils::d2h::{gpu_lk_counters_to_multiplicity, gpu_compact_ec_d2h}; +use crate::instructions::gpu::utils::debug_compare::debug_compare_keccak; + /// Extract column map from a constructed EcallKeccakConfig. /// /// VM state columns are listed individually. Keccak math columns use @@ -169,6 +189,330 @@ pub fn pack_keccak_instances( .collect() } +/// GPU dispatch entry point for keccak ecall witness generation. +/// +/// Unlike `try_gpu_assign_instances`, keccak has a rotation-aware matrix layout +/// (each logical instance spans 32 physical rows) and requires building +/// structural witness on CPU with selector indices from the cyclic group. 
+#[cfg(feature = "gpu")] +pub fn gpu_assign_keccak_instances( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, Multiplicity)>, ZKVMError> { + use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; + use gkr_iop::gpu::get_cuda_hal; + + // Guard: disabled or force-CPU + if is_gpu_witgen_disabled() || is_force_cpu_path() { + return Ok(None); + } + // Check if keccak is disabled via CENO_GPU_DISABLE_WITGEN_KINDS=keccak + if is_kind_disabled(GpuWitgenKind::Keccak) { + return Ok(None); + } + + // GPU only supports BabyBear field + if std::any::TypeId::of::() + != std::any::TypeId::of::<::BaseField>() + { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(hal) => hal, + Err(_) => return Ok(None), + }; + + // Empty step_indices: return empty matrices + if step_indices.is_empty() { + let rotation = KECCAK_ROUNDS_CEIL_LOG2; + let raw_witin = RowMajorMatrix::::new_by_rotation( + 0, + rotation, + num_witin, + InstancePaddingStrategy::Default, + ); + let raw_structural = RowMajorMatrix::::new_by_rotation( + 0, + rotation, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + let lk = LkMultiplicity::default(); + return Ok(Some(( + [raw_witin, raw_structural], + lk.into_finalize_result(), + ))); + } + + let num_instances = step_indices.len(); + tracing::debug!("[GPU witgen] keccak with {} instances", num_instances); + + info_span!("gpu_witgen_keccak", n = num_instances).in_scope(|| { + gpu_assign_keccak_inner::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + &hal, + ) + .map(Some) + }) +} + +/// Keccak-specific GPU witness generation, separate from `gpu_assign_instances_inner` because: +/// 1. Rotation: each instance spans 32 rows (not 1), requiring `new_by_rotation` +/// 2. 
Structural witness: 3 selectors (sel_first/sel_last/sel_all) vs the standard 1 +/// 3. Input packing: needs `packed_instances` with `syscall_witnesses` +/// +/// The LK/shardram collection logic (Steps 6-7) is identical to the standard path; +/// it is duplicated here rather than shared. +#[cfg(feature = "gpu")] +fn gpu_assign_keccak_inner( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], + hal: &CudaHalBB31, +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; + + let num_instances = step_indices.len(); + let num_padded_instances = num_instances.next_power_of_two().max(2); + let num_padded_rows = num_padded_instances * 32; // 2^5 = 32 rows per instance + let rotation = KECCAK_ROUNDS_CEIL_LOG2; // = 5 + + // Step 1: Extract column map + let col_map = info_span!("col_map") + .in_scope(|| extract_keccak_column_map(config, num_witin)); + + // Step 2: Pack instances + let packed_instances = info_span!("pack_instances").in_scope(|| { + pack_keccak_instances( + steps, + step_indices, + &shard_ctx.syscall_witnesses, + ) + }); + + // Step 3: Compute fetch params + let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(steps, step_indices); + + // Step 4: Ensure shard metadata cached + info_span!("ensure_shard_meta") + .in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx, steps.len()))?; + + // Snapshot shared addr count before kernel (for debug comparison) + let addr_count_before = if crate::instructions::gpu::config::is_debug_compare_enabled() { + read_shared_addr_count() + } else { + 0 + }; + + // Step 5: Launch GPU kernel + let gpu_result = info_span!("gpu_kernel").in_scope(|| { + with_cached_shard_meta(|shard_bufs| { + hal.witgen + .witgen_keccak( + &col_map, + &packed_instances, + num_padded_rows, + shard_ctx.current_shard_offset_cycle(), + 
fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_keccak failed: {e}").into()) + }) + }) + })?; + + // D2H keccak's addr entries from shared buffer (delta since before kernel) + let gpu_keccak_addrs = if crate::instructions::gpu::config::is_debug_compare_enabled() { + let addr_count_after = read_shared_addr_count(); + if addr_count_after > addr_count_before { + read_shared_addr_range(addr_count_before, addr_count_after) + } else { + Vec::new() + } + } else { + Vec::new() + }; + + // Step 6: Collect LK multiplicity + let lk_multiplicity = info_span!("gpu_lk_d2h") + .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_result.lk_counters))?; + + // Debug LK comparison is done in the unit test instead. + + // Step 7: Handle compact EC records (shared buffer path) + if gpu_result.compact_ec.is_none() && gpu_result.compact_addr.is_none() { + // Shared buffer path: EC records + addr_accessed accumulated on device + // in shared buffers across all kernel invocations. Skip per-kernel D2H. 
+ } else if let Some(compact) = gpu_result.compact_ec { + info_span!("gpu_ec_shard").in_scope(|| { + let compact_records = + info_span!("compact_d2h").in_scope(|| gpu_compact_ec_d2h(&compact))?; + + // D2H compact addr_accessed + info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { + if let Some(ref ca) = gpu_result.compact_addr { + let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr_count D2H failed: {e}").into(), + ) + })?; + let n = count_vec[0] as usize; + if n > 0 { + let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr D2H failed: {e}").into(), + ) + })?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + } + Ok(()) + })?; + + // Populate shard_ctx with GPU EC records + let raw_bytes = unsafe { + std::slice::from_raw_parts( + compact_records.as_ptr() as *const u8, + compact_records.len() * std::mem::size_of::(), + ) + }; + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + + Ok::<(), ZKVMError>(()) + })?; + } + + // Step 8: Transpose GPU witness (column-major -> row-major) + D2H + let raw_witin = info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin) + .in_scope(|| { + let mut rmm_buffer = hal + .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.witness.device_buffer, + num_padded_rows, + num_witin, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_data: Vec<::BaseField> = + rmm_buffer.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()) + })?; + + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + 
let data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok::<_, ZKVMError>(RowMajorMatrix::::from_values_with_rotation( + data, + num_witin, + rotation, + num_instances, + InstancePaddingStrategy::Default, + )) + })?; + + // Step 9: Build structural witness on CPU with selector indices + let raw_structural = info_span!("structural_witness").in_scope(|| { + let mut raw_structural = RowMajorMatrix::::new_by_rotation( + num_instances, + rotation, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + // Get selector column IDs from config + let sel_first = config + .layout + .selector_type_layout + .sel_first + .as_ref() + .expect("sel_first must be Some"); + let sel_last = config + .layout + .selector_type_layout + .sel_last + .as_ref() + .expect("sel_last must be Some"); + + let sel_first_id = sel_first.selector_expr().id(); + let sel_last_id = sel_last.selector_expr().id(); + let sel_all_id = config + .layout + .selector_type_layout + .sel_all + .selector_expr() + .id(); + + let sel_first_indices = sel_first.sparse_indices(); + let sel_last_indices = sel_last.sparse_indices(); + let sel_all_indices = config.layout.selector_type_layout.sel_all.sparse_indices(); + + // Only set selectors for real instances, not padding ones. 
+ for instance_chunk in raw_structural.iter_mut().take(num_instances) { + // instance_chunk is a &mut [F] of size 32 * num_structural_witin + for &idx in sel_first_indices { + instance_chunk[idx * num_structural_witin + sel_first_id] = E::BaseField::ONE; + } + for &idx in sel_last_indices { + instance_chunk[idx * num_structural_witin + sel_last_id] = E::BaseField::ONE; + } + for &idx in sel_all_indices { + instance_chunk[idx * num_structural_witin + sel_all_id] = E::BaseField::ONE; + } + } + raw_structural.padding_by_strategy(); + + raw_structural + }); + + // Debug comparisons (activated by env vars) + debug_compare_keccak::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + &lk_multiplicity, + &raw_witin, + &gpu_keccak_addrs, + )?; + + Ok(([raw_witin, raw_structural], lk_multiplicity)) +} + #[cfg(test)] mod tests { use super::*; @@ -246,7 +590,7 @@ mod tests { let cpu_structural = &cpu_rmms[1]; // --- GPU path (full pipeline via gpu_assign_keccak_instances) --- - use crate::instructions::gpu::dispatch::gpu_assign_keccak_instances; + use super::gpu_assign_keccak_instances; let mut shard_ctx_gpu = ShardContext::default(); shard_ctx_gpu.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); let (gpu_rmms, gpu_lk) = gpu_assign_keccak_instances::( @@ -357,3 +701,4 @@ mod tests { assert_eq!(lk_mismatches, 0, "GPU vs CPU LK multiplicity mismatch"); } } + diff --git a/ceno_zkvm/src/instructions/gpu/config.rs b/ceno_zkvm/src/instructions/gpu/config.rs index db3c92cb6..efb5ffc6c 100644 --- a/ceno_zkvm/src/instructions/gpu/config.rs +++ b/ceno_zkvm/src/instructions/gpu/config.rs @@ -1,7 +1,10 @@ /// GPU witgen path-control helpers: kind tags, verified-kind queries, and /// environment-variable disable switches. /// -/// Extracted from `witgen_gpu.rs` — pure code move, no behavioural changes. 
+/// Environment variables (3 total): +/// - `CENO_GPU_DISABLE_WITGEN` — global kill switch, all chips use CPU +/// - `CENO_GPU_DISABLE_WITGEN_KINDS=add,sub,keccak,...` — per-kind disable (comma-separated tags) +/// - `CENO_GPU_DEBUG_COMPARE_WITGEN` — enable GPU vs CPU comparison for all chips (witness, LK, shard, EC) use super::dispatch::GpuWitgenKind; pub(crate) fn kind_tag(kind: GpuWitgenKind) -> &'static str { @@ -10,206 +13,40 @@ pub(crate) fn kind_tag(kind: GpuWitgenKind) -> &'static str { GpuWitgenKind::Sub => "sub", GpuWitgenKind::LogicR(_) => "logic_r", GpuWitgenKind::Lw => "lw", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::LogicI(_) => "logic_i", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Addi => "addi", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Lui => "lui", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Auipc => "auipc", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Jal => "jal", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::ShiftR(_) => "shift_r", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::ShiftI(_) => "shift_i", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Slt(_) => "slt", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Slti(_) => "slti", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::BranchEq(_) => "branch_eq", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::BranchCmp(_) => "branch_cmp", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Jalr => "jalr", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Sw => "sw", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Sh => "sh", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Sb => "sb", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::LoadSub { .. 
} => "load_sub", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Mul(_) => "mul", - #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Div(_) => "div", GpuWitgenKind::Keccak => "keccak", } } -/// Returns true if the GPU CUDA kernel for this kind has been verified to produce -/// correct LK multiplicity counters matching the CPU baseline. -/// Unverified kinds fall back to CPU full lk_shardram (GPU still handles witness). +/// Check if a specific GPU witgen kind is disabled via `CENO_GPU_DISABLE_WITGEN_KINDS` env var. /// -/// Override with `CENO_GPU_DISABLE_LK_KINDS=add,sub,...` to force specific kinds -/// back to CPU LK (for binary-search debugging). -/// Set `CENO_GPU_DISABLE_LK_KINDS=all` to disable GPU LK for ALL kinds. -pub(crate) fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { - if is_lk_kind_disabled(kind) { - return false; - } - match kind { - // Phase B verified (Add/Sub/LogicR/Lw) - GpuWitgenKind::Add => true, - GpuWitgenKind::Sub => true, - GpuWitgenKind::LogicR(_) => true, - GpuWitgenKind::Lw => true, - // Phase C verified via debug_compare_final_lk - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Addi => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LogicI(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Lui => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Slti(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::BranchEq(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::BranchCmp(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sw => true, - // Phase C CUDA kernel fixes applied - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::ShiftI(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Auipc => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Jal => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Jalr => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sb => true, - // Remaining 
kinds enabled - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::ShiftR(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Slt(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Sh => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LoadSub { .. } => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Mul(_) => true, - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::Div(_) => true, - // Keccak has its own dispatch path with its own LK handling. - GpuWitgenKind::Keccak => false, - #[cfg(not(feature = "u16limb_circuit"))] - _ => false, - } -} - -/// Returns true if GPU shard records are verified for this kind. -/// Set CENO_GPU_DISABLE_SHARD_KINDS=all to force ALL kinds back to CPU shard path. -pub(crate) fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { - // Global kill switch: force pure CPU shard path for baseline testing - if std::env::var_os("CENO_GPU_CPU_SHARD").is_some() { - return false; - } - if is_shard_kind_disabled(kind) { - return false; - } - match kind { - GpuWitgenKind::Add | GpuWitgenKind::Sub | GpuWitgenKind::LogicR(_) | GpuWitgenKind::Lw => { - true - } - #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LogicI(_) - | GpuWitgenKind::Addi - | GpuWitgenKind::Lui - | GpuWitgenKind::Auipc - | GpuWitgenKind::Jal - | GpuWitgenKind::ShiftR(_) - | GpuWitgenKind::ShiftI(_) - | GpuWitgenKind::Slt(_) - | GpuWitgenKind::Slti(_) - | GpuWitgenKind::BranchEq(_) - | GpuWitgenKind::BranchCmp(_) - | GpuWitgenKind::Jalr - | GpuWitgenKind::Sw - | GpuWitgenKind::Sh - | GpuWitgenKind::Sb - | GpuWitgenKind::LoadSub { .. } - | GpuWitgenKind::Mul(_) - | GpuWitgenKind::Div(_) => true, - // Keccak has its own dispatch path, never enters try_gpu_assign_instances. - GpuWitgenKind::Keccak => false, - #[cfg(not(feature = "u16limb_circuit"))] - _ => false, - } -} - -/// Check if GPU LK is disabled for a specific kind via CENO_GPU_DISABLE_LK_KINDS env var. 
-/// Format: CENO_GPU_DISABLE_LK_KINDS=add,sub,lw (comma-separated kind tags) -/// Special value: CENO_GPU_DISABLE_LK_KINDS=all (disables GPU LK for ALL kinds) -pub(crate) fn is_lk_kind_disabled(kind: GpuWitgenKind) -> bool { - thread_local! { - static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; - } - DISABLED.with(|cell| { - let disabled = cell.get_or_init(|| { - std::env::var("CENO_GPU_DISABLE_LK_KINDS") - .ok() - .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) - .unwrap_or_default() - }); - if disabled.is_empty() { - return false; - } - if disabled.iter().any(|d| d == "all") { - return true; - } - let tag = kind_tag(kind); - disabled.iter().any(|d| d == tag) - }) -} - -/// Check if GPU shard records are disabled for a specific kind via env var. -pub(crate) fn is_shard_kind_disabled(kind: GpuWitgenKind) -> bool { - thread_local! { - static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; - } - DISABLED.with(|cell| { - let disabled = cell.get_or_init(|| { - std::env::var("CENO_GPU_DISABLE_SHARD_KINDS") - .ok() - .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) - .unwrap_or_default() - }); - if disabled.is_empty() { - return false; - } - if disabled.iter().any(|d| d == "all") { - return true; - } - let tag = kind_tag(kind); - disabled.iter().any(|d| d == tag) - }) -} - -/// Check if a specific GPU witgen kind is disabled via CENO_GPU_DISABLE_KINDS env var. -/// Format: CENO_GPU_DISABLE_KINDS=add,sub,lw (comma-separated kind tags) +/// Format: `CENO_GPU_DISABLE_WITGEN_KINDS=add,sub,keccak,lw` (comma-separated kind tags) +/// +/// This covers all chips including keccak. pub(crate) fn is_kind_disabled(kind: GpuWitgenKind) -> bool { thread_local! 
{ static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; } DISABLED.with(|cell| { let disabled = cell.get_or_init(|| { - std::env::var("CENO_GPU_DISABLE_KINDS") + std::env::var("CENO_GPU_DISABLE_WITGEN_KINDS") .ok() .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) .unwrap_or_default() @@ -222,15 +59,14 @@ pub(crate) fn is_kind_disabled(kind: GpuWitgenKind) -> bool { }) } -/// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. -/// The value is cached at first access so it's immune to runtime env var manipulation. +/// Returns true if GPU witgen is globally disabled via `CENO_GPU_DISABLE_WITGEN` env var. +/// The value is cached at first access. pub(crate) fn is_gpu_witgen_disabled() -> bool { use std::sync::OnceLock; static DISABLED: OnceLock = OnceLock::new(); *DISABLED.get_or_init(|| { let val = std::env::var_os("CENO_GPU_DISABLE_WITGEN"); let disabled = val.is_some(); - // Use eprintln to bypass tracing filters — always visible on stderr eprintln!( "[GPU witgen] CENO_GPU_DISABLE_WITGEN={:?} → disabled={}", val, disabled @@ -238,3 +74,12 @@ pub(crate) fn is_gpu_witgen_disabled() -> bool { disabled }) } + +/// Returns true if `CENO_GPU_DEBUG_COMPARE_WITGEN` is set (any value). +/// When enabled, GPU vs CPU comparison runs for ALL categories: +/// witness, LK multiplicity, shard records, EC points, and E2E shard context. 
+pub(crate) fn is_debug_compare_enabled() -> bool { + use std::sync::OnceLock; + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITGEN").is_some()) +} diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index c2cd131e9..612d36e5f 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -22,10 +22,10 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use super::{ config::{ - is_gpu_witgen_disabled, is_kind_disabled, kind_has_verified_lk, kind_has_verified_shard, + is_gpu_witgen_disabled, is_kind_disabled, }, utils::debug_compare::{ - debug_compare_final_lk, debug_compare_keccak, debug_compare_shard_ec, + debug_compare_final_lk, debug_compare_shard_ec, debug_compare_shardram, debug_compare_witness, }, }; @@ -42,44 +42,44 @@ pub enum GpuWitgenKind { Add, Sub, LogicR(u32), // 0=AND, 1=OR, 2=XOR - #[cfg(feature = "u16limb_circuit")] + LogicI(u32), // 0=AND, 1=OR, 2=XOR - #[cfg(feature = "u16limb_circuit")] + Addi, - #[cfg(feature = "u16limb_circuit")] + Lui, - #[cfg(feature = "u16limb_circuit")] + Auipc, - #[cfg(feature = "u16limb_circuit")] + Jal, - #[cfg(feature = "u16limb_circuit")] + ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA - #[cfg(feature = "u16limb_circuit")] + ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI - #[cfg(feature = "u16limb_circuit")] + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) - #[cfg(feature = "u16limb_circuit")] + Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) - #[cfg(feature = "u16limb_circuit")] + BranchEq(u32), // 1=BEQ, 0=BNE - #[cfg(feature = "u16limb_circuit")] + BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) - #[cfg(feature = "u16limb_circuit")] + Jalr, - #[cfg(feature = "u16limb_circuit")] + Sw, - #[cfg(feature = "u16limb_circuit")] + Sh, - #[cfg(feature = "u16limb_circuit")] + Sb, - #[cfg(feature = "u16limb_circuit")] + LoadSub { load_width: u32, is_signed: u32, }, - 
#[cfg(feature = "u16limb_circuit")] + Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU - #[cfg(feature = "u16limb_circuit")] + Div(u32), // 0=DIV, 1=DIVU, 2=REM, 3=REMU Lw, Keccak, @@ -95,7 +95,7 @@ pub use super::utils::d2h::gpu_batch_continuation_ec; use super::{ cache::{ ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, - upload_shard_steps_cached, with_cached_shard_meta, with_cached_shard_steps, + upload_shard_steps_cached, with_cached_gpu_ctx, with_cached_shard_meta, with_cached_shard_steps, }, utils::d2h::{ CompactEcBuf, LkResult, RamBuf, WitResult, gpu_collect_shard_records, gpu_compact_ec_d2h, @@ -114,7 +114,7 @@ pub fn set_force_cpu_path(force: bool) { FORCE_CPU_PATH.with(|f| f.set(force)); } -fn is_force_cpu_path() -> bool { +pub(crate) fn is_force_cpu_path() -> bool { FORCE_CPU_PATH.with(|f| f.get()) } @@ -222,15 +222,20 @@ fn gpu_assign_instances_inner>( // Step 2: Collect lk and shardram // Priority: GPU shard records > CPU shard records > full CPU lk and shardram - let lk_multiplicity = if gpu_lk_counters.is_some() && kind_has_verified_lk(kind) { + // + // Keccak never enters this function (it has `gpu_assign_keccak_inner`). + // Guard defensively in case the enum value is ever passed here by mistake. + let is_standard_kind = !matches!(kind, GpuWitgenKind::Keccak); + + let lk_multiplicity = if gpu_lk_counters.is_some() && is_standard_kind { let lk_multiplicity = info_span!("gpu_lk_d2h") .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()))?; - if gpu_compact_ec.is_none() && gpu_compact_addr.is_none() && kind_has_verified_shard(kind) { + if gpu_compact_ec.is_none() && gpu_compact_addr.is_none() && is_standard_kind { // Shared buffer path: EC records + addr_accessed accumulated on device // in shared buffers across all kernel invocations. Skip per-kernel D2H. // Data will be consumed in batch by assign_shared_circuit. 
- } else if gpu_compact_ec.is_some() && kind_has_verified_shard(kind) { + } else if gpu_compact_ec.is_some() && is_standard_kind { // GPU EC path: compact records already have EC points computed on device. // D2H only the active records (much smaller than full N*3 slot buffer). info_span!("gpu_ec_shard").in_scope(|| { @@ -292,7 +297,7 @@ fn gpu_assign_instances_inner>( })?; // Debug: compare GPU shard_ctx vs CPU shard_ctx independently - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_some() { + if crate::instructions::gpu::config::is_debug_compare_enabled() { let slots = ram_slots_d2h()?; debug_compare_shard_ec::( &compact_records, @@ -316,7 +321,7 @@ fn gpu_assign_instances_inner>( Ok::<(), ZKVMError>(()) })?; - } else if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { + } else if gpu_ram_slots.is_some() && is_standard_kind { // GPU shard records path (no EC): D2H + lightweight CPU scan info_span!("gpu_shard_records").in_scope(|| { let ram_buf = gpu_ram_slots.unwrap(); @@ -337,7 +342,7 @@ fn gpu_assign_instances_inner>( } else { // CPU: collect shard records only (send/addr_accessed). info_span!("cpu_shard_records").in_scope(|| { - let _ = collect_shardram::(config, shard_ctx, shard_steps, step_indices)?; + let _ = cpu_collect_shardram::(config, shard_ctx, shard_steps, step_indices)?; Ok::<(), ZKVMError>(()) })?; } @@ -345,7 +350,7 @@ fn gpu_assign_instances_inner>( } else { // GPU LK counters missing or unverified — fall back to full CPU lk and shardram info_span!("cpu_lk_shardram").in_scope(|| { - collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) + cpu_collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) })? 
}; debug_compare_final_lk::( @@ -479,8 +484,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::add::extract_add_column_map(arith_config, num_witin)); info_span!("hal_witgen_add").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_add( @@ -500,7 +504,6 @@ fn gpu_fill_witness>( }) ) }) - }) }) } GpuWitgenKind::Sub => { @@ -511,8 +514,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::sub::extract_sub_column_map(arith_config, num_witin)); info_span!("hal_witgen_sub").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_sub( @@ -532,7 +534,6 @@ fn gpu_fill_witness>( }) ) }) - }) }) } GpuWitgenKind::LogicR(logic_kind) => { @@ -544,8 +545,7 @@ fn gpu_fill_witness>( super::chips::logic_r::extract_logic_r_column_map(logic_config, num_witin) }); info_span!("hal_witgen_logic_r").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_logic_r( @@ -566,10 +566,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(logic_kind) => { let logic_config = unsafe { &*(config as *const I::InstructionConfig @@ -579,8 +578,7 @@ fn gpu_fill_witness>( super::chips::logic_i::extract_logic_i_column_map(logic_config, num_witin) }); info_span!("hal_witgen_logic_i").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_logic_i( @@ -601,10 +599,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => { let addi_config = unsafe { 
&*(config as *const I::InstructionConfig @@ -613,8 +610,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::addi::extract_addi_column_map(addi_config, num_witin)); info_span!("hal_witgen_addi").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_addi( @@ -634,10 +630,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => { let lui_config = unsafe { &*(config as *const I::InstructionConfig @@ -646,8 +641,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::lui::extract_lui_column_map(lui_config, num_witin)); info_span!("hal_witgen_lui").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_lui( @@ -667,10 +661,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => { let auipc_config = unsafe { &*(config as *const I::InstructionConfig @@ -680,8 +673,7 @@ fn gpu_fill_witness>( super::chips::auipc::extract_auipc_column_map(auipc_config, num_witin) }); info_span!("hal_witgen_auipc").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_auipc( @@ -701,10 +693,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => { let jal_config = unsafe { &*(config as *const I::InstructionConfig @@ -713,8 +704,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::jal::extract_jal_column_map(jal_config, num_witin)); info_span!("hal_witgen_jal").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + 
with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_jal( @@ -734,10 +724,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig @@ -749,8 +738,7 @@ fn gpu_fill_witness>( super::chips::shift_r::extract_shift_r_column_map(shift_config, num_witin) }); info_span!("hal_witgen_shift_r").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_shift_r( @@ -771,10 +759,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig @@ -786,8 +773,7 @@ fn gpu_fill_witness>( super::chips::shift_i::extract_shift_i_column_map(shift_config, num_witin) }); info_span!("hal_witgen_shift_i").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_shift_i( @@ -808,10 +794,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(is_signed) => { let slt_config = unsafe { &*(config as *const I::InstructionConfig @@ -820,8 +805,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::slt::extract_slt_column_map(slt_config, num_witin)); info_span!("hal_witgen_slt").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_slt( @@ -842,10 +826,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(is_signed) => { let slti_config = unsafe { &*(config as *const I::InstructionConfig @@ -854,8 +837,7 @@ fn 
gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::slti::extract_slti_column_map(slti_config, num_witin)); info_span!("hal_witgen_slti").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_slti( @@ -876,10 +858,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(is_beq) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig @@ -891,8 +872,7 @@ fn gpu_fill_witness>( super::chips::branch_eq::extract_branch_eq_column_map(branch_config, num_witin) }); info_span!("hal_witgen_branch_eq").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_branch_eq( @@ -913,10 +893,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(is_signed) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig @@ -928,8 +907,7 @@ fn gpu_fill_witness>( super::chips::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin) }); info_span!("hal_witgen_branch_cmp").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_branch_cmp( @@ -950,10 +928,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => { let jalr_config = unsafe { &*(config as *const I::InstructionConfig @@ -962,8 +939,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::jalr::extract_jalr_column_map(jalr_config, num_witin)); info_span!("hal_witgen_jalr").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, 
shard_bufs| { split_full!( hal.witgen .witgen_jalr( @@ -983,10 +959,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => { let sw_config = unsafe { &*(config as *const I::InstructionConfig @@ -996,8 +971,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::sw::extract_sw_column_map(sw_config, num_witin)); info_span!("hal_witgen_sw").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_sw( @@ -1018,10 +992,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => { let sh_config = unsafe { &*(config as *const I::InstructionConfig @@ -1031,8 +1004,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::sh::extract_sh_column_map(sh_config, num_witin)); info_span!("hal_witgen_sh").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_sh( @@ -1053,10 +1025,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => { let sb_config = unsafe { &*(config as *const I::InstructionConfig @@ -1066,8 +1037,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::sb::extract_sb_column_map(sb_config, num_witin)); info_span!("hal_witgen_sb").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_sb( @@ -1088,10 +1058,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { load_width, is_signed, @@ -1105,8 +1074,7 @@ fn gpu_fill_witness>( }); let mem_max_bits = load_config.memory_addr.max_bits as u32; 
info_span!("hal_witgen_load_sub").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_load_sub( @@ -1129,10 +1097,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(mul_kind) => { let mul_config = unsafe { &*(config as *const I::InstructionConfig @@ -1142,8 +1109,7 @@ fn gpu_fill_witness>( super::chips::mul::extract_mul_column_map(mul_config, num_witin) }); info_span!("hal_witgen_mul").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_mul( @@ -1164,10 +1130,9 @@ fn gpu_fill_witness>( }) ) }) - }) }) } - #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(div_kind) => { let div_config = unsafe { &*(config as *const I::InstructionConfig @@ -1176,8 +1141,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::div::extract_div_column_map(div_config, num_witin)); info_span!("hal_witgen_div").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_div( @@ -1198,11 +1162,10 @@ fn gpu_fill_witness>( }) ) }) - }) }) } GpuWitgenKind::Lw => { - #[cfg(feature = "u16limb_circuit")] + let load_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load_v2::LoadConfig) @@ -1216,8 +1179,7 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::chips::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { - with_cached_shard_steps(|gpu_records| { - with_cached_shard_meta(|shard_bufs| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { split_full!( hal.witgen .witgen_lw( @@ -1238,7 +1200,6 @@ fn 
gpu_fill_witness>( }) ) }) - }) }) } GpuWitgenKind::Keccak => { @@ -1247,339 +1208,4 @@ fn gpu_fill_witness>( } } -/// CPU-side loop to collect lk and shardram only (shard_ctx.send, lk_multiplicity). -/// Runs assign_instance with a scratch buffer per thread. -fn collect_lk_and_shardram>( - config: &I::InstructionConfig, - shard_ctx: &mut ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> Result, ZKVMError> { - cpu_collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) -} - -fn collect_shardram>( - config: &I::InstructionConfig, - shard_ctx: &mut ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> Result, ZKVMError> { - cpu_collect_shardram::(config, shard_ctx, shard_steps, step_indices) -} - -/// GPU dispatch entry point for keccak ecall witness generation. -/// -/// Unlike `try_gpu_assign_instances`, keccak has a rotation-aware matrix layout -/// (each logical instance spans 32 physical rows) and requires building -/// structural witness on CPU with selector indices from the cyclic group. 
-#[cfg(feature = "gpu")] -pub fn gpu_assign_keccak_instances( - config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - steps: &[StepRecord], - step_indices: &[StepIndex], -) -> Result, Multiplicity)>, ZKVMError> { - use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; - use gkr_iop::gpu::get_cuda_hal; - - // Guard: disabled or force-CPU - if is_gpu_witgen_disabled() || is_force_cpu_path() { - return Ok(None); - } - // CENO_GPU_DISABLE_KECCAK=1 → fall back to CPU keccak witgen - if std::env::var_os("CENO_GPU_DISABLE_KECCAK").is_some() { - return Ok(None); - } - - // GPU only supports BabyBear field - if std::any::TypeId::of::() - != std::any::TypeId::of::<::BaseField>() - { - return Ok(None); - } - - let hal = match get_cuda_hal() { - Ok(hal) => hal, - Err(_) => return Ok(None), - }; - - // Empty step_indices: return empty matrices - if step_indices.is_empty() { - let rotation = KECCAK_ROUNDS_CEIL_LOG2; - let raw_witin = RowMajorMatrix::::new_by_rotation( - 0, - rotation, - num_witin, - InstancePaddingStrategy::Default, - ); - let raw_structural = RowMajorMatrix::::new_by_rotation( - 0, - rotation, - num_structural_witin, - InstancePaddingStrategy::Default, - ); - let lk = LkMultiplicity::default(); - return Ok(Some(( - [raw_witin, raw_structural], - lk.into_finalize_result(), - ))); - } - - let num_instances = step_indices.len(); - tracing::debug!("[GPU witgen] keccak with {} instances", num_instances); - - info_span!("gpu_witgen_keccak", n = num_instances).in_scope(|| { - gpu_assign_keccak_inner::( - config, - shard_ctx, - num_witin, - num_structural_witin, - steps, - step_indices, - &hal, - ) - .map(Some) - }) -} - -#[cfg(feature = "gpu")] -fn gpu_assign_keccak_inner( - config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, - shard_ctx: &mut ShardContext, - num_witin: usize, - num_structural_witin: usize, - steps: &[StepRecord], - step_indices: 
&[StepIndex], - hal: &CudaHalBB31, -) -> Result<(RMMCollections, Multiplicity), ZKVMError> { - use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; - - let num_instances = step_indices.len(); - let num_padded_instances = num_instances.next_power_of_two().max(2); - let num_padded_rows = num_padded_instances * 32; // 2^5 = 32 rows per instance - let rotation = KECCAK_ROUNDS_CEIL_LOG2; // = 5 - - // Step 1: Extract column map - let col_map = info_span!("col_map") - .in_scope(|| super::chips::keccak::extract_keccak_column_map(config, num_witin)); - // Step 2: Pack instances - let packed_instances = info_span!("pack_instances").in_scope(|| { - super::chips::keccak::pack_keccak_instances( - steps, - step_indices, - &shard_ctx.syscall_witnesses, - ) - }); - - // Step 3: Compute fetch params - let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(steps, step_indices); - - // Step 4: Ensure shard metadata cached - info_span!("ensure_shard_meta") - .in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx, steps.len()))?; - - // Snapshot shared addr count before kernel (for debug comparison) - let addr_count_before = if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some() { - read_shared_addr_count() - } else { - 0 - }; - - // Step 5: Launch GPU kernel - let gpu_result = info_span!("gpu_kernel").in_scope(|| { - with_cached_shard_meta(|shard_bufs| { - hal.witgen - .witgen_keccak( - &col_map, - &packed_instances, - num_padded_rows, - shard_ctx.current_shard_offset_cycle(), - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_keccak failed: {e}").into()) - }) - }) - })?; - - // D2H keccak's addr entries from shared buffer (delta since before kernel) - let gpu_keccak_addrs = if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some() { - let addr_count_after = read_shared_addr_count(); - if addr_count_after > addr_count_before { - read_shared_addr_range(addr_count_before, addr_count_after) - 
} else { - Vec::new() - } - } else { - Vec::new() - }; - - // Step 6: Collect LK multiplicity - let lk_multiplicity = info_span!("gpu_lk_d2h") - .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_result.lk_counters))?; - - // Debug LK comparison is done in the unit test instead. - - // Step 7: Handle compact EC records (shared buffer path) - if gpu_result.compact_ec.is_none() && gpu_result.compact_addr.is_none() { - // Shared buffer path: EC records + addr_accessed accumulated on device - // in shared buffers across all kernel invocations. Skip per-kernel D2H. - } else if let Some(compact) = gpu_result.compact_ec { - info_span!("gpu_ec_shard").in_scope(|| { - let compact_records = - info_span!("compact_d2h").in_scope(|| gpu_compact_ec_d2h(&compact))?; - - // D2H compact addr_accessed - info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { - if let Some(ref ca) = gpu_result.compact_addr { - let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness( - format!("compact_addr_count D2H failed: {e}").into(), - ) - })?; - let n = count_vec[0] as usize; - if n > 0 { - let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { - ZKVMError::InvalidWitness( - format!("compact_addr D2H failed: {e}").into(), - ) - })?; - let mut forked = shard_ctx.get_forked(); - let thread_ctx = &mut forked[0]; - for &addr in &addrs { - thread_ctx.push_addr_accessed(WordAddr(addr)); - } - } - } - Ok(()) - })?; - - // Populate shard_ctx with GPU EC records - let raw_bytes = unsafe { - std::slice::from_raw_parts( - compact_records.as_ptr() as *const u8, - compact_records.len() * std::mem::size_of::(), - ) - }; - shard_ctx.extend_gpu_ec_records_raw(raw_bytes); - - Ok::<(), ZKVMError>(()) - })?; - } - - // Step 8: Transpose GPU witness (column-major -> row-major) + D2H - let raw_witin = info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin) - .in_scope(|| { - let mut rmm_buffer = hal - .alloc_elems_on_device(num_padded_rows * num_witin, false, 
None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) - })?; - matrix_transpose::( - &hal.inner, - &mut rmm_buffer, - &gpu_result.witness.device_buffer, - num_padded_rows, - num_witin, - ) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; - - let gpu_data: Vec<::BaseField> = - rmm_buffer.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()) - })?; - - // Safety: BabyBear is the only supported GPU field, and E::BaseField must match - let data: Vec = unsafe { - let mut data = std::mem::ManuallyDrop::new(gpu_data); - Vec::from_raw_parts( - data.as_mut_ptr() as *mut E::BaseField, - data.len(), - data.capacity(), - ) - }; - - Ok::<_, ZKVMError>(RowMajorMatrix::::from_values_with_rotation( - data, - num_witin, - rotation, - num_instances, - InstancePaddingStrategy::Default, - )) - })?; - - // Step 9: Build structural witness on CPU with selector indices - let raw_structural = info_span!("structural_witness").in_scope(|| { - let mut raw_structural = RowMajorMatrix::::new_by_rotation( - num_instances, - rotation, - num_structural_witin, - InstancePaddingStrategy::Default, - ); - - // Get selector column IDs from config - let sel_first = config - .layout - .selector_type_layout - .sel_first - .as_ref() - .expect("sel_first must be Some"); - let sel_last = config - .layout - .selector_type_layout - .sel_last - .as_ref() - .expect("sel_last must be Some"); - - let sel_first_id = sel_first.selector_expr().id(); - let sel_last_id = sel_last.selector_expr().id(); - let sel_all_id = config - .layout - .selector_type_layout - .sel_all - .selector_expr() - .id(); - - let sel_first_indices = sel_first.sparse_indices(); - let sel_last_indices = sel_last.sparse_indices(); - let sel_all_indices = config.layout.selector_type_layout.sel_all.sparse_indices(); - - // Only set selectors for real instances, not padding ones. 
- for instance_chunk in raw_structural.iter_mut().take(num_instances) { - // instance_chunk is a &mut [F] of size 32 * num_structural_witin - for &idx in sel_first_indices { - instance_chunk[idx * num_structural_witin + sel_first_id] = E::BaseField::ONE; - } - for &idx in sel_last_indices { - instance_chunk[idx * num_structural_witin + sel_last_id] = E::BaseField::ONE; - } - for &idx in sel_all_indices { - instance_chunk[idx * num_structural_witin + sel_all_id] = E::BaseField::ONE; - } - } - raw_structural.padding_by_strategy(); - - raw_structural - }); - - // Debug comparisons (activated by env vars) - debug_compare_keccak::( - config, - shard_ctx, - num_witin, - num_structural_witin, - steps, - step_indices, - &lk_multiplicity, - &raw_witin, - &gpu_keccak_addrs, - )?; - - Ok(([raw_witin, raw_structural], lk_multiplicity)) -} diff --git a/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs index 3af3b3856..1441a1849 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs @@ -2,10 +2,8 @@ /// /// These functions compare GPU-produced results against CPU baselines /// to validate correctness. Activated by environment variables: -/// - CENO_GPU_DEBUG_COMPARE_LK: compare lookup multiplicities -/// - CENO_GPU_DEBUG_COMPARE_WITNESS: compare witness matrices -/// - CENO_GPU_DEBUG_COMPARE_SHARD: compare shardram records -/// - CENO_GPU_DEBUG_COMPARE_EC: compare EC points +/// All comparisons are activated by setting `CENO_GPU_DEBUG_COMPARE_WITGEN=1`. +/// This enables: LK multiplicity, witness matrix, shardram records, and EC point comparison. 
use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::common::witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}; use ff_ext::ExtensionField; @@ -32,7 +30,7 @@ pub(crate) fn debug_compare_final_lk>( kind: GpuWitgenKind, mixed_lk: &Multiplicity, ) -> Result<(), ZKVMError> { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_none() { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { return Ok(()); } @@ -56,10 +54,7 @@ pub(crate) fn log_lk_diff( cpu_lk: &Multiplicity, actual_lk: &Multiplicity, ) { - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_LK_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(32); + let limit: usize = 16; let mut total_diffs = 0usize; for (table_idx, (cpu_table, actual_table)) in cpu_lk.iter().zip(actual_lk.iter()).enumerate() { @@ -114,7 +109,7 @@ pub(crate) fn debug_compare_witness>( kind: GpuWitgenKind, gpu_witness: &RowMajorMatrix, ) -> Result<(), ZKVMError> { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_none() { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { return Ok(()); } @@ -134,10 +129,7 @@ pub(crate) fn debug_compare_witness>( return Ok(()); } - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(16); + let limit: usize = 16; let cpu_num_cols = cpu_witness.n_col(); let cpu_num_rows = cpu_vals.len() / cpu_num_cols; let mut mismatches = 0usize; @@ -172,7 +164,7 @@ pub(crate) fn debug_compare_shardram>( step_indices: &[StepIndex], kind: GpuWitgenKind, ) -> Result<(), ZKVMError> { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_none() { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { return Ok(()); } @@ -220,7 +212,7 @@ pub(crate) fn debug_compare_shardram>( /// B. shard records (sorted, normalized to ShardRamRecord) /// C. EC points (nonce + SepticPoint x,y) /// -/// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. 
+/// Activated by CENO_GPU_DEBUG_COMPARE_WITGEN=1. pub(crate) fn debug_compare_shard_ec>( compact_records: &[GpuShardRamRecord], ram_slots: &[GpuRamRecordSlot], @@ -230,7 +222,7 @@ pub(crate) fn debug_compare_shard_ec>( step_indices: &[StepIndex], kind: GpuWitgenKind, ) { - if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { return; } @@ -240,10 +232,7 @@ pub(crate) fn debug_compare_shard_ec>( }; use ff_ext::{PoseidonField, SmallField}; - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(16); + let limit: usize = 16; // ========== Build CPU shard context (independent, isolated) ========== let mut cpu_ctx = shard_ctx.new_empty_like(); @@ -601,10 +590,7 @@ pub(crate) fn log_ram_record_diff( cpu_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], mixed_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], ) { - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_SHARD_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(16); + let limit: usize = 16; tracing::error!( "[GPU shard debug] kind={kind:?} {} cpu={} gpu={}", label, @@ -649,8 +635,7 @@ pub(crate) fn lookup_table_name(table_idx: usize) -> &'static str { /// Debug comparison for keccak GPU witgen. /// Runs the CPU path and compares LK / witness / shardram records. /// -/// Activated by CENO_GPU_DEBUG_COMPARE_LK, CENO_GPU_DEBUG_COMPARE_WITNESS, -/// or CENO_GPU_DEBUG_COMPARE_SHARD environment variables. +/// Activated by CENO_GPU_DEBUG_COMPARE_WITGEN=1. 
#[cfg(feature = "gpu")] pub(crate) fn debug_compare_keccak( config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, @@ -663,9 +648,10 @@ pub(crate) fn debug_compare_keccak( gpu_witin: &RowMajorMatrix, gpu_addrs: &[u32], ) -> Result<(), ZKVMError> { - let want_lk = std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_some(); - let want_witness = std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_some(); - let want_shard = std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_some(); + let enabled = crate::instructions::gpu::config::is_debug_compare_enabled(); + let want_lk = enabled; + let want_witness = enabled; + let want_shard = enabled; if !want_lk && !want_witness && !want_shard { return Ok(()); @@ -684,7 +670,7 @@ pub(crate) fn debug_compare_keccak( tracing::info!("[GPU keccak debug] running CPU baseline for comparison"); // Run CPU path via assign_instances. The IN_DEBUG_COMPARE guard prevents - // gpu_assign_keccak_instances from calling debug_compare_keccak again, + // gpu_assign_keccak_instances (in chips/keccak.rs) from calling debug_compare_keccak again, // so it will produce the GPU result, which is then returned without // re-entering this function. We need assign_instances (not cpu_assign_instances) // because keccak has rotation matrices and 3 structural columns. 
@@ -717,10 +703,7 @@ pub(crate) fn debug_compare_keccak( } if want_witness { - let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") - .ok() - .and_then(|v| v.parse::().ok()) - .unwrap_or(32); + let limit: usize = 16; let cpu_witin = &cpu_rmms[0]; let gpu_vals = gpu_witin.values(); let cpu_vals = cpu_witin.values(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index 5ed5455d0..4c82ffd08 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -180,7 +180,7 @@ impl Instruction for KeccakInstruction { ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { #[cfg(feature = "gpu")] { - use crate::instructions::gpu::dispatch::gpu_assign_keccak_instances; + use crate::instructions::gpu::chips::keccak::gpu_assign_keccak_instances; if let Some(result) = gpu_assign_keccak_instances::( config, shard_ctx, From aadf86b74e2bcf6d13a072fe3a42ed22b7616595 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 24 Mar 2026 22:42:14 +0800 Subject: [PATCH 67/73] fmt, lints --- ceno_zkvm/src/e2e.rs | 7 +- ceno_zkvm/src/instructions/gpu/chips/add.rs | 16 +- ceno_zkvm/src/instructions/gpu/chips/addi.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/auipc.rs | 11 +- .../src/instructions/gpu/chips/branch_cmp.rs | 11 +- .../src/instructions/gpu/chips/branch_eq.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/div.rs | 29 +- ceno_zkvm/src/instructions/gpu/chips/jal.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/jalr.rs | 11 +- .../src/instructions/gpu/chips/keccak.rs | 37 +- .../src/instructions/gpu/chips/load_sub.rs | 34 +- .../src/instructions/gpu/chips/logic_i.rs | 11 +- .../src/instructions/gpu/chips/logic_r.rs | 16 +- ceno_zkvm/src/instructions/gpu/chips/lui.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/lw.rs | 16 +- ceno_zkvm/src/instructions/gpu/chips/mul.rs | 29 +- ceno_zkvm/src/instructions/gpu/chips/sb.rs | 11 +- 
ceno_zkvm/src/instructions/gpu/chips/sh.rs | 11 +- .../src/instructions/gpu/chips/shift_i.rs | 11 +- .../src/instructions/gpu/chips/shift_r.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/slt.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/slti.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/sub.rs | 11 +- ceno_zkvm/src/instructions/gpu/chips/sw.rs | 11 +- ceno_zkvm/src/instructions/gpu/dispatch.rs | 914 +++++++++--------- .../src/instructions/gpu/utils/column_map.rs | 4 +- ceno_zkvm/src/instructions/gpu/utils/mod.rs | 2 +- .../instructions/gpu/utils/test_helpers.rs | 38 +- ceno_zkvm/src/scheme/septic_curve.rs | 10 +- 29 files changed, 729 insertions(+), 599 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 7b587d995..d2ca0bae6 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1489,8 +1489,11 @@ pub fn generate_witness<'a, E: ExtensionField>( } } + #[cfg(feature = "gpu")] let debug_compare_e2e_shard = crate::instructions::gpu::config::is_debug_compare_enabled(); + #[cfg(not(feature = "gpu"))] + let debug_compare_e2e_shard = false; let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); info_span!("assign_opcode_circuits").in_scope(|| { system_config @@ -2277,9 +2280,7 @@ fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { type FlatRecord = (u32, u64, u64, u64, u64, Option, u32, usize); -fn flatten_ram_records( - records: &[BTreeMap], -) -> Vec { +fn flatten_ram_records(records: &[BTreeMap]) -> Vec { let mut flat = Vec::new(); for table in records { for (addr, record) in table { diff --git a/ceno_zkvm/src/instructions/gpu/chips/add.rs b/ceno_zkvm/src/instructions/gpu/chips/add.rs index 73eec76c5..3d7f53b65 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/add.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/add.rs @@ -97,15 +97,21 @@ mod tests { } use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_add_column_map, AddInstruction, 
extract_add_column_map); + test_colmap!( + test_extract_add_column_map, + AddInstruction, + extract_add_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_add_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::dispatch; - use crate::instructions::gpu::utils::test_helpers::{ - assert_full_gpu_pipeline, assert_witness_colmajor_eq, + use crate::{ + e2e::ShardContext, + instructions::gpu::{ + dispatch, + utils::test_helpers::{assert_full_gpu_pipeline, assert_witness_colmajor_eq}, + }, }; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/addi.rs b/ceno_zkvm/src/instructions/gpu/chips/addi.rs index 9302a2f7a..c17493a04 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/addi.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/addi.rs @@ -55,13 +55,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_addi_column_map, AddiInstruction, extract_addi_column_map); + test_colmap!( + test_extract_addi_column_map, + AddiInstruction, + extract_addi_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_addi_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/auipc.rs b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs index b4948b7d3..eed81fa7f 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/auipc.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs @@ -55,13 +55,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_auipc_column_map, AuipcInstruction, extract_auipc_column_map); + 
test_colmap!( + test_extract_auipc_column_map, + AuipcInstruction, + extract_auipc_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_auipc_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs index 6ef8c2034..e717d05b7 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs @@ -66,13 +66,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_branch_cmp_column_map, BltInstruction, extract_branch_cmp_column_map); + test_colmap!( + test_extract_branch_cmp_column_map, + BltInstruction, + extract_branch_cmp_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_branch_cmp_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs index 3fa316740..00c16cb4f 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs @@ -59,13 +59,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_branch_eq_column_map, BeqInstruction, 
extract_branch_eq_column_map); + test_colmap!( + test_extract_branch_eq_column_map, + BeqInstruction, + extract_branch_eq_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_branch_eq_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/div.rs b/ceno_zkvm/src/instructions/gpu/chips/div.rs index 24f91457e..3e980ae4f 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/div.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/div.rs @@ -100,16 +100,33 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_div_column_map, DivInstruction, extract_div_column_map); - test_colmap!(test_extract_divu_column_map, DivuInstruction, extract_div_column_map); - test_colmap!(test_extract_rem_column_map, RemInstruction, extract_div_column_map); - test_colmap!(test_extract_remu_column_map, RemuInstruction, extract_div_column_map); + test_colmap!( + test_extract_div_column_map, + DivInstruction, + extract_div_column_map + ); + test_colmap!( + test_extract_divu_column_map, + DivuInstruction, + extract_div_column_map + ); + test_colmap!( + test_extract_rem_column_map, + RemInstruction, + extract_div_column_map + ); + test_colmap!( + test_extract_remu_column_map, + RemuInstruction, + extract_div_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_div_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, 
encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/jal.rs b/ceno_zkvm/src/instructions/gpu/chips/jal.rs index e364a2e0c..22bc084b4 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/jal.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jal.rs @@ -43,13 +43,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_jal_column_map, JalInstruction, extract_jal_column_map); + test_colmap!( + test_extract_jal_column_map, + JalInstruction, + extract_jal_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_jal_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/jalr.rs b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs index b3a52d8ed..a5663b7cf 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/jalr.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs @@ -63,13 +63,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_jalr_column_map, JalrInstruction, extract_jalr_column_map); + test_colmap!( + test_extract_jalr_column_map, + JalrInstruction, + extract_jalr_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_jalr_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git 
a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs index 2f35b167c..e64d9311f 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -7,9 +7,12 @@ use crate::instructions::riscv::ecall::keccak::EcallKeccakConfig; use ceno_emul::SyscallWitness; -use ceno_gpu::{Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose}; -use ceno_gpu::common::witgen::types::GpuShardRamRecord; use ceno_emul::WordAddr; +use ceno_gpu::{ + Buffer, CudaHal, + bb31::CudaHalBB31, + common::{transpose::matrix_transpose, witgen::types::GpuShardRamRecord}, +}; use gkr_iop::utils::lk_multiplicity::Multiplicity; use p3::field::FieldAlgebra; use tracing::info_span; @@ -18,14 +21,21 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ e2e::ShardContext, error::ZKVMError, + instructions::gpu::{ + cache::{ + ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, + with_cached_shard_meta, + }, + config::{is_debug_compare_enabled, is_gpu_witgen_disabled, is_kind_disabled}, + dispatch::{GpuWitgenKind, compute_fetch_params, is_force_cpu_path}, + utils::{ + d2h::{gpu_compact_ec_d2h, gpu_lk_counters_to_multiplicity}, + debug_compare::debug_compare_keccak, + }, + }, tables::RMMCollections, witness::LkMultiplicity, }; -use crate::instructions::gpu::config::{is_gpu_witgen_disabled, is_kind_disabled, is_debug_compare_enabled}; -use crate::instructions::gpu::dispatch::{GpuWitgenKind, compute_fetch_params, is_force_cpu_path}; -use crate::instructions::gpu::cache::{ensure_shard_metadata_cached, with_cached_shard_meta, read_shared_addr_count, read_shared_addr_range}; -use crate::instructions::gpu::utils::d2h::{gpu_lk_counters_to_multiplicity, gpu_compact_ec_d2h}; -use crate::instructions::gpu::utils::debug_compare::debug_compare_keccak; /// Extract column map from a constructed EcallKeccakConfig. 
/// @@ -291,17 +301,11 @@ fn gpu_assign_keccak_inner( let rotation = KECCAK_ROUNDS_CEIL_LOG2; // = 5 // Step 1: Extract column map - let col_map = info_span!("col_map") - .in_scope(|| extract_keccak_column_map(config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| extract_keccak_column_map(config, num_witin)); // Step 2: Pack instances - let packed_instances = info_span!("pack_instances").in_scope(|| { - pack_keccak_instances( - steps, - step_indices, - &shard_ctx.syscall_witnesses, - ) - }); + let packed_instances = info_span!("pack_instances") + .in_scope(|| pack_keccak_instances(steps, step_indices, &shard_ctx.syscall_witnesses)); // Step 3: Compute fetch params let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(steps, step_indices); @@ -701,4 +705,3 @@ mod tests { assert_eq!(lk_mismatches, 0, "GPU vs CPU LK multiplicity mismatch"); } } - diff --git a/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs index 608930e7d..200b069d7 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs @@ -59,7 +59,10 @@ pub fn extract_load_sub_column_map( }; // Signed loads have signed_extend_config - let msb = config.signed_extend_config.as_ref().map(|sec| sec.msb().id as u32); + let msb = config + .signed_extend_config + .as_ref() + .map(|sec| sec.msb().id as u32); LoadSubColumnMap { pc, @@ -104,16 +107,33 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_lh_column_map, LhInstruction, extract_load_sub_column_map); - test_colmap!(test_extract_lhu_column_map, LhuInstruction, extract_load_sub_column_map); - test_colmap!(test_extract_lb_column_map, LbInstruction, extract_load_sub_column_map); - test_colmap!(test_extract_lbu_column_map, LbuInstruction, extract_load_sub_column_map); + test_colmap!( + test_extract_lh_column_map, + LhInstruction, + extract_load_sub_column_map + 
); + test_colmap!( + test_extract_lhu_column_map, + LhuInstruction, + extract_load_sub_column_map + ); + test_colmap!( + test_extract_lb_column_map, + LbInstruction, + extract_load_sub_column_map + ); + test_colmap!( + test_extract_lbu_column_map, + LbuInstruction, + extract_load_sub_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_load_sub_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs index 914437620..791129945 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs @@ -53,13 +53,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_logic_i_column_map, AndiInstruction, extract_logic_i_column_map); + test_colmap!( + test_extract_logic_i_column_map, + AndiInstruction, + extract_logic_i_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_logic_i_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs index df5d64f15..2c22b8c39 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs @@ -55,15 +55,21 @@ mod 
tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_logic_r_column_map, AndInstruction, extract_logic_r_column_map); + test_colmap!( + test_extract_logic_r_column_map, + AndInstruction, + extract_logic_r_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_logic_r_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::dispatch; - use crate::instructions::gpu::utils::test_helpers::{ - assert_full_gpu_pipeline, assert_witness_colmajor_eq, + use crate::{ + e2e::ShardContext, + instructions::gpu::{ + dispatch, + utils::test_helpers::{assert_full_gpu_pipeline, assert_witness_colmajor_eq}, + }, }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/lui.rs b/ceno_zkvm/src/instructions/gpu/chips/lui.rs index 52ea86960..37e359a28 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/lui.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/lui.rs @@ -54,13 +54,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_lui_column_map, LuiInstruction, extract_lui_column_map); + test_colmap!( + test_extract_lui_column_map, + LuiInstruction, + extract_lui_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_lui_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/lw.rs b/ceno_zkvm/src/instructions/gpu/chips/lw.rs index d992486fa..14f09779b 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/lw.rs +++ 
b/ceno_zkvm/src/instructions/gpu/chips/lw.rs @@ -101,15 +101,21 @@ mod tests { } use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_lw_column_map, LwInstruction, extract_lw_column_map); + test_colmap!( + test_extract_lw_column_map, + LwInstruction, + extract_lw_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_lw_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::dispatch; - use crate::instructions::gpu::utils::test_helpers::{ - assert_full_gpu_pipeline, assert_witness_colmajor_eq, + use crate::{ + e2e::ShardContext, + instructions::gpu::{ + dispatch, + utils::test_helpers::{assert_full_gpu_pipeline, assert_witness_colmajor_eq}, + }, }; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/mul.rs b/ceno_zkvm/src/instructions/gpu/chips/mul.rs index 394efe689..f7e8acab3 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/mul.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/mul.rs @@ -72,16 +72,33 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_mul_column_map, MulInstruction, extract_mul_column_map); - test_colmap!(test_extract_mulh_column_map, MulhInstruction, extract_mul_column_map); - test_colmap!(test_extract_mulhu_column_map, MulhuInstruction, extract_mul_column_map); - test_colmap!(test_extract_mulhsu_column_map, MulhsuInstruction, extract_mul_column_map); + test_colmap!( + test_extract_mul_column_map, + MulInstruction, + extract_mul_column_map + ); + test_colmap!( + test_extract_mulh_column_map, + MulhInstruction, + extract_mul_column_map + ); + test_colmap!( + test_extract_mulhu_column_map, + MulhuInstruction, + extract_mul_column_map + ); + test_colmap!( + test_extract_mulhsu_column_map, + MulhsuInstruction, + extract_mul_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_mul_correctness() { - use crate::e2e::ShardContext; - use 
crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/sb.rs b/ceno_zkvm/src/instructions/gpu/chips/sb.rs index dd3484610..f312b06f8 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sb.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sb.rs @@ -94,13 +94,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_sb_column_map, SbInstruction, extract_sb_column_map); + test_colmap!( + test_extract_sb_column_map, + SbInstruction, + extract_sb_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sb_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/sh.rs b/ceno_zkvm/src/instructions/gpu/chips/sh.rs index 7c59366d7..c73768408 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sh.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sh.rs @@ -71,13 +71,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_sh_column_map, ShInstruction, extract_sh_column_map); + test_colmap!( + test_extract_sh_column_map, + ShInstruction, + extract_sh_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sh_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, 
instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs index da0181dc4..7496268d4 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs @@ -66,13 +66,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_shift_i_column_map, SlliInstruction, extract_shift_i_column_map); + test_colmap!( + test_extract_shift_i_column_map, + SlliInstruction, + extract_shift_i_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_shift_i_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs index c2efab990..e1cb7b9e0 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs @@ -72,13 +72,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_shift_r_column_map, SllInstruction, extract_shift_r_column_map); + test_colmap!( + test_extract_shift_r_column_map, + SllInstruction, + extract_shift_r_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_shift_r_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, 
instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/slt.rs b/ceno_zkvm/src/instructions/gpu/chips/slt.rs index e2816af9b..45251fdad 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/slt.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slt.rs @@ -68,13 +68,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_slt_column_map, SltInstruction, extract_slt_column_map); + test_colmap!( + test_extract_slt_column_map, + SltInstruction, + extract_slt_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_slt_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/slti.rs b/ceno_zkvm/src/instructions/gpu/chips/slti.rs index cbd21db06..ba7868ef2 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/slti.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/slti.rs @@ -64,13 +64,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_slti_column_map, SltiInstruction, extract_slti_column_map); + test_colmap!( + test_extract_slti_column_map, + SltiInstruction, + extract_slti_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_slti_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, 
InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/sub.rs b/ceno_zkvm/src/instructions/gpu/chips/sub.rs index 2861ee9b9..6107771db 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sub.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sub.rs @@ -58,13 +58,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_sub_column_map, SubInstruction, extract_sub_column_map); + test_colmap!( + test_extract_sub_column_map, + SubInstruction, + extract_sub_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sub_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git a/ceno_zkvm/src/instructions/gpu/chips/sw.rs b/ceno_zkvm/src/instructions/gpu/chips/sw.rs index 481f8fd6a..f501ba910 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/sw.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/sw.rs @@ -62,13 +62,18 @@ mod tests { type E = BabyBearExt4; use crate::instructions::gpu::utils::column_map::test_colmap; - test_colmap!(test_extract_sw_column_map, SwInstruction, extract_sw_column_map); + test_colmap!( + test_extract_sw_column_map, + SwInstruction, + extract_sw_column_map + ); #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_sw_correctness() { - use crate::e2e::ShardContext; - use crate::instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq; + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; diff --git 
a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index 612d36e5f..772c7cce9 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -21,12 +21,10 @@ use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use super::{ - config::{ - is_gpu_witgen_disabled, is_kind_disabled, - }, + config::{is_gpu_witgen_disabled, is_kind_disabled}, utils::debug_compare::{ - debug_compare_final_lk, debug_compare_shard_ec, - debug_compare_shardram, debug_compare_witness, + debug_compare_final_lk, debug_compare_shard_ec, debug_compare_shardram, + debug_compare_witness, }, }; use crate::{ @@ -42,44 +40,23 @@ pub enum GpuWitgenKind { Add, Sub, LogicR(u32), // 0=AND, 1=OR, 2=XOR - LogicI(u32), // 0=AND, 1=OR, 2=XOR - Addi, - Lui, - Auipc, - Jal, - - ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA - - ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI - - Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) - - Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) - - BranchEq(u32), // 1=BEQ, 0=BNE - + ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA + ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) + Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) + BranchEq(u32), // 1=BEQ, 0=BNE BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) - Jalr, - Sw, - Sh, - Sb, - - LoadSub { - load_width: u32, - is_signed: u32, - }, - + LoadSub { load_width: u32, is_signed: u32 }, Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU - Div(u32), // 0=DIV, 1=DIVU, 2=REM, 3=REMU Lw, Keccak, @@ -95,7 +72,8 @@ pub use super::utils::d2h::gpu_batch_continuation_ec; use super::{ cache::{ ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, - upload_shard_steps_cached, with_cached_gpu_ctx, with_cached_shard_meta, with_cached_shard_steps, + upload_shard_steps_cached, with_cached_gpu_ctx, with_cached_shard_meta, + with_cached_shard_steps, }, utils::d2h::{ CompactEcBuf, LkResult, RamBuf, 
WitResult, gpu_collect_shard_records, gpu_compact_ec_d2h, @@ -485,25 +463,25 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::add::extract_add_column_map(arith_config, num_witin)); info_span!("hal_witgen_add").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_add( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_add failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_add failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } GpuWitgenKind::Sub => { @@ -515,25 +493,25 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::sub::extract_sub_column_map(arith_config, num_witin)); info_span!("hal_witgen_sub").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_sub( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sub failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sub failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } GpuWitgenKind::LogicR(logic_kind) => { @@ -546,29 +524,29 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_logic_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - 
Some(shard_bufs), + split_full!( + hal.witgen + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_r failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::LogicI(logic_kind) => { let logic_config = unsafe { &*(config as *const I::InstructionConfig @@ -579,29 +557,29 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_logic_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_i failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Addi => { let addi_config = unsafe { &*(config as *const I::InstructionConfig @@ -611,28 +589,28 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::addi::extract_addi_column_map(addi_config, num_witin)); info_span!("hal_witgen_addi").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_addi( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + 
ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_addi failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Lui => { let lui_config = unsafe { &*(config as *const I::InstructionConfig @@ -642,28 +620,28 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::lui::extract_lui_column_map(lui_config, num_witin)); info_span!("hal_witgen_lui").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_lui( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lui failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_lui failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Auipc => { let auipc_config = unsafe { &*(config as *const I::InstructionConfig @@ -674,28 +652,28 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_auipc").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_auipc( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_auipc failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Jal => { let jal_config = unsafe { &*(config as *const I::InstructionConfig @@ -705,28 
+683,28 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::jal::extract_jal_column_map(jal_config, num_witin)); info_span!("hal_witgen_jal").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_jal( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jal failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_jal failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::ShiftR(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig @@ -739,29 +717,29 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_shift_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_r failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::ShiftI(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig @@ -774,29 +752,29 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_shift_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - 
shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_shift_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_i failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Slt(is_signed) => { let slt_config = unsafe { &*(config as *const I::InstructionConfig @@ -806,29 +784,29 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::slt::extract_slt_column_map(slt_config, num_witin)); info_span!("hal_witgen_slt").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_slt( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slt failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Slti(is_signed) => { let slti_config = unsafe { &*(config as *const I::InstructionConfig @@ -838,29 +816,29 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::slti::extract_slti_column_map(slti_config, num_witin)); info_span!("hal_witgen_slti").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_slti( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_slti( + &col_map, + gpu_records, + 
&indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slti failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::BranchEq(is_beq) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig @@ -873,29 +851,29 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_branch_eq( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_beq, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_branch_eq failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::BranchCmp(is_signed) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig @@ -908,29 +886,29 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_branch_cmp( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - 
format!("GPU witgen_branch_cmp failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Jalr => { let jalr_config = unsafe { &*(config as *const I::InstructionConfig @@ -940,28 +918,28 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::jalr::extract_jalr_column_map(jalr_config, num_witin)); info_span!("hal_witgen_jalr").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_jalr( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jalr failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_jalr failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Sw => { let sw_config = unsafe { &*(config as *const I::InstructionConfig @@ -972,29 +950,29 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::sw::extract_sw_column_map(sw_config, num_witin)); info_span!("hal_witgen_sw").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_sw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sw failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sw failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Sh => { let sh_config = unsafe { &*(config as *const I::InstructionConfig @@ -1005,29 +983,29 @@ fn 
gpu_fill_witness>( .in_scope(|| super::chips::sh::extract_sh_column_map(sh_config, num_witin)); info_span!("hal_witgen_sh").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_sh( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sh failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sh failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Sb => { let sb_config = unsafe { &*(config as *const I::InstructionConfig @@ -1038,29 +1016,29 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::sb::extract_sb_column_map(sb_config, num_witin)); info_span!("hal_witgen_sb").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_sb( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sb failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sb failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::LoadSub { load_width, is_signed, @@ -1075,64 +1053,63 @@ fn gpu_fill_witness>( let mem_max_bits = load_config.memory_addr.max_bits as u32; info_span!("hal_witgen_load_sub").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_load_sub( - &col_map, - 
gpu_records, - &indices_u32, - shard_offset, - load_width, - is_signed, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_load_sub failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Mul(mul_kind) => { let mul_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) }; - let col_map = info_span!("col_map").in_scope(|| { - super::chips::mul::extract_mul_column_map(mul_config, num_witin) - }); + let col_map = info_span!("col_map") + .in_scope(|| super::chips::mul::extract_mul_column_map(mul_config, num_witin)); info_span!("hal_witgen_mul").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_mul( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mul_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_mul failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_mul failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } - + GpuWitgenKind::Div(div_kind) => { let div_config = unsafe { &*(config as *const I::InstructionConfig @@ -1142,30 +1119,29 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::div::extract_div_column_map(div_config, num_witin)); info_span!("hal_witgen_div").in_scope(|| { 
with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_div( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - div_kind, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_div failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_div failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } GpuWitgenKind::Lw => { - let load_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load_v2::LoadConfig) @@ -1180,26 +1156,26 @@ fn gpu_fill_witness>( .in_scope(|| super::chips::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { with_cached_gpu_ctx(|gpu_records, shard_bufs| { - split_full!( - hal.witgen - .witgen_lw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - Some(shard_bufs), + split_full!( + hal.witgen + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lw failed: {e}").into(), ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_lw failed: {e}").into(), - ) - }) - ) - }) + }) + ) + }) }) } GpuWitgenKind::Keccak => { @@ -1207,5 +1183,3 @@ fn gpu_fill_witness>( } } } - - diff --git a/ceno_zkvm/src/instructions/gpu/utils/column_map.rs b/ceno_zkvm/src/instructions/gpu/utils/column_map.rs index 789df64f4..58f3d5faa 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/column_map.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/column_map.rs @@ -189,7 +189,8 @@ macro_rules! 
test_colmap { ($test_name:ident, $Instruction:ty, $extract_fn:ident) => { #[test] fn $test_name() { - let mut cs = crate::circuit_builder::ConstraintSystem::::new(|| stringify!($test_name)); + let mut cs = + crate::circuit_builder::ConstraintSystem::::new(|| stringify!($test_name)); let mut cb = crate::circuit_builder::CircuitBuilder::new(&mut cs); let config = <$Instruction as crate::instructions::Instruction>::construct_circuit( &mut cb, @@ -208,4 +209,3 @@ macro_rules! test_colmap { #[cfg(test)] pub(crate) use test_colmap; - diff --git a/ceno_zkvm/src/instructions/gpu/utils/mod.rs b/ceno_zkvm/src/instructions/gpu/utils/mod.rs index 5811cf199..1efabe305 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/mod.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/mod.rs @@ -19,7 +19,7 @@ pub mod column_map; pub mod d2h; #[cfg(feature = "gpu")] pub mod debug_compare; -#[cfg(test)] +#[cfg(all(test, feature = "gpu"))] pub mod test_helpers; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs index f8c48f598..5e04edfe2 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs @@ -25,9 +25,7 @@ pub fn assert_witness_colmajor_eq( let cpu_val = &cpu_rowmajor[row * n_cols + col]; if gpu_val != cpu_val { if mismatches < 10 { - eprintln!( - "Mismatch at row={row}, col={col}: GPU={gpu_val:?}, CPU={cpu_val:?}" - ); + eprintln!("Mismatch at row={row}, col={col}: GPU={gpu_val:?}, CPU={cpu_val:?}"); } mismatches += 1; } @@ -40,7 +38,10 @@ pub fn assert_witness_colmajor_eq( /// witness, LK multiplicity, addr_accessed, and read/write records all match /// the CPU reference in `cpu_ctx`. 
#[cfg(test)] -pub fn assert_full_gpu_pipeline>( +pub fn assert_full_gpu_pipeline< + E: ff_ext::ExtensionField, + I: crate::instructions::Instruction, +>( config: &I::InstructionConfig, steps: &[ceno_emul::StepRecord], kind: crate::instructions::gpu::dispatch::GpuWitgenKind, @@ -53,22 +54,25 @@ pub fn assert_full_gpu_pipeline = (0..steps.len()).collect(); let mut gpu_ctx = crate::e2e::ShardContext::default(); - let (gpu_rmms, gpu_lkm) = - crate::instructions::gpu::dispatch::try_gpu_assign_instances::( - config, - &mut gpu_ctx, - num_witin, - num_structural_witin, - steps, - &indices, - kind, - ) - .unwrap() - .expect("GPU path should be available"); + let (gpu_rmms, gpu_lkm) = crate::instructions::gpu::dispatch::try_gpu_assign_instances::( + config, + &mut gpu_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + kind, + ) + .unwrap() + .expect("GPU path should be available"); crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut gpu_ctx).unwrap(); - assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values(), "witness mismatch"); + assert_eq!( + gpu_rmms[0].values(), + cpu_rmms[0].values(), + "witness mismatch" + ); assert_eq!( flatten_lk_for_test(&gpu_lkm), flatten_lk_for_test(cpu_lkm), diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 8714ba8fe..ef654ad04 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1112,10 +1112,7 @@ impl SepticJacobianPoint { mod tests { use super::SepticExtension; use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint}; - use p3::{ - babybear::BabyBear, - field::Field, - }; + use p3::{babybear::BabyBear, field::Field}; use rand::thread_rng; type F = BabyBear; @@ -1344,13 +1341,12 @@ mod tests { #[test] #[cfg(feature = "gpu")] fn test_gpu_poseidon2_sponge_matches_cpu() { - use p3::field::FieldAlgebra; use ceno_gpu::bb31::{ CudaHalBB31, test_impl::{SPONGE_WIDTH, run_gpu_poseidon2_sponge}, }; use ff_ext::{PoseidonField, 
SmallField}; - use p3::symmetric::Permutation; + use p3::{field::FieldAlgebra, symmetric::Permutation}; let hal = CudaHalBB31::new(0).unwrap(); let perm = F::get_default_perm(); @@ -1418,9 +1414,9 @@ mod tests { #[test] #[cfg(feature = "gpu")] fn test_gpu_septic_from_x_matches_cpu() { - use p3::field::FieldAlgebra; use ceno_gpu::bb31::{CudaHalBB31, test_impl::run_gpu_septic_from_x}; use ff_ext::SmallField; + use p3::field::FieldAlgebra; let hal = CudaHalBB31::new(0).unwrap(); From 3c719dcec5873f106c9013b60bed72cb60882a2d Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 25 Mar 2026 09:12:52 +0800 Subject: [PATCH 68/73] shard_ram: funcs --- ceno_zkvm/src/instructions/gpu/cache.rs | 51 ------ .../src/instructions/gpu/chips/shard_ram.rs | 167 ++++++++++++++++++ ceno_zkvm/src/instructions/gpu/dispatch.rs | 6 +- ceno_zkvm/src/instructions/gpu/utils/d2h.rs | 114 ------------ ceno_zkvm/src/structs.rs | 5 +- 5 files changed, 173 insertions(+), 170 deletions(-) diff --git a/ceno_zkvm/src/instructions/gpu/cache.rs b/ceno_zkvm/src/instructions/gpu/cache.rs index e42b1cb88..0797c436a 100644 --- a/ceno_zkvm/src/instructions/gpu/cache.rs +++ b/ceno_zkvm/src/instructions/gpu/cache.rs @@ -384,57 +384,6 @@ pub struct SharedDeviceBufferSet { pub addr_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, } -/// Batch compute EC points for continuation records, keeping results on device. -/// -/// Returns (device_buf_as_u32, num_records) where the device buffer contains -/// GpuShardRamRecord entries with EC points computed. 
-pub fn gpu_batch_continuation_ec_on_device( - write_records: &[(crate::tables::ShardRamRecord, &'static str)], - read_records: &[(crate::tables::ShardRamRecord, &'static str)], -) -> Result< - ( - ceno_gpu::common::buffer::BufferImpl<'static, u32>, - usize, - usize, - ), - ZKVMError, -> { - use gkr_iop::gpu::get_cuda_hal; - - let hal = get_cuda_hal().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) - })?; - - let n_writes = write_records.len(); - let n_reads = read_records.len(); - let total = n_writes + n_reads; - if total == 0 { - let empty = hal - .witgen - .alloc_u32_zeroed(1, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("alloc: {e}").into()))?; - return Ok((empty, 0, 0)); - } - - // Convert to GpuShardRamRecord format (writes first, reads after) - let mut gpu_records: Vec = Vec::with_capacity(total); - for (rec, _name) in write_records.iter().chain(read_records.iter()) { - gpu_records.push(super::utils::d2h::shard_ram_record_to_gpu(rec)); - } - - // GPU batch EC, results stay on device - let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total) - .in_scope(|| { - hal.witgen - .batch_continuation_ec_on_device(&gpu_records, None) - }) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) - })?; - - Ok((device_buf, n_writes, n_reads)) -} - /// Read the current shared addr count from device (single u32 D2H). /// Used by debug comparison to snapshot count before/after a kernel. 
#[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs index e512ab2fb..e0cca71d6 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -81,6 +81,173 @@ pub fn extract_shard_ram_column_map( } } +// --------------------------------------------------------------------------- +// ShardRam EC batch computation +// --------------------------------------------------------------------------- + +use ceno_gpu::common::witgen::types::GpuShardRamRecord; +use gkr_iop::RAMType; +use p3::field::FieldAlgebra; +use tracing::info_span; + +use crate::error::ZKVMError; + +/// Convert a ShardRamRecord to GpuShardRamRecord (metadata only, EC fields zeroed). +pub(crate) fn shard_ram_record_to_gpu(rec: &crate::tables::ShardRamRecord) -> GpuShardRamRecord { + GpuShardRamRecord { + addr: rec.addr, + ram_type: match rec.ram_type { + RAMType::Register => 1, + RAMType::Memory => 2, + _ => 0, + }, + value: rec.value, + _pad0: 0, + shard: rec.shard, + local_clk: rec.local_clk, + global_clk: rec.global_clk, + is_to_write_set: if rec.is_to_write_set { 1 } else { 0 }, + nonce: 0, + point_x: [0; 7], + point_y: [0; 7], + } +} + +/// Convert a GPU-computed GpuShardRamRecord to ECPoint. 
+fn gpu_shard_ram_record_to_ec_point( + gpu_rec: &GpuShardRamRecord, +) -> crate::tables::ECPoint { + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + point_x_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_x[j]); + point_y_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_y[j]); + } + + let x = SepticExtension(point_x_arr); + let y = SepticExtension(point_y_arr); + let point = SepticPoint::from_affine(x, y); + + crate::tables::ECPoint { + nonce: gpu_rec.nonce, + point, + } +} + +/// Batch compute EC points on GPU, D2H results back to CPU as ShardRamInput. +/// +/// Used by the CPU fallback path in `structs.rs` when the full GPU pipeline +/// is unavailable. For the device-resident variant, see `gpu_batch_continuation_ec_on_device`. +pub fn gpu_batch_continuation_ec( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result< + ( + Vec>, + Vec>, + ), + ZKVMError, +> { + use crate::tables::ShardRamInput; + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let total = write_records.len() + read_records.len(); + if total == 0 { + return Ok((vec![], vec![])); + } + + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(shard_ram_record_to_gpu(rec)); + } + + let result = info_span!("gpu_batch_ec", n = total) + .in_scope(|| hal.witgen.batch_continuation_ec(&gpu_records)) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()))?; + + let mut write_inputs = Vec::with_capacity(write_records.len()); + let mut read_inputs = Vec::with_capacity(read_records.len()); + + for (i, gpu_rec) in 
result.iter().enumerate() { + let (rec, name) = if i < write_records.len() { + (&write_records[i].0, write_records[i].1) + } else { + let ri = i - write_records.len(); + (&read_records[ri].0, read_records[ri].1) + }; + + let ec_point = gpu_shard_ram_record_to_ec_point::(gpu_rec); + let input = ShardRamInput { + name, + record: rec.clone(), + ec_point, + }; + + if i < write_records.len() { + write_inputs.push(input); + } else { + read_inputs.push(input); + } + } + + Ok((write_inputs, read_inputs)) +} + +/// Batch compute EC points on GPU, results stay on device. +/// +/// Used by the full GPU pipeline in `structs.rs` where records feed directly +/// into `merge_and_partition_records` on device without D2H. +pub fn gpu_batch_continuation_ec_on_device( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result< + ( + ceno_gpu::common::buffer::BufferImpl<'static, u32>, + usize, + usize, + ), + ZKVMError, +> { + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let n_writes = write_records.len(); + let n_reads = read_records.len(); + let total = n_writes + n_reads; + if total == 0 { + let empty = hal + .witgen + .alloc_u32_zeroed(1, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("alloc: {e}").into()))?; + return Ok((empty, 0, 0)); + } + + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(shard_ram_record_to_gpu(rec)); + } + + let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total) + .in_scope(|| { + hal.witgen + .batch_continuation_ec_on_device(&gpu_records, None) + }) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) + })?; + + Ok((device_buf, n_writes, n_reads)) +} + #[cfg(test)] mod tests { use 
super::*; diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index 772c7cce9..1b7e39e75 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -64,11 +64,11 @@ pub enum GpuWitgenKind { // Re-exports from device_cache module for external callers (e2e.rs, structs.rs). pub use super::cache::{ - SharedDeviceBufferSet, flush_shared_ec_buffers, gpu_batch_continuation_ec_on_device, - invalidate_shard_meta_cache, invalidate_shard_steps_cache, take_shared_device_buffers, + SharedDeviceBufferSet, flush_shared_ec_buffers, invalidate_shard_meta_cache, + invalidate_shard_steps_cache, take_shared_device_buffers, }; // Re-export for external callers (structs.rs). -pub use super::utils::d2h::gpu_batch_continuation_ec; +pub use super::chips::shard_ram::gpu_batch_continuation_ec; use super::{ cache::{ ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, diff --git a/ceno_zkvm/src/instructions/gpu/utils/d2h.rs b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs index 497abb4c2..1a4c96acc 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/d2h.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs @@ -139,120 +139,6 @@ pub(crate) fn gpu_compact_ec_d2h( Ok(records) } -/// Batch compute EC points for continuation circuit ShardRamRecords on GPU. -/// -/// Converts ShardRamRecords to GPU format, launches the `batch_continuation_ec` -/// kernel to compute Poseidon2 + SepticCurve on device, and converts results -/// back to ShardRamInput (with EC points). -/// -/// Returns (write_inputs, read_inputs) maintaining the write-before-read ordering -/// invariant required by ShardRamCircuit::assign_instances. 
-pub fn gpu_batch_continuation_ec( - write_records: &[(crate::tables::ShardRamRecord, &'static str)], - read_records: &[(crate::tables::ShardRamRecord, &'static str)], -) -> Result< - ( - Vec>, - Vec>, - ), - ZKVMError, -> { - use crate::tables::ShardRamInput; - use gkr_iop::gpu::get_cuda_hal; - - let hal = get_cuda_hal().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) - })?; - - let total = write_records.len() + read_records.len(); - if total == 0 { - return Ok((vec![], vec![])); - } - - // Convert ShardRamRecords to GpuShardRamRecord format - let mut gpu_records: Vec = Vec::with_capacity(total); - for (rec, _name) in write_records.iter().chain(read_records.iter()) { - gpu_records.push(shard_ram_record_to_gpu(rec)); - } - - // GPU batch EC computation - let result = info_span!("gpu_batch_ec", n = total) - .in_scope(|| hal.witgen.batch_continuation_ec(&gpu_records)) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()))?; - - // Convert back to ShardRamInput, split into writes and reads - let mut write_inputs = Vec::with_capacity(write_records.len()); - let mut read_inputs = Vec::with_capacity(read_records.len()); - - for (i, gpu_rec) in result.iter().enumerate() { - let (rec, name) = if i < write_records.len() { - (&write_records[i].0, write_records[i].1) - } else { - let ri = i - write_records.len(); - (&read_records[ri].0, read_records[ri].1) - }; - - let ec_point = gpu_shard_ram_record_to_ec_point::(gpu_rec); - let input = ShardRamInput { - name, - record: rec.clone(), - ec_point, - }; - - if i < write_records.len() { - write_inputs.push(input); - } else { - read_inputs.push(input); - } - } - - Ok((write_inputs, read_inputs)) -} - -/// Convert a ShardRamRecord to GpuShardRamRecord (metadata only, EC fields zeroed). 
-pub(crate) fn shard_ram_record_to_gpu(rec: &crate::tables::ShardRamRecord) -> GpuShardRamRecord { - GpuShardRamRecord { - addr: rec.addr, - ram_type: match rec.ram_type { - RAMType::Register => 1, - RAMType::Memory => 2, - _ => 0, - }, - value: rec.value, - _pad0: 0, - shard: rec.shard, - local_clk: rec.local_clk, - global_clk: rec.global_clk, - is_to_write_set: if rec.is_to_write_set { 1 } else { 0 }, - nonce: 0, - point_x: [0; 7], - point_y: [0; 7], - } -} - -/// Convert a GPU-computed GpuShardRamRecord to ECPoint. -pub(crate) fn gpu_shard_ram_record_to_ec_point( - gpu_rec: &GpuShardRamRecord, -) -> crate::tables::ECPoint { - use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; - - let mut point_x_arr = [E::BaseField::ZERO; 7]; - let mut point_y_arr = [E::BaseField::ZERO; 7]; - for j in 0..7 { - point_x_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_x[j]); - point_y_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_y[j]); - } - - let x = SepticExtension(point_x_arr); - let y = SepticExtension(point_y_arr); - let point = SepticPoint::from_affine(x, y); - - crate::tables::ECPoint { - nonce: gpu_rec.nonce, - point, - } -} - pub(crate) fn gpu_lk_counters_to_multiplicity( counters: LkResult, ) -> Result, ZKVMError> { diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 680ff4184..8e2e4a390 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -746,8 +746,9 @@ impl ZKVMWitnesses { final_mem: &[(&'static str, Option>, &[MemFinalRecord])], config: & as TableCircuit>::TableConfig, ) -> Result { - use crate::instructions::gpu::dispatch::{ - gpu_batch_continuation_ec_on_device, take_shared_device_buffers, + use crate::instructions::gpu::{ + chips::shard_ram::gpu_batch_continuation_ec_on_device, + dispatch::take_shared_device_buffers, }; use ceno_gpu::Buffer; use gkr_iop::gpu::get_cuda_hal; From bd18f1d5e4c5563216d40dc245f0112f5014db27 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 25 Mar 2026 09:34:57 
+0800 Subject: [PATCH 69/73] default: disable gpu witgen --- .../src/instructions/gpu/chips/keccak.rs | 4 ++-- ceno_zkvm/src/instructions/gpu/config.rs | 21 ++++++++++--------- ceno_zkvm/src/instructions/gpu/dispatch.rs | 4 ++-- .../instructions/gpu/utils/debug_compare.rs | 2 +- .../instructions/gpu/utils/test_helpers.rs | 10 ++++++--- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs index e64d9311f..7c8200ce2 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -26,7 +26,7 @@ use crate::{ ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, with_cached_shard_meta, }, - config::{is_debug_compare_enabled, is_gpu_witgen_disabled, is_kind_disabled}, + config::{is_debug_compare_enabled, is_gpu_witgen_enabled, is_kind_disabled}, dispatch::{GpuWitgenKind, compute_fetch_params, is_force_cpu_path}, utils::{ d2h::{gpu_compact_ec_d2h, gpu_lk_counters_to_multiplicity}, @@ -217,7 +217,7 @@ pub fn gpu_assign_keccak_instances( use gkr_iop::gpu::get_cuda_hal; // Guard: disabled or force-CPU - if is_gpu_witgen_disabled() || is_force_cpu_path() { + if !is_gpu_witgen_enabled() || is_force_cpu_path() { return Ok(None); } // Check if keccak is disabled via CENO_GPU_DISABLE_WITGEN_KINDS=keccak diff --git a/ceno_zkvm/src/instructions/gpu/config.rs b/ceno_zkvm/src/instructions/gpu/config.rs index efb5ffc6c..6c9ad1e26 100644 --- a/ceno_zkvm/src/instructions/gpu/config.rs +++ b/ceno_zkvm/src/instructions/gpu/config.rs @@ -2,7 +2,7 @@ /// environment-variable disable switches. 
/// /// Environment variables (3 total): -/// - `CENO_GPU_DISABLE_WITGEN` — global kill switch, all chips use CPU +/// - `CENO_GPU_ENABLE_WITGEN` — opt-in GPU witgen (default: CPU) /// - `CENO_GPU_DISABLE_WITGEN_KINDS=add,sub,keccak,...` — per-kind disable (comma-separated tags) /// - `CENO_GPU_DEBUG_COMPARE_WITGEN` — enable GPU vs CPU comparison for all chips (witness, LK, shard, EC) use super::dispatch::GpuWitgenKind; @@ -59,19 +59,20 @@ pub(crate) fn is_kind_disabled(kind: GpuWitgenKind) -> bool { }) } -/// Returns true if GPU witgen is globally disabled via `CENO_GPU_DISABLE_WITGEN` env var. +/// Returns true if GPU witgen is enabled via `CENO_GPU_ENABLE_WITGEN` env var. +/// Default is disabled (CPU witgen). Set `CENO_GPU_ENABLE_WITGEN=1` to opt in. /// The value is cached at first access. -pub(crate) fn is_gpu_witgen_disabled() -> bool { +pub(crate) fn is_gpu_witgen_enabled() -> bool { use std::sync::OnceLock; - static DISABLED: OnceLock = OnceLock::new(); - *DISABLED.get_or_init(|| { - let val = std::env::var_os("CENO_GPU_DISABLE_WITGEN"); - let disabled = val.is_some(); + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| { + let val = std::env::var_os("CENO_GPU_ENABLE_WITGEN"); + let enabled = val.is_some(); eprintln!( - "[GPU witgen] CENO_GPU_DISABLE_WITGEN={:?} → disabled={}", - val, disabled + "[GPU witgen] CENO_GPU_ENABLE_WITGEN={:?} → enabled={}", + val, enabled ); - disabled + enabled }) } diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs index 1b7e39e75..56331a432 100644 --- a/ceno_zkvm/src/instructions/gpu/dispatch.rs +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -21,7 +21,7 @@ use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use super::{ - config::{is_gpu_witgen_disabled, is_kind_disabled}, + config::{is_gpu_witgen_enabled, is_kind_disabled}, utils::debug_compare::{ debug_compare_final_lk, debug_compare_shard_ec, debug_compare_shardram, 
debug_compare_witness, @@ -117,7 +117,7 @@ pub(crate) fn try_gpu_assign_instances>( ) -> Result, Multiplicity)>, ZKVMError> { use gkr_iop::gpu::get_cuda_hal; - if is_gpu_witgen_disabled() || is_force_cpu_path() { + if !is_gpu_witgen_enabled() || is_force_cpu_path() { return Ok(None); } diff --git a/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs index 1441a1849..1c91b6071 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs @@ -657,7 +657,7 @@ pub(crate) fn debug_compare_keccak( return Ok(()); } - // Guard against recursion: is_gpu_witgen_disabled() uses OnceLock so env var + // Guard against recursion: is_gpu_witgen_enabled() uses OnceLock so env var // manipulation doesn't work. Use a thread-local flag instead. thread_local! { static IN_DEBUG_COMPARE: Cell = const { Cell::new(false) }; diff --git a/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs index 5e04edfe2..3fbbad170 100644 --- a/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs +++ b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs @@ -54,7 +54,7 @@ pub fn assert_full_gpu_pipeline< let indices: Vec = (0..steps.len()).collect(); let mut gpu_ctx = crate::e2e::ShardContext::default(); - let (gpu_rmms, gpu_lkm) = crate::instructions::gpu::dispatch::try_gpu_assign_instances::( + let result = crate::instructions::gpu::dispatch::try_gpu_assign_instances::( config, &mut gpu_ctx, num_witin, @@ -63,8 +63,12 @@ pub fn assert_full_gpu_pipeline< &indices, kind, ) - .unwrap() - .expect("GPU path should be available"); + .unwrap(); + // Skip pipeline comparison if GPU witgen is not enabled (CENO_GPU_ENABLE_WITGEN unset) + let Some((gpu_rmms, gpu_lkm)) = result else { + eprintln!("GPU witgen not enabled, skipping full pipeline comparison"); + return; + }; crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut 
gpu_ctx).unwrap(); From 4ac6a6fe235828925a4e44b6e1e5d32e0df10d84 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 25 Mar 2026 09:45:34 +0800 Subject: [PATCH 70/73] e2e: pipeline --- ceno_zkvm/src/e2e.rs | 176 ++++++++++++++++++++++--------------------- 1 file changed, 89 insertions(+), 87 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index d2ca0bae6..b4cc40587 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -2087,89 +2087,91 @@ fn create_proofs_streaming< init_mem_state: &InitMemState, ) -> Vec> { let ctx = prover.pk.program_ctx.as_ref().unwrap(); + + // Two pipeline modes: + // + // Default GPU backend (CENO_GPU_ENABLE_WITGEN unset): + // Overlap: CPU witgen (thread A) and GPU prove (thread B) run in parallel. + // CPU produces witness for shard N+1 while GPU proves shard N. + // Uses a bounded(0) rendezvous channel for back-pressure. + // + // CENO_GPU_ENABLE_WITGEN=1 (GPU witgen) or CPU-only build: + // Sequential: witgen → prove, one shard at a time. + // When GPU witgen is on, GPU is shared between witgen and proving. 
+ let proofs = info_span!("[ceno] app_prove.inner").in_scope(|| { - // #[cfg(feature = "gpu")] - // { - // use crossbeam::channel; - // let (tx, rx) = channel::bounded(0); - // std::thread::scope(|s| { - // // pipeline cpu/gpu workload - // // cpu producer - // s.spawn({ - // move || { - // let wit_iter = generate_witness( - // &ctx.system_config, - // emulation_result, - // ctx.program.clone(), - // &ctx.platform, - // init_mem_state, - // target_shard_id, - // ); - - // let wit_iter = if let Some(target_shard_id) = target_shard_id { - // Box::new(wit_iter.skip(target_shard_id)) as Box> - // } else { - // Box::new(wit_iter) - // }; - - // for proof_input in wit_iter { - // if tx.send(proof_input).is_err() { - // tracing::warn!( - // "witness consumer dropped; stopping witness generation early" - // ); - // break; - // } - // } - // } - // }); - - // // gpu consumer - // { - // let mut proofs = Vec::new(); - // let mut proof_err = None; - // let rx = rx; - // while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { - // if is_mock_proving { - // MockProver::assert_satisfied_full( - // &shard_ctx, - // &ctx.system_config.zkvm_cs, - // ctx.zkvm_fixed_traces.clone(), - // &zkvm_witness, - // &pi, - // &ctx.program, - // ); - // tracing::info!("Mock proving passed"); - // } - - // let transcript = Transcript::new(b"riscv"); - // let start = std::time::Instant::now(); - // match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { - // Ok(zkvm_proof) => { - // tracing::debug!( - // "{}th shard proof created in {:?}", - // shard_ctx.shard_id, - // start.elapsed() - // ); - // proofs.push(zkvm_proof); - // } - // Err(err) => { - // proof_err = Some(err); - // break; - // } - // } - // } - // drop(rx); - // if let Some(err) = proof_err { - // panic!("create_proof failed: {err:?}"); - // } - // proofs - // } - // }) - // } - - // #[cfg(not(feature = "gpu"))] + #[cfg(feature = "gpu")] + { + // With GPU feature: check if GPU witgen is enabled to select pipeline 
mode. + if !crate::instructions::gpu::config::is_gpu_witgen_enabled() { + // Default: overlap CPU witgen with GPU proving. + use crossbeam::channel; + let (tx, rx) = channel::bounded(0); + return std::thread::scope(|s| { + // CPU producer: generate witness shards + s.spawn(move || { + let wit_iter = generate_witness( + &ctx.system_config, + emulation_result, + ctx.program.clone(), + &ctx.platform, + init_mem_state, + target_shard_id, + ); + + let wit_iter: Box> = + if let Some(target_shard_id) = target_shard_id { + Box::new(wit_iter.skip(target_shard_id)) + } else { + Box::new(wit_iter) + }; + + for proof_input in wit_iter { + if tx.send(proof_input).is_err() { + tracing::warn!( + "witness consumer dropped; stopping witness generation early" + ); + break; + } + } + }); + + // GPU consumer: prove each shard as it arrives + let mut proofs = Vec::new(); + while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { + if is_mock_proving { + MockProver::assert_satisfied_full( + &shard_ctx, + &ctx.system_config.zkvm_cs, + ctx.zkvm_fixed_traces.clone(), + &zkvm_witness, + &pi, + &ctx.program, + ); + tracing::info!("Mock proving passed"); + } + + let transcript = Transcript::new(b"riscv"); + let start = std::time::Instant::now(); + let zkvm_proof = prover + .create_proof(&shard_ctx, zkvm_witness, pi, transcript) + .expect("create_proof failed"); + tracing::debug!( + "{}th shard proof created in {:?}", + shard_ctx.shard_id, + start.elapsed() + ); + proofs.push(zkvm_proof); + } + proofs + }); + } + // Fall through: GPU witgen enabled → sequential path below. + } + + // Sequential: witgen → prove, one shard at a time. + // Used by: GPU witgen mode (CENO_GPU_ENABLE_WITGEN=1) and CPU-only builds. 
{ - // Generate witness let wit_iter = generate_witness( &ctx.system_config, emulation_result, @@ -2179,11 +2181,12 @@ fn create_proofs_streaming< target_shard_id, ); - let wit_iter = if let Some(target_shard_id) = target_shard_id { - Box::new(wit_iter.skip(target_shard_id)) as Box> - } else { - Box::new(wit_iter) - }; + let wit_iter: Box> = + if let Some(target_shard_id) = target_shard_id { + Box::new(wit_iter.skip(target_shard_id)) + } else { + Box::new(wit_iter) + }; wit_iter .map(|(zkvm_witness, shard_ctx, pi)| { @@ -2209,7 +2212,6 @@ fn create_proofs_streaming< shard_ctx.shard_id, start.elapsed() ); - // only show e2e stats in cpu mode tracing::info!("e2e proof stat: {}", zkvm_proof); zkvm_proof }) From 4dba098cbf7a2e70d71a46a79b81da846168211f Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 25 Mar 2026 09:59:52 +0800 Subject: [PATCH 71/73] README.md --- ceno_zkvm/src/instructions/gpu/README.md | 214 +++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 ceno_zkvm/src/instructions/gpu/README.md diff --git a/ceno_zkvm/src/instructions/gpu/README.md b/ceno_zkvm/src/instructions/gpu/README.md new file mode 100644 index 000000000..565b73f46 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/README.md @@ -0,0 +1,214 @@ +# GPU Witness Generation + +Accelerate witness generation by offloading computation from CPU to GPU. +This module (`ceno_zkvm/src/instructions/gpu/`) contains all GPU-side dispatch, +caching, and utility code for the witness generation pipeline. + +The CUDA backend lives in the sibling repo `ceno-gpu/` (`cuda_hal/src/common/witgen/`). + +## Architecture + +### Module Layout + +``` +gpu/ +├── dispatch.rs — GPU dispatch entry point (try_gpu_assign_instances, gpu_fill_witness) +├── config.rs — Environment variable config (3 env vars), kind tags +├── cache.rs — Thread-local device buffer caching, shared EC/addr buffers +├── chips/ — Per-chip column map extractors + chip-specific GPU dispatch +│ ├── add.rs ... 
sw.rs (24 RV32IM column map extractors)
+│ ├── keccak.rs (column map + keccak GPU dispatch: gpu_assign_keccak_instances)
+│ └── shard_ram.rs (column map + batch EC computation: gpu_batch_continuation_ec)
+├── utils/
+│ ├── column_map.rs — Shared column map extraction helpers (extract_rs1, extract_rd, ...)
+│ ├── d2h.rs — Device-to-host: witness transpose, LK counter decode, compact EC D2H
+│ ├── debug_compare.rs — GPU vs CPU comparison (activated by CENO_GPU_DEBUG_COMPARE_WITGEN)
+│ ├── lk_ops.rs — LkOp enum, SendEvent struct
+│ ├── sink.rs — LkShardramSink trait, CpuLkShardramSink
+│ ├── emit.rs — Emit helper functions (emit_u16_limbs, emit_logic_u8_ops, ...)
+│ ├── fallback.rs — CPU fallback: cpu_assign_instances, cpu_collect_lk_and_shardram
+│ └── test_helpers.rs — Test utilities: assert_witness_colmajor_eq, assert_full_gpu_pipeline
+└── mod.rs — Module declarations + lk_shardram integration tests (19 tests)
+```
+
+### Data Flow
+
+```
+ Pass 1: PreflightTracer
+ ┌──────────────────────┐
+ │ ShardPlanBuilder │ → shard boundaries
+ │ PackedNextAccessEntry│ → sorted future-access table
+ └──────────┬───────────┘
+ │
+ Pass 2: FullTracer (per shard)
+ ┌──────────▼───────────┐
+ │ Vec<StepRecord> │ 136 bytes/step, #[repr(C)]
+ └──────────┬───────────┘
+ │ H2D (cached per shard in cache.rs)
+ ┌──────────▼───────────────────────────────────┐
+ │ GPU Per-Instruction │
+ │ ┌─────────────┬──────────────┬────────────┐ │
+ │ │ F-1 Witness │ F-2 LK Count │ F-3 EC/Addr│ │
+ │ │ (col-major) │ (atomics) │ (shared buf)│ │
+ │ └──────┬──────┴──────┬───────┴─────┬──────┘ │
+ └─────────┼─────────────┼─────────────┼────────┘
+ │ │ │
+ GPU transpose D2H counters flush at shard end
+ │ │ │
+ ┌─────────▼─────────────▼─────────────▼────────┐
+ │ CPU Merge │
+ │ RowMajorMatrix LkMultiplicity ShardContext │
+ └──────────────────────┬───────────────────────┘
+ │
+ ┌──────────────────────▼───────────────────────┐
+ │ ShardRamCircuit (GPU) │
+ │ Phase 1: per-row Poseidon2 (344 cols) │
+ │ Phase 2:
binary EC tree (layer-by-layer) │
+ └──────────────────────┬───────────────────────┘
+ │
+ ▼
+ Proof Generation
+```
+
+### Per-Shard Pipeline
+
+Within `generate_witness()` (e2e.rs), each shard executes:
+
+1. **upload_shard_steps_cached** — H2D `Vec<StepRecord>` (cached, shared across all chips)
+2. **ensure_shard_metadata_cached** — H2D shard scalars + allocate shared EC/addr buffers
+3. **Per-chip dispatch** — `gpu_fill_witness` matches `GpuWitgenKind` → 22 kernel variants
+ - Each kernel writes: witness columns (col-major), LK counters (atomics), EC records + addr (shared buffers)
+4. **flush_shared_ec_buffers** — D2H shared EC records + addr_accessed into `ShardContext`
+5. **invalidate_shard_steps_cache** — Free GPU shard_steps memory
+6. **assign_shared_circuit** — ShardRamCircuit GPU pipeline (Poseidon2 + EC tree)
+
+### GPU/CPU Decision (dispatch.rs)
+
+```
+try_gpu_assign_instances():
+ 1. is_gpu_witgen_enabled()? → CPU fallback if not set
+ 2. is_force_cpu_path() thread-local? → CPU fallback (debug comparison)
+ 3. I::GPU_LK_SHARDRAM == false? → CPU fallback
+ 4. is_kind_disabled(kind)? → CPU fallback
+ 5. Field != BabyBear? → CPU fallback
+ 6. get_cuda_hal() unavailable? → CPU fallback
+ 7. All pass → GPU path
+```
+
+### Keccak Dispatch
+
+Keccak has a dedicated GPU dispatch path (`chips/keccak.rs::gpu_assign_keccak_instances`)
+separate from `try_gpu_assign_instances` because:
+1. **Rotation**: each instance spans 32 rows (not 1), requiring `new_by_rotation`
+2. **Structural witness**: 3 selectors (sel_first/sel_last/sel_all) vs the standard 1
+3. **Input packing**: needs `packed_instances` with `syscall_witnesses`
+
+The LK/shardram collection logic is identical to the standard path.
+ +### Lk and Shardram Collection + +After GPU computes the witness matrix, LK multiplicities and shard RAM records +are collected through one of several paths (priority order): + +| Path | Witness | LK Multiplicity | Shard Records | When | +|------|---------|-----------------|---------------|------| +| **A** Shared buffer | GPU | GPU counters → D2H | Shared GPU buffer (deferred) | Default for all verified kinds | +| **B** Compact EC | GPU | GPU counters → D2H | Compact EC D2H per-kernel | Older non-shared-buffer kinds | +| **C** CPU shardram | GPU | GPU counters → D2H | CPU `cpu_collect_shardram` | GPU shard unverified | +| **D** CPU full | GPU | CPU `cpu_collect_lk_and_shardram` | CPU full | GPU LK unverified | +| **E** CPU only | CPU | CPU `assign_instance` | CPU `assign_instance` | GPU unavailable | + +Currently all non-Keccak kinds use **Path A**. Paths B-E are fallback/debug paths. + +## E2E Pipeline Modes (e2e.rs) + +``` +create_proofs_streaming() +│ +├─ Default GPU backend (CENO_GPU_ENABLE_WITGEN unset): +│ Overlap pipeline: +│ Thread A (CPU): witgen(shard 0) → witgen(shard 1) → witgen(shard 2) → ... +│ Thread B (GPU): ................prove(shard 0) → prove(shard 1) → ... +│ crossbeam::bounded(0) rendezvous channel for back-pressure +│ +└─ CENO_GPU_ENABLE_WITGEN=1 (GPU witgen) or CPU-only build: + Sequential pipeline: + witgen(shard 0) → prove(shard 0) → witgen(shard 1) → prove(shard 1) → ... + GPU shared between witgen and proving; no overlap possible. +``` + +## Environment Variables + +| Variable | Default | Purpose | +|----------|---------|---------| +| `CENO_GPU_ENABLE_WITGEN` | unset (CPU witgen) | Set to enable GPU witness generation. Sequential witgen+prove pipeline. | +| `CENO_GPU_DISABLE_WITGEN_KINDS` | none | Comma-separated kind tags to disable specific chips' GPU path. Example: `add,keccak,lw`. Falls back to CPU for those chips. | +| `CENO_GPU_DEBUG_COMPARE_WITGEN` | unset | Enable GPU vs CPU comparison for all chips. 
Runs both paths and diffs results. | + +### `CENO_GPU_DEBUG_COMPARE_WITGEN` Coverage + +When set, the following comparisons run automatically: + +**Per-chip (in dispatch.rs, for each opcode circuit):** +- `debug_compare_final_lk` — GPU LK multiplicity vs CPU `assign_instance` baseline (all 8 lookup tables) +- `debug_compare_witness` — GPU witness matrix vs CPU witness (element-by-element, col-major vs row-major) +- `debug_compare_shardram` — GPU shard records (read_records, write_records, addr_accessed) vs CPU +- `debug_compare_shard_ec` — GPU compact EC records vs CPU-computed EC points (nonce, x[7], y[7]) + +**Per-chip, Keccak-specific (in chips/keccak.rs):** +- `debug_compare_keccak` — Combined witness + LK + shard comparison for keccak's rotation-aware layout + +**Per-shard, E2E level (in e2e.rs):** +- `log_shard_ctx_diff` — Full shard context comparison after all opcode circuits (addr_accessed, read/write records across all chips merged) +- `log_combined_lk_diff` — Merged LK multiplicities after `finalize_lk_multiplicities()` (catches cross-chip merge issues) + +All comparisons output to stderr via `eprintln!` / `tracing::error!`, with a default limit of 16 mismatches per category. 
+ +## Tests + +**79 tests total** (`cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "gpu"`) + +| Category | Count | Location | What it tests | +|----------|------:|----------|---------------| +| Column map extraction | 33 | `chips/*.rs` (31 via `test_colmap!` macro + 2 manual) | Circuit config → column map: all IDs in-range and unique | +| GPU witgen correctness | 23 | `chips/*.rs` | GPU kernel output vs CPU `assign_instance` (element-by-element witness comparison) | +| LK+shardram match | 19 | `gpu/mod.rs` | `collect_lk_and_shardram` / `collect_shardram` vs `assign_instance` baseline | +| LkOp encoding | 1 | `utils/mod.rs` | `LkOp::encode_all()` produces correct table/key pairs | +| EC point match | 1 | `scheme/septic_curve.rs` | GPU Poseidon2+SepticCurve EC point vs CPU `to_ec_point` | +| Poseidon2 sponge | 1 | `scheme/septic_curve.rs` | GPU Poseidon2 permutation vs CPU | +| Septic from_x | 1 | `scheme/septic_curve.rs` | GPU `septic_point_from_x` vs CPU | + +### Running Tests + +```bash +# All GPU tests (requires CUDA device) +CENO_GPU_ENABLE_WITGEN=1 cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "gpu" + +# Column map tests only (no CUDA device needed) +cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "test_extract_" + +# LK/shardram tests only (no CUDA device needed) +cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "lk_shardram" + +# With debug comparison enabled +CENO_GPU_ENABLE_WITGEN=1 CENO_GPU_DEBUG_COMPARE_WITGEN=1 cargo test --features gpu,u16limb_circuit -p ceno_host -- test_elf +``` + +## Per-Chip Boilerplate Macros + +Three macros in `instructions.rs` reduce per-chip GPU integration to ~3 lines: + +```rust +impl Instruction for MyChip { + // Emit LK ops + shard RAM records (CPU companion for GPU witgen) + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + emit_u16_limbs(sink, step.rd().unwrap().value.after); + }); + + // Collect shard RAM records only (when GPU handles 
LK) + impl_collect_shardram!(r_insn); + + // GPU dispatch: try GPU → fallback CPU + impl_gpu_assign!(dispatch::GpuWitgenKind::Add); +} +``` From e7bb55b9666ec79ee463d6ce8ac1fe21969a863f Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 25 Mar 2026 11:47:09 +0800 Subject: [PATCH 72/73] invasive_changes.md --- .../src/instructions/gpu/invasive_changes.md | 234 ++++++++++++++++++ 1 file changed, 234 insertions(+) create mode 100644 ceno_zkvm/src/instructions/gpu/invasive_changes.md diff --git a/ceno_zkvm/src/instructions/gpu/invasive_changes.md b/ceno_zkvm/src/instructions/gpu/invasive_changes.md new file mode 100644 index 000000000..fb79f22ac --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/invasive_changes.md @@ -0,0 +1,234 @@ +# GPU Witness Generation — Invasive Changes to Existing Codebase + +This document lists all changes to **existing** ceno structures, traits, and flows +that this PR introduces. GPU-only new code (`instructions/gpu/`) is excluded — +this focuses on what existing code was modified and why. + +--- + +## 1. `ceno_emul` — FFI Layout Changes (+332 / -88 lines) + +### `#[repr(C)]` on emulator types + +The following types were made `#[repr(C)]` to enable zero-copy H2D transfer to GPU: + +| Type | File | Size | Purpose | +|------|------|------|---------| +| `StepRecord` | `tracer.rs` | 136B | Per-step emulator output, bulk H2D | +| `Instruction` | `rv32im.rs` | 12B | Opcode encoding embedded in StepRecord | +| `InsnKind` | `rv32im.rs` | 1B | `#[repr(u8)]` enum discriminant | +| `MemOp` | `tracer.rs` | 16/24B | Read/Write ops embedded in StepRecord | +| `Change` | `tracer.rs` | 2×T | Before/after pair | + +**Impact**: These were previously `#[derive(Debug, Clone)]` with compiler-chosen layout. +Adding `#[repr(C)]` pins field order and padding. No behavioral change for CPU code, +but **field reordering or insertion now requires updating the CUDA mirror structs**. 
+ +### New types in `tracer.rs` + +- `PackedNextAccessEntry` (16B, `#[repr(C)]`) — 40-bit packed cycle+addr for GPU FA table +- `ShardPlanBuilder` — preflight shard planning with cell-count balancing + +### Layout test + +`test_step_record_layout_for_gpu` verifies byte offsets of all `StepRecord` fields +at compile time. CUDA side has matching `static_assert(sizeof(...))`. + +--- + +## 2. `Instruction` Trait — New Methods and Constants + +**File**: `ceno_zkvm/src/instructions.rs` + +| Addition | Purpose | +|----------|---------| +| `const GPU_LK_SHARDRAM: bool = false` | Opt-in flag: does this chip have GPU LK+shardram support? | +| `fn collect_lk_and_shardram(...)` | CPU companion: collect all LK multiplicities + shard RAM records (without witness replay) | +| `fn collect_shardram(...)` | CPU companion: collect shard RAM records only (GPU handles LK) | + +**Default implementations** return `Err(...)` — chips must explicitly opt in. + +**Impact**: Existing chips that don't implement GPU support are unaffected (defaults). +The trait's existing `assign_instance` and `assign_instances` are unchanged. + +Three macros reduce per-chip boilerplate: +- `impl_collect_lk_and_shardram!` — wraps the unsafe `CpuLkShardramSink` prologue +- `impl_collect_shardram!` — one-line delegate to insn_config +- `impl_gpu_assign!` — `#[cfg(feature = "gpu")] assign_instances` override + +--- + +## 3. Gadgets — New `emit_lk_and_shardram` / `emit_shardram` Methods + +**File**: `ceno_zkvm/src/instructions/riscv/insn_base.rs` (+253 lines) + +Every base gadget (`ReadRS1`, `ReadRS2`, `WriteRD`, `ReadMEM`, `WriteMEM`, `MemAddr`) +gained two new methods: + +| Method | What it does | +|--------|-------------| +| `emit_lk_and_shardram(sink, ctx, step)` | Emit LK ops + RAM send events through `LkShardramSink` | +| `emit_shardram(shard_ctx, step)` | Directly write shard RAM records to `ShardContext` (no LK) | + +**Impact**: Additive only — existing `assign_instance` methods are unchanged. 
+The new methods extract the same logic that `assign_instance` performed inline, +but route through the `LkShardramSink` trait instead of directly calling +`lk_multiplicity.assert_ux(...)`. + +### Intermediate configs (`r_insn.rs`, `i_insn.rs`, `b_insn.rs`, `s_insn.rs`, `j_insn.rs`, `im_insn.rs`) + +Each gained corresponding `emit_lk_and_shardram` / `emit_shardram` methods that +compose their gadgets' methods + emit `LkOp::Fetch`. + +--- + +## 4. Per-Chip Circuit Files — GPU Opt-in (+792 / -129 lines across ~20 files) + +Each v2 circuit file (arith.rs, logic_circuit.rs, div_circuit_v2.rs, etc.) gained: + +```rust +const GPU_LK_SHARDRAM: bool = true; // or conditional match + +impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + // chip-specific LK ops +}); +impl_collect_shardram!(r_insn); +impl_gpu_assign!(dispatch::GpuWitgenKind::Add); +``` + +**Impact**: Additive — existing `assign_instance` and `construct_circuit` unchanged. +The `#[cfg(feature = "gpu")] assign_instances` override is only compiled with the +`gpu` feature flag. + +--- + +## 5. 
`ShardContext` — New Fields and Methods
+
+**File**: `ceno_zkvm/src/e2e.rs` (+616 / -199 lines)
+
+### New fields
+
+| Field | Type | Purpose |
+|-------|------|---------|
+| `sorted_next_accesses` | `Arc<[PackedNextAccessEntry]>` | Pre-sorted packed future-access table for GPU bulk H2D |
+| `gpu_ec_records` | `Vec<u8>` | Raw bytes of GPU-produced compact EC shard records |
+| `syscall_witnesses` | `Arc<Vec<_>>` | Keccak syscall data (previously passed separately) |
+
+### New methods
+
+| Method | Purpose |
+|--------|---------|
+| `new_empty_like()` | Clone shard metadata with empty record storage (for debug comparison) |
+| `insert_read_record()` / `insert_write_record()` | Direct record insertion (GPU D2H path) |
+| `push_addr_accessed()` | Direct addr insertion (GPU D2H path) |
+| `extend_gpu_ec_records_raw()` | Append raw GPU EC record bytes |
+| `has_gpu_ec_records()` / `take_gpu_ec_records()` | GPU EC record lifecycle |
+
+### Renamed method
+
+`send()` → split into `record_send_without_touch()` (no addr_accessed tracking) and
+`send()` (which calls `record_send_without_touch` + `push_addr_accessed`).
+
+### Pipeline hooks (in `generate_witness` shard loop)
+
+```rust
+#[cfg(feature = "gpu")]
+flush_shared_ec_buffers(&mut shard_ctx); // D2H shared GPU buffers
+
+#[cfg(feature = "gpu")]
+invalidate_shard_steps_cache(); // free GPU memory
+```
+
+### Pipeline mode (in `create_proofs_streaming`)
+
+New overlap pipeline (default when GPU feature enabled but `CENO_GPU_ENABLE_WITGEN` unset):
+CPU witgen on thread A, GPU prove on thread B, connected by `crossbeam::bounded(0)` channel.
+
+---
+
+## 6. `ZKVMWitnesses` — GPU ShardRam Pipeline
+
+**File**: `ceno_zkvm/src/structs.rs` (+580 / -130 lines)
+
+### `assign_shared_circuit` — new GPU fast path
+
+Added `try_assign_shared_circuit_gpu()` that keeps data on GPU device:
+1. Takes shared device buffers (EC records + addr_accessed)
+2. GPU sort+dedup addr_accessed
+3. GPU batch EC computation for continuation records
+4.
GPU merge+partition records (writes before reads)
+5. GPU ShardRamCircuit witness generation (Poseidon2 + EC tree)
+
+Falls back to CPU path on failure.
+
+### `gpu_ec_records_to_shard_ram_inputs`
+
+Converts raw GPU EC record bytes (`Vec<u8>`) to a `Vec` of shard-RAM inputs with
+pre-computed EC points. Used in the CPU fallback path.
+
+---
+
+## 7. `ShardRamCircuit` — GPU Witness Generation
+
+**File**: `ceno_zkvm/src/tables/shard_ram.rs` (+491 / -14 lines)
+
+### New GPU functions
+
+| Function | Purpose |
+|----------|---------|
+| `try_gpu_assign_instances()` | H2D path: CPU records → GPU kernel → D2H witness |
+| `try_gpu_assign_instances_from_device()` | Device path: records already on GPU → kernel → D2H |
+
+Both run a two-phase GPU pipeline:
+1. **Per-row kernel**: basic fields + Poseidon2 trace (344 witness columns)
+2. **EC tree kernel**: layer-by-layer binary tree EC summation
+
+### Visibility change
+
+`ShardRamConfig` fields changed from private to `pub(crate)` to allow
+column map extraction in `gpu/chips/shard_ram.rs`.
+
+---
+
+## 8. `SepticCurve` — New Math Utilities
+
+**File**: `ceno_zkvm/src/scheme/septic_curve.rs` (+307 lines)
+
+New CPU-side math for EC point computation (mirrored in CUDA):
+
+| Function | Purpose |
+|----------|---------|
+| `SepticExtension::frobenius()` | Frobenius endomorphism for norm computation |
+| `SepticExtension::sqrt()` | Cipolla's algorithm for field square roots |
+| `SepticPoint::from_x()` | Lift x-coordinate to curve point (used by nonce-finding loop) |
+| `QuadraticExtension` | Auxiliary type for Cipolla's algorithm |
+
+---
+
+## 9.
Minor Touches + +| File | Change | +|------|--------| +| `Cargo.toml` | `gpu` feature flag, `crossbeam` dependency | +| `gkr_iop/src/gadgets/is_lt.rs` | `AssertLtConfig.0.diff` field access (already `pub`) | +| `gkr_iop/src/utils/lk_multiplicity.rs` | Minor: `LkMultiplicity::increment` | +| `ceno_zkvm/src/gadgets/signed_ext.rs` | `pub(crate) fn msb()` accessor for GPU column map | +| `ceno_zkvm/src/gadgets/poseidon2.rs` | Column contiguity constants for GPU | +| `ceno_zkvm/src/tables/*.rs` | `pub(crate)` visibility on config fields for GPU column map access | +| `ceno_zkvm/src/scheme/{cpu,gpu,prover,verifier}` | Minor plumbing for GPU proving path | +| `ceno_host/tests/test_elf.rs` | E2E test adjustments | + +--- + +## Summary + +| Category | Nature | Risk | +|----------|--------|------| +| `#[repr(C)]` on emulator types | Layout pinning | Low — additive, but field changes now need CUDA sync | +| `Instruction` trait extensions | Additive (defaults provided) | None — existing chips unaffected | +| Gadget `emit_*` methods | Additive | None — existing `assign_instance` unchanged | +| `ShardContext` new fields | Additive (defaults in `Default`) | Low — `Vec::new()` / `Arc::new()` zero-cost | +| `send()` → `record_send_without_touch()` + `send()` | Rename + split | Low — `send()` still works identically | +| `ShardRamConfig` visibility | `private` → `pub(crate)` | None | +| Pipeline overlap mode | New default behavior | Medium — CPU witgen + GPU prove on separate threads | +| `septic_curve.rs` math | Additive | None — new functions, existing unchanged | From e2b8ae7f95fb3d382d64cd69494ca396bfd0d3e7 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 25 Mar 2026 11:49:07 +0800 Subject: [PATCH 73/73] revert: local path --- Cargo.lock | 123 ++---------------- Cargo.toml | 28 ++-- .../src/instructions/gpu/chips/keccak.rs | 14 +- ceno_zkvm/src/structs.rs | 12 +- 4 files changed, 44 insertions(+), 133 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 
883f5dc37..e9b6acddc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1598,48 +1598,10 @@ version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2" -[[package]] -name = "cuda-config" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee74643f7430213a1a78320f88649de309b20b80818325575e393f848f79f5d" -dependencies = [ - "glob", -] - -[[package]] -name = "cuda-runtime-sys" -version = "0.3.0-alpha.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d070b301187fee3c611e75a425cf12247b7c75c09729dbdef95cb9cb64e8c39" -dependencies = [ - "cuda-config", -] - [[package]] name = "cuda_hal" version = "0.1.0" -dependencies = [ - "anyhow", - "cuda-runtime-sys", - "cudarc", - "downcast-rs", - "ff_ext", - "itertools 0.13.0", - "mpcs", - "multilinear_extensions", - "p3", - "rand 0.8.5", - "rayon", - "sha2", - "sppark", - "sppark_plug", - "sumcheck", - "thiserror 1.0.69", - "tracing", - "transcript", - "witness", -] +source = "git+https://github.com/scroll-tech/ceno-gpu-mock.git?branch=main#fe8f7923b7d3a3823c27949fab0aab8e31011aa9" [[package]] name = "cudarc" @@ -2273,6 +2235,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "once_cell", "p3", @@ -2706,15 +2669,6 @@ dependencies = [ "digest 0.10.7", ] -[[package]] -name = "home" -version = "0.5.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" -dependencies = [ - "windows-sys 0.61.1", -] - [[package]] name = "iana-time-zone" version = "0.1.64" @@ -3146,12 +3100,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -3293,6 +3241,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "bincode 1.3.3", "clap", @@ -3316,6 +3265,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "either", "ff_ext", @@ -4606,6 +4556,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "p3-air", "p3-baby-bear", @@ -5173,6 +5124,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "ff_ext", "p3", @@ -5770,19 +5722,6 @@ dependencies = [ "semver 1.0.26", ] -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - [[package]] name = "rustix" version = "1.0.7" @@ -5792,7 +5731,7 @@ dependencies = [ "bitflags", "errno", "libc", - "linux-raw-sys 0.9.4", + "linux-raw-sys", "windows-sys 0.59.0", ] @@ -6142,6 +6081,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "cfg-if", "dashu", @@ -6176,25 
+6116,6 @@ dependencies = [ "der", ] -[[package]] -name = "sppark" -version = "0.1.11" -dependencies = [ - "cc", - "which", -] - -[[package]] -name = "sppark_plug" -version = "0.1.0" -dependencies = [ - "cc", - "ff_ext", - "itertools 0.13.0", - "p3", - "sppark", -] - [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -6285,6 +6206,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "either", "ff_ext", @@ -6302,6 +6224,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "itertools 0.13.0", "p3", @@ -6382,7 +6305,7 @@ dependencies = [ "fastrand", "getrandom 0.3.2", "once_cell", - "rustix 1.0.7", + "rustix", "windows-sys 0.59.0", ] @@ -6708,6 +6631,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -6998,21 +6922,10 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "which" -version = "4.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" -dependencies = [ - "either", - "home", - "once_cell", - "rustix 0.38.44", -] - [[package]] name = "whir" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "bincode 1.3.3", "clap", @@ -7140,15 +7053,6 @@ dependencies = [ "windows-targets 0.53.4", ] -[[package]] -name = "windows-sys" -version = "0.61.1" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f109e41dd4a3c848907eb83d5a42ea98b3769495597450cf6d153507b166f0f" -dependencies = [ - "windows-link", -] - [[package]] name = "windows-targets" version = "0.52.6" @@ -7308,6 +7212,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.22#a140b93cf80109ef86c2327b9a940d4cace83628" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 09e299842..ec8834541 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,20 +127,20 @@ lto = "thin" #ceno_crypto_primitives = { path = "../ceno-patch/crypto-primitives", package = "ceno_crypto_primitives" } #ceno_syscall = { path = "../ceno-patch/syscall", package = "ceno_syscall" } -[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] -ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } - -[patch."https://github.com/scroll-tech/gkr-backend"] -ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } -mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } -multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } -p3 = { path = "../gkr-backend/crates/p3", package = "p3" } -poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } -sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } -sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } -transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } -whir = { path = "../gkr-backend/crates/whir", package = "whir" } -witness = { path = "../gkr-backend/crates/witness", package = "witness" } +# [patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] +# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = 
["bb31"] } + +# [patch."https://github.com/scroll-tech/gkr-backend"] +# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } +# mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } +# multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } +# p3 = { path = "../gkr-backend/crates/p3", package = "p3" } +# poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } +# sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } +# sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } +# transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } +# whir = { path = "../gkr-backend/crates/whir", package = "whir" } +# witness = { path = "../gkr-backend/crates/witness", package = "witness" } # [patch."https://github.com/scroll-tech/openvm.git"] # openvm = { path = "../openvm-scroll-tech/crates/toolchain/openvm", default-features = false } diff --git a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs index 7c8200ce2..edca0b680 100644 --- a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -438,13 +438,17 @@ fn gpu_assign_keccak_inner( ) }; - Ok::<_, ZKVMError>(RowMajorMatrix::::from_values_with_rotation( - data, - num_witin, - rotation, + // Construct a rotation-aware matrix and fill with GPU-transposed data. + // new_by_rotation allocates the correct padded size; we overwrite the real portion. 
+ let mut rmm = RowMajorMatrix::::new_by_rotation( num_instances, + rotation, + num_witin, InstancePaddingStrategy::Default, - )) + ); + // Access inner p3 matrix's values via DerefMut + std::ops::DerefMut::deref_mut(&mut rmm).values[..data.len()].copy_from_slice(&data); + Ok::<_, ZKVMError>(rmm) })?; // Step 9: Build structural witness on CPU with selector indices diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 8e2e4a390..06b0030c8 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -480,9 +480,9 @@ impl ZKVMWitnesses { use tracing::info_span; // Try the full GPU pipeline: keep data on device, minimal CPU roundtrips. - // Falls back to the traditional path on failure. + // Only when GPU witgen is enabled (otherwise witgen must not touch GPU). #[cfg(feature = "gpu")] - { + if crate::instructions::gpu::config::is_gpu_witgen_enabled() { let gpu_result = self.try_assign_shared_circuit_gpu(cs, shard_ctx, final_mem, config); match gpu_result { Ok(true) => return Ok(()), // GPU pipeline succeeded @@ -582,12 +582,14 @@ impl ZKVMWitnesses { (write_record_pairs, read_record_pairs) }); - // Compute EC points: GPU path (fast) or CPU fallback + // Compute EC points: GPU path (only when GPU witgen enabled) or CPU fallback let global_input = { #[cfg(feature = "gpu")] - let ec_result = { - use crate::instructions::gpu::dispatch::gpu_batch_continuation_ec; + let ec_result = if crate::instructions::gpu::config::is_gpu_witgen_enabled() { + use crate::instructions::gpu::chips::shard_ram::gpu_batch_continuation_ec; gpu_batch_continuation_ec::(&write_record_pairs, &read_record_pairs).ok() + } else { + None }; #[cfg(not(feature = "gpu"))] let ec_result: Option<(Vec>, Vec>)> = None;