From 77c43d128846d29858e2b48755192671e944d139 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:02:34 +0800 Subject: [PATCH 1/3] repr(C) StepRecord --- ceno_emul/src/addr.rs | 4 +- ceno_emul/src/disassemble/mod.rs | 41 +-- ceno_emul/src/lib.rs | 2 +- ceno_emul/src/platform.rs | 4 +- ceno_emul/src/rv32im.rs | 18 +- ceno_emul/src/test_utils.rs | 8 +- ceno_emul/src/tracer.rs | 285 ++++++++++++++---- ceno_emul/src/vm_state.rs | 8 +- ceno_host/tests/test_elf.rs | 87 ++++-- ceno_zkvm/src/e2e.rs | 28 +- .../instructions/riscv/dummy/dummy_ecall.rs | 3 +- .../src/instructions/riscv/dummy/test.rs | 6 +- .../instructions/riscv/ecall/fptower_fp.rs | 5 +- .../riscv/ecall/fptower_fp2_add.rs | 5 +- .../riscv/ecall/fptower_fp2_mul.rs | 5 +- .../src/instructions/riscv/ecall/keccak.rs | 5 +- .../instructions/riscv/ecall/sha_extend.rs | 6 +- .../src/instructions/riscv/ecall/uint256.rs | 10 +- .../riscv/ecall/weierstrass_add.rs | 5 +- .../riscv/ecall/weierstrass_decompress.rs | 5 +- .../riscv/ecall/weierstrass_double.rs | 5 +- .../src/instructions/riscv/ecall_base.rs | 4 +- 22 files changed, 394 insertions(+), 155 deletions(-) diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 24a9ceea2..14c938e07 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -30,12 +30,14 @@ pub type Word = u32; pub type SWord = i32; pub type Addr = u32; pub type Cycle = u64; -pub type RegIdx = usize; +pub type RegIdx = u8; #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[repr(C)] pub struct ByteAddr(pub u32); #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(C)] pub struct WordAddr(pub u32); impl From for WordAddr { diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index 8332a6d6f..853617d63 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -1,4 +1,7 @@ -use crate::rv32im::{InsnKind, Instruction}; +use crate::{ + addr::RegIdx, + rv32im::{InsnKind, Instruction}, +}; use itertools::izip; use rrs_lib::{ InstructionProcessor, @@ -19,9 +22,9 @@ impl Instruction { pub const fn from_r_type(kind: InsnKind, dec_insn: &RType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: 0, raw, } @@ -32,8 +35,8 @@ impl Instruction { pub const fn from_i_type(kind: InsnKind, dec_insn: &IType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.imm, rs2: 0, raw, @@ -45,8 +48,8 @@ impl Instruction { pub const fn from_i_type_shamt(kind: InsnKind, dec_insn: &ITypeShamt, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.shamt as i32, rs2: 0, raw, @@ -59,8 +62,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -72,8 +75,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -231,7 +234,7 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jal(&mut self, dec_insn: JType) -> Self::InstructionResult { Instruction { kind: InsnKind::JAL, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -242,8 +245,8 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jalr(&mut self, dec_insn: IType) -> Self::InstructionResult { Instruction { kind: InsnKind::JALR, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, rs2: 0, imm: dec_insn.imm, raw: self.word, @@ -265,7 +268,7 @@ impl InstructionProcessor for InstructionTranspiler { // See [`InstructionTranspiler::process_auipc`] for more background on the conversion. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -276,7 +279,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::LUI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -311,7 +314,7 @@ impl InstructionProcessor for InstructionTranspiler { // real world scenarios like a `reth` run. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm.wrapping_add(pc as i32), @@ -322,7 +325,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::AUIPC, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 6b16a3587..915edd18f 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -34,7 +34,7 @@ pub use syscalls::{ BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, SyscallSpec, - UINT256_MUL, + SyscallWitness, UINT256_MUL, bn254::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..4c84b96c9 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -134,7 +134,7 @@ impl Platform { /// Virtual address of a register. pub const fn register_vma(index: RegIdx) -> Addr { // Register VMAs are aligned, cannot be confused with indices, and readable in hex. - (index << 8) as Addr + (index as Addr) << 8 } /// Register index from a virtual address (unchecked). @@ -220,7 +220,7 @@ mod tests { // Registers do not overlap with ROM or RAM. for reg in [ Platform::register_vma(0), - Platform::register_vma(VMState::::REG_COUNT - 1), + Platform::register_vma((VMState::::REG_COUNT - 1) as RegIdx), ] { assert!(!p.is_rom(reg)); assert!(!p.is_ram(reg)); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index a3eac8896..48f7beab9 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -29,9 +29,9 @@ use super::addr::{ByteAddr, RegIdx, WORD_SIZE, Word, WordAddr}; pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm, raw: 0, } @@ -43,9 +43,9 @@ pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) pub const fn encode_rv32u(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm: imm as i32, raw: 0, } @@ -113,6 +113,7 @@ pub enum TrapCause { } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] pub struct Instruction { pub kind: InsnKind, pub rs1: RegIdx, @@ -162,6 +163,7 @@ use InsnFormat::*; ToPrimitive, Default, )] +#[repr(u8)] #[allow(clippy::upper_case_acronyms)] pub enum InsnKind { #[default] @@ -425,7 +427,7 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) if !new_pc.is_aligned() { return ctx.trap(TrapCause::InstructionAddressMisaligned); } - ctx.store_register(insn.rd_internal() as usize, out)?; + ctx.store_register(insn.rd_internal() as RegIdx, out)?; ctx.set_pc(new_pc); Ok(true) } @@ -502,7 +504,7 @@ fn step_load(ctx: &mut M, kind: InsnKind, decoded: &Instruction) } _ => unreachable!(), }; - ctx.store_register(decoded.rd_internal() as usize, out)?; + ctx.store_register(decoded.rd_internal() as RegIdx, out)?; ctx.set_pc(ctx.get_pc() + WORD_SIZE); Ok(true) } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 39577c13c..92ad64a97 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -1,10 +1,11 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, Platform, Program, StepRecord, VMState, encode_rv32, - encode_rv32u, syscalls::KECCAK_PERMUTE, + encode_rv32u, + syscalls::{KECCAK_PERMUTE, SyscallWitness}, }; use anyhow::Result; -pub fn keccak_step() -> (StepRecord, Vec) { +pub fn keccak_step() -> (StepRecord, Vec, Vec) { let instructions = vec![ // Call Keccak-f. load_immediate(Platform::reg_arg0() as u32, CENO_PLATFORM.heap.start), @@ -26,8 +27,9 @@ pub fn keccak_step() -> (StepRecord, Vec) { let mut vm = VMState::new(CENO_PLATFORM.clone(), program.into()); vm.iter_until_halt().collect::>>().unwrap(); let steps = vm.tracer().recorded_steps(); + let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions) + (steps[2].clone(), instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 45821ae64..4aa7b3080 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -22,7 +22,8 @@ use std::{collections::BTreeMap, fmt, sync::Arc}; /// - Any of `rs1 / rs2 / rd` **may be `x0`**. The trace handles this like any register, including the value that was _supposed_ to be stored. The circuits must handle this case: either **store `0` or skip `x0` operations**. /// /// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(C)] pub struct StepRecord { cycle: Cycle, pc: Change, @@ -30,14 +31,45 @@ pub struct StepRecord { pub hint_maxtouch_addr: Change, pub insn: Instruction, - rs1: Option, - rs2: Option, + has_rs1: bool, + has_rs2: bool, + has_rd: bool, + has_memory_op: bool, - rd: Option, + rs1: ReadOp, + rs2: ReadOp, + rd: WriteOp, + memory_op: WriteOp, - memory_op: Option, + /// Index into the separate syscall witness storage. + /// `u32::MAX` means no syscall for this step. + syscall_index: u32, +} - syscall: Option, +impl StepRecord { + /// Sentinel value indicating no syscall is associated with this step. + pub const NO_SYSCALL: u32 = u32::MAX; +} + +impl Default for StepRecord { + fn default() -> Self { + Self { + cycle: 0, + pc: Default::default(), + heap_maxtouch_addr: Default::default(), + hint_maxtouch_addr: Default::default(), + insn: Default::default(), + has_rs1: false, + has_rs2: false, + has_rd: false, + has_memory_op: false, + rs1: Default::default(), + rs2: Default::default(), + rd: Default::default(), + memory_op: Default::default(), + syscall_index: StepRecord::NO_SYSCALL, + } + } } pub type StepIndex = usize; @@ -152,7 +184,8 @@ pub trait Tracer { ) -> Option<(WordAddr, WordAddr)>; } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[repr(C)] pub struct MemOp { /// Virtual Memory Address. /// For registers, get it from `Platform::register_vma(idx)`. @@ -605,27 +638,41 @@ impl StepRecord { heap_maxtouch_addr: Change, hint_maxtouch_addr: Change, ) -> StepRecord { + let has_rs1 = rs1_read.is_some(); + let has_rs2 = rs2_read.is_some(); + let has_rd = rd.is_some(); + let has_memory_op = memory_op.is_some(); StepRecord { cycle, pc, - rs1: rs1_read.map(|rs1| ReadOp { - addr: Platform::register_vma(insn.rs1).into(), - value: rs1, - previous_cycle, - }), - rs2: rs2_read.map(|rs2| ReadOp { - addr: Platform::register_vma(insn.rs2).into(), - value: rs2, - previous_cycle, - }), - rd: rd.map(|rd| WriteOp { - addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), - value: rd, - previous_cycle, - }), + has_rs1, + has_rs2, + has_rd, + has_memory_op, + rs1: rs1_read + .map(|rs1| ReadOp { + addr: Platform::register_vma(insn.rs1).into(), + value: rs1, + previous_cycle, + }) + .unwrap_or_default(), + rs2: rs2_read + .map(|rs2| ReadOp { + addr: Platform::register_vma(insn.rs2).into(), + value: rs2, + previous_cycle, + }) + .unwrap_or_default(), + rd: rd + .map(|rd| WriteOp { + addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), + value: rd, + previous_cycle, + }) + .unwrap_or_default(), insn, - memory_op, - syscall: None, + memory_op: memory_op.unwrap_or_default(), + syscall_index: StepRecord::NO_SYSCALL, heap_maxtouch_addr, hint_maxtouch_addr, } @@ -645,19 +692,23 @@ impl StepRecord { } pub fn rs1(&self) -> Option { - self.rs1.clone() + if self.has_rs1 { Some(self.rs1) } else { None } } pub fn rs2(&self) -> Option { - self.rs2.clone() + if self.has_rs2 { Some(self.rs2) } else { None } } pub fn rd(&self) -> Option { - self.rd.clone() + if self.has_rd { Some(self.rd) } else { None } } pub fn memory_op(&self) -> Option { - self.memory_op.clone() + if self.has_memory_op { + Some(self.memory_op) + } else { + None + } } #[inline(always)] @@ -665,8 +716,19 @@ impl StepRecord { self.pc.before == self.pc.after } - pub fn syscall(&self) -> Option<&SyscallWitness> { - self.syscall.as_ref() + /// Returns true if this step has a syscall witness. + pub fn has_syscall(&self) -> bool { + self.syscall_index != Self::NO_SYSCALL + } + + /// Look up the syscall witness from a separate store. + /// The store is typically obtained from `FullTracer::syscall_witnesses()`. + pub fn syscall<'a>(&self, store: &'a [SyscallWitness]) -> Option<&'a SyscallWitness> { + if self.syscall_index == Self::NO_SYSCALL { + None + } else { + Some(&store[self.syscall_index as usize]) + } } } @@ -684,6 +746,9 @@ pub struct FullTracer { pending_index: usize, pending_cycle: Cycle, + /// Syscall witnesses stored separately (StepRecord references by index). + syscall_witnesses: Vec, + // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, @@ -724,6 +789,7 @@ impl FullTracer { len: 0, pending_index: 0, pending_cycle: Self::SUBCYCLES_PER_INSN, + syscall_witnesses: Vec::new(), mmio_min_max_access: Some(mmio_max_access), platform: platform.clone(), latest_accesses: LatestAccesses::new(platform), @@ -760,6 +826,7 @@ impl FullTracer { pub fn reset_step_buffer(&mut self) { self.len = 0; self.pending_index = 0; + self.syscall_witnesses.clear(); self.reset_pending_slot(); } @@ -767,6 +834,11 @@ impl FullTracer { &self.records[..self.len] } + /// Returns the syscall witness store. Pass this to `StepRecord::syscall()`. + pub fn syscall_witnesses(&self) -> &[SyscallWitness] { + &self.syscall_witnesses + } + #[inline(always)] pub fn step_record(&self, index: StepIndex) -> &StepRecord { assert!( @@ -822,41 +894,41 @@ impl FullTracer { #[inline(always)] pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = Platform::register_vma(idx).into(); - match ( - self.records[self.pending_index].rs1.as_ref(), - self.records[self.pending_index].rs2.as_ref(), - ) { - (None, None) => { - self.records[self.pending_index].rs1 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS1), - }); - } - (Some(_), None) => { - self.records[self.pending_index].rs2 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS2), - }); - } - _ => unimplemented!("Only two register reads are supported"), + if !self.records[self.pending_index].has_rs1 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS1); + self.records[self.pending_index].rs1 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs1 = true; + } else if !self.records[self.pending_index].has_rs2 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS2); + self.records[self.pending_index].rs2 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs2 = true; + } else { + unimplemented!("Only two register reads are supported"); } } #[inline(always)] pub fn store_register(&mut self, idx: RegIdx, value: Change) { - if self.records[self.pending_index].rd.is_some() { + if self.records[self.pending_index].has_rd { unimplemented!("Only one register write is supported"); } let addr = Platform::register_vma(idx).into(); let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RD); - self.records[self.pending_index].rd = Some(WriteOp { + self.records[self.pending_index].rd = WriteOp { addr, value, previous_cycle, - }); + }; + self.records[self.pending_index].has_rd = true; } #[inline(always)] @@ -866,7 +938,7 @@ impl FullTracer { #[inline(always)] pub fn store_memory(&mut self, addr: WordAddr, value: Change) { - if self.records[self.pending_index].memory_op.is_some() { + if self.records[self.pending_index].has_memory_op { unimplemented!("Only one memory access is supported"); } @@ -899,19 +971,26 @@ impl FullTracer { } } - self.records[self.pending_index].memory_op = Some(WriteOp { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_MEM); + self.records[self.pending_index].memory_op = WriteOp { addr, value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_MEM), - }); + previous_cycle, + }; + self.records[self.pending_index].has_memory_op = true; } #[inline(always)] pub fn track_syscall(&mut self, effects: SyscallEffects) { let witness = effects.finalize(self); let record = &mut self.records[self.pending_index]; - assert!(record.syscall.is_none(), "Only one syscall per step"); - record.syscall = Some(witness); + assert!( + record.syscall_index == StepRecord::NO_SYSCALL, + "Only one syscall per step" + ); + let idx = self.syscall_witnesses.len(); + self.syscall_witnesses.push(witness); + record.syscall_index = idx as u32; } #[inline(always)] @@ -1371,6 +1450,7 @@ impl Tracer for FullTracer { } #[derive(Copy, Clone, Default, PartialEq, Eq)] +#[repr(C)] pub struct Change { pub before: T, pub after: T, @@ -1387,3 +1467,92 @@ impl fmt::Debug for Change { write!(f, "{:?} -> {:?}", self.before, self.after) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_step_record_is_copy_and_compact() { + // Verify StepRecord is Copy (this compiles only if Copy is implemented) + fn assert_copy() {} + assert_copy::(); + + // Verify repr(C) compactness — should be well under 128 bytes + let size = std::mem::size_of::(); + eprintln!("StepRecord size: {} bytes", size); + assert!( + size <= 144, + "StepRecord should be compact for GPU transfer: got {} bytes", + size + ); + } + + #[test] + fn test_supporting_types_are_copy() { + fn assert_copy() {} + assert_copy::(); + assert_copy::(); + assert_copy::>(); + assert_copy::>(); + } + + /// Verify exact byte offsets of StepRecord fields for CUDA struct alignment. + /// If this test fails, the CUDA step_record.cuh header must be updated to match. + #[test] + fn test_step_record_layout_for_gpu() { + use std::mem; + + macro_rules! offset_of { + ($type:ty, $field:ident) => {{ + let val = <$type>::default(); + let base = &val as *const _ as usize; + let field = &val.$field as *const _ as usize; + field - base + }}; + } + + // Sub-type sizes + assert_eq!(mem::size_of::(), 12, "Instruction size"); + assert_eq!(mem::size_of::(), 16, "ReadOp size"); + assert_eq!(mem::size_of::(), 24, "WriteOp size"); + assert_eq!( + mem::size_of::>(), + 8, + "Change size" + ); + + // StepRecord field offsets — these must match step_record.cuh + assert_eq!(offset_of!(StepRecord, cycle), 0); + assert_eq!(offset_of!(StepRecord, pc), 8); + assert_eq!(offset_of!(StepRecord, heap_maxtouch_addr), 16); + assert_eq!(offset_of!(StepRecord, hint_maxtouch_addr), 24); + assert_eq!(offset_of!(StepRecord, insn), 32); + assert_eq!(offset_of!(StepRecord, has_rs1), 44); + assert_eq!(offset_of!(StepRecord, has_rs2), 45); + assert_eq!(offset_of!(StepRecord, has_rd), 46); + assert_eq!(offset_of!(StepRecord, has_memory_op), 47); + assert_eq!(offset_of!(StepRecord, rs1), 48); + assert_eq!(offset_of!(StepRecord, rs2), 64); + assert_eq!(offset_of!(StepRecord, rd), 80); + assert_eq!(offset_of!(StepRecord, memory_op), 104); + assert_eq!(offset_of!(StepRecord, syscall_index), 128); + + // Total size + assert_eq!(mem::size_of::(), 136, "StepRecord total size"); + assert_eq!(mem::align_of::(), 8, "StepRecord alignment"); + + // InsnKind must be repr(u8) for CUDA compatibility + assert_eq!( + mem::size_of::(), + 1, + "InsnKind must be 1 byte (repr(u8))" + ); + + eprintln!( + "StepRecord layout verified: {} bytes, {} align", + mem::size_of::(), + mem::align_of::() + ); + } +} diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index d65844682..0613436be 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -155,7 +155,7 @@ impl VMState { } pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { - self.registers[idx] = value; + self.registers[idx as usize] = value; } fn halt(&mut self, exit_code: u32) { @@ -171,7 +171,7 @@ impl VMState { } for (idx, value) in effects.iter_reg_values() { - self.registers[idx] = value; + self.registers[idx as usize] = value; } let next_pc = effects.next_pc.unwrap_or(self.pc + PC_STEP_SIZE as u32); @@ -252,7 +252,7 @@ impl EmuContext for VMState { if idx != 0 { let before = self.peek_register(idx); self.tracer.store_register(idx, Change { before, after }); - self.registers[idx] = after; + self.registers[idx as usize] = after; } Ok(()) } @@ -276,7 +276,7 @@ impl EmuContext for VMState { /// Get the value of a register without side-effects. fn peek_register(&self, idx: RegIdx) -> Word { - self.registers[idx] + self.registers[idx as usize] } /// Get the value of a memory word without side-effects. diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index ff752267d..b06c48d6e 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind, - Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, + Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, SyscallWitness, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp, host_utils::{read_all_messages, read_all_messages_as_words}, }; @@ -21,7 +21,7 @@ fn test_ceno_rt_mini() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; Ok(()) } @@ -39,7 +39,7 @@ fn test_ceno_rt_panic() { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let steps = run(&mut state).unwrap(); + let (steps, _syscall_witnesses) = run(&mut state).unwrap(); let last = steps.last().unwrap(); assert_eq!(last.insn().kind, InsnKind::ECALL); assert_eq!(last.rs1().unwrap().value, Platform::ecall_halt()); @@ -56,7 +56,7 @@ fn test_ceno_rt_mem() -> Result<()> { }; let sheap = program.sheap.into(); let mut state = VMState::new(platform, Arc::new(program.clone())); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let value = state.peek_memory(sheap); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); @@ -72,7 +72,7 @@ fn test_ceno_rt_alloc() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; // Search for the RAM action of the test program. let mut found = (false, false); @@ -102,7 +102,7 @@ fn test_ceno_rt_io() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let all_messages = messages_to_strings(&read_all_messages(&state)); for msg in &all_messages { @@ -235,7 +235,7 @@ fn test_hashing() -> Result<()> { fn test_keccak_syscall() -> Result<()> { let program_elf = ceno_examples::keccak_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; // Expect the program to have written successive states between Keccak permutations. let keccak_first_iter_outs = sample_keccak_f(1); @@ -251,7 +251,10 @@ fn test_keccak_syscall() -> Result<()> { } // Find the syscall records. - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 100); // Check the syscall effects. @@ -293,9 +296,12 @@ fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { fn test_secp256k1() -> Result<()> { let program_elf = ceno_examples::secp256k1; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -305,9 +311,12 @@ fn test_secp256k1() -> Result<()> { fn test_secp256k1_add() -> Result<()> { let program_elf = ceno_examples::secp256k1_add_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -358,9 +367,12 @@ fn test_secp256k1_double() -> Result<()> { let program_elf = ceno_examples::secp256k1_double_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -394,9 +406,12 @@ fn test_secp256k1_decompress() -> Result<()> { let program_elf = ceno_examples::secp256k1_decompress_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -456,8 +471,11 @@ fn test_secp256k1_ecrecover() -> Result<()> { let program_elf = ceno_examples::secp256k1_ecrecover; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -479,8 +497,11 @@ fn test_sha256_extend() -> Result<()> { ]; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 48); for round in 0..48 { @@ -534,10 +555,13 @@ fn test_sha256_full() -> Result<()> { fn test_bn254_fptower_syscalls() -> Result<()> { let program_elf = ceno_examples::bn254_fptower_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; const RUNS: usize = 10; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 4 * RUNS); for witness in syscalls.iter() { @@ -584,9 +608,12 @@ fn test_bn254_fptower_syscalls() -> Result<()> { fn test_bn254_curve() -> Result<()> { let program_elf = ceno_examples::bn254_curve_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 3); // add @@ -652,9 +679,12 @@ fn test_uint256_mul() -> Result<()> { let program_elf = ceno_examples::uint256_mul_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -805,9 +835,10 @@ fn messages_to_strings(messages: &[Vec]) -> Vec { .collect() } -fn run(state: &mut VMState) -> Result> { +fn run(state: &mut VMState) -> Result<(Vec, Vec)> { state.iter_until_halt().collect::>>()?; let steps = state.tracer().recorded_steps().to_vec(); + let syscall_witnesses = state.tracer().syscall_witnesses().to_vec(); eprintln!("Emulator ran for {} steps.", steps.len()); - Ok(steps) + Ok((steps, syscall_witnesses)) } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..7a3c4710c 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,9 +24,9 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, - NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, - StepIndex, StepRecord, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, - host_utils::read_all_messages, + NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, RegIdx, + StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, + WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; use clap::ValueEnum; use either::Either; @@ -199,6 +199,8 @@ pub struct ShardContext<'a> { pub platform: Platform, pub shard_heap_addr_range: Range, pub shard_hint_addr_range: Range, + /// Syscall witnesses for StepRecord::syscall() lookups. + pub syscall_witnesses: Arc>, } impl<'a> Default for ShardContext<'a> { @@ -233,6 +235,7 @@ impl<'a> Default for ShardContext<'a> { platform: CENO_PLATFORM.clone(), shard_heap_addr_range: CENO_PLATFORM.heap.clone(), shard_hint_addr_range: CENO_PLATFORM.hints.clone(), + syscall_witnesses: Arc::new(Vec::new()), } } } @@ -279,6 +282,7 @@ impl<'a> ShardContext<'a> { platform: self.platform.clone(), shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), }) .collect_vec(), _ => panic!("invalid type"), @@ -750,6 +754,7 @@ pub trait StepSource: Iterator { fn start_new_shard(&mut self); fn shard_steps(&self) -> &[StepRecord]; fn step_record(&self, idx: StepIndex) -> &StepRecord; + fn syscall_witnesses(&self) -> &[SyscallWitness]; } /// Lazily replays `StepRecord`s by re-running the VM up to the number of steps @@ -822,6 +827,10 @@ impl StepSource for StepReplay { fn step_record(&self, idx: StepIndex) -> &StepRecord { self.vm.tracer().step_record(idx) } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + self.vm.tracer().syscall_witnesses() + } } pub fn emulate_program<'a>( @@ -899,8 +908,8 @@ pub fn emulate_program<'a>( let reg_final = reg_init .iter() .map(|rec| { - let index = rec.addr as usize; - if index < VM_REG_COUNT { + if (rec.addr as usize) < VM_REG_COUNT { + let index = rec.addr as RegIdx; let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { ram_type: RAMType::Register, @@ -1282,6 +1291,7 @@ pub fn generate_witness<'a, E: ExtensionField>( }; tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); let shard_steps = step_iter.shard_steps(); + shard_ctx.syscall_witnesses = Arc::new(step_iter.syscall_witnesses().to_vec()); let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); @@ -2122,7 +2132,9 @@ pub fn verify + serde::Ser #[cfg(test)] mod tests { use crate::e2e::{MultiProver, ShardContextBuilder}; - use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; + use ceno_emul::{ + CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord, SyscallWitness, + }; use itertools::Itertools; use std::sync::Arc; @@ -2224,6 +2236,10 @@ mod tests { fn step_record(&self, idx: StepIndex) -> &StepRecord { &self.steps[self.shard_start + idx] } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + &[] // Test replay doesn't track syscalls + } } let mut steps_iter = TestReplay::new(steps); diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 3ae516e9c..650d5d97a 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -99,7 +99,8 @@ impl Instruction for LargeEcallDummy lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // Assign instruction. config diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index b74c8ca39..bf952a4f1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -18,10 +18,12 @@ fn test_large_ecall_dummy_keccak() { let mut cb = CircuitBuilder::new(&mut cs); let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let (step, program) = ceno_emul::test_utils::keccak_step(); + let (step, program, syscall_witnesses) = ceno_emul::test_utils::keccak_step(); + let mut shard_ctx = ShardContext::default(); + shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); let (raw_witin, lkm) = KeccakDummy::assign_instances_from_steps( &config, - &mut ShardContext::default(), + &mut shard_ctx, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, &[step], diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 57d824b01..7ee531cc8 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -358,7 +358,8 @@ fn assign_fp_op_instances( .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); config .vm_state .assign_instance(instance, &shard_ctx, step)?; @@ -419,7 +420,7 @@ fn assign_fp_op_instances( .map(|&idx| { let step = &steps[idx]; let values: Vec = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6715a0b74..2d99ed4a4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -273,7 +273,8 @@ fn assign_fp2_add_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index 7537d31ca..4aacf3418 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -271,7 +271,8 @@ fn assign_fp2_mul_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index e088cc0cc..f9c9f1712 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -221,7 +221,8 @@ impl Instruction for KeccakInstruction { .zip_eq(indices.iter().copied()) .map(|(instance_with_rotation, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); let bh = BooleanHypercube::new(KECCAK_ROUNDS_CEIL_LOG2); let mut cyclic_group = bh.into_iter(); @@ -285,7 +286,7 @@ impl Instruction for KeccakInstruction { .map(|&idx| -> KeccakInstance { let step = &steps[idx]; let (instance, prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index b61673e4a..3a0f42ab2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -218,7 +218,8 @@ impl Instruction for ShaExtendInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); // vm_state config @@ -285,7 +286,8 @@ impl Instruction for ShaExtendInstruction { .iter() .map(|&idx| -> ShaExtendInstance { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); let w_i_minus_2 = ops.mem_ops[0].value.before; let w_i_minus_7 = ops.mem_ops[1].value.before; let w_i_minus_15 = ops.mem_ops[2].value.before; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index f3a39093f..67a042cd7 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -270,7 +270,8 @@ impl Instruction for Uint256MulInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -336,7 +337,7 @@ impl Instruction for Uint256MulInstruction { .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() @@ -593,7 +594,8 @@ impl Instruction for Uint256InvInstr .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -646,7 +648,7 @@ impl Instruction for Uint256InvInstr .map(|&idx| { let step = &steps[idx]; let (instance, _): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 80c85ef7a..ca4f59ed3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -278,7 +278,8 @@ impl Instruction .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -345,7 +346,7 @@ impl Instruction .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index a07fc00b2..1d79efd63 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -278,7 +278,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 72f5f71d8..206ef143d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -250,7 +250,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall_base.rs b/ceno_zkvm/src/instructions/riscv/ecall_base.rs index 7e655c408..69d9d286f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_base.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_base.rs @@ -18,13 +18,13 @@ use ceno_emul::FullTracer as Tracer; use multilinear_extensions::{ToExpr, WitIn}; #[derive(Debug)] -pub struct OpFixedRS { +pub struct OpFixedRS { pub prev_ts: WitIn, pub prev_value: Option>, pub lt_cfg: AssertLtConfig, } -impl OpFixedRS { +impl OpFixedRS { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, rd_written: RegisterExpr, From ac2bf54d394f0db191cb7f2b5fb139d0b074cc94 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:25:28 +0800 Subject: [PATCH 2/3] fix --- ceno_emul/src/syscalls.rs | 8 ++------ ceno_emul/src/test_utils.rs | 2 +- ceno_zkvm/src/instructions/riscv/memory/gadget.rs | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 5d9674fc6..31c87b271 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -60,19 +60,15 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< /// A syscall event, available to the circuit witness generators. /// TODO: separate mem_ops into two stages: reads-and-writes #[derive(Clone, Debug, Default, PartialEq, Eq)] +#[non_exhaustive] pub struct SyscallWitness { pub mem_ops: Vec, pub reg_ops: Vec, - _marker: (), } impl SyscallWitness { fn new(mem_ops: Vec, reg_ops: Vec) -> SyscallWitness { - SyscallWitness { - mem_ops, - reg_ops, - _marker: (), - } + SyscallWitness { mem_ops, reg_ops } } } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 92ad64a97..625ed52c5 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -29,7 +29,7 @@ pub fn keccak_step() -> (StepRecord, Vec, Vec) { let steps = vm.tracer().recorded_steps(); let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions, syscall_witnesses) + (steps[2], instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 3a8da4a09..a37be1f61 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -138,7 +138,7 @@ impl MemWordUtil { step: &StepRecord, shift: u32, ) -> Result<(), ZKVMError> { - let memory_op = step.memory_op().clone().unwrap(); + let memory_op = step.memory_op().unwrap(); let prev_value = Value::new_unchecked(memory_op.value.before); let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); From 3c44a961255d02b9c7495f106731661ae2964adb Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 25 Mar 2026 15:47:03 +0800 Subject: [PATCH 3/3] avoid to_vec() --- ceno_emul/src/tracer.rs | 7 +++++++ ceno_zkvm/src/e2e.rs | 16 +++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 4aa7b3080..74c8de255 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -835,10 +835,17 @@ impl FullTracer { } /// Returns the syscall witness store. Pass this to `StepRecord::syscall()`. + #[inline(always)] pub fn syscall_witnesses(&self) -> &[SyscallWitness] { &self.syscall_witnesses } + /// Take ownership of syscall witnesses, leaving an empty Vec for the next shard. + /// Avoids the `to_vec()` clone when wrapping in `Arc`. + pub fn take_syscall_witnesses(&mut self) -> Vec { + std::mem::take(&mut self.syscall_witnesses) + } + #[inline(always)] pub fn step_record(&self, index: StepIndex) -> &StepRecord { assert!( diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 7a3c4710c..e47dfd9a2 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -200,6 +200,7 @@ pub struct ShardContext<'a> { pub shard_heap_addr_range: Range, pub shard_hint_addr_range: Range, /// Syscall witnesses for StepRecord::syscall() lookups. + /// Borrowed from the tracer — no per-shard Vec clone. pub syscall_witnesses: Arc>, } @@ -755,6 +756,8 @@ pub trait StepSource: Iterator { fn shard_steps(&self) -> &[StepRecord]; fn step_record(&self, idx: StepIndex) -> &StepRecord; fn syscall_witnesses(&self) -> &[SyscallWitness]; + /// Take ownership of syscall witnesses (zero-copy move, leaves empty Vec). + fn take_syscall_witnesses(&mut self) -> Vec; } /// Lazily replays `StepRecord`s by re-running the VM up to the number of steps @@ -831,6 +834,10 @@ impl StepSource for StepReplay { fn syscall_witnesses(&self) -> &[SyscallWitness] { self.vm.tracer().syscall_witnesses() } + + fn take_syscall_witnesses(&mut self) -> Vec { + self.vm.tracer_mut().take_syscall_witnesses() + } } pub fn emulate_program<'a>( @@ -1290,8 +1297,11 @@ pub fn generate_witness<'a, E: ExtensionField>( None => return None, }; tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); + // Move (not clone) syscall witnesses from tracer into Arc. + // take_syscall_witnesses() swaps the tracer's Vec with an empty one — zero copy. + // Must be called before shard_steps() to avoid borrow conflict. + shard_ctx.syscall_witnesses = Arc::new(step_iter.take_syscall_witnesses()); let shard_steps = step_iter.shard_steps(); - shard_ctx.syscall_witnesses = Arc::new(step_iter.syscall_witnesses().to_vec()); let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); @@ -2240,6 +2250,10 @@ mod tests { fn syscall_witnesses(&self) -> &[SyscallWitness] { &[] // Test replay doesn't track syscalls } + + fn take_syscall_witnesses(&mut self) -> Vec { + Vec::new() + } } let mut steps_iter = TestReplay::new(steps);