diff --git a/Cargo.lock b/Cargo.lock index 2d88e5a73..e9b6acddc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2413,6 +2413,7 @@ dependencies = [ "p3", "rand 0.8.5", "rayon", + "rustc-hash", "serde", "smallvec", "strum", diff --git a/Cargo.toml b/Cargo.toml index b20888473..ec8834541 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,20 +127,20 @@ lto = "thin" #ceno_crypto_primitives = { path = "../ceno-patch/crypto-primitives", package = "ceno_crypto_primitives" } #ceno_syscall = { path = "../ceno-patch/syscall", package = "ceno_syscall" } -#[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] -#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } - -#[patch."https://github.com/scroll-tech/gkr-backend"] -#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } -#mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } -#multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } -#p3 = { path = "../gkr-backend/crates/p3", package = "p3" } -#poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } -#sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } -#sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } -#transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } -#whir = { path = "../gkr-backend/crates/whir", package = "whir" } -#witness = { path = "../gkr-backend/crates/witness", package = "witness" } +# [patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] +# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } + +# [patch."https://github.com/scroll-tech/gkr-backend"] +# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } +# mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } +# multilinear_extensions = { path = 
"../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } +# p3 = { path = "../gkr-backend/crates/p3", package = "p3" } +# poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } +# sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } +# sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } +# transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } +# whir = { path = "../gkr-backend/crates/whir", package = "whir" } +# witness = { path = "../gkr-backend/crates/witness", package = "witness" } # [patch."https://github.com/scroll-tech/openvm.git"] # openvm = { path = "../openvm-scroll-tech/crates/toolchain/openvm", default-features = false } diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 24a9ceea2..14c938e07 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -30,12 +30,14 @@ pub type Word = u32; pub type SWord = i32; pub type Addr = u32; pub type Cycle = u64; -pub type RegIdx = usize; +pub type RegIdx = u8; #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[repr(C)] pub struct ByteAddr(pub u32); #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(C)] pub struct WordAddr(pub u32); impl From for WordAddr { diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index 8332a6d6f..853617d63 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -1,4 +1,7 @@ -use crate::rv32im::{InsnKind, Instruction}; +use crate::{ + addr::RegIdx, + rv32im::{InsnKind, Instruction}, +}; use itertools::izip; use rrs_lib::{ InstructionProcessor, @@ -19,9 +22,9 @@ impl Instruction { pub const fn from_r_type(kind: InsnKind, dec_insn: &RType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 
as RegIdx, imm: 0, raw, } @@ -32,8 +35,8 @@ impl Instruction { pub const fn from_i_type(kind: InsnKind, dec_insn: &IType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.imm, rs2: 0, raw, @@ -45,8 +48,8 @@ impl Instruction { pub const fn from_i_type_shamt(kind: InsnKind, dec_insn: &ITypeShamt, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.shamt as i32, rs2: 0, raw, @@ -59,8 +62,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -72,8 +75,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -231,7 +234,7 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jal(&mut self, dec_insn: JType) -> Self::InstructionResult { Instruction { kind: InsnKind::JAL, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -242,8 +245,8 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jalr(&mut self, dec_insn: IType) -> Self::InstructionResult { Instruction { kind: InsnKind::JALR, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, rs2: 0, imm: dec_insn.imm, raw: self.word, @@ -265,7 +268,7 @@ impl InstructionProcessor for InstructionTranspiler { // See [`InstructionTranspiler::process_auipc`] for more background on the conversion. 
Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -276,7 +279,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::LUI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -311,7 +314,7 @@ impl InstructionProcessor for InstructionTranspiler { // real world scenarios like a `reth` run. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm.wrapping_add(pc as i32), @@ -322,7 +325,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::AUIPC, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 6b16a3587..607f451e0 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -11,8 +11,8 @@ pub use platform::{CENO_PLATFORM, Platform}; mod tracer; pub use tracer::{ Change, FullTracer, FullTracerConfig, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, - PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, StepCellExtractor, StepIndex, - StepRecord, Tracer, WriteOp, + PackedNextAccessEntry, PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, + StepCellExtractor, StepIndex, StepRecord, Tracer, WriteOp, }; mod vm_state; @@ -34,7 +34,7 @@ pub use syscalls::{ BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, SyscallSpec, - UINT256_MUL, + SyscallWitness, UINT256_MUL, bn254::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..4c84b96c9 100644 --- 
a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -134,7 +134,7 @@ impl Platform { /// Virtual address of a register. pub const fn register_vma(index: RegIdx) -> Addr { // Register VMAs are aligned, cannot be confused with indices, and readable in hex. - (index << 8) as Addr + (index as Addr) << 8 } /// Register index from a virtual address (unchecked). @@ -220,7 +220,7 @@ mod tests { // Registers do not overlap with ROM or RAM. for reg in [ Platform::register_vma(0), - Platform::register_vma(VMState::::REG_COUNT - 1), + Platform::register_vma((VMState::::REG_COUNT - 1) as RegIdx), ] { assert!(!p.is_rom(reg)); assert!(!p.is_ram(reg)); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index a3eac8896..48f7beab9 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -29,9 +29,9 @@ use super::addr::{ByteAddr, RegIdx, WORD_SIZE, Word, WordAddr}; pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm, raw: 0, } @@ -43,9 +43,9 @@ pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) pub const fn encode_rv32u(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm: imm as i32, raw: 0, } @@ -113,6 +113,7 @@ pub enum TrapCause { } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] pub struct Instruction { pub kind: InsnKind, pub rs1: RegIdx, @@ -162,6 +163,7 @@ use InsnFormat::*; ToPrimitive, Default, )] +#[repr(u8)] #[allow(clippy::upper_case_acronyms)] pub enum InsnKind { #[default] @@ -425,7 +427,7 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) if !new_pc.is_aligned() { return 
ctx.trap(TrapCause::InstructionAddressMisaligned); } - ctx.store_register(insn.rd_internal() as usize, out)?; + ctx.store_register(insn.rd_internal() as RegIdx, out)?; ctx.set_pc(new_pc); Ok(true) } @@ -502,7 +504,7 @@ fn step_load(ctx: &mut M, kind: InsnKind, decoded: &Instruction) } _ => unreachable!(), }; - ctx.store_register(decoded.rd_internal() as usize, out)?; + ctx.store_register(decoded.rd_internal() as RegIdx, out)?; ctx.set_pc(ctx.get_pc() + WORD_SIZE); Ok(true) } diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 5d9674fc6..31c87b271 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -60,19 +60,15 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< /// A syscall event, available to the circuit witness generators. /// TODO: separate mem_ops into two stages: reads-and-writes #[derive(Clone, Debug, Default, PartialEq, Eq)] +#[non_exhaustive] pub struct SyscallWitness { pub mem_ops: Vec, pub reg_ops: Vec, - _marker: (), } impl SyscallWitness { fn new(mem_ops: Vec, reg_ops: Vec) -> SyscallWitness { - SyscallWitness { - mem_ops, - reg_ops, - _marker: (), - } + SyscallWitness { mem_ops, reg_ops } } } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 39577c13c..0fd9f03d1 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -1,10 +1,12 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, Platform, Program, StepRecord, VMState, encode_rv32, - encode_rv32u, syscalls::KECCAK_PERMUTE, + encode_rv32u, + syscalls::{KECCAK_PERMUTE, SyscallWitness}, + tracer::FullTracerConfig, }; use anyhow::Result; -pub fn keccak_step() -> (StepRecord, Vec) { +pub fn keccak_step() -> (StepRecord, Vec, Vec) { let instructions = vec![ // Call Keccak-f. 
load_immediate(Platform::reg_arg0() as u32, CENO_PLATFORM.heap.start), @@ -23,11 +25,16 @@ pub fn keccak_step() -> (StepRecord, Vec) { instructions.clone(), Default::default(), ); - let mut vm = VMState::new(CENO_PLATFORM.clone(), program.into()); + let mut vm: VMState = VMState::new_with_tracer_config( + CENO_PLATFORM.clone(), + program.into(), + FullTracerConfig { max_step_shard: 10 }, + ); vm.iter_until_halt().collect::>>().unwrap(); let steps = vm.tracer().recorded_steps(); + let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions) + (steps[2], instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 45821ae64..c345b90bb 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -22,7 +22,8 @@ use std::{collections::BTreeMap, fmt, sync::Arc}; /// - Any of `rs1 / rs2 / rd` **may be `x0`**. The trace handles this like any register, including the value that was _supposed_ to be stored. The circuits must handle this case: either **store `0` or skip `x0` operations**. /// /// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(C)] pub struct StepRecord { cycle: Cycle, pc: Change, @@ -30,14 +31,45 @@ pub struct StepRecord { pub hint_maxtouch_addr: Change, pub insn: Instruction, - rs1: Option, - rs2: Option, + has_rs1: bool, + has_rs2: bool, + has_rd: bool, + has_memory_op: bool, - rd: Option, + rs1: ReadOp, + rs2: ReadOp, + rd: WriteOp, + memory_op: WriteOp, - memory_op: Option, + /// Index into the separate syscall witness storage. + /// `u32::MAX` means no syscall for this step. 
+ syscall_index: u32, +} - syscall: Option, +impl StepRecord { + /// Sentinel value indicating no syscall is associated with this step. + pub const NO_SYSCALL: u32 = u32::MAX; +} + +impl Default for StepRecord { + fn default() -> Self { + Self { + cycle: 0, + pc: Default::default(), + heap_maxtouch_addr: Default::default(), + hint_maxtouch_addr: Default::default(), + insn: Default::default(), + has_rs1: false, + has_rs2: false, + has_rd: false, + has_memory_op: false, + rs1: Default::default(), + rs2: Default::default(), + rd: Default::default(), + memory_op: Default::default(), + syscall_index: StepRecord::NO_SYSCALL, + } + } } pub type StepIndex = usize; @@ -54,6 +86,63 @@ pub trait StepCellExtractor { pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; pub type NextCycleAccess = FxHashMap; +/// Packed next-access entry (16 bytes, u128-aligned). +/// Stores (cycle, addr, next_cycle) with 40-bit cycles for GPU bulk H2D upload. +/// Must be layout-compatible with CUDA `PackedNextAccessEntry` in shard_helpers.cuh. 
+#[repr(C, align(16))] +#[derive(Debug, Clone, Copy, Default)] +pub struct PackedNextAccessEntry { + pub cycles_lo: u32, + pub addr: u32, + pub nexts_lo: u32, + pub cycles_hi: u8, + pub nexts_hi: u8, + pub _reserved: u16, +} + +impl PackedNextAccessEntry { + #[inline] + pub fn new(cycle: u64, addr: u32, next_cycle: u64) -> Self { + Self { + cycles_lo: cycle as u32, + addr, + nexts_lo: next_cycle as u32, + cycles_hi: (cycle >> 32) as u8, + nexts_hi: (next_cycle >> 32) as u8, + _reserved: 0, + } + } +} + +impl Eq for PackedNextAccessEntry {} + +impl PartialEq for PackedNextAccessEntry { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.cycles_hi == other.cycles_hi + && self.cycles_lo == other.cycles_lo + && self.addr == other.addr + } +} + +impl Ord for PackedNextAccessEntry { + #[inline] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (self.cycles_hi, self.cycles_lo, self.addr).cmp(&( + other.cycles_hi, + other.cycles_lo, + other.addr, + )) + } +} + +impl PartialOrd for PackedNextAccessEntry { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + fn init_mmio_min_max_access( platform: &Platform, ) -> BTreeMap { @@ -152,7 +241,8 @@ pub trait Tracer { ) -> Option<(WordAddr, WordAddr)>; } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[repr(C)] pub struct MemOp { /// Virtual Memory Address. /// For registers, get it from `Platform::register_vma(idx)`. 
@@ -605,27 +695,41 @@ impl StepRecord { heap_maxtouch_addr: Change, hint_maxtouch_addr: Change, ) -> StepRecord { + let has_rs1 = rs1_read.is_some(); + let has_rs2 = rs2_read.is_some(); + let has_rd = rd.is_some(); + let has_memory_op = memory_op.is_some(); StepRecord { cycle, pc, - rs1: rs1_read.map(|rs1| ReadOp { - addr: Platform::register_vma(insn.rs1).into(), - value: rs1, - previous_cycle, - }), - rs2: rs2_read.map(|rs2| ReadOp { - addr: Platform::register_vma(insn.rs2).into(), - value: rs2, - previous_cycle, - }), - rd: rd.map(|rd| WriteOp { - addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), - value: rd, - previous_cycle, - }), + has_rs1, + has_rs2, + has_rd, + has_memory_op, + rs1: rs1_read + .map(|rs1| ReadOp { + addr: Platform::register_vma(insn.rs1).into(), + value: rs1, + previous_cycle, + }) + .unwrap_or_default(), + rs2: rs2_read + .map(|rs2| ReadOp { + addr: Platform::register_vma(insn.rs2).into(), + value: rs2, + previous_cycle, + }) + .unwrap_or_default(), + rd: rd + .map(|rd| WriteOp { + addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), + value: rd, + previous_cycle, + }) + .unwrap_or_default(), insn, - memory_op, - syscall: None, + memory_op: memory_op.unwrap_or_default(), + syscall_index: StepRecord::NO_SYSCALL, heap_maxtouch_addr, hint_maxtouch_addr, } @@ -645,19 +749,23 @@ impl StepRecord { } pub fn rs1(&self) -> Option { - self.rs1.clone() + if self.has_rs1 { Some(self.rs1) } else { None } } pub fn rs2(&self) -> Option { - self.rs2.clone() + if self.has_rs2 { Some(self.rs2) } else { None } } pub fn rd(&self) -> Option { - self.rd.clone() + if self.has_rd { Some(self.rd) } else { None } } pub fn memory_op(&self) -> Option { - self.memory_op.clone() + if self.has_memory_op { + Some(self.memory_op) + } else { + None + } } #[inline(always)] @@ -665,8 +773,19 @@ impl StepRecord { self.pc.before == self.pc.after } - pub fn syscall(&self) -> Option<&SyscallWitness> { - self.syscall.as_ref() + /// Returns true if 
this step has a syscall witness. + pub fn has_syscall(&self) -> bool { + self.syscall_index != Self::NO_SYSCALL + } + + /// Look up the syscall witness from a separate store. + /// The store is typically obtained from `FullTracer::syscall_witnesses()`. + pub fn syscall<'a>(&self, store: &'a [SyscallWitness]) -> Option<&'a SyscallWitness> { + if self.syscall_index == Self::NO_SYSCALL { + None + } else { + Some(&store[self.syscall_index as usize]) + } } } @@ -684,6 +803,9 @@ pub struct FullTracer { pending_index: usize, pending_cycle: Cycle, + /// Syscall witnesses stored separately (StepRecord references by index). + syscall_witnesses: Vec, + // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, @@ -724,6 +846,7 @@ impl FullTracer { len: 0, pending_index: 0, pending_cycle: Self::SUBCYCLES_PER_INSN, + syscall_witnesses: Vec::new(), mmio_min_max_access: Some(mmio_max_access), platform: platform.clone(), latest_accesses: LatestAccesses::new(platform), @@ -760,6 +883,7 @@ impl FullTracer { pub fn reset_step_buffer(&mut self) { self.len = 0; self.pending_index = 0; + self.syscall_witnesses.clear(); self.reset_pending_slot(); } @@ -767,6 +891,11 @@ impl FullTracer { &self.records[..self.len] } + /// Returns the syscall witness store. Pass this to `StepRecord::syscall()`. 
+ pub fn syscall_witnesses(&self) -> &[SyscallWitness] { + &self.syscall_witnesses + } + #[inline(always)] pub fn step_record(&self, index: StepIndex) -> &StepRecord { assert!( @@ -822,41 +951,41 @@ impl FullTracer { #[inline(always)] pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = Platform::register_vma(idx).into(); - match ( - self.records[self.pending_index].rs1.as_ref(), - self.records[self.pending_index].rs2.as_ref(), - ) { - (None, None) => { - self.records[self.pending_index].rs1 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS1), - }); - } - (Some(_), None) => { - self.records[self.pending_index].rs2 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS2), - }); - } - _ => unimplemented!("Only two register reads are supported"), + if !self.records[self.pending_index].has_rs1 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS1); + self.records[self.pending_index].rs1 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs1 = true; + } else if !self.records[self.pending_index].has_rs2 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS2); + self.records[self.pending_index].rs2 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs2 = true; + } else { + unimplemented!("Only two register reads are supported"); } } #[inline(always)] pub fn store_register(&mut self, idx: RegIdx, value: Change) { - if self.records[self.pending_index].rd.is_some() { + if self.records[self.pending_index].has_rd { unimplemented!("Only one register write is supported"); } let addr = Platform::register_vma(idx).into(); let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RD); - self.records[self.pending_index].rd = Some(WriteOp { + self.records[self.pending_index].rd = WriteOp { addr, value, previous_cycle, - }); + }; + self.records[self.pending_index].has_rd = true; } 
#[inline(always)] @@ -866,7 +995,7 @@ impl FullTracer { #[inline(always)] pub fn store_memory(&mut self, addr: WordAddr, value: Change) { - if self.records[self.pending_index].memory_op.is_some() { + if self.records[self.pending_index].has_memory_op { unimplemented!("Only one memory access is supported"); } @@ -899,19 +1028,26 @@ impl FullTracer { } } - self.records[self.pending_index].memory_op = Some(WriteOp { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_MEM); + self.records[self.pending_index].memory_op = WriteOp { addr, value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_MEM), - }); + previous_cycle, + }; + self.records[self.pending_index].has_memory_op = true; } #[inline(always)] pub fn track_syscall(&mut self, effects: SyscallEffects) { let witness = effects.finalize(self); let record = &mut self.records[self.pending_index]; - assert!(record.syscall.is_none(), "Only one syscall per step"); - record.syscall = Some(witness); + assert!( + record.syscall_index == StepRecord::NO_SYSCALL, + "Only one syscall per step" + ); + let idx = self.syscall_witnesses.len(); + self.syscall_witnesses.push(witness); + record.syscall_index = idx as u32; } #[inline(always)] @@ -972,6 +1108,7 @@ pub struct PreflightTracer { mmio_min_max_access: Option>, latest_accesses: LatestAccesses, next_accesses: NextCycleAccess, + next_accesses_vec: Vec, register_reads_tracked: u8, planner: Option, current_shard_start_cycle: Cycle, @@ -996,6 +1133,7 @@ impl fmt::Debug for PreflightTracer { .field("mmio_min_max_access", &self.mmio_min_max_access) .field("latest_accesses", &self.latest_accesses) .field("next_accesses", &self.next_accesses) + .field("next_accesses_vec_len", &self.next_accesses_vec.len()) .field("register_reads_tracked", &self.register_reads_tracked) .field("planner", &self.planner) .field("current_shard_start_cycle", &self.current_shard_start_cycle) @@ -1093,6 +1231,7 @@ impl PreflightTracer { mmio_min_max_access: 
Some(init_mmio_min_max_access(platform)), latest_accesses: LatestAccesses::new(platform), next_accesses: FxHashMap::default(), + next_accesses_vec: Vec::new(), register_reads_tracked: 0, planner: Some(ShardPlanBuilder::new( max_cell_per_shard, @@ -1105,14 +1244,20 @@ impl PreflightTracer { tracer } - pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess) { + pub fn into_shard_plan( + self, + ) -> ( + ShardPlanBuilder, + NextCycleAccess, + Vec, + ) { let Some(mut planner) = self.planner else { panic!("shard planner missing") }; if !planner.finalized { planner.finalize(self.cycle); } - (planner, self.next_accesses) + (planner, self.next_accesses, self.next_accesses_vec) } #[inline(always)] @@ -1233,6 +1378,8 @@ impl Tracer for PreflightTracer { .entry(prev_cycle) .or_default() .push((addr, cur_cycle)); + self.next_accesses_vec + .push(PackedNextAccessEntry::new(prev_cycle, addr.0, cur_cycle)); } prev_cycle } @@ -1371,6 +1518,7 @@ impl Tracer for FullTracer { } #[derive(Copy, Clone, Default, PartialEq, Eq)] +#[repr(C)] pub struct Change { pub before: T, pub after: T, @@ -1387,3 +1535,92 @@ impl fmt::Debug for Change { write!(f, "{:?} -> {:?}", self.before, self.after) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_step_record_is_copy_and_compact() { + // Verify StepRecord is Copy (this compiles only if Copy is implemented) + fn assert_copy() {} + assert_copy::(); + + // Verify repr(C) compactness — should be well under 128 bytes + let size = std::mem::size_of::(); + eprintln!("StepRecord size: {} bytes", size); + assert!( + size <= 144, + "StepRecord should be compact for GPU transfer: got {} bytes", + size + ); + } + + #[test] + fn test_supporting_types_are_copy() { + fn assert_copy() {} + assert_copy::(); + assert_copy::(); + assert_copy::>(); + assert_copy::>(); + } + + /// Verify exact byte offsets of StepRecord fields for CUDA struct alignment. + /// If this test fails, the CUDA step_record.cuh header must be updated to match. 
+ #[test] + fn test_step_record_layout_for_gpu() { + use std::mem; + + macro_rules! offset_of { + ($type:ty, $field:ident) => {{ + let val = <$type>::default(); + let base = &val as *const _ as usize; + let field = &val.$field as *const _ as usize; + field - base + }}; + } + + // Sub-type sizes + assert_eq!(mem::size_of::(), 12, "Instruction size"); + assert_eq!(mem::size_of::(), 16, "ReadOp size"); + assert_eq!(mem::size_of::(), 24, "WriteOp size"); + assert_eq!( + mem::size_of::>(), + 8, + "Change size" + ); + + // StepRecord field offsets — these must match step_record.cuh + assert_eq!(offset_of!(StepRecord, cycle), 0); + assert_eq!(offset_of!(StepRecord, pc), 8); + assert_eq!(offset_of!(StepRecord, heap_maxtouch_addr), 16); + assert_eq!(offset_of!(StepRecord, hint_maxtouch_addr), 24); + assert_eq!(offset_of!(StepRecord, insn), 32); + assert_eq!(offset_of!(StepRecord, has_rs1), 44); + assert_eq!(offset_of!(StepRecord, has_rs2), 45); + assert_eq!(offset_of!(StepRecord, has_rd), 46); + assert_eq!(offset_of!(StepRecord, has_memory_op), 47); + assert_eq!(offset_of!(StepRecord, rs1), 48); + assert_eq!(offset_of!(StepRecord, rs2), 64); + assert_eq!(offset_of!(StepRecord, rd), 80); + assert_eq!(offset_of!(StepRecord, memory_op), 104); + assert_eq!(offset_of!(StepRecord, syscall_index), 128); + + // Total size + assert_eq!(mem::size_of::(), 136, "StepRecord total size"); + assert_eq!(mem::align_of::(), 8, "StepRecord alignment"); + + // InsnKind must be repr(u8) for CUDA compatibility + assert_eq!( + mem::size_of::(), + 1, + "InsnKind must be 1 byte (repr(u8))" + ); + + eprintln!( + "StepRecord layout verified: {} bytes, {} align", + mem::size_of::(), + mem::align_of::() + ); + } +} diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index d65844682..0613436be 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -155,7 +155,7 @@ impl VMState { } pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { - self.registers[idx] = 
value; + self.registers[idx as usize] = value; } fn halt(&mut self, exit_code: u32) { @@ -171,7 +171,7 @@ impl VMState { } for (idx, value) in effects.iter_reg_values() { - self.registers[idx] = value; + self.registers[idx as usize] = value; } let next_pc = effects.next_pc.unwrap_or(self.pc + PC_STEP_SIZE as u32); @@ -252,7 +252,7 @@ impl EmuContext for VMState { if idx != 0 { let before = self.peek_register(idx); self.tracer.store_register(idx, Change { before, after }); - self.registers[idx] = after; + self.registers[idx as usize] = after; } Ok(()) } @@ -276,7 +276,7 @@ impl EmuContext for VMState { /// Get the value of a register without side-effects. fn peek_register(&self, idx: RegIdx) -> Word { - self.registers[idx] + self.registers[idx as usize] } /// Get the value of a memory word without side-effects. diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index ff752267d..b06c48d6e 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind, - Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, + Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, SyscallWitness, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp, host_utils::{read_all_messages, read_all_messages_as_words}, }; @@ -21,7 +21,7 @@ fn test_ceno_rt_mini() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; Ok(()) } @@ -39,7 +39,7 @@ fn test_ceno_rt_panic() { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let steps = run(&mut state).unwrap(); + let (steps, _syscall_witnesses) = run(&mut state).unwrap(); let last = 
steps.last().unwrap(); assert_eq!(last.insn().kind, InsnKind::ECALL); assert_eq!(last.rs1().unwrap().value, Platform::ecall_halt()); @@ -56,7 +56,7 @@ fn test_ceno_rt_mem() -> Result<()> { }; let sheap = program.sheap.into(); let mut state = VMState::new(platform, Arc::new(program.clone())); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let value = state.peek_memory(sheap); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); @@ -72,7 +72,7 @@ fn test_ceno_rt_alloc() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; // Search for the RAM action of the test program. let mut found = (false, false); @@ -102,7 +102,7 @@ fn test_ceno_rt_io() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let all_messages = messages_to_strings(&read_all_messages(&state)); for msg in &all_messages { @@ -235,7 +235,7 @@ fn test_hashing() -> Result<()> { fn test_keccak_syscall() -> Result<()> { let program_elf = ceno_examples::keccak_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; // Expect the program to have written successive states between Keccak permutations. let keccak_first_iter_outs = sample_keccak_f(1); @@ -251,7 +251,10 @@ fn test_keccak_syscall() -> Result<()> { } // Find the syscall records. - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 100); // Check the syscall effects. 
@@ -293,9 +296,12 @@ fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { fn test_secp256k1() -> Result<()> { let program_elf = ceno_examples::secp256k1; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -305,9 +311,12 @@ fn test_secp256k1() -> Result<()> { fn test_secp256k1_add() -> Result<()> { let program_elf = ceno_examples::secp256k1_add_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -358,9 +367,12 @@ fn test_secp256k1_double() -> Result<()> { let program_elf = ceno_examples::secp256k1_double_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -394,9 +406,12 @@ fn test_secp256k1_decompress() -> Result<()> { let program_elf = ceno_examples::secp256k1_decompress_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + 
.iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -456,8 +471,11 @@ fn test_secp256k1_ecrecover() -> Result<()> { let program_elf = ceno_examples::secp256k1_ecrecover; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -479,8 +497,11 @@ fn test_sha256_extend() -> Result<()> { ]; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 48); for round in 0..48 { @@ -534,10 +555,13 @@ fn test_sha256_full() -> Result<()> { fn test_bn254_fptower_syscalls() -> Result<()> { let program_elf = ceno_examples::bn254_fptower_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; const RUNS: usize = 10; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 4 * RUNS); for witness in syscalls.iter() { @@ -584,9 +608,12 @@ fn test_bn254_fptower_syscalls() -> Result<()> { fn test_bn254_curve() -> Result<()> { let program_elf = ceno_examples::bn254_curve_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, 
syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 3); // add @@ -652,9 +679,12 @@ fn test_uint256_mul() -> Result<()> { let program_elf = ceno_examples::uint256_mul_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -805,9 +835,10 @@ fn messages_to_strings(messages: &[Vec]) -> Vec { .collect() } -fn run(state: &mut VMState) -> Result> { +fn run(state: &mut VMState) -> Result<(Vec, Vec)> { state.iter_until_halt().collect::>>()?; let steps = state.tracer().recorded_steps().to_vec(); + let syscall_witnesses = state.tracer().syscall_witnesses().to_vec(); eprintln!("Emulator ran for {} steps.", steps.len()); - Ok(steps) + Ok((steps, syscall_witnesses)) } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 42d4e869b..eb350c444 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -134,3 +134,8 @@ name = "weierstrass_add" [[bench]] harness = false name = "weierstrass_double" + +[[bench]] +harness = false +name = "witgen_add_gpu" +required-features = ["gpu"] diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs new file mode 100644 index 000000000..7b68d93bb --- /dev/null +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -0,0 +1,134 @@ +use std::time::Duration; + +use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; +use ceno_zkvm::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{Instruction, 
riscv::arith::AddInstruction}, + structs::ProgramParams, +}; +use criterion::*; +use ff_ext::BabyBearExt4; + +#[cfg(feature = "gpu")] +use ceno_gpu::bb31::CudaHalBB31; +#[cfg(feature = "gpu")] +use ceno_zkvm::instructions::riscv::gpu::add::extract_add_column_map; + +mod alloc; + +type E = BabyBearExt4; + +criterion_group! { + name = witgen_add; + config = Criterion::default().warm_up_time(Duration::from_millis(2000)); + targets = bench_witgen_add +} + +criterion_main!(witgen_add); + +fn make_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 1; + let rs2 = (i as u32) % 500 + 3; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect() +} + +#[cfg(feature = "gpu")] +fn step_records_to_bytes(records: &[StepRecord]) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + records.as_ptr() as *const u8, + records.len() * std::mem::size_of::(), + ) + } +} + +fn bench_witgen_add(c: &mut Criterion) { + let mut cs = ConstraintSystem::::new(|| "bench"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + #[cfg(feature = "gpu")] + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + #[cfg(feature = "gpu")] + let col_map = extract_add_column_map(&config, num_witin); + + for pow in [10, 12, 14, 16, 18] { + let n = 1usize << pow; + let mut group = c.benchmark_group(format!("witgen_add_2^{}", pow)); + group.sample_size(10); + + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU 
benchmark + group.bench_function("cpu_assign_instances", |b| { + b.iter(|| { + let mut shard_ctx = ShardContext::default(); + AddInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + }) + }); + + // GPU benchmark (total: H2D records + indices + kernel + synchronize) + #[cfg(feature = "gpu")] + group.bench_function("gpu_total", |b| { + let steps_bytes = step_records_to_bytes(&steps); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + b.iter(|| { + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let shard_ctx = ShardContext::default(); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap() + }) + }); + + // GPU benchmark (kernel only: records pre-uploaded) + #[cfg(feature = "gpu")] + { + let steps_bytes = step_records_to_bytes(&steps); + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let shard_ctx = ShardContext::default(); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + + group.bench_function("gpu_kernel_only", |b| { + b.iter(|| { + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap() + }) + }); + } + + group.finish(); + } +} diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..b4cc40587 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,10 +24,12 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, - NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, - StepIndex, StepRecord, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, - host_utils::read_all_messages, + NextCycleAccess, PackedNextAccessEntry, Platform, PreflightTracer, 
PreflightTracerConfig, + Program, RegIdx, StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, + VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; +#[cfg(feature = "gpu")] +use ceno_gpu::CudaHal; use clap::ValueEnum; use either::Either; use ff_ext::{ExtensionField, SmallField}; @@ -39,6 +41,7 @@ use itertools::MinMaxResult; use itertools::{Itertools, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; +use rayon::prelude::*; use rustc_hash::FxHashSet; use serde::Serialize; #[cfg(debug_assertions)] @@ -179,11 +182,18 @@ impl Default for MultiProver { } } +/// Pre-sorted packed future access entries for GPU bulk H2D upload. +/// Sorted by (cycle, addr) composite key. +pub struct SortedNextAccesses { + pub packed: Vec, +} + pub struct ShardContext<'a> { pub shard_id: usize, num_shards: usize, max_cycle: Cycle, pub addr_future_accesses: Arc, + pub sorted_next_accesses: Arc, addr_accessed_tbs: Either>, &'a mut Vec>, read_records_tbs: Either>, &'a mut BTreeMap>, @@ -199,6 +209,12 @@ pub struct ShardContext<'a> { pub platform: Platform, pub shard_heap_addr_range: Range, pub shard_hint_addr_range: Range, + /// Syscall witnesses for StepRecord::syscall() lookups. + pub syscall_witnesses: Arc>, + /// GPU-produced compact EC shard records (raw bytes of GpuShardRamRecord). + /// Each record is GPU_SHARD_RAM_RECORD_SIZE bytes. These bypass BTreeMap and + /// are converted to ShardRamInput in assign_shared_circuit. 
+ pub gpu_ec_records: Vec, } impl<'a> Default for ShardContext<'a> { @@ -213,6 +229,7 @@ impl<'a> Default for ShardContext<'a> { num_shards: 1, max_cycle: Cycle::MAX, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { packed: vec![] }), addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), read_records_tbs: Either::Left( (0..max_threads) @@ -233,10 +250,15 @@ impl<'a> Default for ShardContext<'a> { platform: CENO_PLATFORM.clone(), shard_heap_addr_range: CENO_PLATFORM.heap.clone(), shard_hint_addr_range: CENO_PLATFORM.hints.clone(), + syscall_witnesses: Arc::new(Vec::new()), + gpu_ec_records: vec![], } } } +/// Size of a single GpuShardRamRecord in bytes (must match CUDA struct). +pub const GPU_SHARD_RAM_RECORD_SIZE: usize = 104; + /// `prover_id` and `num_provers` in MultiProver are exposed as arguments /// to specify the number of physical provers in a cluster, /// each mark with a prover_id. @@ -248,6 +270,41 @@ impl<'a> Default for ShardContext<'a> { /// for example, if there are 10 shards and 3 provers, /// the shard counts will be distributed as 3, 3, and 4, ensuring an even workload across all provers. impl<'a> ShardContext<'a> { + /// Create a new ShardContext with the same shard metadata but empty record storage. + /// Useful for debug comparisons against the actual shard context. 
+ pub fn new_empty_like(&self) -> ShardContext<'static> { + let max_threads = max_usable_threads(); + ShardContext { + shard_id: self.shard_id, + num_shards: self.num_shards, + max_cycle: self.max_cycle, + addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), + addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), + read_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + write_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, + max_num_cross_shard_accesses: self.max_num_cross_shard_accesses, + prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), + prev_shard_heap_range: self.prev_shard_heap_range.clone(), + prev_shard_hint_range: self.prev_shard_hint_range.clone(), + platform: self.platform.clone(), + shard_heap_addr_range: self.shard_heap_addr_range.clone(), + shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], + } + } + pub fn get_forked(&mut self) -> Vec> { match ( &mut self.read_records_tbs, @@ -267,6 +324,7 @@ impl<'a> ShardContext<'a> { num_shards: self.num_shards, max_cycle: self.max_cycle, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), addr_accessed_tbs: Either::Right(addr_accessed_tbs), read_records_tbs: Either::Right(read), write_records_tbs: Either::Right(write), @@ -279,6 +337,8 @@ impl<'a> ShardContext<'a> { platform: self.platform.clone(), shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], }) .collect_vec(), _ => panic!("invalid type"), @@ -387,9 +447,55 @@ impl<'a> 
ShardContext<'a> { }) } + #[inline(always)] + pub fn insert_read_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .read_records_tbs + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn insert_write_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .write_records_tbs + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn push_addr_accessed(&mut self, addr: WordAddr) { + let addr_accessed = self + .addr_accessed_tbs + .as_mut() + .right() + .expect("illegal type"); + addr_accessed.push(addr); + } + + /// Extend GPU EC records with raw bytes from GpuShardRamRecord slice. + /// Called from the GPU EC path to accumulate records across kernel invocations. + pub fn extend_gpu_ec_records_raw(&mut self, raw_bytes: &[u8]) { + self.gpu_ec_records.extend_from_slice(raw_bytes); + } + + /// Returns true if GPU EC records have been collected. + pub fn has_gpu_ec_records(&self) -> bool { + !self.gpu_ec_records.is_empty() + } + + /// Take GPU EC records, leaving the field empty. + pub fn take_gpu_ec_records(&mut self) -> Vec { + std::mem::take(&mut self.gpu_ec_records) + } + #[inline(always)] #[allow(clippy::too_many_arguments)] - pub fn send( + pub fn record_send_without_touch( &mut self, ram_type: crate::structs::RAMType, addr: WordAddr, @@ -406,15 +512,9 @@ impl<'a> ShardContext<'a> { let addr_raw = addr.baddr().0; let is_heap = self.platform.heap.contains(&addr_raw); let is_hint = self.platform.hints.contains(&addr_raw); - // 1. 
checking reads from the external bus if prev_cycle > 0 || (prev_cycle == 0 && (!is_heap && !is_hint)) { let prev_shard_id = self.extract_shard_id_by_cycle(prev_cycle); - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -433,22 +533,15 @@ impl<'a> ShardContext<'a> { prev_cycle == 0 && (is_heap || is_hint), "addr {addr_raw:x} prev_cycle {prev_cycle}, is_heap {is_heap}, is_hint {is_hint}", ); - // 2. handle heap/hint initial reads outside the shard range. let prev_shard_id = if is_heap && !self.shard_heap_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_heap_addr(addr_raw)) } else if is_hint && !self.shard_hint_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_hint_addr(addr_raw)) } else { - // dynamic init in current shard, skip and do nothing None }; if let Some(prev_shard_id) = prev_shard_id { - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -466,18 +559,12 @@ impl<'a> ShardContext<'a> { } } - // check write to external mem bus if let Some(future_touch_cycle) = self.find_future_next_access(cycle, addr) && self.after_current_shard_cycle(future_touch_cycle) && self.is_in_current_shard(cycle) { let shard_cycle = self.aligned_current_ts(cycle); - let ram_record = self - .write_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_write_record( addr, RAMRecord { ram_type, @@ -492,26 +579,53 @@ impl<'a> ShardContext<'a> { }, ); } + } - let addr_accessed = self - .addr_accessed_tbs - .as_mut() - .right() - .expect("illegal type"); - addr_accessed.push(addr); + #[inline(always)] + #[allow(clippy::too_many_arguments)] + pub fn send( + &mut self, + ram_type: crate::structs::RAMType, + addr: WordAddr, + id: u64, + cycle: Cycle, + prev_cycle: Cycle, + value: Word, + 
prev_value: Option, + ) { + self.record_send_without_touch(ram_type, addr, id, cycle, prev_cycle, value, prev_value); + self.push_addr_accessed(addr); } /// merge addr accessed in different threads pub fn get_addr_accessed(&self) -> FxHashSet { - let mut merged = FxHashSet::default(); if let Either::Left(addr_accessed_tbs) = &self.addr_accessed_tbs { + let total: usize = addr_accessed_tbs.iter().map(|v| v.len()).sum(); + let mut merged = FxHashSet::with_capacity_and_hasher(total, Default::default()); for addrs in addr_accessed_tbs { merged.extend(addrs.iter().copied()); } + merged + } else { + panic!("invalid type"); + } + } + + /// merge addr accessed into a sorted Vec for fast binary search lookups. + /// Much faster than FxHashSet for large sets (avoids hashing overhead). + pub fn get_addr_accessed_sorted(&self) -> Vec { + if let Either::Left(addr_accessed_tbs) = &self.addr_accessed_tbs { + let total: usize = addr_accessed_tbs.iter().map(|v| v.len()).sum(); + let mut merged = Vec::with_capacity(total); + for addrs in addr_accessed_tbs { + merged.extend_from_slice(addrs); + } + merged.par_sort_unstable(); + merged.dedup(); + merged } else { panic!("invalid type"); } - merged } /// Splits a total count `num_shards` into up to `num_provers` non-empty parts, distributing as evenly as possible. 
@@ -604,6 +718,7 @@ impl ShardStepSummary { pub struct ShardContextBuilder { pub cur_shard_id: usize, addr_future_accesses: Arc, + sorted_next_accesses: Arc, prev_shard_cycle_range: Vec, prev_shard_heap_range: Vec, prev_shard_hint_range: Vec, @@ -617,6 +732,7 @@ impl Default for ShardContextBuilder { ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { packed: vec![] }), prev_shard_cycle_range: vec![], prev_shard_heap_range: vec![], prev_shard_hint_range: vec![], @@ -634,12 +750,47 @@ impl ShardContextBuilder { shard_cycle_boundaries: Arc>, max_cycle: Cycle, addr_future_accesses: NextCycleAccess, + next_accesses_vec: Vec, ) -> Self { assert_eq!(multi_prover.max_provers, 1); assert_eq!(multi_prover.prover_id, 0); + + let sorted_next_accesses = info_span!("next_access_presort").in_scope(|| { + let source = std::env::var("CENO_NEXT_ACCESS_SOURCE").unwrap_or_default(); + let mut entries = if source == "hashmap" { + tracing::info!("[next-access presort] converting from HashMap"); + info_span!("next_access_from_hashmap").in_scope(|| { + let mut entries = Vec::new(); + for (cycle, pairs) in addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push(PackedNextAccessEntry::new(*cycle, addr.0, next_cycle)); + } + } + entries + }) + } else { + tracing::info!( + "[next-access presort] using preflight-appended vec ({} entries)", + next_accesses_vec.len() + ); + next_accesses_vec + }; + let len = entries.len(); + info_span!("next_access_par_sort", n = len).in_scope(|| { + entries.par_sort_unstable(); + }); + tracing::info!( + "[next-access presort] sorted {} entries ({:.2} MB)", + len, + len * 16 / (1024 * 1024) + ); + Arc::new(SortedNextAccesses { packed: entries }) + }); + ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(addr_future_accesses), + sorted_next_accesses, prev_shard_cycle_range: vec![0], prev_shard_heap_range: vec![0], 
prev_shard_hint_range: vec![0], @@ -726,6 +877,7 @@ impl ShardContextBuilder { cur_shard_cycle_range: summary.first_cycle as usize ..(summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN) as usize, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), prev_shard_heap_range: self.prev_shard_heap_range.clone(), prev_shard_hint_range: self.prev_shard_hint_range.clone(), @@ -750,6 +902,7 @@ pub trait StepSource: Iterator { fn start_new_shard(&mut self); fn shard_steps(&self) -> &[StepRecord]; fn step_record(&self, idx: StepIndex) -> &StepRecord; + fn syscall_witnesses(&self) -> &[SyscallWitness]; } /// Lazily replays `StepRecord`s by re-running the VM up to the number of steps @@ -822,6 +975,10 @@ impl StepSource for StepReplay { fn step_record(&self, idx: StepIndex) -> &StepRecord { self.vm.tracer().step_record(idx) } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + self.vm.tracer().syscall_witnesses() + } } pub fn emulate_program<'a>( @@ -899,8 +1056,8 @@ pub fn emulate_program<'a>( let reg_final = reg_init .iter() .map(|rec| { - let index = rec.addr as usize; - if index < VM_REG_COUNT { + if (rec.addr as usize) < VM_REG_COUNT { + let index = rec.addr as RegIdx; let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { ram_type: RAMType::Register, @@ -1039,7 +1196,7 @@ pub fn emulate_program<'a>( } let tracer = vm.take_tracer(); - let (plan_builder, next_accesses) = tracer.into_shard_plan(); + let (plan_builder, next_accesses, next_accesses_vec) = tracer.into_shard_plan(); let max_step_shard = plan_builder.max_step_shard(); let shard_cycle_boundaries = Arc::new(plan_builder.into_cycle_boundaries()); let shard_ctx_builder = ShardContextBuilder::from_plan( @@ -1048,6 +1205,7 @@ pub fn emulate_program<'a>( shard_cycle_boundaries.clone(), max_cycle, next_accesses, + next_accesses_vec, ); tracing::info!( "num_shards: {}, 
max_cycle {}, shard_cycle_boundaries {:?}", @@ -1270,18 +1428,19 @@ pub fn generate_witness<'a, E: ExtensionField>( shard_id = shard_ctx_builder.cur_shard_id ) .in_scope(|| { - let time = std::time::Instant::now(); instrunction_dispatch_ctx.begin_shard(); let (mut shard_ctx, shard_summary) = - match shard_ctx_builder.position_next_shard( - &mut step_iter, - |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), - ) { + match info_span!("position_next_shard").in_scope(|| { + shard_ctx_builder.position_next_shard( + &mut step_iter, + |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), + ) + }) { Some(result) => result, None => return None, }; - tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); let shard_steps = step_iter.shard_steps(); + shard_ctx.syscall_witnesses = Arc::new(step_iter.syscall_witnesses().to_vec()); let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); @@ -1330,31 +1489,95 @@ pub fn generate_witness<'a, E: ExtensionField>( } } - let time = std::time::Instant::now(); - system_config - .config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut shard_ctx, - &mut instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, - ) - .unwrap(); - tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); - let time = std::time::Instant::now(); - system_config - .dummy_config - .assign_opcode_circuit( - &system_config.zkvm_cs, + #[cfg(feature = "gpu")] + let debug_compare_e2e_shard = + crate::instructions::gpu::config::is_debug_compare_enabled(); + #[cfg(not(feature = "gpu"))] + let debug_compare_e2e_shard = false; + let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); + info_span!("assign_opcode_circuits").in_scope(|| { + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &mut instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); + + // Flush 
shared EC/addr buffers from GPU after all opcode circuits are done. + // This batch-D2Hs accumulated EC records and addr_accessed into shard_ctx. + #[cfg(feature = "gpu")] + info_span!("flush_shared_ec").in_scope(|| { + crate::instructions::gpu::dispatch::flush_shared_ec_buffers( &mut shard_ctx, - &instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, ) - .unwrap(); - tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); - zkvm_witness.finalize_lk_multiplicities(); + }).unwrap(); + + // Free GPU shard_steps cache after all opcode circuits are done. + #[cfg(feature = "gpu")] + crate::instructions::gpu::cache::invalidate_shard_steps_cache(); + + info_span!("assign_dummy_circuits").in_scope(|| { + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); + info_span!("finalize_lk_multiplicities").in_scope(|| { + zkvm_witness.finalize_lk_multiplicities(); + }); + + if let Some(mut cpu_shard_ctx) = debug_shard_ctx_template { + let mut cpu_witness = ZKVMWitnesses::default(); + let mut cpu_dispatch_ctx = system_config.inst_dispatch_builder.to_dispatch_ctx(); + cpu_dispatch_ctx.begin_shard(); + for (step_idx, step) in shard_steps.iter().enumerate() { + cpu_dispatch_ctx.ingest_step(step_idx, step); + } + + // Force CPU path for the debug comparison (thread-local, no env var races). 
+ #[cfg(feature = "gpu")] + crate::instructions::gpu::dispatch::set_force_cpu_path(true); + + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &mut cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + cpu_witness.finalize_lk_multiplicities(); + + #[cfg(feature = "gpu")] + crate::instructions::gpu::dispatch::set_force_cpu_path(false); + + log_shard_ctx_diff("post_opcode_assignment", &cpu_shard_ctx, &shard_ctx); + + // Compare combined_lk_mlt (the merged LK after finalize_lk_multiplicities). + // This catches issues where per-chip LK appears correct but the merge differs. + log_combined_lk_diff(&cpu_witness, &zkvm_witness); + } // Memory record routing (per address / waddr) // @@ -1389,110 +1612,102 @@ pub fn generate_witness<'a, E: ExtensionField>( // ├─ later rw? NO -> ShardRAM + LocalFinalize // └─ later rw? 
YES -> ShardRAM - let time = std::time::Instant::now(); - system_config - .config - .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_table_circuits").in_scope(|| { + system_config + .config + .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) + }).unwrap(); + + info_span!("assign_init_table").in_scope(|| { + if shard_ctx.is_first_shard() { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.stack, + ) + } else { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &[], + &[], + &[], + &[], + ) + } + }).unwrap(); - if shard_ctx.is_first_shard() { - let time = std::time::Instant::now(); + info_span!("assign_dynamic_init_table").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_dynamic_init_table_circuit( &system_config.zkvm_cs, &mut zkvm_witness, &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.heap, ) - .unwrap(); - tracing::debug!("assign_init_table_circuit finish in {:?}", time.elapsed()); - } else { + }).unwrap(); + + info_span!("assign_continuation").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_continuation_circuit( &system_config.zkvm_cs, + &shard_ctx, &mut zkvm_witness, &pi, - &[], - &[], - &[], - &[], + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, ) - .unwrap(); - } 
+ }).unwrap(); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_dynamic_init_table_circuit( - &system_config.zkvm_cs, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!( - "assign_dynamic_init_table_circuit finish in {:?}", - time.elapsed() - ); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_continuation_circuit( - &system_config.zkvm_cs, - &shard_ctx, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.stack, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!("assign_continuation_circuit finish in {:?}", time.elapsed()); - - let time = std::time::Instant::now(); - zkvm_witness - .assign_table_circuit::>( - &system_config.zkvm_cs, - &system_config.prog_config, - &program, - ) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_program_table").in_scope(|| { + zkvm_witness + .assign_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + &program, + ) + }).unwrap(); if let Some(shard_ram_witnesses) = zkvm_witness.get_witness(&ShardRamCircuit::::name()) { - let time = std::time::Instant::now(); - let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses - .iter() - .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) - .map(|shard_ram_witness| { - ShardRamCircuit::::extract_ec_sum( - &system_config.mmu_config.ram_bus_circuit, - &shard_ram_witness.witness_rmms[0], - ) - }) - .sum(); - - let xy = shard_ram_ec_sum - .x - .0 - .iter() - .chain(shard_ram_ec_sum.y.0.iter()); - for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { - *v = f.to_canonical_u64() as u32; - } - tracing::debug!("update pi shard_rw_sum finish in {:?}", time.elapsed()); + 
info_span!("shard_ram_ec_sum").in_scope(|| { + let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses + .iter() + .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) + .map(|shard_ram_witness| { + ShardRamCircuit::::extract_ec_sum( + &system_config.mmu_config.ram_bus_circuit, + &shard_ram_witness.witness_rmms[0], + ) + }) + .sum(); + + let xy = shard_ram_ec_sum + .x + .0 + .iter() + .chain(shard_ram_ec_sum.y.0.iter()); + for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { + *v = f.to_canonical_u64() as u32; + } + }); } Some((zkvm_witness, shard_ctx, pi)) @@ -1872,16 +2087,29 @@ fn create_proofs_streaming< init_mem_state: &InitMemState, ) -> Vec> { let ctx = prover.pk.program_ctx.as_ref().unwrap(); + + // Two pipeline modes: + // + // Default GPU backend (CENO_GPU_ENABLE_WITGEN unset): + // Overlap: CPU witgen (thread A) and GPU prove (thread B) run in parallel. + // CPU produces witness for shard N+1 while GPU proves shard N. + // Uses a bounded(0) rendezvous channel for back-pressure. + // + // CENO_GPU_ENABLE_WITGEN=1 (GPU witgen) or CPU-only build: + // Sequential: witgen → prove, one shard at a time. + // When GPU witgen is on, GPU is shared between witgen and proving. + let proofs = info_span!("[ceno] app_prove.inner").in_scope(|| { #[cfg(feature = "gpu")] { - use crossbeam::channel; - let (tx, rx) = channel::bounded(0); - std::thread::scope(|s| { - // pipeline cpu/gpu workload - // cpu producer - s.spawn({ - move || { + // With GPU feature: check if GPU witgen is enabled to select pipeline mode. + if !crate::instructions::gpu::config::is_gpu_witgen_enabled() { + // Default: overlap CPU witgen with GPU proving. 
+ use crossbeam::channel; + let (tx, rx) = channel::bounded(0); + return std::thread::scope(|s| { + // CPU producer: generate witness shards + s.spawn(move || { let wit_iter = generate_witness( &ctx.system_config, emulation_result, @@ -1891,11 +2119,12 @@ fn create_proofs_streaming< target_shard_id, ); - let wit_iter = if let Some(target_shard_id) = target_shard_id { - Box::new(wit_iter.skip(target_shard_id)) as Box> - } else { - Box::new(wit_iter) - }; + let wit_iter: Box> = + if let Some(target_shard_id) = target_shard_id { + Box::new(wit_iter.skip(target_shard_id)) + } else { + Box::new(wit_iter) + }; for proof_input in wit_iter { if tx.send(proof_input).is_err() { @@ -1905,14 +2134,10 @@ fn create_proofs_streaming< break; } } - } - }); + }); - // gpu consumer - { + // GPU consumer: prove each shard as it arrives let mut proofs = Vec::new(); - let mut proof_err = None; - let rx = rx; while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { if is_mock_proving { MockProver::assert_satisfied_full( @@ -1928,33 +2153,25 @@ fn create_proofs_streaming< let transcript = Transcript::new(b"riscv"); let start = std::time::Instant::now(); - match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { - Ok(zkvm_proof) => { - tracing::debug!( - "{}th shard proof created in {:?}", - shard_ctx.shard_id, - start.elapsed() - ); - proofs.push(zkvm_proof); - } - Err(err) => { - proof_err = Some(err); - break; - } - } - } - drop(rx); - if let Some(err) = proof_err { - panic!("create_proof failed: {err:?}"); + let zkvm_proof = prover + .create_proof(&shard_ctx, zkvm_witness, pi, transcript) + .expect("create_proof failed"); + tracing::debug!( + "{}th shard proof created in {:?}", + shard_ctx.shard_id, + start.elapsed() + ); + proofs.push(zkvm_proof); } proofs - } - }) + }); + } + // Fall through: GPU witgen enabled → sequential path below. } - #[cfg(not(feature = "gpu"))] + // Sequential: witgen → prove, one shard at a time. 
+ // Used by: GPU witgen mode (CENO_GPU_ENABLE_WITGEN=1) and CPU-only builds. { - // Generate witness let wit_iter = generate_witness( &ctx.system_config, emulation_result, @@ -1964,11 +2181,12 @@ fn create_proofs_streaming< target_shard_id, ); - let wit_iter = if let Some(target_shard_id) = target_shard_id { - Box::new(wit_iter.skip(target_shard_id)) as Box> - } else { - Box::new(wit_iter) - }; + let wit_iter: Box> = + if let Some(target_shard_id) = target_shard_id { + Box::new(wit_iter.skip(target_shard_id)) + } else { + Box::new(wit_iter) + }; wit_iter .map(|(zkvm_witness, shard_ctx, pi)| { @@ -1994,7 +2212,6 @@ fn create_proofs_streaming< shard_ctx.shard_id, start.elapsed() ); - // only show e2e stats in cpu mode tracing::info!("e2e proof stat: {}", zkvm_proof); zkvm_proof }) @@ -2042,6 +2259,199 @@ pub fn run_e2e_verify>( } } +fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { + ShardContext { + shard_id: src.shard_id, + num_shards: src.num_shards, + max_cycle: src.max_cycle, + addr_future_accesses: src.addr_future_accesses.clone(), + sorted_next_accesses: src.sorted_next_accesses.clone(), + cur_shard_cycle_range: src.cur_shard_cycle_range.clone(), + expected_inst_per_shard: src.expected_inst_per_shard, + max_num_cross_shard_accesses: src.max_num_cross_shard_accesses, + prev_shard_cycle_range: src.prev_shard_cycle_range.clone(), + prev_shard_heap_range: src.prev_shard_heap_range.clone(), + prev_shard_hint_range: src.prev_shard_hint_range.clone(), + platform: src.platform.clone(), + shard_heap_addr_range: src.shard_heap_addr_range.clone(), + shard_hint_addr_range: src.shard_hint_addr_range.clone(), + syscall_witnesses: src.syscall_witnesses.clone(), + ..Default::default() + } +} + +type FlatRecord = (u32, u64, u64, u64, u64, Option, u32, usize); + +fn flatten_ram_records(records: &[BTreeMap]) -> Vec { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + 
record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +fn log_shard_ctx_diff(kind: &str, cpu: &ShardContext, gpu: &ShardContext) { + let cpu_addr = cpu.get_addr_accessed(); + let gpu_addr = gpu.get_addr_accessed(); + if cpu_addr != gpu_addr { + tracing::error!( + "[GPU e2e debug] {} addr_accessed cpu={} gpu={}", + kind, + cpu_addr.len(), + gpu_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu.read_records()); + let gpu_reads = flatten_ram_records(gpu.read_records()); + if cpu_reads != gpu_reads { + tracing::error!( + "[GPU e2e debug] {} read_records cpu={} gpu={}", + kind, + cpu_reads.len(), + gpu_reads.len() + ); + } + + let cpu_writes = flatten_ram_records(cpu.write_records()); + let gpu_writes = flatten_ram_records(gpu.write_records()); + if cpu_writes != gpu_writes { + tracing::error!( + "[GPU e2e debug] {} write_records cpu={} gpu={}", + kind, + cpu_writes.len(), + gpu_writes.len() + ); + } +} + +fn log_combined_lk_diff( + cpu_witness: &ZKVMWitnesses, + gpu_witness: &ZKVMWitnesses, +) { + let cpu_combined = cpu_witness.combined_lk_mlt().expect("cpu combined_lk_mlt"); + let gpu_combined = gpu_witness.combined_lk_mlt().expect("gpu combined_lk_mlt"); + + let table_names = [ + "Dynamic", + "DoubleU8", + "And", + "Or", + "Xor", + "Ltu", + "Pow", + "Instruction", + ]; + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, gpu_table)) in + cpu_combined.iter().zip(gpu_combined.iter()).enumerate() + { + let mut keys: Vec = cpu_table.keys().chain(gpu_table.keys()).copied().collect(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = 0usize; + for &key in &keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let gpu_count = gpu_table.get(&key).copied().unwrap_or(0); + if cpu_count != gpu_count { + table_diffs += 1; + if table_diffs <= 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e 
debug] combined_lk table={} key={} cpu={} gpu={}", + name, + key, + cpu_count, + gpu_count + ); + } + } + } + total_diffs += table_diffs; + if table_diffs > 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] combined_lk table={} total_diffs={} (showing first 8)", + name, + table_diffs + ); + } + } + + // Also compare per-chip LK multiplicities + let cpu_lk_keys: std::collections::BTreeSet<_> = cpu_witness.lk_mlts().keys().collect(); + let gpu_lk_keys: std::collections::BTreeSet<_> = gpu_witness.lk_mlts().keys().collect(); + if cpu_lk_keys != gpu_lk_keys { + tracing::error!( + "[GPU e2e debug] lk_mlts key mismatch: cpu_only={:?} gpu_only={:?}", + cpu_lk_keys.difference(&gpu_lk_keys).collect::>(), + gpu_lk_keys.difference(&cpu_lk_keys).collect::>(), + ); + } + for name in cpu_lk_keys.intersection(&gpu_lk_keys) { + let cpu_lk = cpu_witness.lk_mlts().get(*name).unwrap(); + let gpu_lk = gpu_witness.lk_mlts().get(*name).unwrap(); + let mut chip_diffs = 0usize; + for (t_idx, (ct, gt)) in cpu_lk.iter().zip(gpu_lk.iter()).enumerate() { + let mut ks: Vec = ct.keys().chain(gt.keys()).copied().collect(); + ks.sort_unstable(); + ks.dedup(); + for &k in &ks { + let cv = ct.get(&k).copied().unwrap_or(0); + let gv = gt.get(&k).copied().unwrap_or(0); + if cv != gv { + chip_diffs += 1; + if chip_diffs <= 4 { + let tname = table_names.get(t_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} table={} key={} cpu={} gpu={}", + name, + tname, + k, + cv, + gv + ); + } + } + } + } + if chip_diffs > 0 { + total_diffs += chip_diffs; + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} total_diffs={}", + name, + chip_diffs + ); + } + } + + if total_diffs == 0 { + tracing::info!( + "[GPU e2e debug] combined_lk_mlt + per_chip_lk: CPU/GPU match (tables={}, chips={})", + cpu_combined.len(), + cpu_lk_keys.len() + ); + } else { + tracing::error!( + "[GPU e2e debug] TOTAL LK DIFFS = {} (combined + 
per-chip)", + total_diffs + ); + } +} + #[cfg(debug_assertions)] fn debug_memory_ranges<'a, T: Tracer, I: Iterator>( vm: &VMState, @@ -2122,7 +2532,9 @@ pub fn verify + serde::Ser #[cfg(test)] mod tests { use crate::e2e::{MultiProver, ShardContextBuilder}; - use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; + use ceno_emul::{ + CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord, SyscallWitness, + }; use itertools::Itertools; use std::sync::Arc; @@ -2182,6 +2594,7 @@ mod tests { shard_cycle_boundaries, max_cycle, NextCycleAccess::default(), + Vec::new(), ); struct TestReplay { steps: Vec, @@ -2224,6 +2637,10 @@ mod tests { fn step_record(&self, idx: StepIndex) -> &StepRecord { &self.steps[self.shard_start + idx] } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + &[] // Test replay doesn't track syscalls + } } let mut steps_iter = TestReplay::new(steps); diff --git a/ceno_zkvm/src/gadgets/poseidon2.rs b/ceno_zkvm/src/gadgets/poseidon2.rs index 1dbfcccb5..5e079489b 100644 --- a/ceno_zkvm/src/gadgets/poseidon2.rs +++ b/ceno_zkvm/src/gadgets/poseidon2.rs @@ -78,8 +78,8 @@ pub struct Poseidon2Config< const HALF_FULL_ROUNDS: usize, const PARTIAL_ROUNDS: usize, > { - p3_cols: Vec, // columns in the plonky3-air - post_linear_layer_cols: Vec, /* additional columns to hold the state after linear layers */ + pub(crate) p3_cols: Vec, // columns in the plonky3-air + pub(crate) post_linear_layer_cols: Vec, /* additional columns to hold the state after linear layers */ constants: RoundConstants, } diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index 4be082386..90d489f2c 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -44,6 +44,11 @@ impl SignedExtendConfig { self.msb.expr() } + #[allow(dead_code)] // used by GPU column map extraction (cfg gated) + pub(crate) fn msb(&self) -> WitIn { + self.msb + } + fn construct_circuit( cb: 
&mut CircuitBuilder, n_bits: usize, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 9dd99ef92..da7bb43b9 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -19,12 +19,17 @@ use rayon::{ }; use witness::{InstancePaddingStrategy, RowMajorMatrix}; +pub mod gpu; pub mod riscv; +pub use gpu::utils::{cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram}; + pub trait Instruction { type InstructionConfig: Send + Sync; type InsnType: Clone + Copy; + const GPU_LK_SHARDRAM: bool = false; + fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default } @@ -96,6 +101,36 @@ pub trait Instruction { step: &StepRecord, ) -> Result<(), ZKVMError>; + fn collect_lk_and_shardram( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement lk and shardram collection", + Self::name() + ) + .into(), + )) + } + + fn collect_shardram( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement shardram-only collection", + Self::name() + ) + .into(), + )) + } + fn assign_instances( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, @@ -190,3 +225,148 @@ pub trait Instruction { pub fn full_step_indices(steps: &[StepRecord]) -> Vec { (0..steps.len()).collect() } + +// --------------------------------------------------------------------------- +// Macros to reduce per-chip boilerplate +// --------------------------------------------------------------------------- + +/// Implement `collect_lk_and_shardram` with a common prologue +/// (create `CpuLkShardramSink`, dispatch to `config.$field.emit_lk_and_shardram`) +/// and a chip-specific 
body for additional LK ops. +/// +/// The closure receives `(sink, step, config, ctx)`: +/// - `sink: &mut CpuLkShardramSink` — emit LK ops and send events +/// - `step: &StepRecord` — current step +/// - `config: &Self::InstructionConfig` — circuit config (for sub-configs) +/// - `ctx: &ShardContext` — read-only shard context +/// +/// Usage inside `impl Instruction for MyChip`: +/// ```ignore +/// impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { +/// emit_u16_limbs(sink, step.rd().unwrap().value.after); +/// }); +/// ``` +#[macro_export] +macro_rules! impl_collect_lk_and_shardram { + ($field:ident, |$sink:ident, $step:ident, $config:ident, $ctx:ident| $body:block) => { + fn collect_lk_and_shardram( + config: &Self::InstructionConfig, + shard_ctx: &mut $crate::e2e::ShardContext, + lk_multiplicity: &mut $crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), $crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut $crate::e2e::ShardContext; + let _ctx = unsafe { &*shard_ctx_ptr }; + let mut _sink_val = unsafe { + $crate::instructions::gpu::utils::CpuLkShardramSink::from_raw( + shard_ctx_ptr, + lk_multiplicity, + ) + }; + config.$field.emit_lk_and_shardram(&mut _sink_val, _ctx, step); + let $sink = &mut _sink_val; + let $step = step; + let $config = config; + let $ctx = _ctx; + $body + Ok(()) + } + }; +} + +/// Implement `collect_shardram` by delegating to +/// `config.$field.emit_shardram(shard_ctx, lk_multiplicity, step)`. +/// +/// Every chip's implementation is identical except for the config field name +/// (`r_insn`, `i_insn`, `b_insn`, `s_insn`, `j_insn`, `im_insn`). +/// +/// Usage inside `impl Instruction for MyChip`: +/// ```ignore +/// impl_collect_shardram!(r_insn); +/// ``` +#[macro_export] +macro_rules! 
impl_collect_shardram { + ($field:ident) => { + fn collect_shardram( + config: &Self::InstructionConfig, + shard_ctx: &mut $crate::e2e::ShardContext, + lk_multiplicity: &mut $crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), $crate::error::ZKVMError> { + config + .$field + .emit_shardram(shard_ctx, lk_multiplicity, step); + Ok(()) + } + }; +} + +/// Implement the `#[cfg(feature = "gpu")] fn assign_instances` override that: +/// 1. Computes `Option` from `$kind_expr` +/// 2. Tries `try_gpu_assign_instances` → returns on success +/// 3. Falls back to `cpu_assign_instances` +/// +/// Usage inside `impl Instruction for MyChip`: +/// ```ignore +/// // Single kind (always GPU): +/// impl_gpu_assign!(GpuWitgenKind::Lui); +/// +/// // Match expression → Option: +/// impl_gpu_assign!(match I::INST_KIND { +/// InsnKind::ADD => Some(GpuWitgenKind::Add), +/// InsnKind::SUB => Some(GpuWitgenKind::Sub), +/// _ => None, +/// }); +/// ``` +#[macro_export] +macro_rules! impl_gpu_assign { + // Match/block → Option + (match $($rest:tt)*) => { + $crate::impl_gpu_assign!(@ match $($rest)*); + }; + // Single kind — always use GPU + ($kind:expr) => { + $crate::impl_gpu_assign!(@ Some($kind)); + }; + (@ $kind_expr:expr) => { + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut $crate::e2e::ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[ceno_emul::StepIndex], + ) -> Result< + ( + $crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + $crate::error::ZKVMError, + > { + use $crate::instructions::gpu::dispatch; + let gpu_kind: Option = $kind_expr; + if let Some(kind) = gpu_kind { + if let Some(result) = dispatch::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? 
{ + return Ok(result); + } + } + $crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } + }; +} diff --git a/ceno_zkvm/src/instructions/gpu/README.md b/ceno_zkvm/src/instructions/gpu/README.md new file mode 100644 index 000000000..565b73f46 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/README.md @@ -0,0 +1,214 @@ +# GPU Witness Generation + +Accelerate witness generation by offloading computation from CPU to GPU. +This module (`ceno_zkvm/src/instructions/gpu/`) contains all GPU-side dispatch, +caching, and utility code for the witness generation pipeline. + +The CUDA backend lives in the sibling repo `ceno-gpu/` (`cuda_hal/src/common/witgen/`). + +## Architecture + +### Module Layout + +``` +gpu/ +├── dispatch.rs — GPU dispatch entry point (try_gpu_assign_instances, gpu_fill_witness) +├── config.rs — Environment variable config (3 env vars), kind tags +├── cache.rs — Thread-local device buffer caching, shared EC/addr buffers +├── chips/ — Per-chip column map extractors + chip-specific GPU dispatch +│ ├── add.rs ... sw.rs (24 RV32IM column map extractors) +│ ├── keccak.rs (column map + keccak GPU dispatch: gpu_assign_keccak_instances) +│ └── shard_ram.rs (column map + batch EC computation: gpu_batch_continuation_ec) +├── utils/ +│ ├── column_map.rs — Shared column map extraction helpers (extract_rs1, extract_rd, ...) +│ ├── d2h.rs — Device-to-host: witness transpose, LK counter decode, compact EC D2H +│ ├── debug_compare.rs— GPU vs CPU comparison (activated by CENO_GPU_DEBUG_COMPARE_WITGEN) +│ ├── lk_ops.rs — LkOp enum, SendEvent struct +│ ├── sink.rs — LkShardramSink trait, CpuLkShardramSink +│ ├── emit.rs — Emit helper functions (emit_u16_limbs, emit_logic_u8_ops, ...) 
+│ ├── fallback.rs — CPU fallback: cpu_assign_instances, cpu_collect_lk_and_shardram +│ └── test_helpers.rs — Test utilities: assert_witness_colmajor_eq, assert_full_gpu_pipeline +└── mod.rs — Module declarations + lk_shardram integration tests (19 tests) +``` + +### Data Flow + +``` + Pass 1: PreflightTracer + ┌──────────────────────┐ + │ ShardPlanBuilder │ → shard boundaries + │ PackedNextAccessEntry│ → sorted future-access table + └──────────┬───────────┘ + │ + Pass 2: FullTracer (per shard) + ┌──────────▼───────────┐ + │ Vec │ 136 bytes/step, #[repr(C)] + └──────────┬───────────┘ + │ H2D (cached per shard in cache.rs) + ┌──────────▼───────────────────────────────────┐ + │ GPU Per-Instruction │ + │ ┌─────────────┬──────────────┬────────────┐ │ + │ │ F-1 Witness │ F-2 LK Count │ F-3 EC/Addr│ │ + │ │ (col-major) │ (atomics) │ (shared buf)│ │ + │ └──────┬──────┴──────┬───────┴─────┬──────┘ │ + └─────────┼─────────────┼─────────────┼────────┘ + │ │ │ + GPU transpose D2H counters flush at shard end + │ │ │ + ┌─────────▼─────────────▼─────────────▼────────┐ + │ CPU Merge │ + │ RowMajorMatrix LkMultiplicity ShardContext │ + └──────────────────────┬───────────────────────┘ + │ + ┌──────────────────────▼───────────────────────┐ + │ ShardRamCircuit (GPU) │ + │ Phase 1: per-row Poseidon2 (344 cols) │ + │ Phase 2: binary EC tree (layer-by-layer) │ + └──────────────────────┬───────────────────────┘ + │ + ▼ + Proof Generation +``` + +### Per-Shard Pipeline + +Within `generate_witness()` (e2e.rs), each shard executes: + +1. **upload_shard_steps_cached** — H2D `Vec` (cached, shared across all chips) +2. **ensure_shard_metadata_cached** — H2D shard scalars + allocate shared EC/addr buffers +3. **Per-chip dispatch** — `gpu_fill_witness` matches `GpuWitgenKind` → 22 kernel variants + - Each kernel writes: witness columns (col-major), LK counters (atomics), EC records + addr (shared buffers) +4. 
**flush_shared_ec_buffers** — D2H shared EC records + addr_accessed into `ShardContext` +5. **invalidate_shard_steps_cache** — Free GPU shard_steps memory +6. **assign_shared_circuit** — ShardRamCircuit GPU pipeline (Poseidon2 + EC tree) + +### GPU/CPU Decision (dispatch.rs) + +``` +try_gpu_assign_instances(): + 1. is_gpu_witgen_enabled()? → CPU fallback if not set + 2. is_force_cpu_path() thread-local? → CPU fallback (debug comparison) + 3. I::GPU_LK_SHARDRAM == false? → CPU fallback + 4. is_kind_disabled(kind)? → CPU fallback + 5. Field != BabyBear? → CPU fallback + 6. get_cuda_hal() unavailable? → CPU fallback + 7. All pass → GPU path +``` + +### Keccak Dispatch + +Keccak has a dedicated GPU dispatch path (`chips/keccak.rs::gpu_assign_keccak_instances`) +separate from `try_gpu_assign_instances` because: +1. **Rotation**: each instance spans 32 rows (not 1), requiring `new_by_rotation` +2. **Structural witness**: 3 selectors (sel_first/sel_last/sel_all) vs the standard 1 +3. **Input packing**: needs `packed_instances` with `syscall_witnesses` + +The LK/shardram collection logic is identical to the standard path. + +### Lk and Shardram Collection + +After GPU computes the witness matrix, LK multiplicities and shard RAM records +are collected through one of several paths (priority order): + +| Path | Witness | LK Multiplicity | Shard Records | When | +|------|---------|-----------------|---------------|------| +| **A** Shared buffer | GPU | GPU counters → D2H | Shared GPU buffer (deferred) | Default for all verified kinds | +| **B** Compact EC | GPU | GPU counters → D2H | Compact EC D2H per-kernel | Older non-shared-buffer kinds | +| **C** CPU shardram | GPU | GPU counters → D2H | CPU `cpu_collect_shardram` | GPU shard unverified | +| **D** CPU full | GPU | CPU `cpu_collect_lk_and_shardram` | CPU full | GPU LK unverified | +| **E** CPU only | CPU | CPU `assign_instance` | CPU `assign_instance` | GPU unavailable | + +Currently all non-Keccak kinds use **Path A**. 
Paths B-E are fallback/debug paths. + +## E2E Pipeline Modes (e2e.rs) + +``` +create_proofs_streaming() +│ +├─ Default GPU backend (CENO_GPU_ENABLE_WITGEN unset): +│ Overlap pipeline: +│ Thread A (CPU): witgen(shard 0) → witgen(shard 1) → witgen(shard 2) → ... +│ Thread B (GPU): ................prove(shard 0) → prove(shard 1) → ... +│ crossbeam::bounded(0) rendezvous channel for back-pressure +│ +└─ CENO_GPU_ENABLE_WITGEN=1 (GPU witgen) or CPU-only build: + Sequential pipeline: + witgen(shard 0) → prove(shard 0) → witgen(shard 1) → prove(shard 1) → ... + GPU shared between witgen and proving; no overlap possible. +``` + +## Environment Variables + +| Variable | Default | Purpose | +|----------|---------|---------| +| `CENO_GPU_ENABLE_WITGEN` | unset (CPU witgen) | Set to enable GPU witness generation. Sequential witgen+prove pipeline. | +| `CENO_GPU_DISABLE_WITGEN_KINDS` | none | Comma-separated kind tags to disable specific chips' GPU path. Example: `add,keccak,lw`. Falls back to CPU for those chips. | +| `CENO_GPU_DEBUG_COMPARE_WITGEN` | unset | Enable GPU vs CPU comparison for all chips. Runs both paths and diffs results. 
| + +### `CENO_GPU_DEBUG_COMPARE_WITGEN` Coverage + +When set, the following comparisons run automatically: + +**Per-chip (in dispatch.rs, for each opcode circuit):** +- `debug_compare_final_lk` — GPU LK multiplicity vs CPU `assign_instance` baseline (all 8 lookup tables) +- `debug_compare_witness` — GPU witness matrix vs CPU witness (element-by-element, col-major vs row-major) +- `debug_compare_shardram` — GPU shard records (read_records, write_records, addr_accessed) vs CPU +- `debug_compare_shard_ec` — GPU compact EC records vs CPU-computed EC points (nonce, x[7], y[7]) + +**Per-chip, Keccak-specific (in chips/keccak.rs):** +- `debug_compare_keccak` — Combined witness + LK + shard comparison for keccak's rotation-aware layout + +**Per-shard, E2E level (in e2e.rs):** +- `log_shard_ctx_diff` — Full shard context comparison after all opcode circuits (addr_accessed, read/write records across all chips merged) +- `log_combined_lk_diff` — Merged LK multiplicities after `finalize_lk_multiplicities()` (catches cross-chip merge issues) + +All comparisons output to stderr via `eprintln!` / `tracing::error!`, with a default limit of 16 mismatches per category. 
+ +## Tests + +**79 tests total** (`cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "gpu"`) + +| Category | Count | Location | What it tests | +|----------|------:|----------|---------------| +| Column map extraction | 33 | `chips/*.rs` (31 via `test_colmap!` macro + 2 manual) | Circuit config → column map: all IDs in-range and unique | +| GPU witgen correctness | 23 | `chips/*.rs` | GPU kernel output vs CPU `assign_instance` (element-by-element witness comparison) | +| LK+shardram match | 19 | `gpu/mod.rs` | `collect_lk_and_shardram` / `collect_shardram` vs `assign_instance` baseline | +| LkOp encoding | 1 | `utils/mod.rs` | `LkOp::encode_all()` produces correct table/key pairs | +| EC point match | 1 | `scheme/septic_curve.rs` | GPU Poseidon2+SepticCurve EC point vs CPU `to_ec_point` | +| Poseidon2 sponge | 1 | `scheme/septic_curve.rs` | GPU Poseidon2 permutation vs CPU | +| Septic from_x | 1 | `scheme/septic_curve.rs` | GPU `septic_point_from_x` vs CPU | + +### Running Tests + +```bash +# All GPU tests (requires CUDA device) +CENO_GPU_ENABLE_WITGEN=1 cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "gpu" + +# Column map tests only (no CUDA device needed) +cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "test_extract_" + +# LK/shardram tests only (no CUDA device needed) +cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "lk_shardram" + +# With debug comparison enabled +CENO_GPU_ENABLE_WITGEN=1 CENO_GPU_DEBUG_COMPARE_WITGEN=1 cargo test --features gpu,u16limb_circuit -p ceno_host -- test_elf +``` + +## Per-Chip Boilerplate Macros + +Three macros in `instructions.rs` reduce per-chip GPU integration to ~3 lines: + +```rust +impl Instruction for MyChip { + // Emit LK ops + shard RAM records (CPU companion for GPU witgen) + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + emit_u16_limbs(sink, step.rd().unwrap().value.after); + }); + + // Collect shard RAM records only (when GPU handles 
LK) + impl_collect_shardram!(r_insn); + + // GPU dispatch: try GPU → fallback CPU + impl_gpu_assign!(dispatch::GpuWitgenKind::Add); +} +``` diff --git a/ceno_zkvm/src/instructions/gpu/cache.rs b/ceno_zkvm/src/instructions/gpu/cache.rs new file mode 100644 index 000000000..0797c436a --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/cache.rs @@ -0,0 +1,515 @@ +/// Device buffer caching for GPU witness generation. +/// +/// Manages thread-local caches for shard step data and shard metadata +/// device buffers, avoiding redundant host-to-device transfers within +/// the same shard. +use ceno_emul::{StepRecord, WordAddr}; +use ceno_gpu::{ + Buffer, CudaHal, CudaSlice, + bb31::{CudaHalBB31, ShardDeviceBuffers}, + common::witgen::types::{GpuShardRamRecord, GpuShardScalars}, +}; +use std::cell::RefCell; +use tracing::info_span; + +use crate::{e2e::ShardContext, error::ZKVMError}; + +/// Cached shard_steps device buffer with metadata for logging. +struct ShardStepsCache { + host_ptr: usize, + byte_len: usize, + shard_id: usize, + n_steps: usize, + device_buf: CudaSlice, +} + +// Thread-local cache for shard_steps device buffer. Invalidated when shard changes. +thread_local! { + static SHARD_STEPS_DEVICE: RefCell> = + const { RefCell::new(None) }; +} + +/// Upload shard_steps to GPU, reusing cached device buffer if the same data. 
+pub(crate) fn upload_shard_steps_cached( + hal: &CudaHalBB31, + shard_steps: &[StepRecord], + shard_id: usize, +) -> Result<(), ZKVMError> { + let ptr = shard_steps.as_ptr() as usize; + let byte_len = shard_steps.len() * std::mem::size_of::(); + + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.host_ptr == ptr && c.byte_len == byte_len { + return Ok(()); // cache hit + } + } + // Cache miss: upload + let mb = byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] uploading shard_steps: shard_id={}, n_steps={}, {:.2} MB", + shard_id, + shard_steps.len(), + mb, + ); + let bytes: &[u8] = + unsafe { std::slice::from_raw_parts(shard_steps.as_ptr() as *const u8, byte_len) }; + let device_buf = hal.inner.htod_copy_stream(None, bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard_steps H2D failed: {e}").into()) + })?; + *cache = Some(ShardStepsCache { + host_ptr: ptr, + byte_len, + shard_id, + n_steps: shard_steps.len(), + device_buf, + }); + Ok(()) + }) +} + +/// Borrow the cached device buffer for kernel launch. +/// Panics if `upload_shard_steps_cached` was not called first. +pub(crate) fn with_cached_shard_steps(f: impl FnOnce(&CudaSlice) -> R) -> R { + SHARD_STEPS_DEVICE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard_steps not uploaded"); + f(&c.device_buf) + }) +} + +/// Invalidate the cached shard_steps device buffer. +/// Call this when shard processing is complete to free GPU memory. +pub fn invalidate_shard_steps_cache() { + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + let mb = c.byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] releasing shard_steps cache: shard_id={}, n_steps={}, {:.2} MB", + c.shard_id, + c.n_steps, + mb, + ); + } + *cache = None; + }); +} + +/// Cached shard metadata device buffers for GPU shard records. 
+/// Invalidated when shard_id changes; shared across all kernel invocations in one shard. +struct ShardMetadataCache { + shard_id: usize, + device_bufs: ShardDeviceBuffers, + /// Shared EC record buffer (owns the GPU memory, pointer stored in device_bufs). + shared_ec_buf: Option>, + /// Shared EC record count buffer (single u32 counter). + shared_ec_count: Option>, + /// Shared addr_accessed buffer (u32 word addresses). + shared_addr_buf: Option>, + /// Shared addr_accessed count buffer (single u32 counter). + shared_addr_count: Option>, +} + +thread_local! { + static SHARD_META_CACHE: RefCell> = + const { RefCell::new(None) }; +} + +/// Build and cache shard metadata device buffers for GPU shard records. +/// +/// FA (future access) device buffers are global and identical across all shards, +/// so they are uploaded once and reused via move. Only per-shard data (scalars + +/// prev_shard_ranges) is re-uploaded when the shard changes. +pub(crate) fn ensure_shard_metadata_cached( + hal: &CudaHalBB31, + shard_ctx: &ShardContext, + n_total_steps: usize, +) -> Result<(), ZKVMError> { + let shard_id = shard_ctx.shard_id; + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.shard_id == shard_id { + return Ok(()); // cache hit + } + } + + // Move FA device buffer from previous cache (reuse across shards). + // FA data is global — identical across all shards — so we reuse, not re-upload. 
+ let existing_fa = cache.take().map(|c| { + let ShardDeviceBuffers { + next_access_packed, + scalars: _, + prev_shard_cycle_range: _, + prev_shard_heap_range: _, + prev_shard_hint_range: _, + gpu_ec_shard_id: _, + shared_ec_out_ptr: _, + shared_ec_count_ptr: _, + shared_addr_out_ptr: _, + shared_addr_count_ptr: _, + shared_ec_capacity: _, + shared_addr_capacity: _, + } = c.device_bufs; + next_access_packed + }); + + let next_access_packed_device = if let Some(fa) = existing_fa { + fa // Reuse existing GPU memory — zero cost pointer move + } else { + // First shard: bulk H2D upload packed FA entries (no sort here) + let sorted = &shard_ctx.sorted_next_accesses; + tracing::info_span!("next_access_h2d").in_scope(|| -> Result<_, ZKVMError> { + let packed_bytes: &[u8] = if sorted.packed.is_empty() { + &[0u8; 16] // sentinel for empty + } else { + unsafe { + std::slice::from_raw_parts( + sorted.packed.as_ptr() as *const u8, + sorted.packed.len() + * std::mem::size_of::(), + ) + } + }; + let buf = hal + .inner + .htod_copy_stream(None, packed_bytes) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) + })?; + let next_access_device = ceno_gpu::common::buffer::BufferImpl::new(buf); + let mb = packed_bytes.len() as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU shard] FA uploaded once: {} entries, {:.2} MB (packed)", + sorted.packed.len(), + mb, + ); + Ok(next_access_device) + })? 
+ }; + + // Per-shard: always re-upload scalars + prev_shard_ranges + let scalars = GpuShardScalars { + shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, + shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, + shard_offset_cycle: shard_ctx.current_shard_offset_cycle(), + shard_id: shard_id as u32, + heap_start: shard_ctx.platform.heap.start, + heap_end: shard_ctx.platform.heap.end, + hint_start: shard_ctx.platform.hints.start, + hint_end: shard_ctx.platform.hints.end, + shard_heap_start: shard_ctx.shard_heap_addr_range.start, + shard_heap_end: shard_ctx.shard_heap_addr_range.end, + shard_hint_start: shard_ctx.shard_hint_addr_range.start, + shard_hint_end: shard_ctx.shard_hint_addr_range.end, + next_access_count: shard_ctx.sorted_next_accesses.packed.len() as u32, + num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, + num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, + num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, + }; + + let (scalars_device, pscr_device, pshr_device, pshi_device) = + tracing::info_span!("shard_scalars_h2d").in_scope(|| -> Result<_, ZKVMError> { + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = + hal.inner + .htod_copy_stream(None, scalars_bytes) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("shard scalars H2D failed: {e}").into(), + ) + })?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()) + })?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()) + })?; + + let pshi = 
&shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()) + })?; + + Ok((scalars_device, pscr_device, pshr_device, pshi_device)) + })?; + + tracing::info!( + "[GPU shard] shard_id={}: per-shard scalars updated", + shard_id, + ); + + // Allocate shared EC/addr compact buffers for this shard. + // + // EC records: cross-shard only (sparse subset of RAM ops). + // 104 bytes each (26 u32s). Cap at 16M entries ≈ 1.6 GB. + // Addr records: every gpu_send() emits one (dense). + // 4 bytes each (1 u32). Cap at 256M entries ≈ 1 GB. + let max_ops_per_step = 52u64; // keccak worst case + let total_ops_estimate = n_total_steps as u64 * max_ops_per_step; + let ec_capacity = total_ops_estimate.min(16 * 1024 * 1024) as usize; + let ec_u32s = ec_capacity * 26; // 26 u32s per GpuShardRamRecord (104 bytes) + let addr_capacity = total_ops_estimate.min(256 * 1024 * 1024) as usize; + + let shared_ec_buf = hal + .witgen + .alloc_u32_zeroed(ec_u32s, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf alloc: {e}").into()))?; + let shared_ec_count = hal + .witgen + .alloc_u32_zeroed(1, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_count alloc: {e}").into()))?; + let shared_addr_buf = hal + .witgen + .alloc_u32_zeroed(addr_capacity, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_addr_buf alloc: {e}").into()))?; + let shared_addr_count = hal.witgen.alloc_u32_zeroed(1, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count alloc: {e}").into()) + })?; + + let shared_ec_out_ptr = shared_ec_buf.device_ptr() as u64; + let shared_ec_count_ptr = shared_ec_count.device_ptr() as u64; + let shared_addr_out_ptr = shared_addr_buf.device_ptr() as u64; + let shared_addr_count_ptr = shared_addr_count.device_ptr() as u64; + + tracing::info!( + "[GPU shard] shard_id={}: shared 
buffers allocated: ec_capacity={}, addr_capacity={}", + shard_id, + ec_capacity, + addr_capacity, + ); + + *cache = Some(ShardMetadataCache { + shard_id, + device_bufs: ShardDeviceBuffers { + scalars: scalars_device, + next_access_packed: next_access_packed_device, + prev_shard_cycle_range: pscr_device, + prev_shard_heap_range: pshr_device, + prev_shard_hint_range: pshi_device, + gpu_ec_shard_id: Some(shard_id as u64), + shared_ec_out_ptr, + shared_ec_count_ptr, + shared_addr_out_ptr, + shared_addr_count_ptr, + shared_ec_capacity: ec_capacity as u32, + shared_addr_capacity: addr_capacity as u32, + }, + shared_ec_buf: Some(shared_ec_buf), + shared_ec_count: Some(shared_ec_count), + shared_addr_buf: Some(shared_addr_buf), + shared_addr_count: Some(shared_addr_count), + }); + Ok(()) + }) +} + +/// Borrow the cached shard device buffers for kernel launch. +pub(crate) fn with_cached_shard_meta(f: impl FnOnce(&ShardDeviceBuffers) -> R) -> R { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not uploaded"); + f(&c.device_bufs) + }) +} + +/// Borrow both cached device buffers (shard_steps + shard_meta) in one call. +/// Eliminates the nested `with_cached_shard_steps(|s| with_cached_shard_meta(|m| ...))` pattern. +pub(crate) fn with_cached_gpu_ctx( + f: impl FnOnce(&CudaSlice, &ShardDeviceBuffers) -> R, +) -> R { + SHARD_STEPS_DEVICE.with(|steps_cache| { + let steps = steps_cache.borrow(); + let s = steps.as_ref().expect("shard_steps not uploaded"); + SHARD_META_CACHE.with(|meta_cache| { + let meta = meta_cache.borrow(); + let m = meta.as_ref().expect("shard metadata not uploaded"); + f(&s.device_buf, &m.device_bufs) + }) + }) +} + +/// Invalidate the shard metadata cache (call when shard processing is complete). +pub fn invalidate_shard_meta_cache() { + SHARD_META_CACHE.with(|cache| { + *cache.borrow_mut() = None; + }); +} + +/// Take ownership of shared EC and addr_accessed device buffers from the cache. 
+/// +/// Returns (shared_ec_buf, ec_count, shared_addr_buf, addr_count) or None if unavailable. +/// The cache is invalidated after this call — must be called at most once per shard. +pub fn take_shared_device_buffers() -> Option { + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + let c = cache.as_mut()?; + + let ec_buf = c.shared_ec_buf.take()?; + let ec_count = c.shared_ec_count.take()?; + let addr_buf = c.shared_addr_buf.take()?; + let addr_count = c.shared_addr_count.take()?; + + Some(SharedDeviceBufferSet { + ec_buf, + ec_count, + addr_buf, + addr_count, + }) + }) +} + +/// Shared device buffers taken from the shard metadata cache. +pub struct SharedDeviceBufferSet { + pub ec_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub ec_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub addr_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32>, + pub addr_count: ceno_gpu::common::buffer::BufferImpl<'static, u32>, +} + +/// Read the current shared addr count from device (single u32 D2H). +/// Used by debug comparison to snapshot count before/after a kernel. +#[cfg(feature = "gpu")] +pub(crate) fn read_shared_addr_count() -> usize { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not cached"); + let buf = c + .shared_addr_count + .as_ref() + .expect("shared_addr_count not allocated"); + let v: Vec = buf.to_vec().expect("shared_addr_count D2H failed"); + v[0] as usize + }) +} + +/// Read a range of addr entries [start..end) from the shared addr buffer. 
+#[cfg(feature = "gpu")] +pub(crate) fn read_shared_addr_range(start: usize, end: usize) -> Vec { + if start >= end { + return Vec::new(); + } + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not cached"); + let buf = c + .shared_addr_buf + .as_ref() + .expect("shared_addr_buf not allocated"); + let all: Vec = buf.to_vec_n(end).expect("shared_addr_buf D2H failed"); + all[start..end].to_vec() + }) +} + +/// Batch D2H of shared EC records and addr_accessed buffers after all kernel invocations. +/// +/// Called once per shard after all opcode `gpu_assign_instances_inner` calls complete. +/// Transfers accumulated EC records and addresses from shared GPU buffers into `shard_ctx`. +/// +/// If the shared buffers have already been taken by `take_shared_device_buffers` +/// (for the full GPU pipeline), this is a no-op. +pub fn flush_shared_ec_buffers(shard_ctx: &mut ShardContext) -> Result<(), ZKVMError> { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = match cache.as_ref() { + Some(c) => c, + None => return Ok(()), // cache already invalidated — no-op + }; + + // If buffers have been taken by take_shared_device_buffers, skip D2H + let ec_count_buf = match c.shared_ec_count.as_ref() { + Some(b) => b, + None => { + tracing::debug!( + "[GPU shard] flush_shared_ec_buffers: buffers already taken, no-op" + ); + return Ok(()); + } + }; + let ec_count_vec: Vec = ec_count_buf + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()))?; + let ec_count = ec_count_vec[0] as usize; + let ec_capacity = c.device_bufs.shared_ec_capacity as usize; + + assert!( + ec_count <= ec_capacity, + "GPU shared EC buffer overflow: count={} > capacity={}. 
\ + Increase ec_capacity in ensure_shard_metadata_cached.", + ec_count, + ec_capacity, + ); + + if ec_count > 0 { + // D2H EC records (only the active portion) + let ec_buf = c.shared_ec_buf.as_ref().unwrap(); + let ec_u32s = ec_count * 26; // 26 u32s per GpuShardRamRecord + let raw_u32: Vec = ec_buf + .to_vec_n(ec_u32s) + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_ec_buf D2H: {e}").into()))?; + let raw_bytes = unsafe { + std::slice::from_raw_parts(raw_u32.as_ptr() as *const u8, raw_u32.len() * 4) + }; + tracing::info!( + "[GPU shard] flush_shared_ec_buffers: {} EC records, {:.2} MB", + ec_count, + raw_bytes.len() as f64 / (1024.0 * 1024.0), + ); + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + } + + // D2H addr_accessed count + let addr_count_buf = c + .shared_addr_count + .as_ref() + .ok_or_else(|| ZKVMError::InvalidWitness("shared_addr_count not allocated".into()))?; + let addr_count_vec: Vec = addr_count_buf + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()))?; + let addr_count = addr_count_vec[0] as usize; + let addr_capacity = c.device_bufs.shared_addr_capacity as usize; + + assert!( + addr_count <= addr_capacity, + "GPU shared addr buffer overflow: count={} > capacity={}", + addr_count, + addr_capacity, + ); + + if addr_count > 0 { + let addr_buf = c.shared_addr_buf.as_ref().unwrap(); + let addrs: Vec = addr_buf.to_vec_n(addr_count).map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_buf D2H: {e}").into()) + })?; + tracing::info!( + "[GPU shard] flush_shared_ec_buffers: {} addr_accessed, {:.2} MB", + addr_count, + addr_count as f64 * 4.0 / (1024.0 * 1024.0), + ); + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + + Ok(()) + }) +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/add.rs b/ceno_zkvm/src/instructions/gpu/chips/add.rs new file mode 100644 index 000000000..3d7f53b65 
--- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/add.rs @@ -0,0 +1,188 @@ +use ceno_gpu::common::witgen::types::AddColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::arith::ArithConfig, +}; + +/// Extract column map from a constructed ArithConfig (ADD variant). +/// +/// This reads all WitIn.id values from the config tree and packs them +/// into an AddColumnMap suitable for GPU kernel dispatch. +pub fn extract_add_column_map( + config: &ArithConfig, + num_witin: usize, +) -> AddColumnMap { + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_carries = extract_carries::(&config.rd_written, "rd_written"); + + AddColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith::AddInstruction}, + structs::ProgramParams, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn make_test_steps(n: usize) -> Vec { + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (0, 1), + (1, 0), + (u32::MAX, 1), // overflow + (u32::MAX, u32::MAX), // double overflow + (0x80000000, 0x80000000), // INT_MIN + INT_MIN + (0x7FFFFFFF, 1), // INT_MAX 
+ 1 + (0xFFFF0000, 0x0000FFFF), // limb carry + ]; + + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 1, (i as u32) % 500 + 3) + }; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect() + } + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_add_column_map, + AddInstruction, + extract_add_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_add_correctness() { + use crate::{ + e2e::ShardContext, + instructions::gpu::{ + dispatch, + utils::test_helpers::{assert_full_gpu_pipeline, assert_witness_colmajor_eq}, + }, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Construct circuit + let mut cs = ConstraintSystem::::new(|| "test_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + // Generate test data + let n = 1024; + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path (AOS with indirect indexing) + let col_map = extract_add_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset 
= shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_add( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + + assert_full_gpu_pipeline::>( + &config, + &steps, + dispatch::GpuWitgenKind::Add, + &cpu_rmms, + &cpu_lkm, + &shard_ctx, + num_witin, + num_structural_witin, + ); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/addi.rs b/ceno_zkvm/src/instructions/gpu/chips/addi.rs new file mode 100644 index 000000000..c17493a04 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/addi.rs @@ -0,0 +1,144 @@ +use ceno_gpu::common::witgen::types::AddiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_carries, extract_rd, extract_rs1, extract_state, extract_uint_limbs, + }, + riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig, +}; + +/// Extract column map from a constructed InstructionConfig (ADDI v2). 
+pub fn extract_addi_column_map( + config: &InstructionConfig, + num_witin: usize, +) -> AddiColumnMap { + let im = &config.i_insn; + + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let rd_carries = extract_carries::(&config.rd_written, "rd_written"); + + AddiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + rd_carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith_imm::AddiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_addi_column_map, + AddiInstruction, + extract_addi_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_addi_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_addi_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32) * 137 + 1; + let imm = ((i as 
i32) % 2048 - 1024) as i32; + let rd_after = rs1.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_addi_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_addi( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/auipc.rs b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs new file mode 100644 index 000000000..eed81fa7f --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/auipc.rs @@ -0,0 +1,144 @@ +use ceno_gpu::common::witgen::types::AuipcColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::auipc::AuipcConfig, +}; + +/// Extract column map from a constructed 
AuipcConfig. +pub fn extract_auipc_column_map( + config: &AuipcConfig, + num_witin: usize, +) -> AuipcColumnMap { + let im = &config.i_insn; + + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); + let pc_limbs: [u32; 2] = [config.pc_limbs[0].id as u32, config.pc_limbs[1].id as u32]; + let imm_limbs: [u32; 3] = [ + config.imm_limbs[0].id as u32, + config.imm_limbs[1].id as u32, + config.imm_limbs[2].id as u32, + ]; + + AuipcColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + pc_limbs, + imm_limbs, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::auipc::AuipcInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_auipc_column_map, + AuipcInstruction, + extract_auipc_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_auipc_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_auipc_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: 
Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // AUIPC immediate is upper 20 bits + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rd_after = pc.0.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_auipc_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_auipc( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs new file mode 100644 index 000000000..e717d05b7 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_cmp.rs @@ -0,0 +1,161 @@ +use ceno_gpu::common::witgen::types::BranchCmpColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + 
gpu::utils::column_map::{ + extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs, + }, + riscv::branch::branch_circuit_v2::BranchConfig, +}; + +/// Extract column map from a constructed BranchConfig (BLT/BGE/BLTU/BGEU variant). +pub fn extract_branch_cmp_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchCmpColumnMap { + let rs1_limbs = extract_uint_limbs::(&config.read_rs1, "read_rs1"); + let rs2_limbs = extract_uint_limbs::(&config.read_rs2, "read_rs2"); + + let lt_config = config.uint_lt_config.as_ref().unwrap(); + let cmp_lt = lt_config.cmp_lt.id as u32; + let a_msb_f = lt_config.a_msb_f.id as u32; + let b_msb_f = lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + lt_config.diff_marker[0].id as u32, + lt_config.diff_marker[1].id as u32, + ]; + let diff_val = lt_config.diff_val.id as u32; + + let (pc, next_pc, ts) = extract_state_branching(&config.b_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.b_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.b_insn.rs2); + let imm = config.b_insn.imm.id as u32; + + BranchCmpColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::branch::BltInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_branch_cmp_column_map, + BltInstruction, + extract_branch_cmp_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_cmp_correctness() { + use crate::{ + e2e::ShardContext, 
instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_blt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let taken = (rs1 as i32) < (rs2 as i32); + let pc = ByteAddr(0x2000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_cmp_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_branch_cmp( + &col_map, + &gpu_records, + 
&indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs new file mode 100644 index 000000000..00c16cb4f --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/branch_eq.rs @@ -0,0 +1,158 @@ +use ceno_gpu::common::witgen::types::BranchEqColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rs1, extract_rs2, extract_state_branching, extract_uint_limbs, + }, + riscv::branch::branch_circuit_v2::BranchConfig, +}; + +/// Extract column map from a constructed BranchConfig (BEQ/BNE variant). +pub fn extract_branch_eq_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchEqColumnMap { + let rs1_limbs = extract_uint_limbs::(&config.read_rs1, "read_rs1"); + let rs2_limbs = extract_uint_limbs::(&config.read_rs2, "read_rs2"); + + let branch_taken = config.eq_branch_taken_bit.as_ref().unwrap().id as u32; + let diff_inv_marker: [u32; 2] = { + let markers = config.eq_diff_inv_marker.as_ref().unwrap(); + [markers[0].id as u32, markers[1].id as u32] + }; + + let (pc, next_pc, ts) = extract_state_branching(&config.b_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.b_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.b_insn.rs2); + let imm = config.b_insn.imm.id as u32; + + BranchEqColumnMap { + rs1_limbs, + rs2_limbs, + branch_taken, + diff_inv_marker, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, 
riscv::branch::BeqInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_branch_eq_column_map, + BeqInstruction, + extract_branch_eq_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_eq_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_beq_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as u32) * 137) ^ 0xABCD; + let rs2 = if i % 3 == 0 { + rs1 + } else { + ((i as u32) * 89) ^ 0x1234 + }; + let taken = rs1 == rs2; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_eq_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = 
shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_branch_eq( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/div.rs b/ceno_zkvm/src/instructions/gpu/chips/div.rs new file mode 100644 index 000000000..3e980ae4f --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/div.rs @@ -0,0 +1,332 @@ +use ceno_gpu::common::witgen::types::DivColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::div::div_circuit_v2::DivRemConfig, +}; + +/// Extract column map from a constructed DivRemConfig. 
+/// div_kind: 0=DIV, 1=DIVU, 2=REM, 3=REMU +pub fn extract_div_column_map( + config: &DivRemConfig, + num_witin: usize, +) -> DivColumnMap { + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); + + let dividend = extract_uint_limbs::(&config.dividend, "dividend"); + let divisor = extract_uint_limbs::(&config.divisor, "divisor"); + let quotient = extract_uint_limbs::(&config.quotient, "quotient"); + let remainder = extract_uint_limbs::(&config.remainder, "remainder"); + + // Sign/control bits + let dividend_sign = config.dividend_sign.id as u32; + let divisor_sign = config.divisor_sign.id as u32; + let quotient_sign = config.quotient_sign.id as u32; + let remainder_zero = config.remainder_zero.id as u32; + let divisor_zero = config.divisor_zero.id as u32; + + // Inverse witnesses + let divisor_sum_inv = config.divisor_sum_inv.id as u32; + let remainder_sum_inv = config.remainder_sum_inv.id as u32; + let remainder_inv: [u32; 2] = [ + config.remainder_inv[0].id as u32, + config.remainder_inv[1].id as u32, + ]; + + // sign_xor + let sign_xor = config.sign_xor.id as u32; + + let remainder_prime = + extract_uint_limbs::(&config.remainder_prime, "remainder_prime"); + + // lt_marker + let lt_marker: [u32; 2] = [config.lt_marker[0].id as u32, config.lt_marker[1].id as u32]; + + // lt_diff + let lt_diff = config.lt_diff.id as u32; + + DivColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + dividend, + divisor, + quotient, + remainder, + dividend_sign, + divisor_sign, + quotient_sign, + remainder_zero, + divisor_zero, + divisor_sum_inv, + remainder_sum_inv, + remainder_inv, + sign_xor, + remainder_prime, + lt_marker, + lt_diff, + 
num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::div::{DivInstruction, DivuInstruction, RemInstruction, RemuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_div_column_map, + DivInstruction, + extract_div_column_map + ); + test_colmap!( + test_extract_divu_column_map, + DivuInstruction, + extract_div_column_map + ); + test_colmap!( + test_extract_rem_column_map, + RemInstruction, + extract_div_column_map + ); + test_colmap!( + test_extract_remu_column_map, + RemuInstruction, + extract_div_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_div_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::DIV, 0, "DIV"), + (InsnKind::DIVU, 1, "DIVU"), + (InsnKind::REM, 2, "REM"), + (InsnKind::REMU, 3, "REMU"), + ]; + + for &(insn_kind, div_kind, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::DIV => { + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::DIVU => { + DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REM => { + RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REMU => { + 
RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 1), // 0 / 1 + (1, 1), // 1 / 1 + (0, 0), // 0 / 0 (zero divisor) + (12345, 0), // non-zero / 0 (zero divisor) + (u32::MAX, 0), // max / 0 (zero divisor) + (0x80000000, 0), // INT_MIN / 0 (zero divisor) + (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 + (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 + (0x80000000, 1), // INT_MIN / 1 + (0x80000000, 2), // INT_MIN / 2 + (u32::MAX, u32::MAX), // max / max + (u32::MAX, 1), // max / 1 + (1, u32::MAX), // 1 / max + ]; + + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // Use edge cases first, then varied values with zero divisor + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + let rs1 = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2 = if i % 50 == 0 { + 0 // test zero divisor + } else { + (i as u32).wrapping_mul(54321).wrapping_add(13) + }; + (rs1, rs2) + }; + let rd_after = match insn_kind { + InsnKind::DIV => { + if rs2_val == 0 { + u32::MAX // -1 as u32 + } else { + (rs1_val as i32).wrapping_div(rs2_val as i32) as u32 + } + } + InsnKind::DIVU => { + if rs2_val == 0 { + u32::MAX + } else { + rs1_val / rs2_val + } + } + InsnKind::REM => { + if rs2_val == 0 { + rs1_val + } else { + (rs1_val as i32).wrapping_rem(rs2_val as i32) as u32 + } + } + InsnKind::REMU => { + if rs2_val == 0 { + rs1_val + } else { + rs1_val % rs2_val + } + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) 
+ }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::DIV => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::DIVU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::REM => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::REMU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_div_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_div( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + div_kind, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/jal.rs b/ceno_zkvm/src/instructions/gpu/chips/jal.rs new file mode 100644 index 000000000..22bc084b4 --- /dev/null +++ 
b/ceno_zkvm/src/instructions/gpu/chips/jal.rs @@ -0,0 +1,132 @@ +use ceno_gpu::common::witgen::types::JalColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{extract_rd, extract_state_branching, extract_uint_limbs}, + riscv::jump::jal_v2::JalConfig, +}; + +/// Extract column map from a constructed JalConfig. +pub fn extract_jal_column_map( + config: &JalConfig, + num_witin: usize, +) -> JalColumnMap { + let jm = &config.j_insn; + + let (pc, next_pc, ts) = extract_state_branching(&jm.vm_state); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&jm.rd); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); + + JalColumnMap { + pc, + next_pc, + ts, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_jal_column_map, + JalInstruction, + extract_jal_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_jal_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jal_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = 
(0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // JAL offset must be even; use small positive/negative offsets + let offset = (((i as i32) % 256) - 128) * 2; // even offsets + let new_pc = ByteAddr(pc.0.wrapping_add_signed(offset)); + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, offset); + StepRecord::new_j_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jal_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_jal( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/jalr.rs b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs new file mode 100644 index 000000000..a5663b7cf --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/jalr.rs @@ -0,0 +1,154 @@ +use ceno_gpu::common::witgen::types::JalrColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + 
gpu::utils::column_map::{ + extract_rd, extract_rs1, extract_state_branching, extract_uint_limbs, extract_wit_ids, + }, + riscv::jump::jalr_v2::JalrConfig, +}; + +/// Extract column map from a constructed JalrConfig. +pub fn extract_jalr_column_map( + config: &JalrConfig, + num_witin: usize, +) -> JalrColumnMap { + let im = &config.i_insn; + + let (pc, next_pc, ts) = extract_state_branching(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let jump_pc_addr = extract_uint_limbs::(&config.jump_pc_addr.addr, "jump_pc_addr"); + let jump_pc_addr_bit = + extract_wit_ids::<2>(&config.jump_pc_addr.low_bits, "jump_pc_addr low_bits"); + + // rd_high + let rd_high = config.rd_high.id as u32; + + JalrColumnMap { + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + jump_pc_addr, + jump_pc_addr_bit, + rd_high, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalrInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_jalr_column_map, + JalrInstruction, + extract_jalr_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_jalr_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to 
create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jalr_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val: u32 = 0x0010_0000u32.wrapping_add(i as u32 * 137); + let imm: i32 = ((i as i32) % 2048) - 1024; // range [-1024, 1023] + let jump_raw = rs1_val.wrapping_add(imm as u32); + let new_pc = ByteAddr(jump_raw & !1u32); // aligned to 2 bytes + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JALR, 1, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + rs1_val, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jalr_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_jalr( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + 
gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/keccak.rs b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs new file mode 100644 index 000000000..edca0b680 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/keccak.rs @@ -0,0 +1,711 @@ +use ceno_emul::{StepIndex, StepRecord}; +use ceno_gpu::common::witgen::types::{GpuKeccakInstance, GpuKeccakWriteOp, KeccakColumnMap}; +use ff_ext::ExtensionField; +use std::sync::Arc; + +use crate::instructions::riscv::ecall::keccak::EcallKeccakConfig; + +use ceno_emul::SyscallWitness; + +use ceno_emul::WordAddr; +use ceno_gpu::{ + Buffer, CudaHal, + bb31::CudaHalBB31, + common::{transpose::matrix_transpose, witgen::types::GpuShardRamRecord}, +}; +use gkr_iop::utils::lk_multiplicity::Multiplicity; +use p3::field::FieldAlgebra; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::{ + e2e::ShardContext, + error::ZKVMError, + instructions::gpu::{ + cache::{ + ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, + with_cached_shard_meta, + }, + config::{is_debug_compare_enabled, is_gpu_witgen_enabled, is_kind_disabled}, + dispatch::{GpuWitgenKind, compute_fetch_params, is_force_cpu_path}, + utils::{ + d2h::{gpu_compact_ec_d2h, gpu_lk_counters_to_multiplicity}, + debug_compare::debug_compare_keccak, + }, + }, + tables::RMMCollections, + witness::LkMultiplicity, +}; + +/// Extract column map from a constructed EcallKeccakConfig. +/// +/// VM state columns are listed individually. Keccak math columns use +/// a single base offset since they're allocated contiguously via transmute. 
+pub fn extract_keccak_column_map( + config: &EcallKeccakConfig, + num_witin: usize, +) -> KeccakColumnMap { + // StateInOut + let pc = config.vm_state.pc.id as u32; + let ts = config.vm_state.ts.id as u32; + + // OpFixedRS - ecall_id + let ecall_prev_ts = config.ecall_id.prev_ts.id as u32; + let ecall_lt_diff = { + let diffs = &config.ecall_id.lt_cfg.0.diff; + assert_eq!( + diffs.len(), + 2, + "Expected 2 AssertLt diff limbs for ecall_id" + ); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // MemAddr - state_ptr address limbs + let addr_limbs = { + let limbs = config + .state_ptr + .1 + .addr + .wits_in() + .expect("MemAddr should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 addr limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // OpFixedRS - state_ptr register write + let sptr_prev_ts = config.state_ptr.0.prev_ts.id as u32; + let sptr_prev_val = { + let limbs = config + .state_ptr + .0 + .prev_value + .as_ref() + .expect("state_ptr should have prev_value") + .wits_in() + .expect("prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let sptr_lt_diff = { + let diffs = &config.state_ptr.0.lt_cfg.0.diff; + assert_eq!( + diffs.len(), + 2, + "Expected 2 AssertLt diff limbs for state_ptr" + ); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteMEM x50: prev_ts + lt_diff[2] each + let mut mem_prev_ts = [0u32; 50]; + let mut mem_lt_diff_0 = [0u32; 50]; + let mut mem_lt_diff_1 = [0u32; 50]; + for (i, writer) in config.mem_rw.iter().enumerate() { + mem_prev_ts[i] = writer.prev_ts.id as u32; + let diffs = &writer.lt_cfg.0.diff; + assert_eq!( + diffs.len(), + 2, + "Expected 2 AssertLt diff limbs for mem_rw[{}]", + i + ); + mem_lt_diff_0[i] = diffs[0].id as u32; + mem_lt_diff_1[i] = diffs[1].id as u32; + } + + // Keccak math columns base offset (contiguous block) + let keccak_base_col = config.layout.layer_exprs.wits.input8[0].id as u32; + + 
// Verify contiguity of keccak math columns + #[cfg(debug_assertions)] + { + let base = keccak_base_col as usize; + let expected_size = + std::mem::size_of::>(); + // Check that the last keccak column is at base + expected_size - 1 + let last_rc = config.layout.layer_exprs.wits.rc.last().unwrap(); + assert_eq!( + last_rc.id as usize, + base + expected_size - 1, + "Keccak math columns not contiguous: last rc id {} != expected {}", + last_rc.id, + base + expected_size - 1 + ); + } + + KeccakColumnMap { + pc, + ts, + ecall_prev_ts, + ecall_lt_diff, + addr_limbs, + sptr_prev_ts, + sptr_prev_val, + sptr_lt_diff, + mem_prev_ts, + mem_lt_diff_0, + mem_lt_diff_1, + keccak_base_col, + num_cols: num_witin as u32, + } +} + +/// Pack step records + syscall witnesses into flat GPU-transferable instances. +pub fn pack_keccak_instances( + steps: &[StepRecord], + step_indices: &[StepIndex], + syscall_witnesses: &Arc>, +) -> Vec { + step_indices + .iter() + .map(|&idx| { + let step = &steps[idx]; + let sw = step + .syscall(syscall_witnesses) + .expect("keccak step must have syscall witness"); + + // Register op (state_ptr) + let reg_op = &sw.reg_ops[0]; + let gpu_reg_op = GpuKeccakWriteOp { + addr: reg_op.addr.0, + value_before: reg_op.value.before, + value_after: reg_op.value.after, + _pad: 0, + previous_cycle: reg_op.previous_cycle, + }; + + // Memory ops (50 read-writes) + let mut mem_ops = [GpuKeccakWriteOp::default(); 50]; + for (i, op) in sw.mem_ops.iter().enumerate() { + mem_ops[i] = GpuKeccakWriteOp { + addr: op.addr.0, + value_before: op.value.before, + value_after: op.value.after, + _pad: 0, + previous_cycle: op.previous_cycle, + }; + } + + GpuKeccakInstance { + pc: step.pc().before.0, + _pad0: 0, + cycle: step.cycle(), + ecall_prev_cycle: step.rs1().unwrap().previous_cycle, + reg_op: gpu_reg_op, + mem_ops, + } + }) + .collect() +} + +/// GPU dispatch entry point for keccak ecall witness generation. 
+/// +/// Unlike `try_gpu_assign_instances`, keccak has a rotation-aware matrix layout +/// (each logical instance spans 32 physical rows) and requires building +/// structural witness on CPU with selector indices from the cyclic group. +#[cfg(feature = "gpu")] +pub fn gpu_assign_keccak_instances( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, Multiplicity)>, ZKVMError> { + use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; + use gkr_iop::gpu::get_cuda_hal; + + // Guard: disabled or force-CPU + if !is_gpu_witgen_enabled() || is_force_cpu_path() { + return Ok(None); + } + // Check if keccak is disabled via CENO_GPU_DISABLE_WITGEN_KINDS=keccak + if is_kind_disabled(GpuWitgenKind::Keccak) { + return Ok(None); + } + + // GPU only supports BabyBear field + if std::any::TypeId::of::() + != std::any::TypeId::of::<::BaseField>() + { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(hal) => hal, + Err(_) => return Ok(None), + }; + + // Empty step_indices: return empty matrices + if step_indices.is_empty() { + let rotation = KECCAK_ROUNDS_CEIL_LOG2; + let raw_witin = RowMajorMatrix::::new_by_rotation( + 0, + rotation, + num_witin, + InstancePaddingStrategy::Default, + ); + let raw_structural = RowMajorMatrix::::new_by_rotation( + 0, + rotation, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + let lk = LkMultiplicity::default(); + return Ok(Some(( + [raw_witin, raw_structural], + lk.into_finalize_result(), + ))); + } + + let num_instances = step_indices.len(); + tracing::debug!("[GPU witgen] keccak with {} instances", num_instances); + + info_span!("gpu_witgen_keccak", n = num_instances).in_scope(|| { + gpu_assign_keccak_inner::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + &hal, + ) + .map(Some) + }) +} + +/// Keccak-specific GPU 
witness generation, separate from `gpu_assign_instances_inner` because: +/// 1. Rotation: each instance spans 32 rows (not 1), requiring `new_by_rotation` +/// 2. Structural witness: 3 selectors (sel_first/sel_last/sel_all) vs the standard 1 +/// 3. Input packing: needs `packed_instances` with `syscall_witnesses` +/// +/// The LK/shardram collection logic (Steps 6-7) is identical to the standard path; +/// it is duplicated here rather than shared. +#[cfg(feature = "gpu")] +fn gpu_assign_keccak_inner( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], + hal: &CudaHalBB31, +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::precompiles::KECCAK_ROUNDS_CEIL_LOG2; + + let num_instances = step_indices.len(); + let num_padded_instances = num_instances.next_power_of_two().max(2); + let num_padded_rows = num_padded_instances * 32; // 2^5 = 32 rows per instance + let rotation = KECCAK_ROUNDS_CEIL_LOG2; // = 5 + + // Step 1: Extract column map + let col_map = info_span!("col_map").in_scope(|| extract_keccak_column_map(config, num_witin)); + + // Step 2: Pack instances + let packed_instances = info_span!("pack_instances") + .in_scope(|| pack_keccak_instances(steps, step_indices, &shard_ctx.syscall_witnesses)); + + // Step 3: Compute fetch params + let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(steps, step_indices); + + // Step 4: Ensure shard metadata cached + info_span!("ensure_shard_meta") + .in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx, steps.len()))?; + + // Snapshot shared addr count before kernel (for debug comparison) + let addr_count_before = if crate::instructions::gpu::config::is_debug_compare_enabled() { + read_shared_addr_count() + } else { + 0 + }; + + // Step 5: Launch GPU kernel + let gpu_result = info_span!("gpu_kernel").in_scope(|| { + 
with_cached_shard_meta(|shard_bufs| { + hal.witgen + .witgen_keccak( + &col_map, + &packed_instances, + num_padded_rows, + shard_ctx.current_shard_offset_cycle(), + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_keccak failed: {e}").into()) + }) + }) + })?; + + // D2H keccak's addr entries from shared buffer (delta since before kernel) + let gpu_keccak_addrs = if crate::instructions::gpu::config::is_debug_compare_enabled() { + let addr_count_after = read_shared_addr_count(); + if addr_count_after > addr_count_before { + read_shared_addr_range(addr_count_before, addr_count_after) + } else { + Vec::new() + } + } else { + Vec::new() + }; + + // Step 6: Collect LK multiplicity + let lk_multiplicity = info_span!("gpu_lk_d2h") + .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_result.lk_counters))?; + + // Debug LK comparison is done in the unit test instead. + + // Step 7: Handle compact EC records (shared buffer path) + if gpu_result.compact_ec.is_none() && gpu_result.compact_addr.is_none() { + // Shared buffer path: EC records + addr_accessed accumulated on device + // in shared buffers across all kernel invocations. Skip per-kernel D2H. 
+ } else if let Some(compact) = gpu_result.compact_ec { + info_span!("gpu_ec_shard").in_scope(|| { + let compact_records = + info_span!("compact_d2h").in_scope(|| gpu_compact_ec_d2h(&compact))?; + + // D2H compact addr_accessed + info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { + if let Some(ref ca) = gpu_result.compact_addr { + let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr_count D2H failed: {e}").into(), + ) + })?; + let n = count_vec[0] as usize; + if n > 0 { + let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr D2H failed: {e}").into(), + ) + })?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + } + Ok(()) + })?; + + // Populate shard_ctx with GPU EC records + let raw_bytes = unsafe { + std::slice::from_raw_parts( + compact_records.as_ptr() as *const u8, + compact_records.len() * std::mem::size_of::(), + ) + }; + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + + Ok::<(), ZKVMError>(()) + })?; + } + + // Step 8: Transpose GPU witness (column-major -> row-major) + D2H + let raw_witin = info_span!("transpose_d2h", rows = num_padded_rows, cols = num_witin) + .in_scope(|| { + let mut rmm_buffer = hal + .alloc_elems_on_device(num_padded_rows * num_witin, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.witness.device_buffer, + num_padded_rows, + num_witin, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_data: Vec<::BaseField> = + rmm_buffer.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()) + })?; + + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + 
let data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + // Construct a rotation-aware matrix and fill with GPU-transposed data. + // new_by_rotation allocates the correct padded size; we overwrite the real portion. + let mut rmm = RowMajorMatrix::::new_by_rotation( + num_instances, + rotation, + num_witin, + InstancePaddingStrategy::Default, + ); + // Access inner p3 matrix's values via DerefMut + std::ops::DerefMut::deref_mut(&mut rmm).values[..data.len()].copy_from_slice(&data); + Ok::<_, ZKVMError>(rmm) + })?; + + // Step 9: Build structural witness on CPU with selector indices + let raw_structural = info_span!("structural_witness").in_scope(|| { + let mut raw_structural = RowMajorMatrix::::new_by_rotation( + num_instances, + rotation, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + // Get selector column IDs from config + let sel_first = config + .layout + .selector_type_layout + .sel_first + .as_ref() + .expect("sel_first must be Some"); + let sel_last = config + .layout + .selector_type_layout + .sel_last + .as_ref() + .expect("sel_last must be Some"); + + let sel_first_id = sel_first.selector_expr().id(); + let sel_last_id = sel_last.selector_expr().id(); + let sel_all_id = config + .layout + .selector_type_layout + .sel_all + .selector_expr() + .id(); + + let sel_first_indices = sel_first.sparse_indices(); + let sel_last_indices = sel_last.sparse_indices(); + let sel_all_indices = config.layout.selector_type_layout.sel_all.sparse_indices(); + + // Only set selectors for real instances, not padding ones. 
+ for instance_chunk in raw_structural.iter_mut().take(num_instances) { + // instance_chunk is a &mut [F] of size 32 * num_structural_witin + for &idx in sel_first_indices { + instance_chunk[idx * num_structural_witin + sel_first_id] = E::BaseField::ONE; + } + for &idx in sel_last_indices { + instance_chunk[idx * num_structural_witin + sel_last_id] = E::BaseField::ONE; + } + for &idx in sel_all_indices { + instance_chunk[idx * num_structural_witin + sel_all_id] = E::BaseField::ONE; + } + } + raw_structural.padding_by_strategy(); + + raw_structural + }); + + // Debug comparisons (activated by env vars) + debug_compare_keccak::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + &lk_multiplicity, + &raw_witin, + &gpu_keccak_addrs, + )?; + + Ok(([raw_witin, raw_structural], lk_multiplicity)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::ecall::keccak::KeccakInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_keccak_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let (config, _gkr_circuit) = + KeccakInstruction::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + + let col_map = extract_keccak_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + // All column IDs should be within range + // Note: keccak_base_col and num_cols are metadata, not column indices + let metadata_indices = [flat.len() - 1, flat.len() - 2]; // num_cols, keccak_base_col + for (i, &col) in flat.iter().enumerate() { + if metadata_indices.contains(&i) { + continue; + } + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + } + + #[test] + fn 
test_gpu_witgen_keccak_correctness() { + use crate::e2e::ShardContext; + + let mut cs = ConstraintSystem::::new(|| "test_keccak_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let (config, _gkr_circuit) = + KeccakInstruction::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + // Get test data from emulator + let (step, _program, syscall_witnesses) = ceno_emul::test_utils::keccak_step(); + let steps = vec![step]; + let step_indices: Vec = vec![0]; + + // --- CPU path (force CPU via thread-local flag) --- + use crate::instructions::gpu::dispatch::set_force_cpu_path; + set_force_cpu_path(true); + let mut shard_ctx = ShardContext::default(); + shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses.clone()); + let (cpu_rmms, _cpu_lkm) = KeccakInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &step_indices, + ) + .unwrap(); + set_force_cpu_path(false); + let cpu_witness = &cpu_rmms[0]; + let cpu_structural = &cpu_rmms[1]; + + // --- GPU path (full pipeline via gpu_assign_keccak_instances) --- + use super::gpu_assign_keccak_instances; + let mut shard_ctx_gpu = ShardContext::default(); + shard_ctx_gpu.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); + let (gpu_rmms, gpu_lk) = gpu_assign_keccak_instances::( + &config, + &mut shard_ctx_gpu, + num_witin, + num_structural_witin, + &steps, + &step_indices, + ) + .unwrap() + .expect("GPU path should not return None"); + let gpu_witness = &gpu_rmms[0]; + let gpu_structural = &gpu_rmms[1]; + + // --- Compare witness (raw_witin) --- + let gpu_data = gpu_witness.values(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "witness size mismatch"); + + let mut mismatches = 0; + for (i, (g, c)) in gpu_data.iter().zip(cpu_data.iter()).enumerate() { + if g != c { + if mismatches < 
20 { + let row = i / num_witin; + let col = i % num_witin; + eprintln!( + "Witness mismatch row={}, col={}: GPU={:?}, CPU={:?}", + row, col, g, c + ); + } + mismatches += 1; + } + } + eprintln!( + "Keccak witness: {} mismatches out of {} cells", + mismatches, + gpu_data.len() + ); + + // --- Compare structural witness --- + let gpu_struct_data = gpu_structural.values(); + let cpu_struct_data = cpu_structural.values(); + assert_eq!( + gpu_struct_data.len(), + cpu_struct_data.len(), + "structural witness size mismatch" + ); + + let mut struct_mismatches = 0; + for (i, (g, c)) in gpu_struct_data + .iter() + .zip(cpu_struct_data.iter()) + .enumerate() + { + if g != c { + if struct_mismatches < 20 { + let row = i / num_structural_witin; + let col = i % num_structural_witin; + eprintln!( + "Structural mismatch row={}, col={}: GPU={:?}, CPU={:?}", + row, col, g, c + ); + } + struct_mismatches += 1; + } + } + eprintln!( + "Keccak structural: {} mismatches out of {} cells", + struct_mismatches, + gpu_struct_data.len() + ); + + // --- Compare LK multiplicity --- + let mut lk_mismatches = 0; + for (table_idx, (gpu_map, cpu_map)) in gpu_lk.0.iter().zip(_cpu_lkm.0.iter()).enumerate() { + for (&k, &gpu_v) in gpu_map.iter() { + let cpu_v = cpu_map.get(&k).copied().unwrap_or(0); + if gpu_v != cpu_v { + if lk_mismatches < 30 { + eprintln!( + "LK mismatch table={}, key={:#x}: GPU={}, CPU={}", + table_idx, k, gpu_v, cpu_v, + ); + } + lk_mismatches += 1; + } + } + for (&k, &cpu_v) in cpu_map.iter() { + if !gpu_map.contains_key(&k) { + if lk_mismatches < 30 { + eprintln!( + "LK mismatch table={}, key={:#x}: GPU=missing, CPU={}", + table_idx, k, cpu_v, + ); + } + lk_mismatches += 1; + } + } + } + eprintln!("Keccak LK: {} mismatches", lk_mismatches); + + assert_eq!(mismatches, 0, "GPU vs CPU witness mismatch"); + assert_eq!( + struct_mismatches, 0, + "GPU vs CPU structural witness mismatch" + ); + assert_eq!(lk_mismatches, 0, "GPU vs CPU LK multiplicity mismatch"); + } +} diff --git 
a/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs new file mode 100644 index 000000000..200b069d7 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/load_sub.rs @@ -0,0 +1,327 @@ +use ceno_gpu::common::witgen::types::LoadSubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs, + }, + riscv::memory::load_v2::LoadConfig, +}; + +/// Extract column map from a constructed LoadConfig for sub-word loads (LH/LHU/LB/LBU). +pub fn extract_load_sub_column_map( + config: &LoadConfig, + num_witin: usize, +) -> LoadSubColumnMap { + let im = &config.im_insn; + + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + let (mem_prev_ts, mem_lt_diff) = extract_read_mem(&im.mem_read); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); + let mem_read = extract_uint_limbs::(&config.memory_read, "memory_read"); + + // Infer variant from config: LB/LBU have 2 low_bits, LH/LHU have 1. 
+ let low_bits = &config.memory_addr.low_bits; + let is_byte = low_bits.len() == 2; + + let addr_bit_1 = if is_byte { + low_bits[1].id as u32 + } else { + assert_eq!(low_bits.len(), 1, "LH/LHU should have 1 low_bit"); + low_bits[0].id as u32 + }; + + let target_limb = config + .target_limb + .expect("sub-word loads must have target_limb") + .id as u32; + + // LB/LBU: addr_bit_0, target_byte, dummy_byte + let (addr_bit_0, target_byte, dummy_byte) = if is_byte { + let bytes = config + .target_limb_bytes + .as_ref() + .expect("LB/LBU must have target_limb_bytes"); + assert_eq!(bytes.len(), 2); + ( + Some(low_bits[0].id as u32), + Some(bytes[0].id as u32), + Some(bytes[1].id as u32), + ) + } else { + (None, None, None) + }; + + // Signed loads have signed_extend_config + let msb = config + .signed_extend_config + .as_ref() + .map(|sec| sec.msb().id as u32); + + LoadSubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr, + mem_read, + addr_bit_1, + target_limb, + addr_bit_0, + target_byte, + dummy_byte, + msb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::memory::{LbInstruction, LbuInstruction, LhInstruction, LhuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_lh_column_map, + LhInstruction, + extract_load_sub_column_map + ); + test_colmap!( + test_extract_lhu_column_map, + LhuInstruction, + extract_load_sub_column_map + ); + test_colmap!( + test_extract_lb_column_map, + LbInstruction, + extract_load_sub_column_map + ); + test_colmap!( + test_extract_lbu_column_map, + LbuInstruction, + extract_load_sub_column_map + ); + + #[test] + 
#[cfg(feature = "gpu")] + fn test_gpu_witgen_load_sub_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Test all 4 variants + let variants: &[(InsnKind, bool, bool, &str)] = &[ + (InsnKind::LH, false, true, "LH"), + (InsnKind::LHU, false, false, "LHU"), + (InsnKind::LB, true, true, "LB"), + (InsnKind::LBU, true, false, "LBU"), + ]; + + for &(insn_kind, is_byte, is_signed, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + // We need to construct the right instruction type + let config = match insn_kind { + InsnKind::LH => { + LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LHU => { + LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LB => { + LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LBU => { + LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = if is_byte { + [0, 1, -1, -3] + } else { + [0, 2, -2, -6] + }; + + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 500000; + + // Compute rd_after based on load type + let shift = mem_addr & 3; + 
let bit_1 = (shift >> 1) & 1; + let bit_0 = shift & 1; + let target_limb: u16 = if bit_1 == 0 { + (mem_val & 0xFFFF) as u16 + } else { + (mem_val >> 16) as u16 + }; + let rd_after = match insn_kind { + InsnKind::LH => (target_limb as i16) as i32 as u32, + InsnKind::LHU => target_limb as u32, + InsnKind::LB => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + (byte as i8) as i32 as u32 + } + InsnKind::LBU => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + byte as u32 + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, rd_after), + mem_read_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::LH => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::LHU => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::LB => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::LBU => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_load_sub_column_map(&config, 
num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let load_width: u32 = if is_byte { 8 } else { 16 }; + let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; + let gpu_result = hal + .witgen + .witgen_load_sub( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed_u32, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs new file mode 100644 index 000000000..791129945 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_i.rs @@ -0,0 +1,160 @@ +use ceno_gpu::common::witgen::types::LogicIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::logic_imm::logic_imm_circuit_v2::LogicConfig, +}; + +/// Extract column map from a constructed LogicConfig (I-type v2: ANDI/ORI/XORI). 
+pub fn extract_logic_i_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicIColumnMap { + let im = &config.i_insn; + + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); + let imm_lo_bytes = extract_uint_limbs::(&config.imm_lo, "imm_lo"); + let imm_hi_bytes = extract_uint_limbs::(&config.imm_hi, "imm_hi"); + + LogicIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm_lo_bytes, + imm_hi_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::logic_imm::AndiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_logic_i_column_map, + AndiInstruction, + extract_logic_i_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_i_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_logic_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: 
&[(u32, u32)] = &[ + (0, 0), + (u32::MAX, 0xFFF), // all bits AND max imm + (u32::MAX, 0), + (0, 0xFFF), + (0xAAAAAAAA, 0x555), // alternating + (0xFFFF0000, 0xFFF), + (0x12345678, 0x000), + (0xDEADBEEF, 0xABC), + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, imm) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, + (i as u32) % 4096, + ) + }; + let rd_after = rs1 & imm; // ANDI + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32u(InsnKind::ANDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_logic_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_logic_i( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs new 
file mode 100644 index 000000000..2c22b8c39 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/logic_r.rs @@ -0,0 +1,181 @@ +use ceno_gpu::common::witgen::types::LogicRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::logic::logic_circuit::LogicConfig, +}; + +/// Extract column map from a constructed LogicConfig (R-type: AND/OR/XOR). +pub fn extract_logic_r_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicRColumnMap { + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); + + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_bytes = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); + + LogicRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::logic::AndInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_logic_r_column_map, + AndInstruction, + extract_logic_r_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_r_correctness() { + use crate::{ + e2e::ShardContext, + instructions::gpu::{ + dispatch, + utils::test_helpers::{assert_full_gpu_pipeline, assert_witness_colmajor_eq}, + 
}, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_and_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (u32::MAX, u32::MAX), + (u32::MAX, 0), + (0, u32::MAX), + (0xAAAAAAAA, 0x55555555), // alternating bits + (0xFFFF0000, 0x0000FFFF), // no overlap + (0xDEADBEEF, 0xFFFFFFFF), // identity + (0x12345678, 0x00000000), // zero + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + 0xDEAD_0000u32 | (i as u32), + 0x00FF_FF00u32 | ((i as u32) << 8), + ) + }; + let rd_after = rs1 & rs2; // AND + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_logic_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = 
hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_logic_r( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + + assert_full_gpu_pipeline::>( + &config, + &steps, + dispatch::GpuWitgenKind::LogicR(0), + &cpu_rmms, + &cpu_lkm, + &shard_ctx, + num_witin, + num_structural_witin, + ); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/lui.rs b/ceno_zkvm/src/instructions/gpu/chips/lui.rs new file mode 100644 index 000000000..37e359a28 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/lui.rs @@ -0,0 +1,143 @@ +use ceno_gpu::common::witgen::types::LuiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state}, + riscv::lui::LuiConfig, +}; + +/// Extract column map from a constructed LuiConfig. 
+pub fn extract_lui_column_map( + config: &LuiConfig, + num_witin: usize, +) -> LuiColumnMap { + let im = &config.i_insn; + + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + + // LUI-specific: rd bytes (skip byte 0) + imm + let rd_bytes: [u32; 3] = [ + config.rd_written[0].id as u32, + config.rd_written[1].id as u32, + config.rd_written[2].id as u32, + ]; + let imm = config.imm.id as u32; + + LuiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::lui::LuiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_lui_column_map, + LuiInstruction, + extract_lui_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_lui_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lui_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // LUI immediate 
is upper 20 bits + let rd_after = imm as u32; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_lui_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_lui( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/lw.rs b/ceno_zkvm/src/instructions/gpu/chips/lw.rs new file mode 100644 index 000000000..14f09779b --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/lw.rs @@ -0,0 +1,189 @@ +use ceno_gpu::common::witgen::types::LwColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::gpu::utils::column_map::{ + extract_rd, extract_read_mem, extract_rs1, extract_state, extract_uint_limbs, +}; + +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::memory::load::LoadConfig; 
+#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::memory::load_v2::LoadConfig; + +/// Extract column map from a constructed LoadConfig (LW variant). +pub fn extract_lw_column_map( + config: &LoadConfig, + num_witin: usize, +) -> LwColumnMap { + let im = &config.im_insn; + + let (pc, ts) = extract_state(&im.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&im.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&im.rd); + let (mem_prev_ts, mem_lt_diff) = extract_read_mem(&im.mem_read); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let imm = config.imm.id as u32; + #[cfg(feature = "u16limb_circuit")] + let imm_sign = Some(config.imm_sign.id as u32); + #[cfg(not(feature = "u16limb_circuit"))] + let imm_sign = None; + let mem_addr_limbs = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); + let mem_read_limbs = extract_uint_limbs::(&config.memory_read, "memory_read"); + + LwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr_limbs, + mem_read_limbs, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + structs::ProgramParams, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + type LwInstruction = crate::instructions::riscv::LwInstruction; + + fn make_lw_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + // Use varying immediates including negative values to test imm_field encoding + let imm_values: [i32; 4] = [0, 4, -4, -8]; + (0..n) + .map(|i| { + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = 
rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 1000000; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LW, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, mem_val), + mem_read_op, + 0, + ) + }) + .collect() + } + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_lw_column_map, + LwInstruction, + extract_lw_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_lw_correctness() { + use crate::{ + e2e::ShardContext, + instructions::gpu::{ + dispatch, + utils::test_helpers::{assert_full_gpu_pipeline, assert_witness_colmajor_eq}, + }, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps = make_lw_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, cpu_lkm) = crate::instructions::cpu_assign_instances::( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path (AOS with indirect indexing) + let col_map = extract_lw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + 
std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_lw( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + + assert_full_gpu_pipeline::( + &config, + &steps, + dispatch::GpuWitgenKind::Lw, + &cpu_rmms, + &cpu_lkm, + &shard_ctx, + num_witin, + num_structural_witin, + ); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/mod.rs b/ceno_zkvm/src/instructions/gpu/chips/mod.rs new file mode 100644 index 000000000..ebb4d36ae --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/mod.rs @@ -0,0 +1,48 @@ +#[cfg(feature = "gpu")] +pub mod add; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod addi; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod auipc; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_cmp; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_eq; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod div; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jal; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jalr; +#[cfg(feature = "gpu")] +pub mod keccak; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod load_sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod logic_i; +#[cfg(feature = "gpu")] +pub mod logic_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod lui; +#[cfg(feature = "gpu")] +pub mod lw; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod mul; +#[cfg(all(feature = "gpu", 
feature = "u16limb_circuit"))] +pub mod sb; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sh; +#[cfg(feature = "gpu")] +pub mod shard_ram; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_i; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slt; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slti; +#[cfg(feature = "gpu")] +pub mod sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sw; diff --git a/ceno_zkvm/src/instructions/gpu/chips/mul.rs b/ceno_zkvm/src/instructions/gpu/chips/mul.rs new file mode 100644 index 000000000..f7e8acab3 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/mul.rs @@ -0,0 +1,282 @@ +use ceno_gpu::common::witgen::types::MulColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::mulh::mulh_circuit_v2::MulhConfig, +}; + +/// Extract column map from a constructed MulhConfig. +/// mul_kind: 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU +pub fn extract_mul_column_map( + config: &MulhConfig, + num_witin: usize, +) -> MulColumnMap { + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_low: [u32; 2] = [config.rd_low[0].id as u32, config.rd_low[1].id as u32]; + + // MULH/MULHU/MULHSU have rd_high + extensions; MUL does not. 
+ let (rd_high, rs1_ext, rs2_ext) = match config.rd_high.as_ref() { + Some(h) => ( + Some([h[0].id as u32, h[1].id as u32]), + Some(config.rs1_ext.expect("MULH variants must have rs1_ext").id as u32), + Some(config.rs2_ext.expect("MULH variants must have rs2_ext").id as u32), + ), + None => (None, None, None), + }; + + MulColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_low, + rd_high, + rs1_ext, + rs2_ext, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::mulh::{MulInstruction, MulhInstruction, MulhsuInstruction, MulhuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_mul_column_map, + MulInstruction, + extract_mul_column_map + ); + test_colmap!( + test_extract_mulh_column_map, + MulhInstruction, + extract_mul_column_map + ); + test_colmap!( + test_extract_mulhu_column_map, + MulhuInstruction, + extract_mul_column_map + ); + test_colmap!( + test_extract_mulhsu_column_map, + MulhsuInstruction, + extract_mul_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_mul_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::MUL, 0, "MUL"), + (InsnKind::MULH, 1, "MULH"), + (InsnKind::MULHU, 2, "MULHU"), + (InsnKind::MULHSU, 3, "MULHSU"), + ]; + + for &(insn_kind, mul_kind, name) in variants { + eprintln!("Testing 
{} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::MUL => { + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULH => { + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHU => { + MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHSU => { + MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), // zero * zero + (0, 12345), // zero * non-zero + (12345, 0), // non-zero * zero + (1, 1), // identity + (u32::MAX, 1), // max * 1 + (1, u32::MAX), // 1 * max + (u32::MAX, u32::MAX), // max * max + (0x80000000, 2), // INT_MIN * 2 (for MULH) + (2, 0x80000000), // 2 * INT_MIN + (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed + (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) + (0x7FFFFFFF, 0x7FFFFFFF), // INT_MAX * INT_MAX + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + (i as u32).wrapping_mul(12345).wrapping_add(7), + (i as u32).wrapping_mul(54321).wrapping_add(13), + ) + }; + let rd_after = match insn_kind { + InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), + InsnKind::MULH => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) + as u32 + } + InsnKind::MULHU => { + ((rs1_val as u64).wrapping_mul(rs2_val as u64) >> 32) as u32 + } + InsnKind::MULHSU => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i64) >> 32) as u32 + } + _ => unreachable!(), + }; + let rd_before = 
(i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::MUL => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::MULH => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHSU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_mul_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_mul( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + mul_kind, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + eprintln!("{} GPU 
vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/sb.rs b/ceno_zkvm/src/instructions/gpu/chips/sb.rs new file mode 100644 index 000000000..f312b06f8 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/sb.rs @@ -0,0 +1,203 @@ +use ceno_gpu::common::witgen::types::SbColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, + }, + riscv::memory::store_v2::StoreConfig, +}; + +/// Extract column map from a constructed StoreConfig (SB variant, N_ZEROS=0). +pub fn extract_sb_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SbColumnMap { + let sm = &config.s_insn; + + let (pc, ts) = extract_state(&sm.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&sm.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&sm.rs2); + let (mem_prev_ts, mem_lt_diff) = extract_write_mem(&sm.mem_write); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val = + extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); + + // SB-specific: 2 low_bits (bit_0, bit_1) + assert_eq!( + config.memory_addr.low_bits.len(), + 2, + "SB should have 2 low_bits" + ); + let mem_addr_bit_0 = config.memory_addr.low_bits[0].id as u32; + let mem_addr_bit_1 = config.memory_addr.low_bits[1].id as u32; + + // MemWordUtil fields (SB has N_ZEROS=0 so these exist) + let mem_word_util = config + .next_memory_value + .as_ref() + .expect("SB must have next_memory_value (MemWordUtil)"); + assert_eq!(mem_word_util.prev_limb_bytes.len(), 2); + let prev_limb_bytes: [u32; 2] = [ + mem_word_util.prev_limb_bytes[0].id as u32, + 
mem_word_util.prev_limb_bytes[1].id as u32, + ]; + assert_eq!(mem_word_util.rs2_limb_bytes.len(), 1); + let rs2_limb_byte = mem_word_util.rs2_limb_bytes[0].id as u32; + let expected_limb = mem_word_util + .expected_limb + .as_ref() + .expect("SB must have expected_limb") + .id as u32; + + SbColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_0, + mem_addr_bit_1, + prev_limb_bytes, + rs2_limb_byte, + expected_limb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SbInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_sb_column_map, + SbInstruction, + extract_sb_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sb_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sb_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 1, -1, -3]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: 
i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as u32) * 77 % 500000; + // SB stores the low byte of rs2 into the selected byte + let bit_0 = mem_addr & 1; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_byte = (rs2_val & 0xFF) as u8; + let byte_idx = (bit_1 * 2 + bit_0) as usize; + let mut bytes = prev_mem_val.to_le_bytes(); + bytes[byte_idx] = rs2_byte; + let new_mem_val = u32::from_le_bytes(bytes); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SB, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sb_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_sb( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/sh.rs 
b/ceno_zkvm/src/instructions/gpu/chips/sh.rs new file mode 100644 index 000000000..c73768408 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/sh.rs @@ -0,0 +1,180 @@ +use ceno_gpu::common::witgen::types::ShColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, + }, + riscv::memory::store_v2::StoreConfig, +}; + +/// Extract column map from a constructed StoreConfig (SH variant, N_ZEROS=1). +pub fn extract_sh_column_map( + config: &StoreConfig, + num_witin: usize, +) -> ShColumnMap { + let sm = &config.s_insn; + + let (pc, ts) = extract_state(&sm.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&sm.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&sm.rs2); + let (mem_prev_ts, mem_lt_diff) = extract_write_mem(&sm.mem_write); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val = + extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); + + // SH-specific: 1 low_bit (bit_1 for halfword select) + assert_eq!( + config.memory_addr.low_bits.len(), + 1, + "SH should have 1 low_bit" + ); + let mem_addr_bit_1 = config.memory_addr.low_bits[0].id as u32; + + ShColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_1, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::ShInstruction}, + structs::ProgramParams, + }; + use 
ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_sh_column_map, + ShInstruction, + extract_sh_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sh_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sh_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 2, -2, -6]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + // SH stores the low halfword of rs2 into the selected halfword + let prev_mem_val = (i as u32) * 77 % 500000; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_hw = rs2_val & 0xFFFF; + let new_mem_val = if bit_1 == 0 { + (prev_mem_val & 0xFFFF0000) | rs2_hw + } else { + (prev_mem_val & 0x0000FFFF) | (rs2_hw << 16) + }; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SH, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let 
mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sh_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_sh( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs new file mode 100644 index 000000000..e0cca71d6 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/shard_ram.rs @@ -0,0 +1,300 @@ +use ceno_gpu::common::witgen::types::ShardRamColumnMap; +use ff_ext::ExtensionField; + +use crate::tables::ShardRamConfig; + +/// Extract column map from a constructed ShardRamConfig. +/// +/// This reads all WitIn.id values from the config and packs them +/// into a ShardRamColumnMap suitable for GPU kernel dispatch. 
+pub fn extract_shard_ram_column_map( + config: &ShardRamConfig, + num_witin: usize, +) -> ShardRamColumnMap { + let addr = config.addr.id as u32; + let is_ram_register = config.is_ram_register.id as u32; + + let value_limbs = config + .value + .wits_in() + .expect("value should have WitIn limbs"); + assert_eq!(value_limbs.len(), 2, "Expected 2 value limbs"); + let value = [value_limbs[0].id as u32, value_limbs[1].id as u32]; + + let shard = config.shard.id as u32; + let global_clk = config.global_clk.id as u32; + let local_clk = config.local_clk.id as u32; + let nonce = config.nonce.id as u32; + let is_global_write = config.is_global_write.id as u32; + + let mut x = [0u32; 7]; + let mut y = [0u32; 7]; + let mut slope = [0u32; 7]; + for i in 0..7 { + x[i] = config.x[i].id as u32; + y[i] = config.y[i].id as u32; + slope[i] = config.slope[i].id as u32; + } + + // Poseidon2 columns: p3_cols are contiguous, followed by post_linear_layer_cols + let poseidon2_base_col = config.perm_config.p3_cols[0].id as u32; + let num_p3_cols = config.perm_config.p3_cols.len() as u32; + let num_post_linear = config.perm_config.post_linear_layer_cols.len() as u32; + let num_poseidon2_cols = num_p3_cols + num_post_linear; + + // Verify contiguity: p3_cols should be contiguous + for (i, col) in config.perm_config.p3_cols.iter().enumerate() { + debug_assert_eq!( + col.id as u32, + poseidon2_base_col + i as u32, + "p3_cols not contiguous at index {}", + i + ); + } + // post_linear_layer_cols should be contiguous after p3_cols + let post_base = poseidon2_base_col + num_p3_cols; + for (i, col) in config.perm_config.post_linear_layer_cols.iter().enumerate() { + debug_assert_eq!( + col.id as u32, + post_base + i as u32, + "post_linear_layer_cols not contiguous at index {}", + i + ); + } + + ShardRamColumnMap { + addr, + is_ram_register, + value, + shard, + global_clk, + local_clk, + nonce, + is_global_write, + x, + y, + slope, + poseidon2_base_col, + num_poseidon2_cols, + num_p3_cols, + 
num_cols: num_witin as u32, + } +} + +// --------------------------------------------------------------------------- +// ShardRam EC batch computation +// --------------------------------------------------------------------------- + +use ceno_gpu::common::witgen::types::GpuShardRamRecord; +use gkr_iop::RAMType; +use p3::field::FieldAlgebra; +use tracing::info_span; + +use crate::error::ZKVMError; + +/// Convert a ShardRamRecord to GpuShardRamRecord (metadata only, EC fields zeroed). +pub(crate) fn shard_ram_record_to_gpu(rec: &crate::tables::ShardRamRecord) -> GpuShardRamRecord { + GpuShardRamRecord { + addr: rec.addr, + ram_type: match rec.ram_type { + RAMType::Register => 1, + RAMType::Memory => 2, + _ => 0, + }, + value: rec.value, + _pad0: 0, + shard: rec.shard, + local_clk: rec.local_clk, + global_clk: rec.global_clk, + is_to_write_set: if rec.is_to_write_set { 1 } else { 0 }, + nonce: 0, + point_x: [0; 7], + point_y: [0; 7], + } +} + +/// Convert a GPU-computed GpuShardRamRecord to ECPoint. +fn gpu_shard_ram_record_to_ec_point( + gpu_rec: &GpuShardRamRecord, +) -> crate::tables::ECPoint { + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + point_x_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_x[j]); + point_y_arr[j] = E::BaseField::from_canonical_u32(gpu_rec.point_y[j]); + } + + let x = SepticExtension(point_x_arr); + let y = SepticExtension(point_y_arr); + let point = SepticPoint::from_affine(x, y); + + crate::tables::ECPoint { + nonce: gpu_rec.nonce, + point, + } +} + +/// Batch compute EC points on GPU, D2H results back to CPU as ShardRamInput. +/// +/// Used by the CPU fallback path in `structs.rs` when the full GPU pipeline +/// is unavailable. For the device-resident variant, see `gpu_batch_continuation_ec_on_device`. 
+pub fn gpu_batch_continuation_ec( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result< + ( + Vec>, + Vec>, + ), + ZKVMError, +> { + use crate::tables::ShardRamInput; + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let total = write_records.len() + read_records.len(); + if total == 0 { + return Ok((vec![], vec![])); + } + + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(shard_ram_record_to_gpu(rec)); + } + + let result = info_span!("gpu_batch_ec", n = total) + .in_scope(|| hal.witgen.batch_continuation_ec(&gpu_records)) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU batch EC failed: {e}").into()))?; + + let mut write_inputs = Vec::with_capacity(write_records.len()); + let mut read_inputs = Vec::with_capacity(read_records.len()); + + for (i, gpu_rec) in result.iter().enumerate() { + let (rec, name) = if i < write_records.len() { + (&write_records[i].0, write_records[i].1) + } else { + let ri = i - write_records.len(); + (&read_records[ri].0, read_records[ri].1) + }; + + let ec_point = gpu_shard_ram_record_to_ec_point::(gpu_rec); + let input = ShardRamInput { + name, + record: rec.clone(), + ec_point, + }; + + if i < write_records.len() { + write_inputs.push(input); + } else { + read_inputs.push(input); + } + } + + Ok((write_inputs, read_inputs)) +} + +/// Batch compute EC points on GPU, results stay on device. +/// +/// Used by the full GPU pipeline in `structs.rs` where records feed directly +/// into `merge_and_partition_records` on device without D2H. 
+pub fn gpu_batch_continuation_ec_on_device( + write_records: &[(crate::tables::ShardRamRecord, &'static str)], + read_records: &[(crate::tables::ShardRamRecord, &'static str)], +) -> Result< + ( + ceno_gpu::common::buffer::BufferImpl<'static, u32>, + usize, + usize, + ), + ZKVMError, +> { + use gkr_iop::gpu::get_cuda_hal; + + let hal = get_cuda_hal().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU not available for batch EC: {e}").into()) + })?; + + let n_writes = write_records.len(); + let n_reads = read_records.len(); + let total = n_writes + n_reads; + if total == 0 { + let empty = hal + .witgen + .alloc_u32_zeroed(1, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("alloc: {e}").into()))?; + return Ok((empty, 0, 0)); + } + + let mut gpu_records: Vec = Vec::with_capacity(total); + for (rec, _name) in write_records.iter().chain(read_records.iter()) { + gpu_records.push(shard_ram_record_to_gpu(rec)); + } + + let (device_buf, _count) = info_span!("gpu_batch_ec_on_device", n = total) + .in_scope(|| { + hal.witgen + .batch_continuation_ec_on_device(&gpu_records, None) + }) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU batch EC on device failed: {e}").into()) + })?; + + Ok((device_buf, n_writes, n_reads)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + structs::ProgramParams, + tables::{ShardRamCircuit, TableCircuit}, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shard_ram_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let (config, _gkr_circuit) = + ShardRamCircuit::::build_gkr_iop_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + + let col_map = extract_shard_ram_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + // Basic columns should be in range + // (excluding poseidon2 meta entries which are counts, not column IDs) + 
for (i, &col) in flat[..30].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (flat index {}) out of range: {} >= {}", + col, + i, + col, + col_map.num_cols + ); + } + + // Check uniqueness of actual column IDs (first 30 entries) + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..30] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + + // Verify Poseidon2 column counts are reasonable + assert_eq!(col_map.num_p3_cols, 299, "Expected 299 p3 cols"); + assert_eq!( + col_map.num_poseidon2_cols, 344, + "Expected 344 total Poseidon2 cols" + ); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs new file mode 100644 index 000000000..7496268d4 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_i.rs @@ -0,0 +1,171 @@ +use ceno_gpu::common::witgen::types::ShiftIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::shift::shift_circuit_v2::ShiftImmConfig, +}; + +/// Extract column map from a constructed ShiftImmConfig (I-type: SLLI/SRLI/SRAI). 
+pub fn extract_shift_i_column_map( + config: &ShiftImmConfig, + num_witin: usize, +) -> ShiftIColumnMap { + let (pc, ts) = extract_state(&config.i_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.i_insn.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.i_insn.rd); + + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); + let imm = config.imm.id as u32; + + // ShiftBase + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); + + ShiftIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift_imm::SlliInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_shift_i_column_map, + SlliInstruction, + extract_shift_i_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_i_correctness() { + use crate::{ + e2e::ShardContext, 
instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, shamt) = if i < EDGE_CASES.len() { + let (r, s) = EDGE_CASES[i]; + (r, s as i32) + } else { + ((i as u32).wrapping_mul(0x01010101), (i as i32) % 32) + }; + let rd_after = rs1 << (shamt as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLLI, 2, 0, 4, shamt); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() 
as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_shift_i( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs new file mode 100644 index 000000000..e1cb7b9e0 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/shift_r.rs @@ -0,0 +1,177 @@ +use ceno_gpu::common::witgen::types::ShiftRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::shift::shift_circuit_v2::ShiftRTypeConfig, +}; + +/// Extract column map from a constructed ShiftRTypeConfig (R-type: SLL/SRL/SRA). 
+pub fn extract_shift_r_column_map( + config: &ShiftRTypeConfig, + num_witin: usize, +) -> ShiftRColumnMap { + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); + + let rs1_bytes = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_bytes = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_bytes = extract_uint_limbs::(&config.rd_written, "rd_written"); + + // ShiftBase + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); + + ShiftRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift::SllInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_shift_r_column_map, + SllInstruction, + extract_shift_r_column_map 
+ ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_r_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_r_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32).wrapping_mul(0x01010101), (i as u32) % 32) + }; + let rd_after = rs1 << (rs2 & 0x1F); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + 
std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_shift_r( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/slt.rs b/ceno_zkvm/src/instructions/gpu/chips/slt.rs new file mode 100644 index 000000000..45251fdad --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/slt.rs @@ -0,0 +1,164 @@ +use ceno_gpu::common::witgen::types::SltColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::slt::slt_circuit_v2::SetLessThanConfig, +}; + +/// Extract column map from a constructed SetLessThanConfig (SLT/SLTU). 
+pub fn extract_slt_column_map( + config: &SetLessThanConfig, + num_witin: usize, +) -> SltColumnMap { + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); + + SltColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slt::SltInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_slt_column_map, + SltInstruction, + extract_slt_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_slt_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed 
to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + // Mix positive, negative, equal cases + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let rd_after = if (rs1 as i32) < (rs2 as i32) { + 1u32 + } else { + 0u32 + }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLT, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slt_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_slt( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git 
a/ceno_zkvm/src/instructions/gpu/chips/slti.rs b/ceno_zkvm/src/instructions/gpu/chips/slti.rs new file mode 100644 index 000000000..ba7868ef2 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/slti.rs @@ -0,0 +1,158 @@ +use ceno_gpu::common::witgen::types::SltiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{extract_rd, extract_rs1, extract_state, extract_uint_limbs}, + riscv::slti::slti_circuit_v2::SetLessThanImmConfig, +}; + +/// Extract column map from a constructed SetLessThanImmConfig (SLTI/SLTIU). +pub fn extract_slti_column_map( + config: &SetLessThanImmConfig, + num_witin: usize, +) -> SltiColumnMap { + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + let (pc, ts) = extract_state(&config.i_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.i_insn.rs1); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.i_insn.rd); + + SltiColumnMap { + rs1_limbs, + imm, + imm_sign, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slti::SltiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use 
crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_slti_column_map, + SltiInstruction, + extract_slti_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_slti_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slti_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let imm = ((i as i32) % 2048 - 1024) as i32; // -1024..1023 + let rd_after = if (rs1 as i32) < (imm as i32) { + 1u32 + } else { + 0u32 + }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slti_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + 
steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_slti( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/sub.rs b/ceno_zkvm/src/instructions/gpu/chips/sub.rs new file mode 100644 index 000000000..6107771db --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/chips/sub.rs @@ -0,0 +1,164 @@ +use ceno_gpu::common::witgen::types::SubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_carries, extract_rd, extract_rs1, extract_rs2, extract_state, extract_uint_limbs, + }, + riscv::arith::ArithConfig, +}; + +/// Extract column map from a constructed ArithConfig (SUB variant). +/// +/// SUB proves: rs1 = rs2 + rd. The carries come from the (rs2 + rd) addition, +/// stored in rs1_read.carries (since rs1_read = rs2.add(rd) in construct_circuit). 
+pub fn extract_sub_column_map( + config: &ArithConfig, + num_witin: usize, +) -> SubColumnMap { + let (pc, ts) = extract_state(&config.r_insn.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&config.r_insn.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&config.r_insn.rs2); + let (rd_id, rd_prev_ts, rd_prev_val, rd_lt_diff) = extract_rd(&config.r_insn.rd); + + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let rd_limbs = extract_uint_limbs::(&config.rd_written, "rd_written"); + let carries = extract_carries::(&config.rs1_read, "rs1_read"); + + SubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs2_limbs, + rd_limbs, + carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith::SubInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_sub_column_map, + SubInstruction, + extract_sub_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sub_correctness() { + use crate::{ + e2e::ShardContext, instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sub_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 
0), + (0, 1), // underflow + (1, 0), + (0, u32::MAX), // underflow + (u32::MAX, u32::MAX), + (0x80000000, 1), // INT_MIN - 1 + (0, 0x80000000), // 0 - INT_MIN + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX - (-1) + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 500, (i as u32) % 300 + 1) + }; + let rd_after = rs1.wrapping_sub(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SUB, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_sub_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_sub( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/chips/sw.rs b/ceno_zkvm/src/instructions/gpu/chips/sw.rs new file mode 100644 index 000000000..f501ba910 --- /dev/null 
+++ b/ceno_zkvm/src/instructions/gpu/chips/sw.rs @@ -0,0 +1,163 @@ +use ceno_gpu::common::witgen::types::SwColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::{ + gpu::utils::column_map::{ + extract_rs1, extract_rs2, extract_state, extract_uint_limbs, extract_write_mem, + }, + riscv::memory::store_v2::StoreConfig, +}; + +/// Extract column map from a constructed StoreConfig (SW variant, N_ZEROS=2). +pub fn extract_sw_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SwColumnMap { + let sm = &config.s_insn; + + let (pc, ts) = extract_state(&sm.vm_state); + let (rs1_id, rs1_prev_ts, rs1_lt_diff) = extract_rs1(&sm.rs1); + let (rs2_id, rs2_prev_ts, rs2_lt_diff) = extract_rs2(&sm.rs2); + let (mem_prev_ts, mem_lt_diff) = extract_write_mem(&sm.mem_write); + + let rs1_limbs = extract_uint_limbs::(&config.rs1_read, "rs1_read"); + let rs2_limbs = extract_uint_limbs::(&config.rs2_read, "rs2_read"); + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val = + extract_uint_limbs::(&config.prev_memory_value, "prev_memory_value"); + let mem_addr = extract_uint_limbs::(&config.memory_addr.addr, "memory_addr"); + + SwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SwInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + use crate::instructions::gpu::utils::column_map::test_colmap; + test_colmap!( + test_extract_sw_column_map, + SwInstruction, + extract_sw_column_map + ); + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sw_correctness() { + use crate::{ + e2e::ShardContext, 
instructions::gpu::utils::test_helpers::assert_witness_colmajor_eq, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 4, -4, -8]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let rs2_val = (i as u32) * 111 % 1000000; // value to store + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as u32) * 77 % 500000; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SW, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: Change::new(prev_mem_val, rs2_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + 
) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen + .witgen_sw( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + assert_witness_colmajor_eq(&gpu_data, cpu_witness.values(), n, num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/config.rs b/ceno_zkvm/src/instructions/gpu/config.rs new file mode 100644 index 000000000..6c9ad1e26 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/config.rs @@ -0,0 +1,86 @@ +/// GPU witgen path-control helpers: kind tags, verified-kind queries, and +/// environment-variable disable switches. +/// +/// Environment variables (3 total): +/// - `CENO_GPU_ENABLE_WITGEN` — opt-in GPU witgen (default: CPU) +/// - `CENO_GPU_DISABLE_WITGEN_KINDS=add,sub,keccak,...` — per-kind disable (comma-separated tags) +/// - `CENO_GPU_DEBUG_COMPARE_WITGEN` — enable GPU vs CPU comparison for all chips (witness, LK, shard, EC) +use super::dispatch::GpuWitgenKind; + +pub(crate) fn kind_tag(kind: GpuWitgenKind) -> &'static str { + match kind { + GpuWitgenKind::Add => "add", + GpuWitgenKind::Sub => "sub", + GpuWitgenKind::LogicR(_) => "logic_r", + GpuWitgenKind::Lw => "lw", + GpuWitgenKind::LogicI(_) => "logic_i", + GpuWitgenKind::Addi => "addi", + GpuWitgenKind::Lui => "lui", + GpuWitgenKind::Auipc => "auipc", + GpuWitgenKind::Jal => "jal", + GpuWitgenKind::ShiftR(_) => "shift_r", + GpuWitgenKind::ShiftI(_) => "shift_i", + GpuWitgenKind::Slt(_) => "slt", + GpuWitgenKind::Slti(_) => "slti", + GpuWitgenKind::BranchEq(_) => "branch_eq", + GpuWitgenKind::BranchCmp(_) => "branch_cmp", + GpuWitgenKind::Jalr => "jalr", + GpuWitgenKind::Sw => "sw", + GpuWitgenKind::Sh => "sh", + GpuWitgenKind::Sb => "sb", + GpuWitgenKind::LoadSub { .. 
} => "load_sub", + GpuWitgenKind::Mul(_) => "mul", + GpuWitgenKind::Div(_) => "div", + GpuWitgenKind::Keccak => "keccak", + } +} + +/// Check if a specific GPU witgen kind is disabled via `CENO_GPU_DISABLE_WITGEN_KINDS` env var. +/// +/// Format: `CENO_GPU_DISABLE_WITGEN_KINDS=add,sub,keccak,lw` (comma-separated kind tags) +/// +/// This covers all chips including keccak. +pub(crate) fn is_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_WITGEN_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Returns true if GPU witgen is enabled via `CENO_GPU_ENABLE_WITGEN` env var. +/// Default is disabled (CPU witgen). Set `CENO_GPU_ENABLE_WITGEN=1` to opt in. +/// The value is cached at first access. +pub(crate) fn is_gpu_witgen_enabled() -> bool { + use std::sync::OnceLock; + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| { + let val = std::env::var_os("CENO_GPU_ENABLE_WITGEN"); + let enabled = val.is_some(); + eprintln!( + "[GPU witgen] CENO_GPU_ENABLE_WITGEN={:?} → enabled={}", + val, enabled + ); + enabled + }) +} + +/// Returns true if `CENO_GPU_DEBUG_COMPARE_WITGEN` is set (any value). +/// When enabled, GPU vs CPU comparison runs for ALL categories: +/// witness, LK multiplicity, shard records, EC points, and E2E shard context. 
+pub(crate) fn is_debug_compare_enabled() -> bool { + use std::sync::OnceLock; + static ENABLED: OnceLock = OnceLock::new(); + *ENABLED.get_or_init(|| std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITGEN").is_some()) +} diff --git a/ceno_zkvm/src/instructions/gpu/dispatch.rs b/ceno_zkvm/src/instructions/gpu/dispatch.rs new file mode 100644 index 000000000..56331a432 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/dispatch.rs @@ -0,0 +1,1185 @@ +/// GPU witness generation dispatcher for the proving pipeline. +/// +/// This module provides `try_gpu_assign_instances` which: +/// 1. Runs the GPU kernel to fill the witness matrix (fast) +/// 2. Runs a lightweight CPU loop to collect lk and shardram without witness replay +/// 3. Returns the GPU-generated witness + CPU-collected lk and shardram +use ceno_emul::{StepIndex, StepRecord, WordAddr}; +use ceno_gpu::{ + Buffer, CudaHal, + bb31::CudaHalBB31, + common::{ + transpose::matrix_transpose, + witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}, + }, +}; +use ff_ext::ExtensionField; +use gkr_iop::utils::lk_multiplicity::Multiplicity; +use p3::field::FieldAlgebra; +use std::cell::Cell; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use super::{ + config::{is_gpu_witgen_enabled, is_kind_disabled}, + utils::debug_compare::{ + debug_compare_final_lk, debug_compare_shard_ec, debug_compare_shardram, + debug_compare_witness, + }, +}; +use crate::{ + e2e::ShardContext, + error::ZKVMError, + instructions::{Instruction, cpu_collect_lk_and_shardram, cpu_collect_shardram}, + tables::RMMCollections, + witness::LkMultiplicity, +}; + +#[derive(Debug, Clone, Copy)] +pub enum GpuWitgenKind { + Add, + Sub, + LogicR(u32), // 0=AND, 1=OR, 2=XOR + LogicI(u32), // 0=AND, 1=OR, 2=XOR + Addi, + Lui, + Auipc, + Jal, + ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA + ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) + Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) + BranchEq(u32), // 
1=BEQ, 0=BNE + BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) + Jalr, + Sw, + Sh, + Sb, + LoadSub { load_width: u32, is_signed: u32 }, + Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU + Div(u32), // 0=DIV, 1=DIVU, 2=REM, 3=REMU + Lw, + Keccak, +} + +// Re-exports from device_cache module for external callers (e2e.rs, structs.rs). +pub use super::cache::{ + SharedDeviceBufferSet, flush_shared_ec_buffers, invalidate_shard_meta_cache, + invalidate_shard_steps_cache, take_shared_device_buffers, +}; +// Re-export for external callers (structs.rs). +pub use super::chips::shard_ram::gpu_batch_continuation_ec; +use super::{ + cache::{ + ensure_shard_metadata_cached, read_shared_addr_count, read_shared_addr_range, + upload_shard_steps_cached, with_cached_gpu_ctx, with_cached_shard_meta, + with_cached_shard_steps, + }, + utils::d2h::{ + CompactEcBuf, LkResult, RamBuf, WitResult, gpu_collect_shard_records, gpu_compact_ec_d2h, + gpu_lk_counters_to_multiplicity, gpu_witness_to_rmm, + }, +}; + +thread_local! { + /// Thread-local flag to force CPU path (used by debug comparison code). + static FORCE_CPU_PATH: Cell = const { Cell::new(false) }; +} + +/// Force the current thread to use CPU path for all GPU witgen calls. +/// Used by debug comparison code in e2e.rs to run a CPU-only reference. +pub fn set_force_cpu_path(force: bool) { + FORCE_CPU_PATH.with(|f| f.set(force)); +} + +pub(crate) fn is_force_cpu_path() -> bool { + FORCE_CPU_PATH.with(|f| f.get()) +} + +/// Try to run GPU witness generation for the given instruction. +/// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). 
+/// +/// # Safety invariant +/// +/// The caller **must** ensure that `I::InstructionConfig` matches `kind`: +/// - `GpuWitgenKind::Add` requires `I` to be `ArithInstruction` (config = `ArithConfig`) +/// - `GpuWitgenKind::Lw` requires `I` to be `LoadInstruction` (config = `LoadConfig`) +/// +/// Violating this will cause undefined behavior via pointer cast in [`gpu_fill_witness`]. +pub(crate) fn try_gpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result, Multiplicity)>, ZKVMError> { + use gkr_iop::gpu::get_cuda_hal; + + if !is_gpu_witgen_enabled() || is_force_cpu_path() { + return Ok(None); + } + + if !I::GPU_LK_SHARDRAM { + return Ok(None); + } + + if is_kind_disabled(kind) { + return Ok(None); + } + + let total_instances = step_indices.len(); + if total_instances == 0 { + // Empty: just return empty matrices + let num_structural_witin = num_structural_witin.max(1); + let raw_witin = RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); + let raw_structural = + RowMajorMatrix::::new(0, num_structural_witin, I::padding_strategy()); + let lk = LkMultiplicity::default(); + return Ok(Some(( + [raw_witin, raw_structural], + lk.into_finalize_result(), + ))); + } + + // GPU only supports BabyBear field + if std::any::TypeId::of::() + != std::any::TypeId::of::<::BaseField>() + { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(hal) => hal, + Err(_) => return Ok(None), // GPU not available, fallback to CPU + }; + + tracing::debug!("[GPU witgen] {:?} with {} instances", kind, total_instances); + info_span!("gpu_witgen", kind = ?kind, n = total_instances).in_scope(|| { + gpu_assign_instances_inner::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &hal, + ) + .map(Some) + }) +} + +fn gpu_assign_instances_inner>( + config: 
&I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + hal: &CudaHalBB31, +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + let num_structural_witin = num_structural_witin.max(1); + let total_instances = step_indices.len(); + + // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) + let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec, gpu_compact_addr) = + info_span!("gpu_kernel").in_scope(|| { + gpu_fill_witness::( + hal, + config, + shard_ctx, + num_witin, + shard_steps, + step_indices, + kind, + ) + })?; + + // Step 2: Collect lk and shardram + // Priority: GPU shard records > CPU shard records > full CPU lk and shardram + // + // Keccak never enters this function (it has `gpu_assign_keccak_inner`). + // Guard defensively in case the enum value is ever passed here by mistake. + let is_standard_kind = !matches!(kind, GpuWitgenKind::Keccak); + + let lk_multiplicity = if gpu_lk_counters.is_some() && is_standard_kind { + let lk_multiplicity = info_span!("gpu_lk_d2h") + .in_scope(|| gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()))?; + + if gpu_compact_ec.is_none() && gpu_compact_addr.is_none() && is_standard_kind { + // Shared buffer path: EC records + addr_accessed accumulated on device + // in shared buffers across all kernel invocations. Skip per-kernel D2H. + // Data will be consumed in batch by assign_shared_circuit. + } else if gpu_compact_ec.is_some() && is_standard_kind { + // GPU EC path: compact records already have EC points computed on device. + // D2H only the active records (much smaller than full N*3 slot buffer). + info_span!("gpu_ec_shard").in_scope(|| { + let compact = gpu_compact_ec.unwrap(); + let compact_records = + info_span!("compact_d2h").in_scope(|| gpu_compact_ec_d2h(&compact))?; + + // D2H ram_slots lazily (only for debug or fallback). 
+ // Avoid the 68 MB D2H in the common case. + let ram_slots_d2h = || -> Result, ZKVMError> { + if let Some(ref ram_buf) = gpu_ram_slots { + let sv: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) + })?; + Ok(unsafe { + let ptr = sv.as_ptr() as *const GpuRamRecordSlot; + let len = sv.len() * 4 / std::mem::size_of::(); + std::slice::from_raw_parts(ptr, len).to_vec() + }) + } else { + Ok(vec![]) + } + }; + + // D2H compact addr_accessed (GPU-side compaction via atomicAdd). + // Much smaller than full ram_slots D2H (4 bytes/addr vs 48 bytes/slot). + info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { + if let Some(ref ca) = gpu_compact_addr { + let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr_count D2H failed: {e}").into(), + ) + })?; + let n = count_vec[0] as usize; + if n > 0 { + let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr D2H failed: {e}").into(), + ) + })?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + } else { + // Fallback: D2H full ram_slots for addr_accessed + let slots = ram_slots_d2h()?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for slot in &slots { + if slot.flags & (1 << 4) != 0 { + thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + } + } + Ok(()) + })?; + + // Debug: compare GPU shard_ctx vs CPU shard_ctx independently + if crate::instructions::gpu::config::is_debug_compare_enabled() { + let slots = ram_slots_d2h()?; + debug_compare_shard_ec::( + &compact_records, + &slots, + config, + shard_ctx, + shard_steps, + step_indices, + kind, + ); + } + + // Populate shard_ctx: gpu_ec_records (raw bytes for assign_shared_circuit) + let raw_bytes = unsafe { + std::slice::from_raw_parts( + 
compact_records.as_ptr() as *const u8, + compact_records.len() * std::mem::size_of::(), + ) + }; + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + + Ok::<(), ZKVMError>(()) + })?; + } else if gpu_ram_slots.is_some() && is_standard_kind { + // GPU shard records path (no EC): D2H + lightweight CPU scan + info_span!("gpu_shard_records").in_scope(|| { + let ram_buf = gpu_ram_slots.unwrap(); + let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) + })?; + let slots: &[GpuRamRecordSlot] = unsafe { + std::slice::from_raw_parts( + slot_bytes.as_ptr() as *const GpuRamRecordSlot, + slot_bytes.len() * 4 / std::mem::size_of::(), + ) + }; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + gpu_collect_shard_records(thread_ctx, slots); + Ok::<(), ZKVMError>(()) + })?; + } else { + // CPU: collect shard records only (send/addr_accessed). + info_span!("cpu_shard_records").in_scope(|| { + let _ = cpu_collect_shardram::(config, shard_ctx, shard_steps, step_indices)?; + Ok::<(), ZKVMError>(()) + })?; + } + lk_multiplicity + } else { + // GPU LK counters missing or unverified — fall back to full CPU lk and shardram + info_span!("cpu_lk_shardram").in_scope(|| { + cpu_collect_lk_and_shardram::(config, shard_ctx, shard_steps, step_indices) + })? 
+ }; + debug_compare_final_lk::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &lk_multiplicity, + )?; + debug_compare_shardram::(config, shard_ctx, shard_steps, step_indices, kind)?; + + // Step 3: Build structural witness (just selector = ONE) + let mut raw_structural = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + for row in raw_structural.iter_mut() { + *row.last_mut().unwrap() = E::BaseField::ONE; + } + raw_structural.padding_by_strategy(); + + // Step 4: Transpose (column-major → row-major) on GPU, then D2H copy to RowMajorMatrix + let mut raw_witin = info_span!("transpose_d2h", rows = total_instances, cols = num_witin) + .in_scope(|| { + gpu_witness_to_rmm::( + hal, + gpu_witness, + total_instances, + num_witin, + I::padding_strategy(), + ) + })?; + raw_witin.padding_by_strategy(); + debug_compare_witness::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &raw_witin, + )?; + + Ok(([raw_witin, raw_structural], lk_multiplicity)) +} + +// Type aliases and D2H conversion functions live in super::utils::d2h. + +/// Compute fetch counter parameters from step data. +pub(crate) fn compute_fetch_params( + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> (u32, usize) { + let mut min_pc = u32::MAX; + let mut max_pc = 0u32; + for &idx in step_indices { + let pc = shard_steps[idx].pc().before.0; + min_pc = min_pc.min(pc); + max_pc = max_pc.max(pc); + } + if min_pc > max_pc { + return (0, 0); + } + let fetch_base_pc = min_pc; + let fetch_num_slots = ((max_pc - min_pc) / 4 + 1) as usize; + (fetch_base_pc, fetch_num_slots) +} + +/// GPU kernel dispatch based on instruction kind. +/// All kinds return witness + LK counters (merged into single GPU kernel). 
+fn gpu_fill_witness>( + hal: &CudaHalBB31, + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result< + ( + WitResult, + Option, + Option, + Option, + Option, + ), + ZKVMError, +> { + // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). + let shard_id = shard_ctx.shard_id; + info_span!("upload_shard_steps") + .in_scope(|| upload_shard_steps_cached(hal, shard_steps, shard_id))?; + + // Convert step_indices from usize to u32 for GPU. + let indices_u32: Vec = info_span!("indices_u32", n = step_indices.len()) + .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec, compact_addr) + macro_rules! split_full { + ($result:expr) => {{ + let full = $result?; + Ok(( + full.witness, + Some(full.lk_counters), + full.ram_slots, + full.compact_ec, + full.compact_addr, + )) + }}; + } + + // Compute fetch params for all GPU kinds (LK counters are merged into all kernels) + let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); + + // Ensure shard metadata is cached for GPU shard records (shared across all kernel kinds) + info_span!("ensure_shard_meta") + .in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx, shard_steps.len()))?; + + match kind { + GpuWitgenKind::Add => { + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::add::extract_add_column_map(arith_config, num_witin)); + info_span!("hal_witgen_add").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + 
fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_add failed: {e}").into(), + ) + }) + ) + }) + }) + } + GpuWitgenKind::Sub => { + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::sub::extract_sub_column_map(arith_config, num_witin)); + info_span!("hal_witgen_sub").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sub failed: {e}").into(), + ) + }) + ) + }) + }) + } + GpuWitgenKind::LogicR(logic_kind) => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::chips::logic_r::extract_logic_r_column_map(logic_config, num_witin) + }); + info_span!("hal_witgen_logic_r").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::LogicI(logic_kind) => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::chips::logic_i::extract_logic_i_column_map(logic_config, num_witin) + }); + info_span!("hal_witgen_logic_i").in_scope(|| { + 
with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Addi => { + let addi_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::addi::extract_addi_column_map(addi_config, num_witin)); + info_span!("hal_witgen_addi").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Lui => { + let lui_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::lui::LuiConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::lui::extract_lui_column_map(lui_config, num_witin)); + info_span!("hal_witgen_lui").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lui failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Auipc => { + let auipc_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::auipc::AuipcConfig) + }; + let col_map = info_span!("col_map").in_scope(|| { + 
super::chips::auipc::extract_auipc_column_map(auipc_config, num_witin) + }); + info_span!("hal_witgen_auipc").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Jal => { + let jal_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::jump::jal_v2::JalConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::jal::extract_jal_column_map(jal_config, num_witin)); + info_span!("hal_witgen_jal").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jal failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::ShiftR(shift_kind) => { + let shift_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig< + E, + >) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::chips::shift_r::extract_shift_r_column_map(shift_config, num_witin) + }); + info_span!("hal_witgen_shift_r").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::ShiftI(shift_kind) => { + let shift_config = unsafe { + &*(config as *const 
I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig< + E, + >) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::chips::shift_i::extract_shift_i_column_map(shift_config, num_witin) + }); + info_span!("hal_witgen_shift_i").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_shift_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Slt(is_signed) => { + let slt_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::slt::extract_slt_column_map(slt_config, num_witin)); + info_span!("hal_witgen_slt").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Slti(is_signed) => { + let slti_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::slti::extract_slti_column_map(slti_config, num_witin)); + info_span!("hal_witgen_slti").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_slti( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { 
+ ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::BranchEq(is_beq) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::chips::branch_eq::extract_branch_eq_column_map(branch_config, num_witin) + }); + info_span!("hal_witgen_branch_eq").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::BranchCmp(is_signed) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::chips::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin) + }); + info_span!("hal_witgen_branch_cmp").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Jalr => { + let jalr_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::jump::jalr_v2::JalrConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::jalr::extract_jalr_column_map(jalr_config, num_witin)); + info_span!("hal_witgen_jalr").in_scope(|| { + 
with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jalr failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Sw => { + let sw_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let mem_max_bits = sw_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::sw::extract_sw_column_map(sw_config, num_witin)); + info_span!("hal_witgen_sw").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sw failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Sh => { + let sh_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let mem_max_bits = sh_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::sh::extract_sh_column_map(sh_config, num_witin)); + info_span!("hal_witgen_sh").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sh failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Sb => { + let sb_config = unsafe { + &*(config as *const I::InstructionConfig + as *const 
crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let mem_max_bits = sb_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::sb::extract_sb_column_map(sb_config, num_witin)); + info_span!("hal_witgen_sb").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sb failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::LoadSub { + load_width, + is_signed, + } => { + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::chips::load_sub::extract_load_sub_column_map(load_config, num_witin) + }); + let mem_max_bits = load_config.memory_addr.max_bits as u32; + info_span!("hal_witgen_load_sub").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Mul(mul_kind) => { + let mul_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::mul::extract_mul_column_map(mul_config, num_witin)); + info_span!("hal_witgen_mul").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, 
+ mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_mul failed: {e}").into(), + ) + }) + ) + }) + }) + } + + GpuWitgenKind::Div(div_kind) => { + let div_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::div::div_circuit_v2::DivRemConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::div::extract_div_column_map(div_config, num_witin)); + info_span!("hal_witgen_div").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_div failed: {e}").into(), + ) + }) + ) + }) + }) + } + GpuWitgenKind::Lw => { + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + #[cfg(not(feature = "u16limb_circuit"))] + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load::LoadConfig) + }; + let mem_max_bits = load_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::chips::lw::extract_lw_column_map(load_config, num_witin)); + info_span!("hal_witgen_lw").in_scope(|| { + with_cached_gpu_ctx(|gpu_records, shard_bufs| { + split_full!( + hal.witgen + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lw failed: {e}").into(), + ) + }) + ) + }) + }) + } + GpuWitgenKind::Keccak => { + unreachable!("keccak uses gpu_assign_keccak_instances, not try_gpu_assign_instances") + } + } +} diff --git 
a/ceno_zkvm/src/instructions/gpu/invasive_changes.md b/ceno_zkvm/src/instructions/gpu/invasive_changes.md new file mode 100644 index 000000000..fb79f22ac --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/invasive_changes.md @@ -0,0 +1,234 @@ +# GPU Witness Generation — Invasive Changes to Existing Codebase + +This document lists all changes to **existing** ceno structures, traits, and flows +that this PR introduces. GPU-only new code (`instructions/gpu/`) is excluded — +this focuses on what existing code was modified and why. + +--- + +## 1. `ceno_emul` — FFI Layout Changes (+332 / -88 lines) + +### `#[repr(C)]` on emulator types + +The following types were made `#[repr(C)]` to enable zero-copy H2D transfer to GPU: + +| Type | File | Size | Purpose | +|------|------|------|---------| +| `StepRecord` | `tracer.rs` | 136B | Per-step emulator output, bulk H2D | +| `Instruction` | `rv32im.rs` | 12B | Opcode encoding embedded in StepRecord | +| `InsnKind` | `rv32im.rs` | 1B | `#[repr(u8)]` enum discriminant | +| `MemOp` | `tracer.rs` | 16/24B | Read/Write ops embedded in StepRecord | +| `Change` | `tracer.rs` | 2×T | Before/after pair | + +**Impact**: These were previously `#[derive(Debug, Clone)]` with compiler-chosen layout. +Adding `#[repr(C)]` pins field order and padding. No behavioral change for CPU code, +but **field reordering or insertion now requires updating the CUDA mirror structs**. + +### New types in `tracer.rs` + +- `PackedNextAccessEntry` (16B, `#[repr(C)]`) — 40-bit packed cycle+addr for GPU FA table +- `ShardPlanBuilder` — preflight shard planning with cell-count balancing + +### Layout test + +`test_step_record_layout_for_gpu` verifies byte offsets of all `StepRecord` fields +at compile time. CUDA side has matching `static_assert(sizeof(...))`. + +--- + +## 2. 
`Instruction` Trait — New Methods and Constants + +**File**: `ceno_zkvm/src/instructions.rs` + +| Addition | Purpose | +|----------|---------| +| `const GPU_LK_SHARDRAM: bool = false` | Opt-in flag: does this chip have GPU LK+shardram support? | +| `fn collect_lk_and_shardram(...)` | CPU companion: collect all LK multiplicities + shard RAM records (without witness replay) | +| `fn collect_shardram(...)` | CPU companion: collect shard RAM records only (GPU handles LK) | + +**Default implementations** return `Err(...)` — chips must explicitly opt in. + +**Impact**: Existing chips that don't implement GPU support are unaffected (defaults). +The trait's existing `assign_instance` and `assign_instances` are unchanged. + +Three macros reduce per-chip boilerplate: +- `impl_collect_lk_and_shardram!` — wraps the unsafe `CpuLkShardramSink` prologue +- `impl_collect_shardram!` — one-line delegate to insn_config +- `impl_gpu_assign!` — `#[cfg(feature = "gpu")] assign_instances` override + +--- + +## 3. Gadgets — New `emit_lk_and_shardram` / `emit_shardram` Methods + +**File**: `ceno_zkvm/src/instructions/riscv/insn_base.rs` (+253 lines) + +Every base gadget (`ReadRS1`, `ReadRS2`, `WriteRD`, `ReadMEM`, `WriteMEM`, `MemAddr`) +gained two new methods: + +| Method | What it does | +|--------|-------------| +| `emit_lk_and_shardram(sink, ctx, step)` | Emit LK ops + RAM send events through `LkShardramSink` | +| `emit_shardram(shard_ctx, step)` | Directly write shard RAM records to `ShardContext` (no LK) | + +**Impact**: Additive only — existing `assign_instance` methods are unchanged. +The new methods extract the same logic that `assign_instance` performed inline, +but route through the `LkShardramSink` trait instead of directly calling +`lk_multiplicity.assert_ux(...)`. 
+ +### Intermediate configs (`r_insn.rs`, `i_insn.rs`, `b_insn.rs`, `s_insn.rs`, `j_insn.rs`, `im_insn.rs`) + +Each gained corresponding `emit_lk_and_shardram` / `emit_shardram` methods that +compose their gadgets' methods + emit `LkOp::Fetch`. + +--- + +## 4. Per-Chip Circuit Files — GPU Opt-in (+792 / -129 lines across ~20 files) + +Each v2 circuit file (arith.rs, logic_circuit.rs, div_circuit_v2.rs, etc.) gained: + +```rust +const GPU_LK_SHARDRAM: bool = true; // or conditional match + +impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + // chip-specific LK ops +}); +impl_collect_shardram!(r_insn); +impl_gpu_assign!(dispatch::GpuWitgenKind::Add); +``` + +**Impact**: Additive — existing `assign_instance` and `construct_circuit` unchanged. +The `#[cfg(feature = "gpu")] assign_instances` override is only compiled with the +`gpu` feature flag. + +--- + +## 5. `ShardContext` — New Fields and Methods + +**File**: `ceno_zkvm/src/e2e.rs` (+616 / -199 lines) + +### New fields + +| Field | Type | Purpose | +|-------|------|---------| +| `sorted_next_accesses` | `Arc` | Pre-sorted packed future-access table for GPU bulk H2D | +| `gpu_ec_records` | `Vec` | Raw bytes of GPU-produced compact EC shard records | +| `syscall_witnesses` | `Arc>` | Keccak syscall data (previously passed separately) | + +### New methods + +| Method | Purpose | +|--------|---------| +| `new_empty_like()` | Clone shard metadata with empty record storage (for debug comparison) | +| `insert_read_record()` / `insert_write_record()` | Direct record insertion (GPU D2H path) | +| `push_addr_accessed()` | Direct addr insertion (GPU D2H path) | +| `extend_gpu_ec_records_raw()` | Append raw GPU EC record bytes | +| `has_gpu_ec_records()` / `take_gpu_ec_records()` | GPU EC record lifecycle | + +### Renamed method + +`send()` → split into `record_send_without_touch()` (no addr_accessed tracking) and +`send()` (which calls `record_send_without_touch` + `push_addr_accessed`). 
+ +### Pipeline hooks (in `generate_witness` shard loop) + +```rust +#[cfg(feature = "gpu")] +flush_shared_ec_buffers(&mut shard_ctx); // D2H shared GPU buffers + +#[cfg(feature = "gpu")] +invalidate_shard_steps_cache(); // free GPU memory +``` + +### Pipeline mode (in `create_proofs_streaming`) + +New overlap pipeline (default when GPU feature enabled but `CENO_GPU_ENABLE_WITGEN` unset): +CPU witgen on thread A, GPU prove on thread B, connected by `crossbeam::bounded(0)` channel. + +--- + +## 6. `ZKVMWitnesses` — GPU ShardRam Pipeline + +**File**: `ceno_zkvm/src/structs.rs` (+580 / -130 lines) + +### `assign_shared_circuit` — new GPU fast path + +Added `try_assign_shared_circuit_gpu()` that keeps data on GPU device: +1. Takes shared device buffers (EC records + addr_accessed) +2. GPU sort+dedup addr_accessed +3. GPU batch EC computation for continuation records +4. GPU merge+partition records (writes before reads) +5. GPU ShardRamCircuit witness generation (Poseidon2 + EC tree) + +Falls back to CPU path on failure. + +### `gpu_ec_records_to_shard_ram_inputs` + +Converts raw GPU EC bytes (`Vec`) to `Vec>` with pre-computed +EC points. Used in the CPU fallback path. + +--- + +## 7. `ShardRamCircuit` — GPU Witness Generation + +**File**: `ceno_zkvm/src/tables/shard_ram.rs` (+491 / -14 lines) + +### New GPU functions + +| Function | Purpose | +|----------|---------| +| `try_gpu_assign_instances()` | H2D path: CPU records → GPU kernel → D2H witness | +| `try_gpu_assign_instances_from_device()` | Device path: records already on GPU → kernel → D2H | + +Both run a two-phase GPU pipeline: +1. **Per-row kernel**: basic fields + Poseidon2 trace (344 witness columns) +2. **EC tree kernel**: layer-by-layer binary tree EC summation + +### Visibility change + +`ShardRamConfig` fields changed from private to `pub(crate)` to allow +column map extraction in `gpu/chips/shard_ram.rs`. + +--- + +## 8. 
`SepticCurve` — New Math Utilities + +**File**: `ceno_zkvm/src/scheme/septic_curve.rs` (+307 lines) + +New CPU-side math for EC point computation (mirrored in CUDA): + +| Function | Purpose | +|----------|---------| +| `SepticExtension::frobenius()` | Frobenius endomorphism for norm computation | +| `SepticExtension::sqrt()` | Cipolla's algorithm for field square roots | +| `SepticPoint::from_x()` | Lift x-coordinate to curve point (used by nonce-finding loop) | +| `QuadraticExtension` | Auxiliary type for Cipolla's algorithm | + +--- + +## 9. Minor Touches + +| File | Change | +|------|--------| +| `Cargo.toml` | `gpu` feature flag, `crossbeam` dependency | +| `gkr_iop/src/gadgets/is_lt.rs` | `AssertLtConfig.0.diff` field access (already `pub`) | +| `gkr_iop/src/utils/lk_multiplicity.rs` | Minor: `LkMultiplicity::increment` | +| `ceno_zkvm/src/gadgets/signed_ext.rs` | `pub(crate) fn msb()` accessor for GPU column map | +| `ceno_zkvm/src/gadgets/poseidon2.rs` | Column contiguity constants for GPU | +| `ceno_zkvm/src/tables/*.rs` | `pub(crate)` visibility on config fields for GPU column map access | +| `ceno_zkvm/src/scheme/{cpu,gpu,prover,verifier}` | Minor plumbing for GPU proving path | +| `ceno_host/tests/test_elf.rs` | E2E test adjustments | + +--- + +## Summary + +| Category | Nature | Risk | +|----------|--------|------| +| `#[repr(C)]` on emulator types | Layout pinning | Low — additive, but field changes now need CUDA sync | +| `Instruction` trait extensions | Additive (defaults provided) | None — existing chips unaffected | +| Gadget `emit_*` methods | Additive | None — existing `assign_instance` unchanged | +| `ShardContext` new fields | Additive (defaults in `Default`) | Low — `Vec::new()` / `Arc::new()` zero-cost | +| `send()` → `record_send_without_touch()` + `send()` | Rename + split | Low — `send()` still works identically | +| `ShardRamConfig` visibility | `private` → `pub(crate)` | None | +| Pipeline overlap mode | New default behavior | Medium — 
CPU witgen + GPU prove on separate threads | +| `septic_curve.rs` math | Additive | None — new functions, existing unchanged | diff --git a/ceno_zkvm/src/instructions/gpu/mod.rs b/ceno_zkvm/src/instructions/gpu/mod.rs new file mode 100644 index 000000000..e18382c1d --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/mod.rs @@ -0,0 +1,833 @@ +#[cfg(feature = "gpu")] +pub mod cache; +pub mod chips; +#[cfg(feature = "gpu")] +pub mod config; +#[cfg(feature = "gpu")] +pub mod dispatch; +pub mod utils; + +#[cfg(test)] +mod tests { + + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{ + Instruction, cpu_assign_instances, cpu_collect_lk_and_shardram, cpu_collect_shardram, + riscv::{ + AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, + branch::{BeqInstruction, BltInstruction}, + div::{DivInstruction, RemuInstruction}, + logic::AndInstruction, + mulh::{MulInstruction, MulhInstruction}, + shift::SraInstruction, + shift_imm::SlliInstruction, + slt::SltInstruction, + slti::SltiInstruction, + }, + }, + structs::ProgramParams, + }; + use ceno_emul::{ + ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, + encode_rv32, + }; + use ff_ext::GoldilocksExt2; + use gkr_iop::tables::LookupTable; + + type E = GoldilocksExt2; + + fn assert_lk_shardram_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_lk_and_shardram::(config, &mut collect_ctx, steps, &indices).unwrap(); + 
assert_eq!(flatten_lk(&expected_lk), flatten_lk(&actual_lk)); + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn assert_shard_lk_shardram_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_shardram::(config, &mut collect_ctx, steps, &indices).unwrap(); + + assert_eq!( + expected_lk[LookupTable::Instruction as usize], + actual_lk[LookupTable::Instruction as usize] + ); + for (table_idx, table) in actual_lk.iter().enumerate() { + if table_idx != LookupTable::Instruction as usize { + assert!( + table.is_empty(), + "unexpected non-fetch shard-only multiplicity in table {table_idx}: {table:?}" + ); + } + } + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + 
.map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + #[test] + fn test_add_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i; + let rhs = 100 + i; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x1000 + i * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i; + let rhs = 0x00ff_ff00 | (i << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x2000 + i * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_add_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_shard_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, 
&ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i; + let rhs = 100 + i; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 84 + (i as u64) * 4, + ByteAddr(0x5000 + i * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_shard_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_shard_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i; + let rhs = 0x00ff_ff00 | (i << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 100 + (i as u64) * 4, + ByteAddr(0x5100 + i * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_shard_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1000u32 + i * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabc0_0000 | i; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + 
previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x3000 + i * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_shard_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_shard_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1400u32 + i * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabd0_0000 | i; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 116 + (i as u64) * 4, + ByteAddr(0x5200 + i * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_shard_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_beq_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "beq_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [ + (true, 0x1122_3344, 0x1122_3344), + (false, 0x5566_7788, 0x99aa_bbcc), + ]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4000 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 4 + i as u64 * 4, + Change::new(pc, next_pc), + 
encode_rv32(InsnKind::BEQ, 8 + i as u32, 16 + i as u32, 0, 8), + lhs, + rhs, + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_blt_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "blt_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(true, (-2i32) as u32, 1u32), (false, 7u32, (-3i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4100 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 12 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BLT, 4 + i as u32, 5 + i as u32, 0, -8), + lhs, + rhs, + 10, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jal_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jal_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let offsets = [8, -8]; + let steps: Vec<_> = offsets + .into_iter() + .enumerate() + .map(|(i, offset)| { + let pc = ByteAddr(0x4200 + i as u32 * 4); + StepRecord::new_j_instruction( + 20 + i as u64 * 4, + Change::new(pc, ByteAddr(pc.0.wrapping_add_signed(offset))), + encode_rv32(InsnKind::JAL, 0, 0, 3 + i as u32, offset), + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jalr_lk_shardram_match_assign_instance() { 
+ let mut cs = ConstraintSystem::::new(|| "jalr_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(100u32, 3), (0x4010u32, -5)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4300 + i as u32 * 4); + let next_pc = ByteAddr(rs1.wrapping_add_signed(imm) & !1); + StepRecord::new_i_instruction( + 28 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::JALR, 2 + i as u32, 0, 6 + i as u32, imm), + rs1, + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slt_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slt_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [((-1i32) as u32, 0u32), (5u32, (-2i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let insn = + encode_rv32(InsnKind::SLT, 9 + i as u32, 10 + i as u32, 11 + i as u32, 0); + StepRecord::new_r_instruction( + 36 + i as u64 * 4, + ByteAddr(0x4400 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, ((lhs as i32) < (rhs as i32)) as u32), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slti_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slti_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0u32, -1), ((-2i32) as u32, 1)]; + let steps: Vec<_> = cases + 
.into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let insn = encode_rv32(InsnKind::SLTI, 12 + i as u32, 0, 13 + i as u32, imm); + let pc = ByteAddr(0x4500 + i as u32 * 4); + StepRecord::new_i_instruction( + 44 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + insn, + rs1, + Change::new(0, ((rs1 as i32) < imm) as u32), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sra_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sra_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SraInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x8765_4321u32, 4u32), (0xf000_0000u32, 31u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let shift = rhs & 31; + let rd = ((lhs as i32) >> shift) as u32; + StepRecord::new_r_instruction( + 52 + i as u64 * 4, + ByteAddr(0x4600 + i as u32 * 4), + encode_rv32(InsnKind::SRA, 6 + i as u32, 7 + i as u32, 8 + i as u32, 0), + lhs, + rhs, + Change::new(0, rd), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slli_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slli_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x1234_5678u32, 3), (0x0000_0001u32, 31)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4700 + i as u32 * 4); + StepRecord::new_i_instruction( + 60 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + encode_rv32(InsnKind::SLLI, 9 + i as u32, 0, 10 + i as u32, imm), + rs1, + Change::new(0, 
rs1.wrapping_shl((imm & 31) as u32)), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sb_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sb_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..2) + .map(|i| { + let rs1 = 0x4800u32 + i * 16; + let rs2 = 0x1234_5600u32 | i; + let imm = i as i32 - 1; + let addr = ByteAddr::from(rs1.wrapping_add_signed(imm)); + let prev = 0x4030_2010u32 + i; + let shift = (addr.shift() * 8) as usize; + let mut next = prev & !(0xff << shift); + next |= (rs2 & 0xff) << shift; + StepRecord::new_s_instruction( + 68 + i as u64 * 4, + ByteAddr(0x4800 + i * 4), + encode_rv32(InsnKind::SB, 11 + i, 12 + i, 0, imm), + rs1, + rs2, + WriteOp { + addr: addr.waddr(), + value: Change::new(prev, next), + previous_cycle: 4, + }, + 8, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mul_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mul_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2u32, 11u32), (u32::MAX, 17u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + StepRecord::new_r_instruction( + 76 + i as u64 * 4, + ByteAddr(0x4900 + i as u32 * 4), + encode_rv32( + InsnKind::MUL, + 13 + i as u32, + 14 + i as u32, + 15 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, lhs.wrapping_mul(rhs)), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + 
); + } + + #[test] + fn test_mulh_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mulh_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2i32, -11i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let outcome = ((lhs as i64).wrapping_mul(rhs as i64) >> 32) as u32; + StepRecord::new_r_instruction( + 84 + i as u64 * 4, + ByteAddr(0x4a00 + i as u32 * 4), + encode_rv32( + InsnKind::MULH, + 16 + i as u32, + 17 + i as u32, + 18 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, outcome), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_div_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "div_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17i32, -3i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { + -1i32 + } else { + lhs.wrapping_div(rhs) + } as u32; + StepRecord::new_r_instruction( + 92 + i as u64 * 4, + ByteAddr(0x4b00 + i as u32 * 4), + encode_rv32( + InsnKind::DIV, + 19 + i as u32, + 20 + i as u32, + 21 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_remu_lk_shardram_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "remu_lk_shardram"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, 
&ProgramParams::default()).unwrap(); + let cases = [(17u32, 3u32), (0x8000_0001u32, 0u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { lhs } else { lhs % rhs }; + StepRecord::new_r_instruction( + 100 + i as u64 * 4, + ByteAddr(0x4c00 + i as u32 * 4), + encode_rv32( + InsnKind::REMU, + 22 + i as u32, + 23 + i as u32, + 24 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_lk_shardram_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/utils/column_map.rs b/ceno_zkvm/src/instructions/gpu/utils/column_map.rs new file mode 100644 index 000000000..58f3d5faa --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/column_map.rs @@ -0,0 +1,211 @@ +//! Shared column-map extraction helpers for GPU witness generation. +//! +//! Each GPU column-map extractor (`add.rs`, `sub.rs`, …) reads `WitIn.id` +//! fields from a circuit config and packs them into a `#[repr(C)]` struct +//! for the CUDA kernel. The base instruction fields (pc, ts, rs1, rs2, rd, +//! mem timestamps) are identical across instruction formats — these helpers +//! eliminate the duplication. + +use ff_ext::ExtensionField; +use gkr_iop::gadgets::AssertLtConfig; +use multilinear_extensions::WitIn; + +use crate::{ + instructions::riscv::insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, + uint::UIntLimbs, +}; + +// --------------------------------------------------------------------------- +// StateInOut +// --------------------------------------------------------------------------- + +/// Extract `(pc, ts)` from a non-branching `StateInOut`. +#[inline] +pub fn extract_state(vm: &StateInOut) -> (u32, u32) { + (vm.pc.id as u32, vm.ts.id as u32) +} + +/// Extract `(pc, next_pc, ts)` from a branching `StateInOut`. 
+#[inline] +pub fn extract_state_branching(vm: &StateInOut) -> (u32, u32, u32) { + ( + vm.pc.id as u32, + vm.next_pc + .expect("branching StateInOut must have next_pc") + .id as u32, + vm.ts.id as u32, + ) +} + +// --------------------------------------------------------------------------- +// Register reads / writes +// --------------------------------------------------------------------------- + +/// Extract `(id, prev_ts, lt_diff[2])` from a `ReadRS1`. +#[inline] +pub fn extract_rs1(rs1: &ReadRS1) -> (u32, u32, [u32; 2]) { + ( + rs1.id.id as u32, + rs1.prev_ts.id as u32, + extract_lt_diff(&rs1.lt_cfg), + ) +} + +/// Extract `(id, prev_ts, lt_diff[2])` from a `ReadRS2`. +#[inline] +pub fn extract_rs2(rs2: &ReadRS2) -> (u32, u32, [u32; 2]) { + ( + rs2.id.id as u32, + rs2.prev_ts.id as u32, + extract_lt_diff(&rs2.lt_cfg), + ) +} + +/// Extract `(id, prev_ts, prev_val[2], lt_diff[2])` from a `WriteRD`. +#[inline] +pub fn extract_rd(rd: &WriteRD) -> (u32, u32, [u32; 2], [u32; 2]) { + let prev_val = extract_uint_limbs::(&rd.prev_value, "WriteRD prev_value"); + ( + rd.id.id as u32, + rd.prev_ts.id as u32, + prev_val, + extract_lt_diff(&rd.lt_cfg), + ) +} + +// --------------------------------------------------------------------------- +// Memory reads / writes +// --------------------------------------------------------------------------- + +/// Extract `(prev_ts, lt_diff[2])` from a `ReadMEM`. +#[inline] +pub fn extract_read_mem(mem: &ReadMEM) -> (u32, [u32; 2]) { + (mem.prev_ts.id as u32, extract_lt_diff(&mem.lt_cfg)) +} + +/// Extract `(prev_ts, lt_diff[2])` from a `WriteMEM`. +#[inline] +pub fn extract_write_mem(mem: &WriteMEM) -> (u32, [u32; 2]) { + (mem.prev_ts.id as u32, extract_lt_diff(&mem.lt_cfg)) +} + +// --------------------------------------------------------------------------- +// Primitive helpers +// --------------------------------------------------------------------------- + +/// Extract the two diff-limb column IDs from an `AssertLtConfig`. 
+#[inline] +pub fn extract_lt_diff(lt: &AssertLtConfig) -> [u32; 2] { + let d = <.0.diff; + assert_eq!( + d.len(), + 2, + "Expected 2 AssertLt diff limbs, got {}", + d.len() + ); + [d[0].id as u32, d[1].id as u32] +} + +/// Extract `N` limb column IDs from a `UIntLimbs` via `wits_in()`. +#[inline] +pub fn extract_uint_limbs( + u: &UIntLimbs, + label: &str, +) -> [u32; N] { + let limbs = u + .wits_in() + .unwrap_or_else(|| panic!("{label} should have WitIn limbs")); + assert_eq!( + limbs.len(), + N, + "Expected {N} limbs for {label}, got {}", + limbs.len() + ); + std::array::from_fn(|i| limbs[i].id as u32) +} + +/// Extract `N` carry column IDs from a `UIntLimbs`'s carries. +#[inline] +pub fn extract_carries( + u: &UIntLimbs, + label: &str, +) -> [u32; N] { + let carries = u + .carries + .as_ref() + .unwrap_or_else(|| panic!("{label} should have carries")); + assert_eq!( + carries.len(), + N, + "Expected {N} carries for {label}, got {}", + carries.len() + ); + std::array::from_fn(|i| carries[i].id as u32) +} + +/// Extract `N` column IDs from a `&[WitIn]` slice (e.g. byte decomposition). +#[inline] +pub fn extract_wit_ids(wits: &[WitIn], label: &str) -> [u32; N] { + assert_eq!( + wits.len(), + N, + "Expected {N} WitIn entries for {label}, got {}", + wits.len() + ); + std::array::from_fn(|i| wits[i].id as u32) +} + +// --------------------------------------------------------------------------- +// Test helpers +// --------------------------------------------------------------------------- + +#[cfg(test)] +pub fn validate_column_map(flat: &[u32], num_cols: u32) { + for (i, &col) in flat.iter().enumerate() { + assert!( + col < num_cols, + "Column index {i}: value {col} out of range (num_cols = {num_cols})" + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in flat { + assert!(seen.insert(col), "Duplicate column ID: {col}"); + } +} + +/// Generate a `test_extract_*_column_map` test that: +/// 1. 
Constructs the circuit via `$Instruction::construct_circuit` +/// 2. Extracts the column map via `$extract_fn` +/// 3. Validates all column IDs are in-range and unique +/// +/// Usage: +/// ```ignore +/// test_colmap!(test_extract_add_column_map, AddInstruction, extract_add_column_map); +/// // Multi-variant (e.g. div/divu/rem/remu sharing one extractor): +/// test_colmap!(test_extract_divu_column_map, DivuInstruction, extract_div_column_map); +/// ``` +#[cfg(test)] +macro_rules! test_colmap { + ($test_name:ident, $Instruction:ty, $extract_fn:ident) => { + #[test] + fn $test_name() { + let mut cs = + crate::circuit_builder::ConstraintSystem::::new(|| stringify!($test_name)); + let mut cb = crate::circuit_builder::CircuitBuilder::new(&mut cs); + let config = <$Instruction as crate::instructions::Instruction>::construct_circuit( + &mut cb, + &crate::structs::ProgramParams::default(), + ) + .unwrap(); + let col_map = $extract_fn(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + crate::instructions::gpu::utils::column_map::validate_column_map( + &flat, + col_map.num_cols, + ); + } + }; +} + +#[cfg(test)] +pub(crate) use test_colmap; diff --git a/ceno_zkvm/src/instructions/gpu/utils/d2h.rs b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs new file mode 100644 index 000000000..1a4c96acc --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/d2h.rs @@ -0,0 +1,262 @@ +/// Device-to-host conversion functions for GPU witness generation. 
+/// +/// This module handles: +/// - Type aliases for GPU buffer types +/// - D2H transfer of witness matrices (transpose + copy) +/// - D2H transfer of lookup counter buffers +/// - D2H transfer of compact EC records +/// - Conversion between host and GPU shard RAM record formats +/// - Batch EC point computation on GPU for continuation circuits +use ceno_emul::WordAddr; +use ceno_gpu::{ + Buffer, CudaHal, + bb31::CudaHalBB31, + common::{ + transpose::matrix_transpose, + witgen::types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord}, + }, +}; +use ff_ext::ExtensionField; +use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use p3::field::FieldAlgebra; +use rustc_hash::FxHashMap; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::{ + e2e::{RAMRecord, ShardContext}, + error::ZKVMError, +}; + +pub(crate) type WitBuf = + ceno_gpu::common::BufferImpl<'static, ::BaseField>; +pub(crate) type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; +pub(crate) type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; +pub(crate) type WitResult = ceno_gpu::common::witgen::types::GpuWitnessResult; +pub(crate) type LkResult = ceno_gpu::common::witgen::types::GpuLookupCountersResult; +pub(crate) type CompactEcBuf = ceno_gpu::common::witgen::types::CompactEcResult; + +/// CPU-side lightweight scan of GPU-produced RAM record slots. +/// +/// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, +/// replacing the previous `collect_shardram()` CPU loop. 
+pub(crate) fn gpu_collect_shard_records(shard_ctx: &mut ShardContext, slots: &[GpuRamRecordSlot]) { + let current_shard_id = shard_ctx.shard_id; + + for slot in slots { + // Check was_sent flag (bit 4): this slot corresponds to a send() call + if slot.flags & (1 << 4) != 0 { + shard_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + + // Check active flag (bit 0): this slot has a read or write record + if slot.flags & 1 == 0 { + continue; + } + + let ram_type = match (slot.flags >> 5) & 0x7 { + 1 => RAMType::Register, + 2 => RAMType::Memory, + _ => continue, + }; + let has_prev_value = slot.flags & (1 << 3) != 0; + let prev_value = if has_prev_value { + Some(slot.prev_value) + } else { + None + }; + let addr = WordAddr(slot.addr); + + // Insert read record (bit 1) + if slot.flags & (1 << 1) != 0 { + shard_ctx.insert_read_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: 0, + prev_value, + value: slot.value, + shard_id: slot.read_shard_id as usize, + }, + ); + } + + // Insert write record (bit 2) + if slot.flags & (1 << 2) != 0 { + shard_ctx.insert_write_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: slot.shard_cycle, + prev_value, + value: slot.value, + shard_id: current_shard_id, + }, + ); + } + } +} + +/// D2H the compact EC result: read count, then partial-D2H only that many records. 
+pub(crate) fn gpu_compact_ec_d2h( + compact: &CompactEcResult, +) -> Result, ZKVMError> { + // D2H the count (1 u32) + let count_vec: Vec = compact + .count_buf + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()))?; + let count = count_vec[0] as usize; + if count == 0 { + return Ok(vec![]); + } + + // Partial D2H: only transfer the first `count` records (not the full allocation) + let record_u32s = std::mem::size_of::() / 4; // 26 + let total_u32s = count * record_u32s; + let buf_vec: Vec = compact + .buffer + .to_vec_n(total_u32s) + .map_err(|e| ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()))?; + + let records: Vec = unsafe { + let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; + std::slice::from_raw_parts(ptr, count).to_vec() + }; + tracing::debug!( + "GPU EC compact D2H: {} active records ({} bytes)", + count, + total_u32s * 4 + ); + Ok(records) +} + +pub(crate) fn gpu_lk_counters_to_multiplicity( + counters: LkResult, +) -> Result, ZKVMError> { + let mut tables: [FxHashMap; 8] = Default::default(); + + // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) + info_span!("lk_dynamic_d2h").in_scope(|| { + let counts: Vec = counters.dynamic.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Dynamic as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; + + // Dense tables: same pattern, skip None + info_span!("lk_dense_d2h").in_scope(|| { + let dense: &[(LookupTable, &Option)] = &[ + (LookupTable::DoubleU8, &counters.double_u8), + (LookupTable::And, &counters.and_table), + (LookupTable::Or, &counters.or_table), + (LookupTable::Xor, &counters.xor_table), + (LookupTable::Ltu, &counters.ltu_table), + 
(LookupTable::Pow, &counters.pow_table), + ]; + for &(table, ref buf_opt) in dense { + if let Some(buf) = buf_opt { + let counts: Vec = buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU {:?} lk D2H failed: {e}", table).into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[table as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + } + } + Ok::<(), ZKVMError>(()) + })?; + + // Fetch (Instruction table) + if let Some(fetch_buf) = counters.fetch { + info_span!("lk_fetch_d2h").in_scope(|| { + let base_pc = counters.fetch_base_pc; + let counts = fetch_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Instruction as usize]; + map.reserve(nnz); + for (slot_idx, &count) in counts.iter().enumerate() { + if count != 0 { + let pc = base_pc as u64 + (slot_idx as u64) * 4; + map.insert(pc, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; + } + + Ok(Multiplicity(tables)) +} + +/// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. +/// +/// GPU witgen kernels output column-major layout for better memory coalescing. +/// This function transposes to row-major on GPU before copying to host. +pub(crate) fn gpu_witness_to_rmm( + hal: &CudaHalBB31, + gpu_result: ceno_gpu::common::witgen::types::GpuWitnessResult< + ceno_gpu::common::BufferImpl<'static, ::BaseField>, + >, + num_rows: usize, + num_cols: usize, + padding: InstancePaddingStrategy, +) -> Result, ZKVMError> { + // Transpose from column-major to row-major on GPU. + // Column-major (num_rows x num_cols) is stored as num_cols groups of num_rows elements, + // which is equivalent to a (num_cols x num_rows) row-major matrix. 
+    // Transposing with cols=num_rows, rows=num_cols produces (num_rows x num_cols) row-major.
+    let mut rmm_buffer = hal
+        .alloc_elems_on_device(num_rows * num_cols, false, None)
+        .map_err(|e| {
+            ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into())
+        })?;
+    matrix_transpose::(
+        &hal.inner,
+        &mut rmm_buffer,
+        &gpu_result.device_buffer,
+        num_rows,
+        num_cols,
+    )
+    .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?;
+
+    let gpu_data: Vec<::BaseField> = rmm_buffer
+        .to_vec()
+        .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?;
+
+    // Safety: BabyBear is the only supported GPU field, and E::BaseField must match
+    let data: Vec = unsafe {
+        let mut data = std::mem::ManuallyDrop::new(gpu_data);
+        Vec::from_raw_parts(
+            data.as_mut_ptr() as *mut E::BaseField,
+            data.len(),
+            data.capacity(),
+        )
+    };
+
+    Ok(RowMajorMatrix::::new_by_values(
+        data, num_cols, padding,
+    ))
+}
diff --git a/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs
new file mode 100644
index 000000000..1c91b6071
--- /dev/null
+++ b/ceno_zkvm/src/instructions/gpu/utils/debug_compare.rs
@@ -0,0 +1,789 @@
+/// Debug comparison functions for GPU witness generation.
+///
+/// These functions compare GPU-produced results against CPU baselines
+/// to validate correctness.
+/// All comparisons are activated by setting `CENO_GPU_DEBUG_COMPARE_WITGEN=1`.
+/// This enables: LK multiplicity, witness matrix, shardram records, and EC point comparison.
+use ceno_emul::{StepIndex, StepRecord, WordAddr}; +use ceno_gpu::common::witgen::types::{GpuRamRecordSlot, GpuShardRamRecord}; +use ff_ext::ExtensionField; +use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use p3::field::FieldAlgebra; +use std::cell::Cell; +use witness::RowMajorMatrix; + +use crate::{ + e2e::ShardContext, + error::ZKVMError, + instructions::{Instruction, cpu_collect_lk_and_shardram, cpu_collect_shardram}, +}; + +use crate::instructions::gpu::dispatch::{GpuWitgenKind, set_force_cpu_path}; + +pub(crate) fn debug_compare_final_lk>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + mixed_lk: &Multiplicity, +) -> Result<(), ZKVMError> { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { + return Ok(()); + } + + // Compare against cpu_assign_instances (the true baseline using assign_instance) + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (_, cpu_assign_lk) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + tracing::info!("[GPU lk debug] kind={kind:?} comparing mixed_lk vs cpu_assign_instances lk"); + log_lk_diff(kind, &cpu_assign_lk, mixed_lk); + Ok(()) +} + +pub(crate) fn log_lk_diff( + kind: GpuWitgenKind, + cpu_lk: &Multiplicity, + actual_lk: &Multiplicity, +) { + let limit: usize = 16; + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, actual_table)) in cpu_lk.iter().zip(actual_lk.iter()).enumerate() { + let mut keys = cpu_table + .keys() + .chain(actual_table.keys()) + .copied() + .collect::>(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = Vec::new(); + for key in keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let actual_count = actual_table.get(&key).copied().unwrap_or(0); + if cpu_count != 
actual_count { + table_diffs.push((key, cpu_count, actual_count)); + } + } + + if !table_diffs.is_empty() { + total_diffs += table_diffs.len(); + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} diff_count={}", + lookup_table_name(table_idx), + table_diffs.len() + ); + for (key, cpu_count, actual_count) in table_diffs.into_iter().take(limit) { + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} key={} cpu={} gpu={}", + lookup_table_name(table_idx), + key, + cpu_count, + actual_count + ); + } + } + } + + if total_diffs == 0 { + tracing::info!("[GPU lk debug] kind={kind:?} CPU/GPU lookup multiplicities match"); + } +} + +pub(crate) fn debug_compare_witness>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + gpu_witness: &RowMajorMatrix, +) -> Result<(), ZKVMError> { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, _) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + let cpu_witness = &cpu_rmms[0]; + let cpu_vals = cpu_witness.values(); + let gpu_vals = gpu_witness.values(); + if cpu_vals == gpu_vals { + return Ok(()); + } + + let limit: usize = 16; + let cpu_num_cols = cpu_witness.n_col(); + let cpu_num_rows = cpu_vals.len() / cpu_num_cols; + let mut mismatches = 0usize; + for row in 0..cpu_num_rows { + for col in 0..cpu_num_cols { + let idx = row * cpu_num_cols + col; + if cpu_vals[idx] != gpu_vals[idx] { + mismatches += 1; + if mismatches <= limit { + tracing::error!( + "[GPU witness debug] kind={kind:?} row={} col={} cpu={:?} gpu={:?}", + row, + col, + cpu_vals[idx], + gpu_vals[idx] + ); + } + } + } + } + tracing::error!( + "[GPU witness debug] kind={kind:?} total_mismatches={}", + mismatches + ); + 
Ok(()) +} + +pub(crate) fn debug_compare_shardram>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result<(), ZKVMError> { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let _ = cpu_collect_lk_and_shardram::(config, &mut cpu_ctx, shard_steps, step_indices)?; + + let mut mixed_ctx = shard_ctx.new_empty_like(); + let _ = cpu_collect_shardram::(config, &mut mixed_ctx, shard_steps, step_indices)?; + + let cpu_addr = cpu_ctx.get_addr_accessed(); + let mixed_addr = mixed_ctx.get_addr_accessed(); + if cpu_addr != mixed_addr { + tracing::error!( + "[GPU shard debug] kind={kind:?} addr_accessed cpu={} gpu={}", + cpu_addr.len(), + mixed_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu_ctx.read_records()); + let mixed_reads = flatten_ram_records(mixed_ctx.read_records()); + if cpu_reads != mixed_reads { + log_ram_record_diff(kind, "read_records", &cpu_reads, &mixed_reads); + } + + let cpu_writes = flatten_ram_records(cpu_ctx.write_records()); + let mixed_writes = flatten_ram_records(mixed_ctx.write_records()); + if cpu_writes != mixed_writes { + log_ram_record_diff(kind, "write_records", &cpu_writes, &mixed_writes); + } + + Ok(()) +} + +/// Compare GPU shard context vs CPU shard context, field by field. +/// +/// Both paths are independent and produce equivalent ShardContext state: +/// CPU path: cpu_collect_shardram -> addr_accessed + write_records + read_records +/// GPU path: compact_records -> shard records (gpu_ec_records) +/// ram_slots WAS_SENT -> addr_accessed +/// (write_records and read_records stay empty for GPU EC kernels) +/// +/// This function builds both independently and compares: +/// A. addr_accessed sets +/// B. shard records (sorted, normalized to ShardRamRecord) +/// C. 
EC points (nonce + SepticPoint x,y) +/// +/// Activated by CENO_GPU_DEBUG_COMPARE_WITGEN=1. +pub(crate) fn debug_compare_shard_ec>( + compact_records: &[GpuShardRamRecord], + ram_slots: &[GpuRamRecordSlot], + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) { + if !crate::instructions::gpu::config::is_debug_compare_enabled() { + return; + } + + use crate::{ + scheme::septic_curve::{SepticExtension, SepticPoint}, + tables::{ECPoint, ShardRamRecord}, + }; + use ff_ext::{PoseidonField, SmallField}; + + let limit: usize = 16; + + // ========== Build CPU shard context (independent, isolated) ========== + let mut cpu_ctx = shard_ctx.new_empty_like(); + if let Err(e) = cpu_collect_shardram::(config, &mut cpu_ctx, shard_steps, step_indices) { + tracing::error!("[GPU EC debug] kind={kind:?} CPU shardram records failed: {e:?}"); + return; + } + + let perm = ::get_default_perm(); + + // CPU: addr_accessed + let cpu_addr = cpu_ctx.get_addr_accessed(); + + // CPU: shard records (BTreeMap -> ShardRamRecord + ECPoint) + let mut cpu_entries: Vec<(ShardRamRecord, ECPoint)> = Vec::new(); + for records in cpu_ctx.write_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, true).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + for records in cpu_ctx.read_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, false).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + cpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Build GPU shard context (independent, from D2H data only) ========== + + // GPU: addr_accessed (from ram_slots WAS_SENT flags) + let gpu_addr: rustc_hash::FxHashSet = ram_slots + .iter() + .filter(|s| s.flags & (1 << 4) != 0) + .map(|s| WordAddr(s.addr)) + .collect(); + + // GPU: shard 
records (compact_records -> ShardRamRecord + ECPoint) + let mut gpu_entries: Vec<(ShardRamRecord, ECPoint)> = compact_records + .iter() + .map(|g| { + let rec = ShardRamRecord { + addr: g.addr, + ram_type: if g.ram_type == 1 { + RAMType::Register + } else { + RAMType::Memory + }, + value: g.value, + shard: g.shard, + local_clk: g.local_clk, + global_clk: g.global_clk, + is_to_write_set: g.is_to_write_set != 0, + }; + let x = SepticExtension(g.point_x.map(|v| E::BaseField::from_canonical_u32(v))); + let y = SepticExtension(g.point_y.map(|v| E::BaseField::from_canonical_u32(v))); + let point = SepticPoint::from_affine(x, y); + let ec = ECPoint:: { + nonce: g.nonce, + point, + }; + (rec, ec) + }) + .collect(); + gpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Compare A: addr_accessed ========== + if cpu_addr != gpu_addr { + let cpu_only: Vec<_> = cpu_addr.difference(&gpu_addr).collect(); + let gpu_only: Vec<_> = gpu_addr.difference(&cpu_addr).collect(); + tracing::error!( + "[GPU EC debug] kind={kind:?} ADDR_ACCESSED MISMATCH: cpu={} gpu={} \ + cpu_only={} gpu_only={}", + cpu_addr.len(), + gpu_addr.len(), + cpu_only.len(), + gpu_only.len() + ); + for (i, addr) in cpu_only.iter().enumerate() { + if i >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", + addr.0 + ); + } + for (i, addr) in gpu_only.iter().enumerate() { + if i >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", + addr.0 + ); + } + } + + // ========== Compare B+C: shard records + EC points ========== + + // Check counts + if cpu_entries.len() != gpu_entries.len() { + tracing::error!( + "[GPU EC debug] kind={kind:?} RECORD COUNT MISMATCH: cpu={} gpu={}", + cpu_entries.len(), + gpu_entries.len() + ); + let cpu_keys: std::collections::BTreeSet<_> = cpu_entries + .iter() + .map(|(r, _)| (r.addr, r.is_to_write_set)) + .collect(); + let gpu_keys: 
std::collections::BTreeSet<_> = gpu_entries + .iter() + .map(|(r, _)| (r.addr, r.is_to_write_set)) + .collect(); + let mut logged = 0usize; + for key in cpu_keys.difference(&gpu_keys) { + if logged >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", + key.0, + key.1 + ); + logged += 1; + } + for key in gpu_keys.difference(&cpu_keys) { + if logged >= limit { + break; + } + tracing::error!( + "[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", + key.0, + key.1 + ); + logged += 1; + } + } + + // Check GPU duplicates (BTreeMap deduplicates, atomicAdd doesn't) + let mut gpu_dup_count = 0usize; + for w in gpu_entries.windows(2) { + if w[0].0.addr == w[1].0.addr + && w[0].0.is_to_write_set == w[1].0.is_to_write_set + && w[0].0.ram_type == w[1].0.ram_type + { + gpu_dup_count += 1; + if gpu_dup_count <= limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} GPU DUPLICATE: addr={} is_write={} ram_type={:?}", + w[0].0.addr, + w[0].0.is_to_write_set, + w[0].0.ram_type + ); + } + } + } + + // Merge-walk sorted lists + let mut ci = 0usize; + let mut gi = 0usize; + let mut record_mismatches = 0usize; + let mut ec_mismatches = 0usize; + let mut matched = 0usize; + + while ci < cpu_entries.len() && gi < gpu_entries.len() { + let (cr, ce) = &cpu_entries[ci]; + let (gr, ge) = &gpu_entries[gi]; + let ck = (cr.addr, cr.is_to_write_set as u8, cr.ram_type as u8); + let gk = (gr.addr, gr.is_to_write_set as u8, gr.ram_type as u8); + + match ck.cmp(&gk) { + std::cmp::Ordering::Less => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + cr.addr, + cr.is_to_write_set, + cr.ram_type, + cr.value, + cr.shard, + cr.global_clk + ); + } + record_mismatches += 1; + ci += 1; + continue; + } + std::cmp::Ordering::Greater => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU: addr={} 
is_write={} ram={:?} val={} shard={} clk={}", + gr.addr, + gr.is_to_write_set, + gr.ram_type, + gr.value, + gr.shard, + gr.global_clk + ); + } + record_mismatches += 1; + gi += 1; + continue; + } + std::cmp::Ordering::Equal => {} + } + + // Keys match -- compare record fields + let mut field_diff = false; + for (name, cv, gv) in [ + ("value", cr.value as u64, gr.value as u64), + ("shard", cr.shard, gr.shard), + ("local_clk", cr.local_clk, gr.local_clk), + ("global_clk", cr.global_clk, gr.global_clk), + ] { + if cv != gv { + field_diff = true; + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} {name}: cpu={cv} gpu={gv}", + cr.addr + ); + } + } + } + if field_diff { + record_mismatches += 1; + } + + // Compare EC points + let mut ec_diff = false; + if ce.nonce != ge.nonce { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} nonce: cpu={} gpu={}", + cr.addr, + ce.nonce, + ge.nonce + ); + } + } + for j in 0..7 { + let cv = ce.point.x.0[j].to_canonical_u64() as u32; + let gv = ge.point.x.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", + cr.addr + ); + } + } + } + for j in 0..7 { + let cv = ce.point.y.0[j].to_canonical_u64() as u32; + let gv = ge.point.y.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", + cr.addr + ); + } + } + } + if ec_diff { + ec_mismatches += 1; + } + + matched += 1; + ci += 1; + gi += 1; + } + + // Remaining unmatched + while ci < cpu_entries.len() { + if record_mismatches < limit { + let (cr, _) = &cpu_entries[ci]; + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU (tail): addr={} is_write={} val={}", + cr.addr, + cr.is_to_write_set, + cr.value + ); + } + record_mismatches += 1; 
+ ci += 1; + } + while gi < gpu_entries.len() { + if record_mismatches < limit { + let (gr, _) = &gpu_entries[gi]; + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU (tail): addr={} is_write={} val={}", + gr.addr, + gr.is_to_write_set, + gr.value + ); + } + record_mismatches += 1; + gi += 1; + } + + // ========== Summary ========== + let addr_ok = cpu_addr == gpu_addr; + if addr_ok && record_mismatches == 0 && ec_mismatches == 0 && gpu_dup_count == 0 { + tracing::info!( + "[GPU EC debug] kind={kind:?} ALL MATCH: {} records, {} addr_accessed, EC points OK", + matched, + cpu_addr.len() + ); + } else { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISMATCH: matched={matched} record_diffs={record_mismatches} \ + ec_diffs={ec_mismatches} gpu_dups={gpu_dup_count} addr_ok={addr_ok} \ + (cpu_records={} gpu_records={} cpu_addrs={} gpu_addrs={})", + cpu_entries.len(), + gpu_entries.len(), + cpu_addr.len(), + gpu_addr.len() + ); + } +} + +pub(crate) fn flatten_ram_records( + records: &[std::collections::BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +pub(crate) fn log_ram_record_diff( + kind: GpuWitgenKind, + label: &str, + cpu_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], + mixed_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], +) { + let limit: usize = 16; + tracing::error!( + "[GPU shard debug] kind={kind:?} {} cpu={} gpu={}", + label, + cpu_records.len(), + mixed_records.len() + ); + let max_len = cpu_records.len().max(mixed_records.len()); + let mut logged = 0usize; + for idx in 0..max_len { + let cpu = cpu_records.get(idx); + let gpu = mixed_records.get(idx); + if cpu != gpu { + tracing::error!( + "[GPU shard debug] kind={kind:?} {} idx={} 
cpu={:?} gpu={:?}", + label, + idx, + cpu, + gpu + ); + logged += 1; + if logged >= limit { + break; + } + } + } +} + +pub(crate) fn lookup_table_name(table_idx: usize) -> &'static str { + match table_idx { + x if x == LookupTable::Dynamic as usize => "Dynamic", + x if x == LookupTable::DoubleU8 as usize => "DoubleU8", + x if x == LookupTable::And as usize => "And", + x if x == LookupTable::Or as usize => "Or", + x if x == LookupTable::Xor as usize => "Xor", + x if x == LookupTable::Ltu as usize => "Ltu", + x if x == LookupTable::Pow as usize => "Pow", + x if x == LookupTable::Instruction as usize => "Instruction", + _ => "Unknown", + } +} + +/// Debug comparison for keccak GPU witgen. +/// Runs the CPU path and compares LK / witness / shardram records. +/// +/// Activated by CENO_GPU_DEBUG_COMPARE_WITGEN=1. +#[cfg(feature = "gpu")] +pub(crate) fn debug_compare_keccak( + config: &crate::instructions::riscv::ecall::keccak::EcallKeccakConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + step_indices: &[StepIndex], + gpu_lk: &Multiplicity, + gpu_witin: &RowMajorMatrix, + gpu_addrs: &[u32], +) -> Result<(), ZKVMError> { + let enabled = crate::instructions::gpu::config::is_debug_compare_enabled(); + let want_lk = enabled; + let want_witness = enabled; + let want_shard = enabled; + + if !want_lk && !want_witness && !want_shard { + return Ok(()); + } + + // Guard against recursion: is_gpu_witgen_enabled() uses OnceLock so env var + // manipulation doesn't work. Use a thread-local flag instead. + thread_local! { + static IN_DEBUG_COMPARE: Cell = const { Cell::new(false) }; + } + if IN_DEBUG_COMPARE.with(|f| f.get()) { + return Ok(()); + } + IN_DEBUG_COMPARE.with(|f| f.set(true)); + + tracing::info!("[GPU keccak debug] running CPU baseline for comparison"); + + // Run CPU path via assign_instances. 
The IN_DEBUG_COMPARE guard prevents + // gpu_assign_keccak_instances (in chips/keccak.rs) from calling debug_compare_keccak again, + // so it will produce the GPU result, which is then returned without + // re-entering this function. We need assign_instances (not cpu_assign_instances) + // because keccak has rotation matrices and 3 structural columns. + // + // To force the CPU path, we use is_force_cpu_path() by setting the env var. + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, cpu_lk) = { + use crate::instructions::riscv::ecall::keccak::KeccakInstruction; + // Set force-CPU flag so gpu_assign_keccak_instances returns None + set_force_cpu_path(true); + let result = + as crate::instructions::Instruction>::assign_instances( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + ); + set_force_cpu_path(false); + IN_DEBUG_COMPARE.with(|f| f.set(false)); + result? + }; + + let kind = GpuWitgenKind::Keccak; + + if want_lk { + tracing::info!("[GPU keccak debug] comparing LK multiplicities"); + log_lk_diff(kind, &cpu_lk, gpu_lk); + } + + if want_witness { + let limit: usize = 16; + let cpu_witin = &cpu_rmms[0]; + let gpu_vals = gpu_witin.values(); + let cpu_vals = cpu_witin.values(); + let mut diffs = 0usize; + for (i, (g, c)) in gpu_vals.iter().zip(cpu_vals.iter()).enumerate() { + if g != c { + if diffs < limit { + let row = i / num_witin; + let col = i % num_witin; + tracing::error!( + "[GPU keccak witness] row={} col={} gpu={:?} cpu={:?}", + row, + col, + g, + c + ); + } + diffs += 1; + } + } + if diffs == 0 { + tracing::info!( + "[GPU keccak debug] witness matrices match ({} elements)", + gpu_vals.len() + ); + } else { + tracing::error!( + "[GPU keccak debug] witness mismatch: {} diffs out of {}", + diffs, + gpu_vals.len() + ); + } + } + + if want_shard { + // Compare addr_accessed: GPU entries were D2H'd from the shared buffer + // delta (before/after kernel launch) and passed in as gpu_addrs. 
+ let cpu_addr = cpu_ctx.get_addr_accessed(); + let gpu_addr_set: rustc_hash::FxHashSet = + gpu_addrs.iter().map(|&a| WordAddr(a)).collect(); + + if cpu_addr.len() != gpu_addr_set.len() { + tracing::error!( + "[GPU keccak shard] addr_accessed count mismatch: cpu={} gpu={}", + cpu_addr.len(), + gpu_addr_set.len() + ); + } + let mut missing_from_gpu = 0usize; + let mut extra_in_gpu = 0usize; + let limit = 16usize; + for addr in &cpu_addr { + if !gpu_addr_set.contains(addr) { + if missing_from_gpu < limit { + tracing::error!("[GPU keccak shard] addr {} in CPU but not GPU", addr.0); + } + missing_from_gpu += 1; + } + } + for &addr in gpu_addrs { + if !cpu_addr.contains(&WordAddr(addr)) { + if extra_in_gpu < limit { + tracing::error!("[GPU keccak shard] addr {} in GPU but not CPU", addr); + } + extra_in_gpu += 1; + } + } + if missing_from_gpu == 0 && extra_in_gpu == 0 { + tracing::info!( + "[GPU keccak shard] addr_accessed matches: {} entries", + cpu_addr.len() + ); + } else { + tracing::error!( + "[GPU keccak shard] addr_accessed diff: missing_from_gpu={} extra_in_gpu={}", + missing_from_gpu, + extra_in_gpu + ); + } + } + + Ok(()) +} diff --git a/ceno_zkvm/src/instructions/gpu/utils/emit.rs b/ceno_zkvm/src/instructions/gpu/utils/emit.rs new file mode 100644 index 000000000..ae4647dea --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/emit.rs @@ -0,0 +1,160 @@ +use gkr_iop::{ + gadgets::{AssertLtConfig, cal_lt_diff}, + tables::{LookupTable, OpsTable}, +}; + +use crate::instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS}; + +use super::{LkOp, LkShardramSink}; + +pub fn emit_assert_lt_ops( + sink: &mut impl LkShardramSink, + lt_cfg: &AssertLtConfig, + lhs: u64, + rhs: u64, +) { + let max_bits = lt_cfg.0.max_bits; + let diff = cal_lt_diff(lhs < rhs, max_bits, lhs, rhs); + for i in 0..(max_bits / u16::BITS as usize) { + let value = ((diff >> (i * u16::BITS as usize)) & 0xffff) as u16; + sink.emit_lk(LkOp::AssertU16 { value }); + } + let remain_bits = max_bits % 
u16::BITS as usize; + if remain_bits > 1 { + let value = (diff >> ((lt_cfg.0.diff.len() - 1) * u16::BITS as usize)) & 0xffff; + sink.emit_lk(LkOp::DynamicRange { + value, + bits: remain_bits as u32, + }); + } +} + +pub fn emit_u16_limbs(sink: &mut impl LkShardramSink, value: u32) { + sink.emit_lk(LkOp::AssertU16 { + value: (value & 0xffff) as u16, + }); + sink.emit_lk(LkOp::AssertU16 { + value: (value >> 16) as u16, + }); +} + +pub fn emit_const_range_op(sink: &mut impl LkShardramSink, value: u64, bits: usize) { + match bits { + 0 | 1 => {} + 14 => sink.emit_lk(LkOp::AssertU14 { + value: value as u16, + }), + 16 => sink.emit_lk(LkOp::AssertU16 { + value: value as u16, + }), + _ => sink.emit_lk(LkOp::DynamicRange { + value, + bits: bits as u32, + }), + } +} + +pub fn emit_byte_decomposition_ops(sink: &mut impl LkShardramSink, bytes: &[u8]) { + for chunk in bytes.chunks(2) { + match chunk { + [a, b] => sink.emit_lk(LkOp::DoubleU8 { a: *a, b: *b }), + [a] => emit_const_range_op(sink, *a as u64, 8), + _ => unreachable!(), + } + } +} + +pub fn emit_signed_extend_op(sink: &mut impl LkShardramSink, n_bits: usize, value: u64) { + let msb = value >> (n_bits - 1); + sink.emit_lk(LkOp::DynamicRange { + value: 2 * value - (msb << n_bits), + bits: n_bits as u32, + }); +} + +pub fn emit_logic_u8_ops( + sink: &mut impl LkShardramSink, + lhs: u64, + rhs: u64, + num_bytes: usize, +) { + for i in 0..num_bytes { + let a = ((lhs >> (i * 8)) & 0xff) as u8; + let b = ((rhs >> (i * 8)) & 0xff) as u8; + let op = match OP::ROM_TYPE { + LookupTable::And => LkOp::And { a, b }, + LookupTable::Or => LkOp::Or { a, b }, + LookupTable::Xor => LkOp::Xor { a, b }, + LookupTable::Ltu => LkOp::Ltu { a, b }, + rom_type => unreachable!("unsupported logic table: {rom_type:?}"), + }; + sink.emit_lk(op); + } +} + +pub fn emit_uint_limbs_lt_ops( + sink: &mut impl LkShardramSink, + is_sign_comparison: bool, + a: &[u16], + b: &[u16], +) { + assert_eq!(a.len(), UINT_LIMBS); + assert_eq!(b.len(), UINT_LIMBS); 
+ + let last = UINT_LIMBS - 1; + let sign_mask = 1 << (LIMB_BITS - 1); + let is_a_neg = is_sign_comparison && (a[last] & sign_mask) != 0; + let is_b_neg = is_sign_comparison && (b[last] & sign_mask) != 0; + + let (cmp_lt, diff_idx) = (0..UINT_LIMBS) + .rev() + .find(|&i| a[i] != b[i]) + .map(|i| ((a[i] < b[i]) ^ is_a_neg ^ is_b_neg, i)) + .unwrap_or((false, UINT_LIMBS)); + + let a_msb_range = if is_a_neg { + a[last] - sign_mask + } else { + a[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + let b_msb_range = if is_b_neg { + b[last] - sign_mask + } else { + b[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + + let to_signed = |value: u16, is_neg: bool| -> i32 { + if is_neg { + value as i32 - (1 << LIMB_BITS) + } else { + value as i32 + } + }; + let diff_val = if diff_idx == UINT_LIMBS { + 0 + } else if diff_idx == last { + let a_signed = to_signed(a[last], is_a_neg); + let b_signed = to_signed(b[last], is_b_neg); + if cmp_lt { + (b_signed - a_signed) as u16 + } else { + (a_signed - b_signed) as u16 + } + } else if cmp_lt { + b[diff_idx] - a[diff_idx] + } else { + a[diff_idx] - b[diff_idx] + }; + + emit_const_range_op( + sink, + if diff_idx == UINT_LIMBS { + 0 + } else { + (diff_val - 1) as u64 + }, + LIMB_BITS, + ); + emit_const_range_op(sink, a_msb_range as u64, LIMB_BITS); + emit_const_range_op(sink, b_msb_range as u64, LIMB_BITS); +} diff --git a/ceno_zkvm/src/instructions/gpu/utils/fallback.rs b/ceno_zkvm/src/instructions/gpu/utils/fallback.rs new file mode 100644 index 000000000..bd15be32c --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/fallback.rs @@ -0,0 +1,158 @@ +use ceno_emul::StepIndex; +use ff_ext::ExtensionField; +use gkr_iop::utils::lk_multiplicity::Multiplicity; +use itertools::Itertools; +use multilinear_extensions::util::max_usable_threads; +use p3::field::FieldAlgebra; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; +use witness::RowMajorMatrix; + +use 
crate::{e2e::ShardContext, error::ZKVMError, tables::RMMCollections, witness::LkMultiplicity}; + +use crate::instructions::Instruction; + +/// CPU-only assign_instances. Extracted so GPU-enabled instructions can call this as fallback. +pub fn cpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + + let nthreads = max_usable_threads(); + let total_instances = step_indices.len(); + let num_instance_per_batch = if total_instances > 256 { + total_instances.div_ceil(nthreads) + } else { + total_instances + } + .max(1); + let lk_multiplicity = LkMultiplicity::default(); + let mut raw_witin = + RowMajorMatrix::::new(total_instances, num_witin, I::padding_strategy()); + let mut raw_structual_witin = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), indices), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(indices.iter().copied()) + .map(|((instance, structural_instance), step_idx)| { + *structural_instance.last_mut().unwrap() = E::BaseField::ONE; + I::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + 
}) + .collect::>() + }, + ) + .collect::>()?; + + raw_witin.padding_by_strategy(); + raw_structual_witin.padding_by_strategy(); + Ok(( + [raw_witin, raw_structual_witin], + lk_multiplicity.into_finalize_result(), + )) +} + +/// CPU-only lk_shardram collection for GPU-enabled instructions. +/// +/// This path deliberately avoids scratch witness buffers and calls only the +/// instruction-specific side-effect collector. +pub fn cpu_collect_lk_and_shardram>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_lk_shardram_inner::(config, shard_ctx, shard_steps, step_indices, false) +} + +/// CPU-side `send()` / `addr_accessed` collection for GPU-assisted lk paths. +/// +/// Implementations may still increment fetch multiplicity on CPU, but all other +/// lookup multiplicities are expected to come from the GPU path. +pub fn cpu_collect_shardram>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_lk_shardram_inner::(config, shard_ctx, shard_steps, step_indices, true) +} + +fn cpu_collect_lk_shardram_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + shard_only: bool, +) -> Result, ZKVMError> { + let nthreads = max_usable_threads(); + let total = step_indices.len(); + let batch_size = if total > 256 { + total.div_ceil(nthreads) + } else { + total + } + .max(1); + + let lk_multiplicity = LkMultiplicity::default(); + let shard_ctx_vec = shard_ctx.get_forked(); + + step_indices + .par_chunks(batch_size) + .zip(shard_ctx_vec) + .flat_map(|(indices, mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + indices + .iter() + .copied() + .map(|step_idx| { + if shard_only { + I::collect_shardram( + config, + &mut 
shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } else { + I::collect_lk_and_shardram( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } + }) + .collect::>() + }) + .collect::>()?; + + Ok(lk_multiplicity.into_finalize_result()) +} diff --git a/ceno_zkvm/src/instructions/gpu/utils/lk_ops.rs b/ceno_zkvm/src/instructions/gpu/utils/lk_ops.rs new file mode 100644 index 000000000..ef409f845 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/lk_ops.rs @@ -0,0 +1,76 @@ +use ceno_emul::{Cycle, Word, WordAddr}; +use gkr_iop::tables::LookupTable; +use smallvec::SmallVec; + +use crate::structs::RAMType; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LkOp { + AssertU16 { value: u16 }, + DynamicRange { value: u64, bits: u32 }, + AssertU14 { value: u16 }, + Fetch { pc: u32 }, + DoubleU8 { a: u8, b: u8 }, + And { a: u8, b: u8 }, + Or { a: u8, b: u8 }, + Xor { a: u8, b: u8 }, + Ltu { a: u8, b: u8 }, + Pow2 { value: u8 }, + ShrByte { shift: u8, carry: u8, bits: u8 }, +} + +impl LkOp { + pub fn encode_all(&self) -> SmallVec<[(LookupTable, u64); 2]> { + match *self { + LkOp::AssertU16 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 16) + value as u64)]) + } + LkOp::DynamicRange { value, bits } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << bits) + value)]) + } + LkOp::AssertU14 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 14) + value as u64)]) + } + LkOp::Fetch { pc } => SmallVec::from_slice(&[(LookupTable::Instruction, pc as u64)]), + LkOp::DoubleU8 { a, b } => { + SmallVec::from_slice(&[(LookupTable::DoubleU8, ((a as u64) << 8) + b as u64)]) + } + LkOp::And { a, b } => { + SmallVec::from_slice(&[(LookupTable::And, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Or { a, b } => { + SmallVec::from_slice(&[(LookupTable::Or, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Xor { a, b } => { + SmallVec::from_slice(&[(LookupTable::Xor, (a as u64) | ((b as 
u64) << 8))]) + } + LkOp::Ltu { a, b } => { + SmallVec::from_slice(&[(LookupTable::Ltu, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Pow2 { value } => { + SmallVec::from_slice(&[(LookupTable::Pow, 2u64 | ((value as u64) << 8))]) + } + LkOp::ShrByte { shift, carry, bits } => SmallVec::from_slice(&[ + ( + LookupTable::DoubleU8, + ((shift as u64) << 8) + ((shift as u64) << bits), + ), + ( + LookupTable::DoubleU8, + ((carry as u64) << 8) + ((carry as u64) << (8 - bits)), + ), + ]), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SendEvent { + pub ram_type: RAMType, + pub addr: WordAddr, + pub id: u64, + pub cycle: Cycle, + pub prev_cycle: Cycle, + pub value: Word, + pub prev_value: Option, +} diff --git a/ceno_zkvm/src/instructions/gpu/utils/mod.rs b/ceno_zkvm/src/instructions/gpu/utils/mod.rs new file mode 100644 index 000000000..1efabe305 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/mod.rs @@ -0,0 +1,68 @@ +//! Host-side operations for GPU-CPU hybrid witness generation. +//! +//! Contains lookup/shard lk_shardram collection abstractions and CPU fallback paths. 
+ +mod emit; +mod fallback; +mod lk_ops; +mod sink; + +// Re-export all public types for convenience +pub use emit::*; +pub use fallback::*; +pub use lk_ops::*; +pub use sink::*; + +#[cfg(feature = "gpu")] +pub mod column_map; +#[cfg(feature = "gpu")] +pub mod d2h; +#[cfg(feature = "gpu")] +pub mod debug_compare; +#[cfg(all(test, feature = "gpu"))] +pub mod test_helpers; + +#[cfg(test)] +mod tests { + use super::*; + use crate::witness::LkMultiplicity; + use gkr_iop::tables::LookupTable; + + #[test] + fn test_lk_op_encodings_match_cpu_multiplicity() { + let ops = [ + LkOp::AssertU16 { value: 7 }, + LkOp::DynamicRange { value: 11, bits: 8 }, + LkOp::AssertU14 { value: 5 }, + LkOp::Fetch { pc: 0x1234 }, + LkOp::DoubleU8 { a: 1, b: 2 }, + LkOp::And { a: 3, b: 4 }, + LkOp::Or { a: 5, b: 6 }, + LkOp::Xor { a: 7, b: 8 }, + LkOp::Ltu { a: 9, b: 10 }, + LkOp::Pow2 { value: 12 }, + LkOp::ShrByte { + shift: 3, + carry: 17, + bits: 2, + }, + ]; + + let mut lk = LkMultiplicity::default(); + for op in ops { + for (table, key) in op.encode_all() { + lk.increment(table, key); + } + } + + let finalized = lk.into_finalize_result(); + assert_eq!(finalized[LookupTable::Dynamic as usize].len(), 3); + assert_eq!(finalized[LookupTable::Instruction as usize].len(), 1); + assert_eq!(finalized[LookupTable::DoubleU8 as usize].len(), 3); + assert_eq!(finalized[LookupTable::And as usize].len(), 1); + assert_eq!(finalized[LookupTable::Or as usize].len(), 1); + assert_eq!(finalized[LookupTable::Xor as usize].len(), 1); + assert_eq!(finalized[LookupTable::Ltu as usize].len(), 1); + assert_eq!(finalized[LookupTable::Pow as usize].len(), 1); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/utils/sink.rs b/ceno_zkvm/src/instructions/gpu/utils/sink.rs new file mode 100644 index 000000000..4e7ff8fde --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/sink.rs @@ -0,0 +1,84 @@ +use ceno_emul::WordAddr; +use std::marker::PhantomData; + +use crate::{e2e::ShardContext, witness::LkMultiplicity}; + 
+use super::{LkOp, SendEvent}; + +pub trait LkShardramSink { + fn emit_lk(&mut self, op: LkOp); + fn emit_send(&mut self, event: SendEvent); + fn touch_addr(&mut self, addr: WordAddr); +} + +pub struct CpuLkShardramSink<'ctx, 'shard, 'lk> { + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + _marker: PhantomData<&'ctx mut ShardContext<'shard>>, +} + +impl<'ctx, 'shard, 'lk> CpuLkShardramSink<'ctx, 'shard, 'lk> { + /// # Safety + /// Caller must ensure `shard_ctx` points to a valid, live `ShardContext` + /// for the duration of the returned sink's lifetime. + pub unsafe fn from_raw( + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + ) -> Self { + Self { + shard_ctx, + lk, + _marker: PhantomData, + } + } + + fn shard_ctx(&mut self) -> &mut ShardContext<'shard> { + // Safety: `from_raw` is only constructed from a live `&mut ShardContext` + // for the duration of lk_shardram collection. + unsafe { &mut *self.shard_ctx } + } +} + +/// Create a `CpuLkShardramSink` and an immutable view of `ShardContext`, +/// then pass both to the closure `f`. +/// +/// This encapsulates the raw-pointer trick needed to hold `&mut ShardContext` +/// (inside the sink, for writes) and `&ShardContext` (for reads like +/// `aligned_prev_ts`) simultaneously. +/// +/// # Safety +/// Safe to call — the unsafety is internal and bounded by the closure lifetime. 
+pub fn with_cpu_sink<'a, R>( + shard_ctx: &'a mut ShardContext<'a>, + lk: &'a mut LkMultiplicity, + f: impl FnOnce(&mut CpuLkShardramSink<'a, 'a, 'a>, &ShardContext) -> R, +) -> R { + let ptr = shard_ctx as *mut ShardContext; + let view = unsafe { &*ptr }; + let mut sink = unsafe { CpuLkShardramSink::from_raw(ptr, lk) }; + f(&mut sink, view) +} + +impl LkShardramSink for CpuLkShardramSink<'_, '_, '_> { + fn emit_lk(&mut self, op: LkOp) { + for (table, key) in op.encode_all() { + self.lk.increment(table, key); + } + } + + fn emit_send(&mut self, event: SendEvent) { + self.shard_ctx().record_send_without_touch( + event.ram_type, + event.addr, + event.id, + event.cycle, + event.prev_cycle, + event.value, + event.prev_value, + ); + } + + fn touch_addr(&mut self, addr: WordAddr) { + self.shard_ctx().push_addr_accessed(addr); + } +} diff --git a/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs new file mode 100644 index 000000000..3fbbad170 --- /dev/null +++ b/ceno_zkvm/src/instructions/gpu/utils/test_helpers.rs @@ -0,0 +1,127 @@ +// --------------------------------------------------------------------------- +// GPU correctness test helpers +// --------------------------------------------------------------------------- + +/// Compare GPU column-major witness data against CPU row-major reference. +/// Panics with detailed mismatch info if any element differs. 
+#[cfg(test)] +pub fn assert_witness_colmajor_eq( + gpu_colmajor: &[F], + cpu_rowmajor: &[F], + n_rows: usize, + n_cols: usize, +) { + assert_eq!( + gpu_colmajor.len(), + cpu_rowmajor.len(), + "Size mismatch: gpu={} cpu={}", + gpu_colmajor.len(), + cpu_rowmajor.len() + ); + let mut mismatches = 0; + for row in 0..n_rows { + for col in 0..n_cols { + let gpu_val = &gpu_colmajor[col * n_rows + row]; + let cpu_val = &cpu_rowmajor[row * n_cols + col]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!("Mismatch at row={row}, col={col}: GPU={gpu_val:?}, CPU={cpu_val:?}"); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {mismatches} mismatches"); +} + +/// Run `try_gpu_assign_instances` + `flush_shared_ec_buffers`, then assert +/// witness, LK multiplicity, addr_accessed, and read/write records all match +/// the CPU reference in `cpu_ctx`. +#[cfg(test)] +pub fn assert_full_gpu_pipeline< + E: ff_ext::ExtensionField, + I: crate::instructions::Instruction, +>( + config: &I::InstructionConfig, + steps: &[ceno_emul::StepRecord], + kind: crate::instructions::gpu::dispatch::GpuWitgenKind, + cpu_rmms: &crate::tables::RMMCollections, + cpu_lkm: &gkr_iop::utils::lk_multiplicity::Multiplicity, + cpu_ctx: &crate::e2e::ShardContext, + num_witin: usize, + num_structural_witin: usize, +) { + let indices: Vec = (0..steps.len()).collect(); + + let mut gpu_ctx = crate::e2e::ShardContext::default(); + let result = crate::instructions::gpu::dispatch::try_gpu_assign_instances::( + config, + &mut gpu_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + kind, + ) + .unwrap(); + // Skip pipeline comparison if GPU witgen is not enabled (CENO_GPU_ENABLE_WITGEN unset) + let Some((gpu_rmms, gpu_lkm)) = result else { + eprintln!("GPU witgen not enabled, skipping full pipeline comparison"); + return; + }; + + crate::instructions::gpu::cache::flush_shared_ec_buffers(&mut gpu_ctx).unwrap(); + + assert_eq!( + gpu_rmms[0].values(), + cpu_rmms[0].values(), + 
"witness mismatch" + ); + assert_eq!( + flatten_lk_for_test(&gpu_lkm), + flatten_lk_for_test(cpu_lkm), + "LK multiplicity mismatch" + ); + assert_eq!( + gpu_ctx.get_addr_accessed(), + cpu_ctx.get_addr_accessed(), + "addr_accessed mismatch" + ); + assert_eq!( + flatten_records_for_test(gpu_ctx.read_records()), + flatten_records_for_test(cpu_ctx.read_records()), + "read_records mismatch" + ); + assert_eq!( + flatten_records_for_test(gpu_ctx.write_records()), + flatten_records_for_test(cpu_ctx.write_records()), + "write_records mismatch" + ); +} + +#[cfg(test)] +fn flatten_lk_for_test( + m: &gkr_iop::utils::lk_multiplicity::Multiplicity, +) -> Vec> { + m.iter() + .map(|table| { + let mut entries: Vec<_> = table.iter().map(|(k, v)| (*k, *v)).collect(); + entries.sort_unstable(); + entries + }) + .collect() +} + +#[cfg(test)] +fn flatten_records_for_test( + records: &[std::collections::BTreeMap], +) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, r)| (*addr, r.prev_cycle, r.cycle, r.shard_id)) + }) + .collect() +} diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index c77b707b4..8798e01aa 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -33,11 +33,11 @@ mod r_insn; mod ecall_insn; #[cfg(feature = "u16limb_circuit")] -mod auipc; +pub(crate) mod auipc; mod im_insn; #[cfg(feature = "u16limb_circuit")] -mod lui; -mod memory; +pub(crate) mod lui; +pub(crate) mod memory; mod s_insn; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index a5f6e006f..198697fca 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,8 +2,14 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - circuit_builder::CircuitBuilder, e2e::ShardContext, 
error::ZKVMError, - instructions::Instruction, structs::ProgramParams, uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, + instructions::{Instruction, gpu::utils::emit_u16_limbs}, + structs::ProgramParams, + uint::Value, + witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -11,11 +17,11 @@ use ff_ext::ExtensionField; /// This config handles R-Instructions that represent registers values as 2 * u16. #[derive(Debug)] pub struct ArithConfig { - r_insn: RInstructionConfig, + pub r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - rd_written: UInt, + pub rs1_read: UInt, + pub rs2_read: UInt, + pub rd_written: UInt, } pub struct ArithInstruction(PhantomData<(E, I)>); @@ -36,6 +42,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = matches!(I::INST_KIND, InsnKind::ADD | InsnKind::SUB); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -132,6 +140,27 @@ impl Instruction for ArithInstruction { + emit_u16_limbs(sink, step.rd().unwrap().value.after); + } + InsnKind::SUB => { + emit_u16_limbs(sink, step.rd().unwrap().value.after); + emit_u16_limbs(sink, step.rs1().unwrap().value); + } + _ => unreachable!("Unsupported instruction kind"), + } + }); + + impl_collect_shardram!(r_insn); + + impl_gpu_assign!(match I::INST_KIND { + InsnKind::ADD => Some(dispatch::GpuWitgenKind::Add), + InsnKind::SUB => Some(dispatch::GpuWitgenKind::Sub), + _ => None, + }); } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index f41832719..96a554a6c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod arith_imm_circuit; #[cfg(feature = 
"u16limb_circuit")] -mod arith_imm_circuit_v2; +pub(crate) mod arith_imm_circuit_v2; #[cfg(feature = "u16limb_circuit")] pub use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::AddiInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 027483d1e..2d0e62969 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -3,8 +3,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_u16_limbs, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, }, structs::ProgramParams, @@ -21,19 +23,21 @@ use witness::set_val; pub struct AddiInstruction(PhantomData); pub struct InstructionConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, - rd_written: UInt, + pub(crate) imm_sign: WitIn, + pub(crate) rd_written: UInt, } impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::ADDI] } @@ -104,4 +108,12 @@ impl Instruction for AddiInstruction { Ok(()) } + + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { + emit_u16_limbs(sink, step.rd().unwrap().value.after); + }); + + impl_collect_shardram!(i_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::Addi); } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 6311fc2aa..e5aea4394 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ 
b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -6,8 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op}, riscv::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, @@ -39,6 +41,8 @@ impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::AUIPC] } @@ -185,6 +189,38 @@ impl Instruction for AuipcInstruction { Ok(()) } + + impl_collect_lk_and_shardram!(i_insn, |sink, step, config, _ctx| { + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(sink, &rd_written); + + let pc = split_to_u8(step.pc().before.0); + // Only iterate over the middle limbs that have witness columns (pc_limbs has UINT_BYTE_LIMBS-2 elements). + // The MSB limb is range-checked via XOR below, the LSB is shared with rd_written[0]. 
+ for val in pc.iter().skip(1).take(config.pc_limbs.len()) { + emit_const_range_op(sink, *val as u64, 8); + } + + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; + for val in split_to_u8::(imm) + .into_iter() + .take(config.imm_limbs.len()) + { + emit_const_range_op(sink, val as u64, 8); + } + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: pc[3], + b: additional_bits as u8, + }); + }); + + impl_collect_shardram!(i_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::Auipc); } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index cdc1db56d..a94266f63 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -7,7 +7,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + instructions::{ + gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + }, tables::InsnRecord, witness::{LkMultiplicity, set_val}, }; @@ -111,4 +114,28 @@ impl BInstructionConfig { Ok(()) } + + pub fn emit_shardram( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.emit_shardram(shard_ctx, step); + self.rs2.emit_shardram(shard_ctx, step); + } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rs2.emit_lk_and_shardram(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index ab080ac0d..e1e5f40df 
100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod branch_circuit; #[cfg(feature = "u16limb_circuit")] -mod branch_circuit_v2; +pub(crate) mod branch_circuit_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 85ef6914b..34b6ebe68 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -4,8 +4,10 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_uint_limbs_lt_ops, riscv::{ RIVInstruction, b_insn::BInstructionConfig, @@ -41,6 +43,8 @@ impl Instruction for BranchCircuit; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -204,4 +208,31 @@ impl Instruction for BranchCircuit Some(dispatch::GpuWitgenKind::BranchEq(1)), + InsnKind::BNE => Some(dispatch::GpuWitgenKind::BranchEq(0)), + InsnKind::BLT => Some(dispatch::GpuWitgenKind::BranchCmp(1)), + InsnKind::BGE => Some(dispatch::GpuWitgenKind::BranchCmp(1)), + InsnKind::BLTU => Some(dispatch::GpuWitgenKind::BranchCmp(0)), + InsnKind::BGEU => Some(dispatch::GpuWitgenKind::BranchCmp(0)), + _ => unreachable!(), + }); } diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 981995452..829f6140c 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -3,7 +3,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod div_circuit; #[cfg(feature = "u16limb_circuit")] -mod div_circuit_v2; +pub(crate) mod div_circuit_v2; use 
super::RIVInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index eb1a5d0f9..acfb4042d 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -14,7 +14,12 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::{Instruction, riscv::constants::LIMB_BITS}, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, + instructions::{ + Instruction, + gpu::utils::{LkOp, LkShardramSink, emit_u16_limbs}, + riscv::constants::LIMB_BITS, + }, structs::ProgramParams, uint::Value, witness::{LkMultiplicity, set_val}, @@ -30,18 +35,18 @@ pub struct DivRemConfig { pub(crate) remainder: UInt, pub(crate) r_insn: RInstructionConfig, - dividend_sign: WitIn, - divisor_sign: WitIn, - quotient_sign: WitIn, - remainder_zero: WitIn, - divisor_zero: WitIn, - divisor_sum_inv: WitIn, - remainder_sum_inv: WitIn, - remainder_inv: [WitIn; UINT_LIMBS], - sign_xor: WitIn, - remainder_prime: UInt, // r' - lt_marker: [WitIn; UINT_LIMBS], - lt_diff: WitIn, + pub(crate) dividend_sign: WitIn, + pub(crate) divisor_sign: WitIn, + pub(crate) quotient_sign: WitIn, + pub(crate) remainder_zero: WitIn, + pub(crate) divisor_zero: WitIn, + pub(crate) divisor_sum_inv: WitIn, + pub(crate) remainder_sum_inv: WitIn, + pub(crate) remainder_inv: [WitIn; UINT_LIMBS], + pub(crate) sign_xor: WitIn, + pub(crate) remainder_prime: UInt, // r' + pub(crate) lt_marker: [WitIn; UINT_LIMBS], + pub(crate) lt_diff: WitIn, } pub struct ArithInstruction(PhantomData<(E, I)>); @@ -50,6 +55,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -376,6 +383,14 @@ impl Instruction for ArithInstruction Some(dispatch::GpuWitgenKind::Div(0u32)), + InsnKind::DIVU => 
Some(dispatch::GpuWitgenKind::Div(1u32)), + InsnKind::REM => Some(dispatch::GpuWitgenKind::Div(2u32)), + InsnKind::REMU => Some(dispatch::GpuWitgenKind::Div(3u32)), + _ => None, + }); + fn assign_instance( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, @@ -522,6 +537,87 @@ impl Instruction for ArithInstruction Instruction for LargeEcallDummy lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // Assign instruction. config diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index b74c8ca39..bf952a4f1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -18,10 +18,12 @@ fn test_large_ecall_dummy_keccak() { let mut cb = CircuitBuilder::new(&mut cs); let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let (step, program) = ceno_emul::test_utils::keccak_step(); + let (step, program, syscall_witnesses) = ceno_emul::test_utils::keccak_step(); + let mut shard_ctx = ShardContext::default(); + shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); let (raw_witin, lkm) = KeccakDummy::assign_instances_from_steps( &config, - &mut ShardContext::default(), + &mut shard_ctx, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, &[step], diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index 84a836bf3..b743e8bb3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -2,7 +2,7 @@ mod fptower_fp; mod fptower_fp2_add; mod fptower_fp2_mul; mod halt; -mod keccak; +pub(crate) mod keccak; mod sha_extend; mod uint256; mod weierstrass_add; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs 
b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 57d824b01..7ee531cc8 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -358,7 +358,8 @@ fn assign_fp_op_instances( .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); config .vm_state .assign_instance(instance, &shard_ctx, step)?; @@ -419,7 +420,7 @@ fn assign_fp_op_instances( .map(|&idx| { let step = &steps[idx]; let values: Vec = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6715a0b74..2d99ed4a4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -273,7 +273,8 @@ fn assign_fp2_add_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index 7537d31ca..4aacf3418 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -271,7 +271,8 @@ fn assign_fp2_mul_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index e088cc0cc..4c82ffd08 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -44,13 +44,13 @@ use crate::{ #[derive(Debug)] pub struct EcallKeccakConfig { pub layout: KeccakLayout, - vm_state: StateInOut, - 
ecall_id: OpFixedRS, - state_ptr: (OpFixedRS, MemAddr), - mem_rw: Vec, + pub(crate) vm_state: StateInOut, + pub(crate) ecall_id: OpFixedRS, + pub(crate) state_ptr: (OpFixedRS, MemAddr), + pub(crate) mem_rw: Vec, } -/// KeccakInstruction can handle any instruction and produce its side-effects. +/// KeccakInstruction can handle any instruction and produce its lk and shardram data. pub struct KeccakInstruction(PhantomData); impl Instruction for KeccakInstruction { @@ -178,6 +178,21 @@ impl Instruction for KeccakInstruction { steps: &[StepRecord], step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + #[cfg(feature = "gpu")] + { + use crate::instructions::gpu::chips::keccak::gpu_assign_keccak_instances; + if let Some(result) = gpu_assign_keccak_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + steps, + step_indices, + )? { + return Ok(result); + } + } + let mut lk_multiplicity = LkMultiplicity::default(); if step_indices.is_empty() { return Ok(( @@ -221,7 +236,8 @@ impl Instruction for KeccakInstruction { .zip_eq(indices.iter().copied()) .map(|(instance_with_rotation, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); let bh = BooleanHypercube::new(KECCAK_ROUNDS_CEIL_LOG2); let mut cyclic_group = bh.into_iter(); @@ -285,7 +301,7 @@ impl Instruction for KeccakInstruction { .map(|&idx| -> KeccakInstance { let step = &steps[idx]; let (instance, prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index b61673e4a..8044e5d2e 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -47,7 +47,7 @@ pub struct EcallShaExtendConfig { mem_rw: Vec, 
} -/// ShaExtendInstruction can handle any instruction and produce its side-effects. +/// ShaExtendInstruction can handle any instruction and produce its lk and shardram data. pub struct ShaExtendInstruction(PhantomData); impl Instruction for ShaExtendInstruction { @@ -218,7 +218,8 @@ impl Instruction for ShaExtendInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); // vm_state config @@ -285,7 +286,8 @@ impl Instruction for ShaExtendInstruction { .iter() .map(|&idx| -> ShaExtendInstance { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); let w_i_minus_2 = ops.mem_ops[0].value.before; let w_i_minus_7 = ops.mem_ops[1].value.before; let w_i_minus_15 = ops.mem_ops[2].value.before; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index f3a39093f..92c3778c5 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -63,7 +63,7 @@ pub struct EcallUint256MulConfig { mem_rw: Vec, } -/// Uint256MulInstruction can handle any instruction and produce its side-effects. +/// Uint256MulInstruction can handle any instruction and produce its lk and shardram data. 
pub struct Uint256MulInstruction(PhantomData); impl Instruction for Uint256MulInstruction { @@ -270,7 +270,8 @@ impl Instruction for Uint256MulInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -336,7 +337,7 @@ impl Instruction for Uint256MulInstruction { .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() @@ -371,7 +372,7 @@ impl Instruction for Uint256MulInstruction { } } -/// Uint256InvInstruction can handle any instruction and produce its side-effects. +/// Uint256InvInstruction can handle any instruction and produce its lk and shardram data. pub struct Uint256InvInstruction(PhantomData<(E, P)>); pub struct Secp256K1EcallSpec; @@ -593,7 +594,8 @@ impl Instruction for Uint256InvInstr .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -646,7 +648,7 @@ impl Instruction for Uint256InvInstr .map(|&idx| { let step = &steps[idx]; let (instance, _): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 80c85ef7a..27470aff1 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -52,7 +52,7 @@ pub struct EcallWeierstrassAddAssignConfig mem_rw: Vec, } -/// WeierstrassAddAssignInstruction can handle any instruction and produce its side-effects. 
+/// WeierstrassAddAssignInstruction can handle any instruction and produce its lk and shardram data. pub struct WeierstrassAddAssignInstruction(PhantomData<(E, EC)>); impl Instruction @@ -278,7 +278,8 @@ impl Instruction .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -345,7 +346,7 @@ impl Instruction .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index a07fc00b2..682b2b967 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -59,7 +59,7 @@ pub struct EcallWeierstrassDecompressConfig, } -/// WeierstrassDecompressInstruction can handle any instruction and produce its side-effects. +/// WeierstrassDecompressInstruction can handle any instruction and produce its lk and shardram data. pub struct WeierstrassDecompressInstruction(PhantomData<(E, EC)>); impl Instruction @@ -278,7 +278,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 72f5f71d8..a95221b7c 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -54,7 +54,7 @@ pub struct EcallWeierstrassDoubleAssignConfig< mem_rw: Vec, } -/// WeierstrassDoubleAssignInstruction can handle any instruction and produce its side-effects. 
+/// WeierstrassDoubleAssignInstruction can handle any instruction and produce its lk and shardram data. pub struct WeierstrassDoubleAssignInstruction(PhantomData<(E, EC)>); impl Instruction @@ -250,7 +250,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall_base.rs b/ceno_zkvm/src/instructions/riscv/ecall_base.rs index 7e655c408..69d9d286f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_base.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_base.rs @@ -18,13 +18,13 @@ use ceno_emul::FullTracer as Tracer; use multilinear_extensions::{ToExpr, WitIn}; #[derive(Debug)] -pub struct OpFixedRS { +pub struct OpFixedRS { pub prev_ts: WitIn, pub prev_value: Option>, pub lt_cfg: AssertLtConfig, } -impl OpFixedRS { +impl OpFixedRS { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, rd_written: RegisterExpr, diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index c726f8a88..39786e9ce 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + instructions::{ + gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -76,4 +79,28 @@ impl IInstructionConfig { Ok(()) } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); + } + + pub fn emit_shardram( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, 
+ ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.emit_shardram(shard_ctx, step); + self.rd.emit_shardram(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index c7f6cace0..78a029477 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -2,7 +2,10 @@ use crate::{ chip_handler::{AddressExpr, MemoryExpr, RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + instructions::{ + gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -17,10 +20,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Register reads and writes /// - Memory reads pub struct IMInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rd: WriteRD, - mem_read: ReadMEM, + pub vm_state: StateInOut, + pub rs1: ReadRS1, + pub rd: WriteRD, + pub mem_read: ReadMEM, } impl IMInstructionConfig { @@ -85,4 +88,30 @@ impl IMInstructionConfig { Ok(()) } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); + self.mem_read.emit_lk_and_shardram(sink, shard_ctx, step); + } + + pub fn emit_shardram( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.emit_shardram(shard_ctx, step); + self.rd.emit_shardram(shard_ctx, step); + self.mem_read.emit_shardram(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 
1a378ad8c..56ca70c3f 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -13,6 +13,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, + instructions::gpu::utils::{LkOp, LkShardramSink, SendEvent, emit_assert_lt_ops}, structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, @@ -141,6 +142,47 @@ impl ReadRS1 { Ok(()) } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs1().expect("rs1 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS1, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs1().expect("rs1 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -209,6 +251,47 @@ impl ReadRS2 { Ok(()) } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs2().expect("rs2 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, + ); + sink.emit_send(SendEvent { + ram_type: 
RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS2, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs2().expect("rs2 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -295,6 +378,61 @@ impl WriteRD { Ok(()) } + + pub fn emit_op_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RD, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: cycle + Tracer::SUBCYCLE_RD, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rd().expect("rd op"); + self.emit_op_lk_and_shardram(sink, shard_ctx, step.cycle(), &op) + } + + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rd().expect("rd op"); + self.emit_op_shardram(shard_ctx, step.cycle(), &op) + } + + pub fn emit_op_shardram(&self, shard_ctx: &mut ShardContext, cycle: Cycle, op: &WriteOp) { + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + cycle + Tracer::SUBCYCLE_RD, + op.previous_cycle, + 
op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -361,6 +499,47 @@ impl ReadMEM { Ok(()) } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + step.cycle() + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -434,13 +613,68 @@ impl WriteMEM { Ok(()) } + + pub fn emit_op_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: cycle + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn emit_lk_and_shardram( + 
&self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + self.emit_op_lk_and_shardram(sink, shard_ctx, step.cycle(), &op) + } + + pub fn emit_shardram(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + self.emit_op_shardram(shard_ctx, step.cycle(), &op) + } + + pub fn emit_op_shardram(&self, shard_ctx: &mut ShardContext, cycle: Cycle, op: &WriteOp) { + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] pub struct MemAddr { - addr: UInt, - low_bits: Vec, - max_bits: usize, + pub addr: UInt, + pub low_bits: Vec, + pub max_bits: usize, } impl MemAddr { @@ -584,6 +818,22 @@ impl MemAddr { Ok(()) } + pub fn emit_lk_and_shardram(&self, sink: &mut impl LkShardramSink, addr: Word) { + let mid_u14 = ((addr & 0xffff) >> Self::N_LOW_BITS) as u16; + sink.emit_lk(LkOp::AssertU14 { value: mid_u14 }); + + for i in 1..UINT_LIMBS { + let high_u16 = ((addr >> (i * 16)) & 0xffff) as u64; + let bits = (self.max_bits - i * 16).min(16); + if bits > 1 { + sink.emit_lk(LkOp::DynamicRange { + value: high_u16, + bits: bits as u32, + }); + } + } + } + fn n_zeros(&self) -> usize { Self::N_LOW_BITS - self.low_bits.len() } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 84cb84679..6cc8b18fe 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{StateInOut, WriteRD}, + instructions::{ + gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{StateInOut, WriteRD}, + }, tables::InsnRecord, 
witness::LkMultiplicity, }; @@ -68,4 +71,26 @@ impl JInstructionConfig { Ok(()) } + + pub fn emit_shardram( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rd.emit_shardram(shard_ctx, step); + } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 7bf1a41f6..c0b121827 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,12 +1,12 @@ #[cfg(not(feature = "u16limb_circuit"))] mod jal; #[cfg(feature = "u16limb_circuit")] -mod jal_v2; +pub(crate) mod jal_v2; #[cfg(not(feature = "u16limb_circuit"))] mod jalr; #[cfg(feature = "u16limb_circuit")] -mod jalr_v2; +pub(crate) mod jalr_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index a766ea795..bb8ba0abe 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -6,8 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, riscv::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, @@ -44,6 +46,8 @@ impl Instruction for JalInstruction { type InstructionConfig = JalConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JAL] } @@ -121,4 +125,21 @@ impl Instruction for 
JalInstruction { Ok(()) } + + impl_collect_lk_and_shardram!(j_insn, |sink, step, _config, _ctx| { + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(sink, &rd_written); + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: rd_written[3], + b: additional_bits as u8, + }); + }); + + impl_collect_shardram!(j_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::Jal); } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7c51728ac..75c0d28cf 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -7,8 +7,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_const_range_op, riscv::{ constants::{PC_BITS, UINT_LIMBS, UInt}, i_insn::IInstructionConfig, @@ -44,6 +46,8 @@ impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JALR] } @@ -188,4 +192,19 @@ impl Instruction for JalrInstruction { Ok(()) } + + impl_collect_lk_and_shardram!(i_insn, |sink, step, config, _ctx| { + let rd_value = Value::new_unchecked(step.rd().unwrap().value.after); + let rd_limb = rd_value.as_u16_limbs(); + emit_const_range_op(sink, rd_limb[0] as u64, 16); + emit_const_range_op(sink, rd_limb[1] as u64, PC_BITS - 16); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let jump_pc = step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32); + config.jump_pc_addr.emit_lk_and_shardram(sink, jump_pc); + }); + + impl_collect_shardram!(i_insn); + + 
impl_gpu_assign!(dispatch::GpuWitgenKind::Jalr); } diff --git a/ceno_zkvm/src/instructions/riscv/logic.rs b/ceno_zkvm/src/instructions/riscv/logic.rs index 9ac2cd4c1..9684c36bf 100644 --- a/ceno_zkvm/src/instructions/riscv/logic.rs +++ b/ceno_zkvm/src/instructions/riscv/logic.rs @@ -1,4 +1,4 @@ -mod logic_circuit; +pub(crate) mod logic_circuit; use gkr_iop::tables::ops::{AndTable, OrTable, XorTable}; use logic_circuit::{LogicInstruction, LogicOp}; diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 4d2cf6db8..6dad59ca7 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -8,8 +8,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_logic_u8_ops, riscv::{constants::UInt8, r_insn::RInstructionConfig}, }, structs::ProgramParams, @@ -31,6 +33,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -72,16 +76,34 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } + + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + emit_logic_u8_ops::( + sink, + step.rs1().unwrap().value as u64, + step.rs2().unwrap().value as u64, + 4, + ); + }); + + impl_collect_shardram!(r_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::LogicR(match I::INST_KIND { + InsnKind::AND => 0, + InsnKind::OR => 1, + InsnKind::XOR => 2, + kind => unreachable!("unsupported logic GPU kind: {kind:?}"), + })); } /// This config implements R-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. 
#[derive(Debug)] pub struct LogicConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub(crate) rd_written: UInt8, } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs index a4b46edcc..44a51233b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -2,7 +2,7 @@ mod logic_imm_circuit; #[cfg(feature = "u16limb_circuit")] -mod logic_imm_circuit_v2; +pub(crate) mod logic_imm_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use crate::instructions::riscv::logic_imm::logic_imm_circuit::LogicInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 14c2adeb0..1137974f8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -9,8 +9,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_logic_u8_ops, riscv::{ constants::{LIMB_BITS, LIMB_MASK, UInt8}, i_insn::IInstructionConfig, @@ -33,6 +35,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -124,18 +128,39 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } + + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { + let rs1_lo = step.rs1().unwrap().value & LIMB_MASK; + let rs1_hi = (step.rs1().unwrap().value >> LIMB_BITS) & LIMB_MASK; + let imm_lo = InsnRecord::::imm_internal(&step.insn()).0 as 
u32 & LIMB_MASK; + let imm_hi = (InsnRecord::::imm_signed_internal(&step.insn()).0 as u32 + >> LIMB_BITS) + & LIMB_MASK; + + emit_logic_u8_ops::(sink, rs1_lo.into(), imm_lo.into(), 2); + emit_logic_u8_ops::(sink, rs1_hi.into(), imm_hi.into(), 2); + }); + + impl_collect_shardram!(i_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::LogicI(match I::INST_KIND { + InsnKind::ANDI => 0, + InsnKind::ORI => 1, + InsnKind::XORI => 2, + kind => unreachable!("unsupported logic_imm GPU kind: {kind:?}"), + })); } /// This config implements I-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. #[derive(Debug)] pub struct LogicConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt8, + pub(crate) rs1_read: UInt8, pub(crate) rd_written: UInt8, - imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, - imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, } impl LogicConfig { diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index deb7b5736..814924fda 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -6,8 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_const_range_op, riscv::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, @@ -36,6 +38,8 @@ impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::LUI] } @@ -103,7 +107,7 @@ impl Instruction for LuiInstruction { .i_insn .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; - let rd_written = split_to_u8(step.rd().unwrap().value.after); + let 
rd_written = split_to_u8::(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { lk_multiplicity.assert_ux::<8>(*val as u64); set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); @@ -113,6 +117,17 @@ impl Instruction for LuiInstruction { Ok(()) } + + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in rd_written.iter().skip(1) { + emit_const_range_op(sink, *val as u64, 8); + } + }); + + impl_collect_shardram!(i_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::Lui); } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index bb29491f7..294d7fd44 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -6,9 +6,9 @@ pub mod load; pub mod store; #[cfg(feature = "u16limb_circuit")] -mod load_v2; +pub mod load_v2; #[cfg(feature = "u16limb_circuit")] -mod store_v2; +pub(crate) mod store_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 3a8da4a09..5b95ddb05 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -14,10 +14,10 @@ use p3::field::{Field, FieldAlgebra}; use witness::set_val; pub struct MemWordUtil { - prev_limb_bytes: Vec, - rs2_limb_bytes: Vec, + pub(crate) prev_limb_bytes: Vec, + pub(crate) rs2_limb_bytes: Vec, - expected_limb: Option, + pub(crate) expected_limb: Option, expect_limbs_expr: [Expression; 2], } @@ -138,7 +138,7 @@ impl MemWordUtil { step: &StepRecord, shift: u32, ) -> Result<(), ZKVMError> { - let memory_op = step.memory_op().clone().unwrap(); + let memory_op = step.memory_op().unwrap(); let prev_value = Value::new_unchecked(memory_op.value.before); let rs2_value = 
Value::new_unchecked(step.rs2().unwrap().value); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 818e8902a..fc37371bc 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -4,6 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -22,16 +23,16 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); @@ -40,6 +41,8 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = matches!(I::INST_KIND, InsnKind::LW); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -226,4 +229,20 @@ impl Instruction for LoadInstruction::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .emit_lk_and_shardram(sink, unaligned_addr.into()); + }); + + impl_collect_shardram!(im_insn); + + impl_gpu_assign!(match I::INST_KIND { + InsnKind::LW => Some(dispatch::GpuWitgenKind::Lw), + _ => None, + }); } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 5a9ed40eb..1810319eb 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ 
-4,6 +4,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::SignedExtendConfig, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, riscv::{ @@ -25,17 +26,17 @@ use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - imm_sign: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub imm_sign: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); @@ -44,6 +45,11 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = matches!( + I::INST_KIND, + InsnKind::LW | InsnKind::LB | InsnKind::LBU | InsnKind::LH | InsnKind::LHU + ); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -251,4 +257,38 @@ impl Instruction for LoadInstruction::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .emit_lk_and_shardram(sink, unaligned_addr.into()); + }); + + impl_collect_shardram!(im_insn); + + impl_gpu_assign!(match I::INST_KIND { + InsnKind::LW => Some(dispatch::GpuWitgenKind::Lw), + InsnKind::LH => Some(dispatch::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 1, + }), + InsnKind::LHU => Some(dispatch::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 0, + }), + InsnKind::LB => Some(dispatch::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 1, + }), + InsnKind::LBU => Some(dispatch::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 0, + }), + _ => None, + }); } diff --git 
a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index a1bd7a812..a6afeb93a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -3,8 +3,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{emit_const_range_op, emit_u16_limbs}, riscv::{ RIVInstruction, constants::{MEM_BITS, UInt}, @@ -24,16 +26,16 @@ use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; pub struct StoreConfig { - s_insn: SInstructionConfig, + pub(crate) s_insn: SInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - imm: WitIn, - imm_sign: WitIn, - prev_memory_value: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + pub(crate) imm: WitIn, + pub(crate) imm_sign: WitIn, + pub(crate) prev_memory_value: UInt, - memory_addr: MemAddr, - next_memory_value: Option>, + pub(crate) memory_addr: MemAddr, + pub(crate) next_memory_value: Option>, } pub struct StoreInstruction(PhantomData<(E, I)>); @@ -44,6 +46,8 @@ impl Instruction type InstructionConfig = StoreConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -171,4 +175,36 @@ impl Instruction Ok(()) } + + impl_collect_lk_and_shardram!(s_insn, |sink, step, config, _ctx| { + emit_u16_limbs(sink, step.memory_op().unwrap().value.before); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config.memory_addr.emit_lk_and_shardram(sink, addr.into()); + + if N_ZEROS == 0 { + let memory_op = step.memory_op().unwrap(); + let prev_value = Value::new_unchecked(memory_op.value.before); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let prev_limb = 
prev_value.as_u16_limbs()[((addr.shift() >> 1) & 1) as usize]; + let rs2_limb = rs2_value.as_u16_limbs()[0]; + + for byte in prev_limb.to_le_bytes() { + emit_const_range_op(sink, byte as u64, 8); + } + for byte in rs2_limb.to_le_bytes() { + emit_const_range_op(sink, byte as u64, 8); + } + } + }); + + impl_collect_shardram!(s_insn); + + impl_gpu_assign!(match I::INST_KIND { + InsnKind::SW => Some(dispatch::GpuWitgenKind::Sw), + InsnKind::SH => Some(dispatch::GpuWitgenKind::Sh), + InsnKind::SB => Some(dispatch::GpuWitgenKind::Sb), + _ => None, + }); } diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 4a4c8065b..61279fbb6 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod mulh_circuit; #[cfg(feature = "u16limb_circuit")] -mod mulh_circuit_v2; +pub(crate) mod mulh_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] use mulh_circuit::MulhInstructionBase; diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index f3bddff1b..2ed8358a6 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -1,8 +1,10 @@ use crate::{ circuit_builder::CircuitBuilder, error::ZKVMError, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink}, riscv::{ RIVInstruction, constants::{LIMB_BITS, UINT_LIMBS, UInt}, @@ -26,13 +28,13 @@ use std::{array, marker::PhantomData}; pub struct MulhInstructionBase(PhantomData<(E, I)>); pub struct MulhConfig { - rs1_read: UInt, - rs2_read: UInt, - r_insn: RInstructionConfig, - rd_low: [WitIn; UINT_LIMBS], - rd_high: Option<[WitIn; UINT_LIMBS]>, - rs1_ext: Option, - rs2_ext: Option, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + 
pub(crate) r_insn: RInstructionConfig, + pub(crate) rd_low: [WitIn; UINT_LIMBS], + pub(crate) rd_high: Option<[WitIn; UINT_LIMBS]>, + pub(crate) rs1_ext: Option, + pub(crate) rs2_ext: Option, phantom: PhantomData, } @@ -40,6 +42,8 @@ impl Instruction for MulhInstructionBas type InstructionConfig = MulhConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -327,6 +331,97 @@ impl Instruction for MulhInstructionBas Ok(()) } + + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + let rs1 = step.rs1().unwrap().value; + let rs1_val = Value::new_unchecked(rs1); + let rs2 = step.rs2().unwrap().value; + let rs2_val = Value::new_unchecked(rs2); + + let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( + I::INST_KIND, + rs1_val + .as_u16_limbs() + .iter() + .map(|x| *x as u32) + .collect::>() + .as_slice(), + rs2_val + .as_u16_limbs() + .iter() + .map(|x| *x as u32) + .collect::>() + .as_slice(), + ); + + for (rd_low, carry_low) in rd_low.iter().zip(carry[0..UINT_LIMBS].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_low as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_low as u64, + bits: 18, + }); + } + + match I::INST_KIND { + InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => { + for (rd_high, carry_high) in rd_high.iter().zip(carry[UINT_LIMBS..].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_high as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_high as u64, + bits: 18, + }); + } + } + _ => {} + } + + let sign_mask = 1 << (LIMB_BITS - 1); + let ext = (1 << LIMB_BITS) - 1; + let rs1_sign = rs1_ext / ext; + let rs2_sign = rs2_ext / ext; + let rs1_limbs = rs1_val.as_u16_limbs(); + let rs2_limbs = rs2_val.as_u16_limbs(); + + match I::INST_KIND { + InsnKind::MULH => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + 
bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask)) as u64, + bits: 16, + }); + } + InsnKind::MULHSU => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask) as u64, + bits: 16, + }); + } + _ => {} + } + }); + + impl_collect_shardram!(r_insn); + + impl_gpu_assign!(match I::INST_KIND { + InsnKind::MUL => Some(dispatch::GpuWitgenKind::Mul(0u32)), + InsnKind::MULH => Some(dispatch::GpuWitgenKind::Mul(1u32)), + InsnKind::MULHU => Some(dispatch::GpuWitgenKind::Mul(2u32)), + InsnKind::MULHSU => Some(dispatch::GpuWitgenKind::Mul(3u32)), + _ => None, + }); } fn run_mulh( diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index a4b9bb128..525998e41 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + instructions::{ + gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -81,4 +84,30 @@ impl RInstructionConfig { Ok(()) } + + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rs2.emit_lk_and_shardram(sink, shard_ctx, step); + self.rd.emit_lk_and_shardram(sink, shard_ctx, step); + } + + pub fn emit_shardram( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + 
lk_multiplicity.fetch(step.pc().before.0); + self.rs1.emit_shardram(shard_ctx, step); + self.rs2.emit_shardram(shard_ctx, step); + self.rd.emit_shardram(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 091ce3000..1227032f0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -69,6 +69,7 @@ use std::{ collections::{BTreeMap, HashMap}, }; use strum::{EnumCount, IntoEnumIterator}; +use tracing::info_span; pub mod mmu; @@ -681,13 +682,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_kinds::() .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -696,13 +701,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_ecall_code($code) .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -846,22 +855,20 @@ impl Rv32imConfig { cs: &ZKVMConstraintSystem, witness: &mut ZKVMWitnesses, ) -> Result<(), ZKVMError> { - witness.assign_table_circuit::>( - cs, - &self.dynamic_range_config, - &(), - )?; - witness.assign_table_circuit::>( - cs, - &self.double_u8_range_config, - &(), - )?; - witness.assign_table_circuit::>(cs, &self.and_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.or_table_config, &())?; - witness.assign_table_circuit::>(cs, 
&self.xor_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.ltu_config, &())?; + macro_rules! assign_table { + ($table:ty, $config:expr) => { + info_span!("assign_table", table = %<$table>::name()) + .in_scope(|| witness.assign_table_circuit::<$table>(cs, $config, &()))?; + }; + } + assign_table!(DynamicRangeTableCircuit, &self.dynamic_range_config); + assign_table!(DoubleU8TableCircuit, &self.double_u8_range_config); + assign_table!(AndTableCircuit, &self.and_table_config); + assign_table!(OrTableCircuit, &self.or_table_config); + assign_table!(XorTableCircuit, &self.xor_table_config); + assign_table!(LtuTableCircuit, &self.ltu_config); #[cfg(not(feature = "u16limb_circuit"))] - witness.assign_table_circuit::>(cs, &self.pow_config, &())?; + assign_table!(PowTableCircuit, &self.pow_config); Ok(()) } @@ -1016,13 +1023,17 @@ impl DummyExtraConfig { let phantom_log_pc_cycle_records = instrunction_dispatch_ctx .records_for_ecall_code(LogPcCycleSpec::CODE) .unwrap_or(&[]); - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.phantom_log_pc_cycle, - shard_steps, - phantom_log_pc_cycle_records, - )?; + let n = phantom_log_pc_cycle_records.len(); + info_span!("assign_chip", chip = %LargeEcallDummy::::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.phantom_log_pc_cycle, + shard_steps, + phantom_log_pc_cycle_records, + ) + })?; Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs index b3f96feb0..3d4f1efc0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im/mmu.rs @@ -199,16 +199,20 @@ impl MmuConfig { .filter(|(_, _, record)| !record.is_empty()) .collect_vec(); - witness.assign_table_circuit::>( - cs, - &self.local_final_circuit, - &(shard_ctx, all_records.as_slice()), - )?; - witness.assign_shared_circuit( - cs, - &(shard_ctx, all_records.as_slice()), - &self.ram_bus_circuit, - )?; + 
tracing::info_span!("local_final_circuit").in_scope(|| { + witness.assign_table_circuit::>( + cs, + &self.local_final_circuit, + &(shard_ctx, all_records.as_slice()), + ) + })?; + tracing::info_span!("shared_circuit").in_scope(|| { + witness.assign_shared_circuit( + cs, + &(shard_ctx, all_records.as_slice()), + &self.ram_bus_circuit, + ) + })?; Ok(()) } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index f252a7c60..38dd29555 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -3,7 +3,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + instructions::{ + gpu::utils::{LkOp, LkShardramSink}, + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -16,10 +19,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Registers reads. 
/// - Memory write pub struct SInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rs2: ReadRS2, - mem_write: WriteMEM, + pub(crate) vm_state: StateInOut, + pub(crate) rs1: ReadRS1, + pub(crate) rs2: ReadRS2, + pub(crate) mem_write: WriteMEM, } impl SInstructionConfig { @@ -91,4 +94,31 @@ impl SInstructionConfig { Ok(()) } + + pub fn emit_shardram( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.emit_shardram(shard_ctx, step); + self.rs2.emit_shardram(shard_ctx, step); + self.mem_write.emit_shardram(shard_ctx, step); + } + + #[allow(dead_code)] + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.emit_lk_and_shardram(sink, shard_ctx, step); + self.rs2.emit_lk_and_shardram(sink, shard_ctx, step); + self.mem_write.emit_lk_and_shardram(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 310d17491..493f76f4e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,8 +1,10 @@ use crate::e2e::ShardContext; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops, emit_const_range_op}, riscv::{ RIVInstruction, constants::{UINT_BYTE_LIMBS, UInt8}, @@ -206,6 +208,45 @@ impl }) } + pub fn emit_lk_and_shardram( + &self, + sink: &mut impl LkShardramSink, + kind: InsnKind, + b: u32, + c: u32, + ) { + let b = split_to_limb::(b); + let c = split_to_limb::(c); + 
let (_, limb_shift, bit_shift) = run_shift::( + kind, + &b.clone().try_into().unwrap(), + &c.clone().try_into().unwrap(), + ); + + let bit_shift_carry: [u32; NUM_LIMBS] = array::from_fn(|i| match kind { + InsnKind::SLL | InsnKind::SLLI => b[i] >> (LIMB_BITS - bit_shift), + _ => b[i] % (1 << bit_shift), + }); + for val in bit_shift_carry { + sink.emit_lk(LkOp::DynamicRange { + value: val as u64, + bits: bit_shift as u32, + }); + } + + let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); + let carry_quotient = + (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u64; + emit_const_range_op(sink, carry_quotient, LIMB_BITS - num_bits_log as usize); + + if matches!(kind, InsnKind::SRA | InsnKind::SRAI) { + sink.emit_lk(LkOp::Xor { + a: b[NUM_LIMBS - 1] as u8, + b: (1 << (LIMB_BITS - 1)) as u8, + }); + } + } + pub fn assign_instances( &self, instance: &mut [::BaseField], @@ -265,11 +306,11 @@ impl } pub struct ShiftRTypeConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub rd_written: UInt8, - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, } pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); @@ -278,6 +319,8 @@ impl Instruction for ShiftLogicalInstru type InstructionConfig = ShiftRTypeConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -363,14 +406,34 @@ impl Instruction for ShiftLogicalInstru Ok(()) } + + impl_collect_lk_and_shardram!(r_insn, |sink, step, config, _ctx| { + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(sink, &rd_written); + config.shift_base_config.emit_lk_and_shardram( + sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.rs2().unwrap().value, + ); + }); + + impl_collect_shardram!(r_insn); + + 
impl_gpu_assign!(dispatch::GpuWitgenKind::ShiftR(match I::INST_KIND { + InsnKind::SLL => 0u32, + InsnKind::SRL => 1u32, + InsnKind::SRA => 2u32, + _ => unreachable!(), + })); } pub struct ShiftImmConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, pub rd_written: UInt8, - i_insn: IInstructionConfig, - imm: WitIn, + pub(crate) i_insn: IInstructionConfig, + pub(crate) imm: WitIn, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); @@ -379,6 +442,8 @@ impl Instruction for ShiftImmInstructio type InstructionConfig = ShiftImmConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -466,6 +531,26 @@ impl Instruction for ShiftImmInstructio Ok(()) } + + impl_collect_lk_and_shardram!(i_insn, |sink, step, config, _ctx| { + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(sink, &rd_written); + config.shift_base_config.emit_lk_and_shardram( + sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.insn().imm as i16 as u16 as u32, + ); + }); + + impl_collect_shardram!(i_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::ShiftI(match I::INST_KIND { + InsnKind::SLLI => 0u32, + InsnKind::SRLI => 1u32, + InsnKind::SRAI => 2u32, + _ => unreachable!(), + })); } fn run_shift( diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index a0dc51bd6..01d39a49c 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod slt_circuit; #[cfg(feature = "u16limb_circuit")] -mod slt_circuit_v2; +pub(crate) mod slt_circuit_v2; use ceno_emul::InsnKind; diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index d57aeb2cd..f397dcb5e 100644 --- 
a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -4,8 +4,10 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_uint_limbs_lt_ops, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, }, structs::ProgramParams, @@ -19,19 +21,21 @@ pub struct SetLessThanInstruction(PhantomData<(E, I)>); /// This config handles R-Instructions that represent registers values as 2 * u16. pub struct SetLessThanConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -113,4 +117,25 @@ impl Instruction for SetLessThanInstruc )?; Ok(()) } + + impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let rs2_limbs = rs2_value.as_u16_limbs(); + emit_uint_limbs_lt_ops( + sink, + matches!(I::INST_KIND, InsnKind::SLT), + rs1_limbs, + rs2_limbs, + ); + }); + + impl_collect_shardram!(r_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::Slt(match I::INST_KIND { + InsnKind::SLT => 1u32, + InsnKind::SLTU => 0u32, + _ => unreachable!(), + })); } diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 90dcb8448..474b664ee 100644 --- 
a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -1,5 +1,5 @@ #[cfg(feature = "u16limb_circuit")] -mod slti_circuit_v2; +pub(crate) mod slti_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] mod slti_circuit; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index b2449614e..4483ea4d4 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -4,8 +4,10 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::{UIntLimbsLT, UIntLimbsLTConfig}, + impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, + gpu::utils::emit_uint_limbs_lt_ops, riscv::{ RIVInstruction, constants::{UINT_LIMBS, UInt}, @@ -25,16 +27,16 @@ use witness::set_val; #[derive(Debug)] pub struct SetLessThanImmConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, + pub(crate) imm_sign: WitIn, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); @@ -43,6 +45,8 @@ impl Instruction for SetLessThanImmInst type InstructionConfig = SetLessThanImmConfig; type InsnType = InsnKind; + const GPU_LK_SHARDRAM: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -133,4 +137,24 @@ impl Instruction for SetLessThanImmInst )?; Ok(()) } + + impl_collect_lk_and_shardram!(i_insn, |sink, step, _config, _ctx| { + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); + emit_uint_limbs_lt_ops( + sink, + matches!(I::INST_KIND, 
InsnKind::SLTI), + rs1_limbs, + &imm_sign_extend, + ); + }); + + impl_collect_shardram!(i_insn); + + impl_gpu_assign!(dispatch::GpuWitgenKind::Slti(match I::INST_KIND { + InsnKind::SLTI => 1u32, + InsnKind::SLTIU => 0u32, + _ => unreachable!(), + })); } diff --git a/ceno_zkvm/src/precompiles/mod.rs b/ceno_zkvm/src/precompiles/mod.rs index 3d9a6e545..7e0e6d049 100644 --- a/ceno_zkvm/src/precompiles/mod.rs +++ b/ceno_zkvm/src/precompiles/mod.rs @@ -1,6 +1,6 @@ mod bitwise_keccakf; mod fptower; -mod lookup_keccakf; +pub(crate) mod lookup_keccakf; mod sha256; mod uint256; mod utils; diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index f9b6b4f76..ef654ad04 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1171,4 +1171,311 @@ mod tests { assert!(j4.is_on_curve()); assert_eq!(j4.into_affine(), p4); } + + /// GPU vs CPU EC point computation test. + /// Launches the `test_septic_ec_point` CUDA kernel with known inputs, + /// then compares GPU outputs against CPU `ShardRamRecord::to_ec_point()`. 
+ #[test] + #[cfg(feature = "gpu")] + fn test_gpu_ec_point_matches_cpu() { + use crate::tables::{ECPoint, ShardRamRecord}; + use ceno_gpu::bb31::{ + CudaHalBB31, + test_impl::{TestEcInput, run_gpu_ec_test}, + }; + use ff_ext::{PoseidonField, SmallField}; + use gkr_iop::RAMType; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Test cases: various write/read, register/memory, edge cases + let test_inputs = vec![ + TestEcInput { + addr: 5, + ram_type: 1, + value: 0x12345678, + is_write: 1, + shard: 1, + global_clk: 100, + }, + TestEcInput { + addr: 5, + ram_type: 1, + value: 0x12345678, + is_write: 0, + shard: 0, + global_clk: 50, + }, + TestEcInput { + addr: 0x80000, + ram_type: 2, + value: 0xDEADBEEF, + is_write: 1, + shard: 2, + global_clk: 200, + }, + TestEcInput { + addr: 0x80000, + ram_type: 2, + value: 0xDEADBEEF, + is_write: 0, + shard: 1, + global_clk: 150, + }, + TestEcInput { + addr: 0, + ram_type: 1, + value: 0, + is_write: 1, + shard: 0, + global_clk: 1, + }, + TestEcInput { + addr: 31, + ram_type: 1, + value: 0xFFFFFFFF, + is_write: 0, + shard: 3, + global_clk: 999, + }, + TestEcInput { + addr: 0x40000000, + ram_type: 2, + value: 42, + is_write: 1, + shard: 5, + global_clk: 500000, + }, + TestEcInput { + addr: 10, + ram_type: 1, + value: 1, + is_write: 0, + shard: 100, + global_clk: 1000000, + }, + ]; + + let gpu_results = run_gpu_ec_test(&hal, &test_inputs); + + let mut mismatches = 0; + for (i, (input, gpu_rec)) in test_inputs.iter().zip(gpu_results.iter()).enumerate() { + // Build CPU ShardRamRecord and compute EC point. + // GPU kernel treats slot.addr as word address and converts to byte address + // via `<< 2` for Memory type; Register type uses reg_id directly. 
+ let addr = if input.ram_type == 2 { + input.addr << 2 + } else { + input.addr + }; + let cpu_record = ShardRamRecord { + addr, + ram_type: if input.ram_type == 1 { + RAMType::Register + } else { + RAMType::Memory + }, + value: input.value, + shard: input.shard, + local_clk: if input.is_write != 0 { + input.global_clk + } else { + 0 + }, + global_clk: input.global_clk, + is_to_write_set: input.is_write != 0, + }; + let cpu_ec: ECPoint = cpu_record.to_ec_point(&perm); + + let mut has_diff = false; + + if gpu_rec.nonce != cpu_ec.nonce { + eprintln!("[{i}] nonce: gpu={} cpu={}", gpu_rec.nonce, cpu_ec.nonce); + has_diff = true; + } + + for j in 0..7 { + let gpu_x = gpu_rec.point_x[j]; + let cpu_x = cpu_ec.point.x.0[j].to_canonical_u64() as u32; + if gpu_x != cpu_x { + eprintln!("[{i}] x[{j}]: gpu={gpu_x} cpu={cpu_x}"); + has_diff = true; + } + let gpu_y = gpu_rec.point_y[j]; + let cpu_y = cpu_ec.point.y.0[j].to_canonical_u64() as u32; + if gpu_y != cpu_y { + eprintln!("[{i}] y[{j}]: gpu={gpu_y} cpu={cpu_y}"); + has_diff = true; + } + } + + if has_diff { + eprintln!( + "MISMATCH [{i}]: addr={} ram_type={} value={:#x} is_write={} shard={} clk={}", + input.addr, + input.ram_type, + input.value, + input.is_write, + input.shard, + input.global_clk + ); + mismatches += 1; + } + } + + assert_eq!( + mismatches, + 0, + "{mismatches}/{} test cases had GPU/CPU EC point mismatches", + test_inputs.len() + ); + eprintln!( + "All {} GPU EC point test cases match CPU!", + test_inputs.len() + ); + } + + /// Verify GPU Poseidon2 permutation matches CPU on the exact sponge packing + /// used by to_ec_point(). This isolates Montgomery encoding correctness. 
+ #[test] + #[cfg(feature = "gpu")] + fn test_gpu_poseidon2_sponge_matches_cpu() { + use ceno_gpu::bb31::{ + CudaHalBB31, + test_impl::{SPONGE_WIDTH, run_gpu_poseidon2_sponge}, + }; + use ff_ext::{PoseidonField, SmallField}; + use p3::{field::FieldAlgebra, symmetric::Permutation}; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Build sponge inputs matching to_ec_point packing: + // [addr, ram_type, value_lo16, value_hi16, shard, global_clk, nonce, 0..0] + let test_cases: Vec<[u32; SPONGE_WIDTH]> = vec![ + // Case 1: typical write record + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 5; // addr + s[1] = 1; // ram_type (Register) + s[2] = 0x5678; // value lo16 + s[3] = 0x1234; // value hi16 + s[4] = 1; // shard + s[5] = 100; // global_clk + s[6] = 0; // nonce + s + }, + // Case 2: memory read, different values + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 0x80000; + s[1] = 2; // Memory + s[2] = 0xBEEF; + s[3] = 0xDEAD; + s[4] = 2; + s[5] = 200; + s[6] = 3; // nonce=3 + s + }, + // Case 3: all zeros (edge case) + [0u32; SPONGE_WIDTH], + ]; + + let count = test_cases.len(); + let flat_input: Vec = test_cases.iter().flat_map(|s| s.iter().copied()).collect(); + let gpu_output = run_gpu_poseidon2_sponge(&hal, &flat_input, count); + + let mut mismatches = 0; + for (i, input) in test_cases.iter().enumerate() { + // CPU Poseidon2 + let cpu_input: Vec = input.iter().map(|&v| F::from_canonical_u32(v)).collect(); + let cpu_output = perm.permute(cpu_input); + + for j in 0..SPONGE_WIDTH { + let gpu_v = gpu_output[i * SPONGE_WIDTH + j]; + let cpu_v = cpu_output[j].to_canonical_u64() as u32; + if gpu_v != cpu_v { + eprintln!("[case {i}] sponge[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + + assert_eq!( + mismatches, 0, + "{mismatches} Poseidon2 output elements differ between GPU and CPU" + ); + eprintln!("All {} Poseidon2 sponge test cases match!", count); + } + + /// Verify GPU septic_point_from_x matches CPU 
SepticPoint::from_x. + /// Tests the full GF(p^7) sqrt (Cipolla) + curve equation. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_septic_from_x_matches_cpu() { + use ceno_gpu::bb31::{CudaHalBB31, test_impl::run_gpu_septic_from_x}; + use ff_ext::SmallField; + use p3::field::FieldAlgebra; + + let hal = CudaHalBB31::new(0).unwrap(); + + // Generate test x-coordinates by hashing known inputs + // (ensures we get a mix of points-exist and points-don't-exist cases) + let test_xs: Vec<[u32; 7]> = vec![ + // x from Poseidon2([5,1,0x5678,0x1234,1,100,0, 0..]) — known to have a point + [ + 1594766074, 868528894, 1733778006, 1242721508, 1690833816, 1437202757, 1753525271, + ], + // Simple: x = [1,0,0,0,0,0,0] + [1, 0, 0, 0, 0, 0, 0], + // x = [0,0,0,0,0,0,0] (zero) + [0, 0, 0, 0, 0, 0, 0], + // x = [42, 17, 999, 0, 0, 0, 0] + [42, 17, 999, 0, 0, 0, 0], + // Random-ish values + [ + 1000000007, 123456789, 987654321, 111111111, 222222222, 333333333, 444444444, + ], + ]; + + let count = test_xs.len(); + let flat_x: Vec = test_xs.iter().flat_map(|x| x.iter().copied()).collect(); + let (gpu_y, gpu_flags) = run_gpu_septic_from_x(&hal, &flat_x, count); + + let mut mismatches = 0; + for (i, x_arr) in test_xs.iter().enumerate() { + // CPU: SepticPoint::from_x + let x = SepticExtension(x_arr.map(|v| F::from_canonical_u32(v))); + let cpu_result = SepticPoint::::from_x(x); + + let gpu_found = gpu_flags[i] != 0; + let cpu_found = cpu_result.is_some(); + + if gpu_found != cpu_found { + eprintln!("[{i}] from_x existence: gpu={gpu_found} cpu={cpu_found}"); + mismatches += 1; + continue; + } + + if let Some(cpu_pt) = cpu_result { + // Compare y coordinates (GPU returns canonical, before any negation) + for j in 0..7 { + let gpu_v = gpu_y[i * 7 + j]; + let cpu_v = cpu_pt.y.0[j].to_canonical_u64() as u32; + // from_x returns the "natural" sqrt; they should match exactly + if gpu_v != cpu_v { + eprintln!("[{i}] y[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + } + + 
assert_eq!( + mismatches, 0, + "{mismatches} septic_from_x results differ between GPU and CPU" + ); + eprintln!("All {} septic_from_x test cases match!", count); + } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f6847140..06b0030c8 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,9 +1,9 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::{E2EProgramCtx, ShardContext}, + e2e::{E2EProgramCtx, GPU_SHARD_RAM_RECORD_SIZE, ShardContext}, error::ZKVMError, instructions::Instruction, - scheme::septic_curve::SepticPoint, + scheme::septic_curve::{SepticExtension, SepticPoint}, state::StateCircuit, tables::{ ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord, @@ -20,7 +20,7 @@ use rayon::{ iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, prelude::ParallelSlice, }; -use rustc_hash::FxHashSet; +use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, @@ -351,7 +351,7 @@ impl ChipInput { pub struct ZKVMWitnesses { pub witnesses: BTreeMap>>, lk_mlts: BTreeMap>, - combined_lk_mlt: Option>>, + combined_lk_mlt: Option>>, } impl ZKVMWitnesses { @@ -363,6 +363,14 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } + pub fn combined_lk_mlt(&self) -> Option<&Vec>> { + self.combined_lk_mlt.as_ref() + } + + pub fn lk_mlts(&self) -> &BTreeMap> { + &self.lk_mlts + } + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, @@ -469,100 +477,164 @@ impl ZKVMWitnesses { ), config: & as TableCircuit>::TableConfig, ) -> Result<(), ZKVMError> { - let perm = ::get_default_perm(); - let addr_accessed = shard_ctx.get_addr_accessed(); - - // future shard needed records := shard_ctx.write_records ∪ // - // (shard_ctx.after_current_shard_cycle(mem_record.cycle) && !addr_accessed.contains(&waddr)) - - // 1. 
process final mem which - // 1.1 init in first shard - // 1.2 not accessed in first shard - // 1.3 accessed in future shard - let first_shard_access_later_records = if shard_ctx.is_first_shard() { - final_mem - .par_iter() - // only process no range restriction memory record - // for range specified it means dynamic init across different shard - .filter(|(_, range, _)| range.is_none()) - .flat_map(|(mem_name, _, final_mem)| { - final_mem.par_iter().filter_map(|mem_record| { - let (waddr, addr) = Self::mem_addresses(mem_record); - Self::make_cross_shard_input( - mem_name, - mem_record, - waddr, - addr, - shard_ctx, - &addr_accessed, - &perm, - ) + use tracing::info_span; + + // Try the full GPU pipeline: keep data on device, minimal CPU roundtrips. + // Only when GPU witgen is enabled (otherwise witgen must not touch GPU). + #[cfg(feature = "gpu")] + if crate::instructions::gpu::config::is_gpu_witgen_enabled() { + let gpu_result = self.try_assign_shared_circuit_gpu(cs, shard_ctx, final_mem, config); + match gpu_result { + Ok(true) => return Ok(()), // GPU pipeline succeeded + Ok(false) => {} // GPU pipeline unavailable, fall through + Err(e) => { + tracing::warn!("GPU full pipeline failed, falling back: {e:?}"); + } + } + } + + let addr_accessed = + info_span!("get_addr_accessed").in_scope(|| shard_ctx.get_addr_accessed_sorted()); + + // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) + // Partition into writes and reads to maintain the ordering invariant required by + // ShardRamCircuit::assign_instances (writes first, reads after). 
+ let (gpu_ec_writes, gpu_ec_reads) = + info_span!("gpu_ec_convert", n = shard_ctx.gpu_ec_records.len() / 104).in_scope(|| { + if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + } else { + (vec![], vec![]) + } + }); + + // Collect cross-shard records (filter only, no EC computation yet) + let (write_record_pairs, read_record_pairs) = + info_span!("collect_records").in_scope(|| { + let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_none()) + .flat_map(|(mem_name, _, final_mem)| { + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect() + } else { + vec![] + }; + + let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_some()) + .flat_map(|(mem_name, range, final_mem)| { + let range = range.as_ref().unwrap(); + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + if !range.contains(&addr) { + return None; + } + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect(); + + let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .write_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, true).into(), "current_shard_external_write") + }) }) - }) - .collect() - } else { - vec![] - }; + .chain(first_shard_access_later_recs) + .chain(current_shard_access_later_recs) + .collect(); - // 2. process records which - // 2.1 init within current shard - // 2.2 not accessed in current shard - // 2.3 access by later shards. 
- let current_shard_access_later = final_mem - .par_iter() - // only process range-restricted memory record - // for range specified it means dynamic init across different shard - .filter(|(_, range, _)| range.is_some()) - .flat_map(|(mem_name, range, final_mem)| { - let range = range.as_ref().unwrap(); - final_mem.par_iter().filter_map(|mem_record| { - let (waddr, addr) = Self::mem_addresses(mem_record); - if !range.contains(&addr) { - return None; - } - Self::make_cross_shard_input( - mem_name, - mem_record, - waddr, - addr, - shard_ctx, - &addr_accessed, - &perm, - ) - }) - }) - .collect::>(); - - let global_input = shard_ctx - .write_records() - .par_iter() - .flat_map(|records| { - // global write -> local reads - records.par_iter().map(|(vma, record)| { - let global_write: ShardRamRecord = (vma, record, true).into(); - let ec_point: ECPoint = global_write.to_ec_point(&perm); - ShardRamInput { - name: "current_shard_external_write", - record: global_write, - ec_point, - } - }) - }) - .chain(first_shard_access_later_records.into_par_iter()) - .chain(current_shard_access_later.into_par_iter()) - .chain(shard_ctx.read_records().par_iter().flat_map(|records| { - // global read -> local write - records.par_iter().map(|(vma, record)| { - let global_read: ShardRamRecord = (vma, record, false).into(); - let ec_point: ECPoint = global_read.to_ec_point(&perm); - ShardRamInput { - name: "current_shard_external_read", - record: global_read, - ec_point, - } - }) - })) - .collect::>(); + let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .read_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, false).into(), "current_shard_external_read") + }) + }) + .collect(); + + (write_record_pairs, read_record_pairs) + }); + + // Compute EC points: GPU path (only when GPU witgen enabled) or CPU fallback + let global_input = { + #[cfg(feature = "gpu")] + let ec_result = if 
crate::instructions::gpu::config::is_gpu_witgen_enabled() { + use crate::instructions::gpu::chips::shard_ram::gpu_batch_continuation_ec; + gpu_batch_continuation_ec::(&write_record_pairs, &read_record_pairs).ok() + } else { + None + }; + #[cfg(not(feature = "gpu"))] + let ec_result: Option<(Vec>, Vec>)> = None; + + if let Some((computed_writes, computed_reads)) = ec_result { + // GPU path: chain computed EC with pre-computed GPU EC records + computed_writes + .into_iter() + .chain(gpu_ec_writes) + .chain(computed_reads) + .chain(gpu_ec_reads) + .collect::>() + } else { + // CPU fallback: compute EC points with Poseidon2 permutation + let perm = ::get_default_perm(); + let cpu_writes: Vec> = write_record_pairs + .into_par_iter() + .map(|(record, name)| { + let ec_point = record.to_ec_point(&perm); + ShardRamInput { + name, + record, + ec_point, + } + }) + .collect(); + let cpu_reads: Vec> = read_record_pairs + .into_par_iter() + .map(|(record, name)| { + let ec_point = record.to_ec_point(&perm); + ShardRamInput { + name, + record, + ec_point, + } + }) + .collect(); + cpu_writes + .into_iter() + .chain(gpu_ec_writes) + .chain(cpu_reads) + .chain(gpu_ec_reads) + .collect::>() + } + }; if tracing::enabled!(Level::DEBUG) { let total = global_input.len() as f64; @@ -592,31 +664,69 @@ impl ZKVMWitnesses { } } + // Invariant: all writes (is_to_write_set=true) must precede all reads. + // ShardRamCircuit::assign_instances uses take_while to count writes. + // Activate with CENO_DEBUG_SHARD_RAM_ORDER=1. 
+ if std::env::var_os("CENO_DEBUG_SHARD_RAM_ORDER").is_some() { + let mut seen_read = false; + for (i, input) in global_input.iter().enumerate() { + if input.record.is_to_write_set { + if seen_read { + tracing::error!( + "[SHARD_RAM_ORDER] BUG: write after read at index={i} \ + addr={} ram_type={:?} shard={} global_clk={} \ + (total={} writes={} reads={})", + input.record.addr, + input.record.ram_type, + shard_ctx.shard_id, + input.record.global_clk, + global_input.len(), + global_input + .iter() + .filter(|x| x.record.is_to_write_set) + .count(), + global_input + .iter() + .filter(|x| !x.record.is_to_write_set) + .count(), + ); + break; + } + } else { + seen_read = true; + } + } + } + assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); - let circuit_inputs = global_input - .par_chunks(shard_ctx.max_num_cross_shard_accesses) - .map(|shard_accesses| { - let witness = ShardRamCircuit::assign_instances( - config, - cs.zkvm_v1_css.num_witin as usize, - cs.zkvm_v1_css.num_structural_witin as usize, - self.combined_lk_mlt.as_ref().unwrap(), - shard_accesses, - )?; - let num_reads = shard_accesses - .par_iter() - .filter(|access| access.record.is_to_write_set) - .count(); - let num_writes = shard_accesses.len() - num_reads; - - Ok(ChipInput::new( - ShardRamCircuit::::name(), - witness, - vec![num_reads, num_writes], - )) - }) - .collect::, ZKVMError>>()?; + let n_global = global_input.len(); + let circuit_inputs = + info_span!("shard_ram_assign_instances", n = n_global).in_scope(|| { + global_input + .par_chunks(shard_ctx.max_num_cross_shard_accesses) + .map(|shard_accesses| { + let witness = ShardRamCircuit::assign_instances( + config, + cs.zkvm_v1_css.num_witin as usize, + cs.zkvm_v1_css.num_structural_witin as usize, + self.combined_lk_mlt.as_ref().unwrap(), + shard_accesses, + )?; + let num_reads = shard_accesses + .par_iter() + .filter(|access| access.record.is_to_write_set) + .count(); + let num_writes = 
shard_accesses.len() - num_reads; + + Ok(ChipInput::new( + ShardRamCircuit::::name(), + witness, + vec![num_reads, num_writes], + )) + }) + .collect::, ZKVMError>>() + })?; // set num_read, num_write as separate instance assert!( self.witnesses @@ -627,6 +737,269 @@ impl ZKVMWitnesses { Ok(()) } + /// Full GPU pipeline for assign_shared_circuit: keep data on device, minimal CPU roundtrips. + /// + /// Returns Ok(true) if successful, Ok(false) if unavailable (no shared device buffers). + #[cfg(feature = "gpu")] + fn try_assign_shared_circuit_gpu( + &mut self, + cs: &ZKVMConstraintSystem, + shard_ctx: &ShardContext, + final_mem: &[(&'static str, Option>, &[MemFinalRecord])], + config: & as TableCircuit>::TableConfig, + ) -> Result { + use crate::instructions::gpu::{ + chips::shard_ram::gpu_batch_continuation_ec_on_device, + dispatch::take_shared_device_buffers, + }; + use ceno_gpu::Buffer; + use gkr_iop::gpu::get_cuda_hal; + use tracing::info_span; + + // 1. Take shared device buffers (if available) + let mut shared = match take_shared_device_buffers() { + Some(s) => s, + None => return Ok(false), + }; + + let hal = match get_cuda_hal() { + Ok(h) => h, + Err(_) => return Ok(false), + }; + + tracing::info!("[GPU full pipeline] starting device-resident assign_shared_circuit"); + + // 2. D2H the EC count and addr count + let ec_count = { + let cv: Vec = shared.ec_count.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_ec_count D2H: {e}").into()) + })?; + cv[0] as usize + }; + let addr_count = { + let cv: Vec = shared.addr_count.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("shared_addr_count D2H: {e}").into()) + })?; + cv[0] as usize + }; + + tracing::info!( + "[GPU full pipeline] shared buffers: {} EC records, {} addr_accessed", + ec_count, + addr_count, + ); + + // 3. 
GPU sort addr_accessed + dedup, then D2H sorted unique addrs + let addr_accessed: Vec = if addr_count > 0 { + info_span!("gpu_sort_addr").in_scope(|| { + let (deduped, unique_count) = hal + .witgen + .sort_and_dedup_u32(&mut shared.addr_buf, addr_count, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU sort addr: {e}").into()))?; + if unique_count == 0 { + return Ok::, ZKVMError>(vec![]); + } + // GPU-sorted + CPU-deduped; convert to WordAddr + let addrs: Vec = deduped.into_iter().map(WordAddr).collect(); + tracing::info!( + "[GPU full pipeline] sorted {} addrs → {} unique", + addr_count, + unique_count, + ); + Ok(addrs) + })? + } else { + vec![] + }; + + // 4. CPU collect_records (24ms, uses sorted unique addrs) + let (write_record_pairs, read_record_pairs) = + info_span!("collect_records").in_scope(|| { + // This is the same logic as the existing path + let first_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + if shard_ctx.is_first_shard() { + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_none()) + .flat_map(|(mem_name, _, final_mem)| { + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect() + } else { + vec![] + }; + + let current_shard_access_later_recs: Vec<(ShardRamRecord, &'static str)> = + final_mem + .par_iter() + .filter(|(_, range, _)| range.is_some()) + .flat_map(|(mem_name, range, final_mem)| { + let range = range.as_ref().unwrap(); + final_mem.par_iter().filter_map(|mem_record| { + let (waddr, addr) = Self::mem_addresses(mem_record); + if !range.contains(&addr) { + return None; + } + Self::make_cross_shard_record( + mem_name, + mem_record, + waddr, + addr, + shard_ctx, + &addr_accessed, + ) + }) + }) + .collect(); + + let write_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .write_records() + .iter() + 
.flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, true).into(), "current_shard_external_write") + }) + }) + .chain(first_shard_access_later_recs) + .chain(current_shard_access_later_recs) + .collect(); + + let read_record_pairs: Vec<(ShardRamRecord, &'static str)> = shard_ctx + .read_records() + .iter() + .flat_map(|records| { + records.iter().map(|(vma, record)| { + ((vma, record, false).into(), "current_shard_external_read") + }) + }) + .collect(); + + (write_record_pairs, read_record_pairs) + }); + + // 5. GPU batch EC on device for continuation records (25ms, results stay on GPU) + let (cont_ec_buf, cont_n_writes, cont_n_reads) = info_span!("gpu_batch_ec_on_device") + .in_scope(|| { + gpu_batch_continuation_ec_on_device(&write_record_pairs, &read_record_pairs) + })?; + let cont_total = cont_n_writes + cont_n_reads; + + tracing::info!( + "[GPU full pipeline] batch EC on device: {} writes + {} reads = {} continuation records", + cont_n_writes, + cont_n_reads, + cont_total, + ); + + // 6. GPU merge shared_ec + batch_ec, then partition by is_to_write_set + let (partitioned_buf, num_writes, total_records) = info_span!("gpu_merge_partition") + .in_scope(|| { + hal.witgen + .merge_and_partition_records( + &shared.ec_buf, + ec_count, + &cont_ec_buf, + cont_total, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU merge+partition: {e}").into()) + }) + })?; + + tracing::info!( + "[GPU full pipeline] merged+partitioned: {} total ({} writes, {} reads)", + total_records, + num_writes, + total_records - num_writes, + ); + + // 7. 
GPU assign_instances from device buffer (chunked by max_cross_shard) + assert!(self.combined_lk_mlt.is_some()); + let cs_inner = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); + let num_witin = cs_inner.zkvm_v1_css.num_witin as usize; + let num_structural_witin = cs_inner.zkvm_v1_css.num_structural_witin as usize; + let max_chunk = shard_ctx.max_num_cross_shard_accesses; + + // Record sizes needed for chunking + let record_u32s = + std::mem::size_of::() / 4; + + let circuit_inputs = info_span!("shard_ram_assign_from_device", n = total_records) + .in_scope(|| { + // Process chunks sequentially (each chunk uses GPU exclusively) + let mut inputs = Vec::new(); + let mut records_offset = 0usize; + let mut writes_remaining = num_writes; + + while records_offset < total_records { + let chunk_size = max_chunk.min(total_records - records_offset); + let chunk_writes = writes_remaining.min(chunk_size); + writes_remaining = writes_remaining.saturating_sub(chunk_size); + + // Create a view into the partitioned buffer for this chunk. + // SAFETY: chunk_buf borrows from partitioned_buf and is dropped + // at the end of each loop iteration, before partitioned_buf goes + // out of scope. The 'static lifetime is required by the HAL API. 
+ let chunk_byte_start = records_offset * record_u32s * 4; + let chunk_byte_end = (records_offset + chunk_size) * record_u32s * 4; + let chunk_view = + partitioned_buf.as_slice_range(chunk_byte_start..chunk_byte_end); + let chunk_buf: ceno_gpu::common::buffer::BufferImpl<'static, u32> = unsafe { + std::mem::transmute( + ceno_gpu::common::buffer::BufferImpl::::new_from_view(chunk_view), + ) + }; + + let witness = ShardRamCircuit::::try_gpu_assign_instances_from_device( + config, + num_witin, + num_structural_witin, + &chunk_buf, + chunk_size, + chunk_writes, + )?; + + let witness = witness.ok_or_else(|| { + ZKVMError::InvalidWitness("GPU shard_ram from_device returned None".into()) + })?; + + let num_reads = chunk_size - chunk_writes; + inputs.push(ChipInput::new( + ShardRamCircuit::::name(), + witness, + vec![chunk_writes, num_reads], + )); + + records_offset += chunk_size; + } + Ok::<_, ZKVMError>(inputs) + })?; + + assert!( + self.witnesses + .insert(ShardRamCircuit::::name(), circuit_inputs) + .is_none() + ); + + tracing::info!( + "[GPU full pipeline] assign_shared_circuit complete: {} total records", + total_records, + ); + + Ok(true) + } + pub fn get_witnesses_name_instance(&self) -> Vec<(String, Vec)> { self.witnesses .iter() @@ -657,17 +1030,19 @@ impl ZKVMWitnesses { } } + /// Filter and construct a cross-shard ShardRamRecord without EC computation. + /// Used by the GPU path where EC is computed in batch on device. 
#[inline(always)] - fn make_cross_shard_input( + fn make_cross_shard_record( mem_name: &'static str, mem_record: &MemFinalRecord, waddr: WordAddr, addr: u32, shard_ctx: &ShardContext, - addr_accessed: &FxHashSet, - perm: &<::BaseField as PoseidonField>::P, - ) -> Option> { - if addr_accessed.contains(&waddr) || !shard_ctx.after_current_shard_cycle(mem_record.cycle) + addr_accessed: &[WordAddr], + ) -> Option<(ShardRamRecord, &'static str)> { + if addr_accessed.binary_search(&waddr).is_ok() + || !shard_ctx.after_current_shard_cycle(mem_record.cycle) { return None; } @@ -685,12 +1060,7 @@ impl ZKVMWitnesses { global_clk: 0, is_to_write_set: true, }; - let ec_point: ECPoint = global_write.to_ec_point(perm); - Some(ShardRamInput { - name: mem_name, - record: global_write, - ec_point, - }) + Some((global_write, mem_name)) } } @@ -840,3 +1210,85 @@ where // mainly used for debugging pub circuit_index_to_name: BTreeMap, } + +/// Convert raw GPU EC record bytes to ShardRamInput. +/// The raw bytes are from `GpuShardRamRecord` structs (104 bytes each). +/// EC points are already computed on GPU — no Poseidon2/SepticCurve needed. +/// Returns (writes, reads) pre-partitioned using parallel iteration. +fn gpu_ec_records_to_shard_ram_inputs( + raw: &[u8], +) -> (Vec>, Vec>) { + assert!(raw.len().is_multiple_of(GPU_SHARD_RAM_RECORD_SIZE)); + let count = raw.len() / GPU_SHARD_RAM_RECORD_SIZE; + + #[inline(always)] + fn convert_record(raw: &[u8], i: usize) -> ShardRamInput { + use gkr_iop::RAMType; + use p3::field::FieldAlgebra; + + let base = i * GPU_SHARD_RAM_RECORD_SIZE; + let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; + + // Read fields directly from the byte buffer. 
+ // Layout matches GpuShardRamRecord (104 bytes, #[repr(C)]): + // 0: addr(u32), 4: ram_type(u32), 8: value(u32), 12: _pad(u32), + // 16: shard(u64), 24: local_clk(u64), 32: global_clk(u64), + // 40: is_to_write_set(u32), 44: nonce(u32), + // 48: point_x[7](u32×7), 76: point_y[7](u32×7) + let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); + let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); + let value = u32::from_le_bytes(r[8..12].try_into().unwrap()); + let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); + let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); + let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); + let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; + let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + point_x_arr[j] = E::BaseField::from_canonical_u32(u32::from_le_bytes( + r[48 + j * 4..52 + j * 4].try_into().unwrap(), + )); + point_y_arr[j] = E::BaseField::from_canonical_u32(u32::from_le_bytes( + r[76 + j * 4..80 + j * 4].try_into().unwrap(), + )); + } + + let record = ShardRamRecord { + addr, + ram_type: if ram_type_val == 1 { + RAMType::Register + } else { + RAMType::Memory + }, + value, + shard, + local_clk, + global_clk, + is_to_write_set, + }; + + ShardRamInput { + name: if is_to_write_set { + "current_shard_external_write" + } else { + "current_shard_external_read" + }, + record, + ec_point: ECPoint { + nonce, + point: SepticPoint::from_affine( + SepticExtension(point_x_arr), + SepticExtension(point_y_arr), + ), + }, + } + } + + // Parallel convert + partition in one pass + (0..count) + .into_par_iter() + .map(|i| convert_record::(raw, i)) + .partition(|input| input.record.is_to_write_set) +} diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 0acfed059..db32dfe66 100644 --- 
a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -7,7 +7,7 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::ToExpr; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::RowMajorMatrix; mod shard_ram; @@ -94,7 +94,7 @@ pub trait TableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], input: &Self::WitnessInput<'_>, ) -> Result, ZKVMError>; } diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index b1216f5ae..05948c00c 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -2,7 +2,8 @@ use super::ops_impl::OpTableConfig; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -47,7 +48,7 @@ impl TableCircuit for OpsTableCircuit config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[OP::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 72b80a548..7046de304 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -4,7 +4,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; use crate::{ @@ -70,7 +70,7 @@ impl OpTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, length: usize, ) -> Result, CircuitBuilderError> { assert_eq!(num_structural_witin, 
1); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3894828d9..96c4356a0 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -15,7 +15,8 @@ use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use witness::{ InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, }; @@ -268,7 +269,7 @@ impl TableCircuit for ProgramTableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], program: &Program, ) -> Result, ZKVMError> { assert!(!program.instructions.is_empty()); diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..dcf2142fa 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -19,7 +19,8 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; -use std::{collections::HashMap, marker::PhantomData, ops::Range}; +use rustc_hash::FxHashMap; +use std::{marker::PhantomData, ops::Range}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] @@ -110,7 +111,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_v: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -167,7 +168,7 @@ impl TableCirc config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_mem: &[MemFinalRecord], ) -> Result, ZKVMError> { // 
assume returned table is well-formed including padding @@ -294,7 +295,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], data: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -380,7 +381,7 @@ impl TableCircuit for LocalFinalRamC config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], (shard_ctx, final_mem): &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index d98161fea..7bdec456b 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -1,6 +1,7 @@ //! Range tables as circuits with trait TableCircuit. -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -68,7 +69,7 @@ impl TableCircuit config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[LookupTable::Dynamic as usize]; @@ -149,7 +150,7 @@ impl], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[R::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index a95664085..fa3901a6b 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; -use 
std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; use crate::{ @@ -56,7 +56,7 @@ impl DynamicRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, max_bits: usize, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (max_bits + 1); @@ -158,7 +158,7 @@ impl DoubleRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (self.range_a_bits + self.range_b_bits); let mut witness: RowMajorMatrix = diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..22e2c72a3 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -1,4 +1,5 @@ -use std::{collections::HashMap, iter::repeat_n, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::{iter::repeat_n, marker::PhantomData}; use crate::{ Value, @@ -165,21 +166,21 @@ impl ShardRamRecord { /// 3. For a local memory write record which will be read in the future, /// the shard ram circuit will insert a local read record and write it to the **global set**. pub struct ShardRamConfig { - addr: WitIn, - is_ram_register: WitIn, - value: UInt, - shard: WitIn, - global_clk: WitIn, - local_clk: WitIn, - nonce: WitIn, + pub(crate) addr: WitIn, + pub(crate) is_ram_register: WitIn, + pub(crate) value: UInt, + pub(crate) shard: WitIn, + pub(crate) global_clk: WitIn, + pub(crate) local_clk: WitIn, + pub(crate) nonce: WitIn, // if it's write to global set, then insert a local read record // s.t. local offline memory checking can cancel out // serves as propagating local write to global. 
- is_global_write: WitIn, - x: Vec, - y: Vec, - slope: Vec, - perm_config: Poseidon2Config, + pub(crate) is_global_write: WitIn, + pub(crate) x: Vec, + pub(crate) y: Vec, + pub(crate) slope: Vec, + pub(crate) perm_config: Poseidon2Config, } impl ShardRamConfig { @@ -474,7 +475,7 @@ impl TableCircuit for ShardRamCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], steps: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { if steps.is_empty() { @@ -483,6 +484,15 @@ impl TableCircuit for ShardRamCircuit { witness::RowMajorMatrix::empty(), ]); } + + #[cfg(feature = "gpu")] + { + if let Some(result) = + Self::try_gpu_assign_instances(config, num_witin, num_structural_witin, steps)? + { + return Ok(result); + } + } // FIXME selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector // we can remove this one all opcode unittest migrate to call `build_gkr_iop_circuit` @@ -643,6 +653,473 @@ impl TableCircuit for ShardRamCircuit { } } +#[cfg(feature = "gpu")] +impl ShardRamCircuit { + /// Try to run assign_instances on GPU. Returns None if GPU is unavailable. 
+ fn try_gpu_assign_instances( + config: &ShardRamConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[ShardRamInput], + ) -> Result>, ZKVMError> { + use ceno_gpu::{ + Buffer, CudaHal, + bb31::CudaHalBB31, + common::{transpose::matrix_transpose, witgen::types::GpuShardRamRecord}, + }; + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + use p3::field::PrimeField32; + + type BB = ::BaseField; + + // GPU only supports BabyBear + if std::any::TypeId::of::() != std::any::TypeId::of::() { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(h) => h, + Err(_) => return Ok(None), + }; + + let num_local_reads = steps + .iter() + .take_while(|s| s.record.is_to_write_set) + .count(); + + let n = next_pow2_instance_padding(steps.len()); + let num_rows_padded = 2 * n; + + // 1. Convert ShardRamInput → GpuShardRamRecord + let gpu_records: Vec = + tracing::info_span!("gpu_shard_ram_pack_records", n = steps.len()).in_scope(|| { + steps + .iter() + .map(|step| { + let r = &step.record; + let ec = &step.ec_point; + let mut rec = GpuShardRamRecord::default(); + rec.addr = r.addr; + rec.ram_type = r.ram_type as u32; + rec.value = r.value; + rec.shard = r.shard; + rec.local_clk = r.local_clk; + rec.global_clk = r.global_clk; + rec.is_to_write_set = if r.is_to_write_set { 1 } else { 0 }; + rec.nonce = ec.nonce; + for i in 0..7 { + // Safe: TypeId check above guarantees E::BaseField == BB + let px: BB = + unsafe { *(&ec.point.x.0[i] as *const E::BaseField as *const BB) }; + let py: BB = + unsafe { *(&ec.point.y.0[i] as *const E::BaseField as *const BB) }; + rec.point_x[i] = px.as_canonical_u32(); + rec.point_y[i] = py.as_canonical_u32(); + } + rec + }) + .collect() + }); + + // 2. Extract column map + let col_map = crate::instructions::gpu::chips::shard_ram::extract_shard_ram_column_map( + config, num_witin, + ); + + // 3. 
GPU Phase 1: per-row assignment + let (gpu_witness, gpu_structural) = tracing::info_span!( + "gpu_shard_ram_per_row", + n = steps.len(), + num_rows_padded, + num_witin, + ) + .in_scope(|| { + hal.witgen + .witgen_shard_ram_per_row( + &col_map, + &gpu_records, + num_local_reads as u32, + num_witin as u32, + num_structural_witin as u32, + num_rows_padded as u32, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU shard_ram per-row kernel failed: {e}").into(), + ) + }) + })?; + + // 4. GPU Phase 2: EC binary tree + let witness_buf = tracing::info_span!("gpu_shard_ram_ec_tree", n).in_scope( + || -> Result<_, ZKVMError> { + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; + + // Build initial layer points (padded to n) using BB (BabyBear) directly + let mut init_x = vec![BB::ZERO; n * 7]; + let mut init_y = vec![BB::ZERO; n * 7]; + for (i, step) in steps.iter().enumerate() { + for j in 0..7 { + // E::BaseField == BB at runtime (checked above), safe to transmute + init_x[i * 7 + j] = unsafe { + *(&step.ec_point.point.x.0[j] as *const E::BaseField as *const BB) + }; + init_y[i * 7 + j] = unsafe { + *(&step.ec_point.point.y.0[j] as *const E::BaseField as *const BB) + }; + } + } + + let mut cur_x = hal.alloc_elems_from_host(&init_x, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_x failed: {e}").into()) + })?; + let mut cur_y = hal.alloc_elems_from_host(&init_y, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc init_y failed: {e}").into()) + })?; + + let mut witness_buf = gpu_witness.device_buffer; + let mut offset = num_rows_padded / 2; // n + let mut current_layer_len = n; + + loop { + if current_layer_len <= 1 { + break; + } + + let (next_x, next_y) = hal + .witgen + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut witness_buf, + 
current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree layer failed: {e}").into(), + ) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; + } + + Ok(witness_buf) + }, + )?; + + // 5. GPU transpose: column-major → row-major + D2H + let (wit_data, struct_data) = + tracing::info_span!("gpu_shard_ram_transpose_d2h", num_rows_padded, num_witin,) + .in_scope(|| -> Result<_, ZKVMError> { + let wit_num_rows = num_rows_padded; + let wit_num_cols = num_witin; + let mut rmm_buf = hal + .witgen + .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buf, + &witness_buf, + wit_num_rows, + wit_num_cols, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) + })?; + + let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) + })?; + let wit_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + let struct_num_cols = num_structural_witin; + let mut struct_rmm_buf = hal + .witgen + .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for struct transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut struct_rmm_buf, + &gpu_structural.device_buffer, + wit_num_rows, + struct_num_cols, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU struct transpose failed: {e}").into(), + ) + })?; + + let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) + 
})?; + let struct_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok((wit_data, struct_data)) + })?; + + let raw_witin = witness::RowMajorMatrix::new_by_values( + wit_data, + num_witin, + InstancePaddingStrategy::Default, + ); + let raw_structural_witin = witness::RowMajorMatrix::new_by_values( + struct_data, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + tracing::info!( + "GPU shard_ram assign_instances done: {} records, {} padded rows", + steps.len(), + num_rows_padded + ); + + Ok(Some([raw_witin, raw_structural_witin])) + } + + /// GPU assign_instances from a device buffer of GpuShardRamRecord. + /// + /// Avoids the ShardRamInput → GpuShardRamRecord conversion and H2D transfer + /// by accepting records that are already on device (from opcode witgen + batch EC). + #[cfg(feature = "gpu")] + pub fn try_gpu_assign_instances_from_device( + config: &ShardRamConfig, + num_witin: usize, + num_structural_witin: usize, + device_records: &ceno_gpu::common::buffer::BufferImpl<'static, u32>, + num_records: usize, + num_local_writes: usize, + ) -> Result>, ZKVMError> { + use ceno_gpu::{Buffer, CudaHal, bb31::CudaHalBB31, common::transpose::matrix_transpose}; + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + + type BB = ::BaseField; + + if std::any::TypeId::of::() != std::any::TypeId::of::() { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(h) => h, + Err(_) => return Ok(None), + }; + + let n = next_pow2_instance_padding(num_records); + let num_rows_padded = 2 * n; + + // 1. Extract column map (same as regular path) + let col_map = crate::instructions::gpu::chips::shard_ram::extract_shard_ram_column_map( + config, num_witin, + ); + + // 2. 
GPU Phase 1: per-row assignment (records already on device) + let (gpu_witness, gpu_structural) = tracing::info_span!( + "gpu_shard_ram_per_row_from_device", + n = num_records, + num_rows_padded, + num_witin, + ) + .in_scope(|| { + hal.witgen + .witgen_shard_ram_per_row_from_device( + &col_map, + device_records, + num_records, + num_local_writes as u32, + num_witin as u32, + num_structural_witin as u32, + num_rows_padded as u32, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU shard_ram per-row (from_device) kernel failed: {e}").into(), + ) + }) + })?; + + // 3. GPU: extract EC points from device records (replaces CPU loop) + let witness_buf = tracing::info_span!("gpu_shard_ram_ec_tree_from_device", n).in_scope( + || -> Result<_, ZKVMError> { + let col_offsets = col_map.to_flat(); + let gpu_cols = hal.alloc_u32_from_host(&col_offsets, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc col offsets failed: {e}").into()) + })?; + + // Extract point_x/y from device records into flat arrays + let (mut cur_x, mut cur_y) = hal + .witgen + .extract_ec_points_from_device(device_records, num_records, n, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU extract_ec_points failed: {e}").into(), + ) + })?; + + let mut witness_buf = gpu_witness.device_buffer; + let mut offset = num_rows_padded / 2; // n + let mut current_layer_len = n; + + loop { + if current_layer_len <= 1 { + break; + } + + let (next_x, next_y) = hal + .witgen + .shard_ram_ec_tree_layer( + &gpu_cols, + &cur_x, + &cur_y, + &mut witness_buf, + current_layer_len, + offset, + num_rows_padded, + None, + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU EC tree layer failed: {e}").into(), + ) + })?; + + current_layer_len /= 2; + offset += current_layer_len; + cur_x = next_x; + cur_y = next_y; + } + + Ok(witness_buf) + }, + )?; + + // 4. 
GPU transpose + D2H (same as regular path) + let (wit_data, struct_data) = tracing::info_span!( + "gpu_shard_ram_transpose_d2h_from_device", + num_rows_padded, + num_witin, + ) + .in_scope(|| -> Result<_, ZKVMError> { + let wit_num_rows = num_rows_padded; + let wit_num_cols = num_witin; + let mut rmm_buf = hal + .witgen + .alloc_elems_on_device(wit_num_rows * wit_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buf, + &witness_buf, + wit_num_rows, + wit_num_cols, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_wit_data: Vec = rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H wit failed: {e}").into()) + })?; + let wit_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_wit_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + let struct_num_cols = num_structural_witin; + let mut struct_rmm_buf = hal + .witgen + .alloc_elems_on_device(wit_num_rows * struct_num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU alloc for struct transpose failed: {e}").into(), + ) + })?; + matrix_transpose::( + &hal.inner, + &mut struct_rmm_buf, + &gpu_structural.device_buffer, + wit_num_rows, + struct_num_cols, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU struct transpose failed: {e}").into()) + })?; + + let gpu_struct_data: Vec = struct_rmm_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU D2H struct failed: {e}").into()) + })?; + let struct_data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_struct_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok((wit_data, struct_data)) + })?; + + let raw_witin = witness::RowMajorMatrix::new_by_values( + 
wit_data, + num_witin, + InstancePaddingStrategy::Default, + ); + let raw_structural_witin = witness::RowMajorMatrix::new_by_values( + struct_data, + num_structural_witin, + InstancePaddingStrategy::Default, + ); + + tracing::info!( + "GPU shard_ram assign_instances (from_device) done: {} records, {} padded rows", + num_records, + num_rows_padded + ); + + Ok(Some([raw_witin, raw_structural_witin])) + } +} + #[cfg(test)] mod tests { use either::Either; diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index f93ef4335..cffea08eb 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -22,6 +22,7 @@ once_cell.workspace = true p3.workspace = true rand.workspace = true rayon.workspace = true +rustc-hash.workspace = true serde.workspace = true smallvec.workspace = true strum.workspace = true diff --git a/gkr_iop/src/gadgets/is_lt.rs b/gkr_iop/src/gadgets/is_lt.rs index d3f4a2ac6..b6dc4720f 100644 --- a/gkr_iop/src/gadgets/is_lt.rs +++ b/gkr_iop/src/gadgets/is_lt.rs @@ -12,7 +12,7 @@ use crate::{ }; #[derive(Debug, Clone)] -pub struct AssertLtConfig(InnerLtConfig); +pub struct AssertLtConfig(pub InnerLtConfig); impl AssertLtConfig { pub fn construct_circuit< diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index 7dded4e70..62de189aa 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -1,8 +1,8 @@ use ff_ext::SmallField; use itertools::izip; +use rustc_hash::FxHashMap; use std::{ cell::RefCell, - collections::HashMap, fmt::Debug, hash::Hash, mem::{self}, @@ -16,7 +16,7 @@ use crate::tables::{ ops::{AndTable, LtuTable, OrTable, PowTable, XorTable}, }; -pub type MultiplicityRaw = [HashMap; mem::variant_count::()]; +pub type MultiplicityRaw = [FxHashMap; mem::variant_count::()]; #[derive(Clone, Default, Debug)] pub struct Multiplicity(pub MultiplicityRaw);