feat: GPU witness generation (RV32IM + Keccak + ShardRam)#1259

Open
Velaciela wants to merge 73 commits into master from feat/gpu-witnessgen

Conversation


@Velaciela Velaciela commented Mar 3, 2026

related: #1265

GPU Witness Generation

Accelerate witness generation by offloading computation from CPU to GPU.
This module (ceno_zkvm/src/instructions/gpu/) contains all GPU-side dispatch,
caching, and utility code for the witness generation pipeline.

The CUDA backend lives in the sibling repo ceno-gpu/ (cuda_hal/src/common/witgen/).

Architecture

Module Layout

gpu/
├── dispatch.rs         — GPU dispatch entry point (try_gpu_assign_instances, gpu_fill_witness)
├── config.rs           — Environment variable config (3 env vars), kind tags
├── cache.rs            — Thread-local device buffer caching, shared EC/addr buffers
├── chips/              — Per-chip column map extractors + chip-specific GPU dispatch
│   ├── add.rs ... sw.rs  (24 RV32IM column map extractors)
│   ├── keccak.rs         (column map + keccak GPU dispatch: gpu_assign_keccak_instances)
│   └── shard_ram.rs      (column map + batch EC computation: gpu_batch_continuation_ec)
├── utils/
│   ├── column_map.rs   — Shared column map extraction helpers (extract_rs1, extract_rd, ...)
│   ├── d2h.rs          — Device-to-host: witness transpose, LK counter decode, compact EC D2H
│   ├── debug_compare.rs — GPU vs CPU comparison (activated by CENO_GPU_DEBUG_COMPARE_WITGEN)
│   ├── lk_ops.rs       — LkOp enum, SendEvent struct
│   ├── sink.rs         — LkShardramSink trait, CpuLkShardramSink
│   ├── emit.rs         — Emit helper functions (emit_u16_limbs, emit_logic_u8_ops, ...)
│   ├── fallback.rs     — CPU fallback: cpu_assign_instances, cpu_collect_lk_and_shardram
│   └── test_helpers.rs — Test utilities: assert_witness_colmajor_eq, assert_full_gpu_pipeline
└── mod.rs              — Module declarations + lk_shardram integration tests (19 tests)

Data Flow

                    Pass 1: PreflightTracer
                    ┌──────────────────────┐
                    │  ShardPlanBuilder     │ → shard boundaries
                    │  PackedNextAccessEntry│ → sorted future-access table
                    └──────────┬───────────┘
                               │
                    Pass 2: FullTracer (per shard)
                    ┌──────────▼───────────┐
                    │  Vec<StepRecord>      │ 136 bytes/step, #[repr(C)]
                    └──────────┬───────────┘
                               │ H2D (cached per shard in cache.rs)
                    ┌──────────▼───────────────────────────────────┐
                    │              GPU Per-Instruction              │
                    │  ┌─────────────┬──────────────┬────────────┐ │
                    │  │ F-1 Witness │ F-2 LK Count │ F-3 EC/Addr│ │
                    │  │ (col-major) │  (atomics)   │ (shared buf)│ │
                    │  └──────┬──────┴──────┬───────┴─────┬──────┘ │
                    └─────────┼─────────────┼─────────────┼────────┘
                              │             │             │
                      GPU transpose    D2H counters   flush at shard end
                              │             │             │
                    ┌─────────▼─────────────▼─────────────▼────────┐
                    │                 CPU Merge                     │
                    │  RowMajorMatrix  LkMultiplicity  ShardContext │
                    └──────────────────────┬───────────────────────┘
                                           │
                    ┌──────────────────────▼───────────────────────┐
                    │           ShardRamCircuit (GPU)               │
                    │  Phase 1: per-row Poseidon2 (344 cols)       │
                    │  Phase 2: binary EC tree (layer-by-layer)    │
                    └──────────────────────┬───────────────────────┘
                                           │
                                           ▼
                                     Proof Generation

Per-Shard Pipeline

Within generate_witness() (e2e.rs), each shard executes:

  1. upload_shard_steps_cached — H2D Vec<StepRecord> (cached, shared across all chips)
  2. ensure_shard_metadata_cached — H2D shard scalars + allocate shared EC/addr buffers
  3. Per-chip dispatch — gpu_fill_witness matches GpuWitgenKind → 22 kernel variants
    • Each kernel writes: witness columns (col-major), LK counters (atomics), EC records + addr (shared buffers)
  4. flush_shared_ec_buffers — D2H shared EC records + addr_accessed into ShardContext
  5. invalidate_shard_steps_cache — Free GPU shard_steps memory
  6. assign_shared_circuit — ShardRamCircuit GPU pipeline (Poseidon2 + EC tree)
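The six steps above can be sketched as a trace of calls; the function names mirror the list, while `ShardTrace` and the stub bodies are invented here purely for illustration:

```rust
// Hypothetical sketch of the per-shard loop in generate_witness();
// step names mirror the PR, bodies are stand-ins that just record the order.

#[derive(Default)]
struct ShardTrace(Vec<&'static str>);

impl ShardTrace {
    fn step(&mut self, name: &'static str) {
        self.0.push(name);
    }
}

fn run_shard(trace: &mut ShardTrace, chips: &[&'static str]) {
    trace.step("upload_shard_steps_cached");    // 1. H2D step records (cached)
    trace.step("ensure_shard_metadata_cached"); // 2. H2D scalars + shared EC/addr buffers
    for _chip in chips {
        trace.step("gpu_fill_witness");         // 3. per-chip kernel dispatch
    }
    trace.step("flush_shared_ec_buffers");      // 4. D2H shared EC records + addrs
    trace.step("invalidate_shard_steps_cache"); // 5. free GPU shard_steps memory
    trace.step("assign_shared_circuit");        // 6. ShardRamCircuit GPU pipeline
}

fn main() {
    let mut trace = ShardTrace::default();
    run_shard(&mut trace, &["add", "keccak"]);
    println!("{:?}", trace.0);
}
```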

GPU/CPU Decision (dispatch.rs)

try_gpu_assign_instances():
  1. is_gpu_witgen_enabled()?          → CPU fallback if not set
  2. is_force_cpu_path() thread-local? → CPU fallback (debug comparison)
  3. I::GPU_LK_SHARDRAM == false?      → CPU fallback
  4. is_kind_disabled(kind)?           → CPU fallback
  5. Field != BabyBear?                → CPU fallback
  6. get_cuda_hal() unavailable?       → CPU fallback
  7. All pass                          → GPU path
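The decision chain reads as a guard cascade. The sketch below uses a stand-in `GpuEnv` struct in place of the real environment-variable and thread-local checks:

```rust
// Sketch of the GPU/CPU decision in try_gpu_assign_instances();
// predicate names follow the PR, the flags struct is a stand-in.

struct GpuEnv {
    witgen_enabled: bool,   // is_gpu_witgen_enabled()
    force_cpu: bool,        // is_force_cpu_path() thread-local
    chip_supported: bool,   // I::GPU_LK_SHARDRAM
    kind_disabled: bool,    // is_kind_disabled(kind)
    field_is_babybear: bool,
    hal_available: bool,    // get_cuda_hal()
}

#[derive(Debug, PartialEq)]
enum Path {
    Gpu,
    CpuFallback(&'static str),
}

fn decide(env: &GpuEnv) -> Path {
    if !env.witgen_enabled { return Path::CpuFallback("CENO_GPU_ENABLE_WITGEN unset"); }
    if env.force_cpu { return Path::CpuFallback("debug comparison forces CPU"); }
    if !env.chip_supported { return Path::CpuFallback("GPU_LK_SHARDRAM == false"); }
    if env.kind_disabled { return Path::CpuFallback("kind disabled via env var"); }
    if !env.field_is_babybear { return Path::CpuFallback("field != BabyBear"); }
    if !env.hal_available { return Path::CpuFallback("no CUDA HAL"); }
    Path::Gpu
}

fn main() {
    let env = GpuEnv {
        witgen_enabled: true, force_cpu: false, chip_supported: true,
        kind_disabled: false, field_is_babybear: true, hal_available: true,
    };
    println!("{:?}", decide(&env));
}
```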

Keccak Dispatch

Keccak has a dedicated GPU dispatch path (chips/keccak.rs::gpu_assign_keccak_instances)
separate from try_gpu_assign_instances because:

  1. Rotation: each instance spans 32 rows (not 1), requiring new_by_rotation
  2. Structural witness: 3 selectors (sel_first/sel_last/sel_all) vs the standard 1
  3. Input packing: needs packed_instances with syscall_witnesses

The LK/shardram collection logic is identical to the standard path.

Lk and Shardram Collection

After GPU computes the witness matrix, LK multiplicities and shard RAM records
are collected through one of several paths (priority order):

| Path | Witness | LK Multiplicity | Shard Records | When |
|---|---|---|---|---|
| A — Shared buffer | GPU | GPU counters → D2H | Shared GPU buffer (deferred) | Default for all verified kinds |
| B — Compact EC | GPU | GPU counters → D2H | Compact EC D2H per-kernel | Older non-shared-buffer kinds |
| C — CPU shardram | GPU | GPU counters → D2H | CPU cpu_collect_shardram | GPU shard unverified |
| D — CPU full | GPU | CPU cpu_collect_lk_and_shardram | CPU full | GPU LK unverified |
| E — CPU only | CPU | CPU assign_instance | CPU assign_instance | GPU unavailable |

Currently all non-Keccak kinds use Path A. Paths B-E are fallback/debug paths.

E2E Pipeline Modes (e2e.rs)

create_proofs_streaming()
│
├─ Default GPU backend (CENO_GPU_ENABLE_WITGEN unset):
│   Overlap pipeline:
│     Thread A (CPU): witgen(shard 0) → witgen(shard 1) → witgen(shard 2) → ...
│     Thread B (GPU): ................prove(shard 0) → prove(shard 1) → ...
│     crossbeam::bounded(0) rendezvous channel for back-pressure
│
└─ CENO_GPU_ENABLE_WITGEN=1 (GPU witgen) or CPU-only build:
    Sequential pipeline:
      witgen(shard 0) → prove(shard 0) → witgen(shard 1) → prove(shard 1) → ...
      GPU shared between witgen and proving; no overlap possible.
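The overlap pipeline's back-pressure comes from the zero-capacity rendezvous channel. A minimal stand-alone sketch, using std's `sync_channel(0)` in place of `crossbeam::bounded(0)` (same rendezvous semantics), with witgen/prove work elided:

```rust
// Sketch of the overlap pipeline: CPU witgen on one thread, GPU proving on
// the other, with a capacity-0 channel as the rendezvous point.
use std::sync::mpsc::sync_channel;
use std::thread;

fn run_pipeline(shards: usize) -> Vec<usize> {
    // Capacity 0 makes send() block until the prover receives, so at most
    // one witnessed-but-unproven shard is ever in flight.
    let (tx, rx) = sync_channel::<usize>(0);

    let witgen = thread::spawn(move || {
        for shard in 0..shards {
            // ... CPU witness generation for `shard` would run here ...
            tx.send(shard).unwrap(); // blocks until the prover is ready
        }
        // `tx` drops here, closing the channel and ending the prover loop.
    });

    let mut proved = Vec::new();
    for shard in rx {
        // ... GPU proving for `shard` would run here ...
        proved.push(shard);
    }
    witgen.join().unwrap();
    proved
}

fn main() {
    println!("proved shards: {:?}", run_pipeline(3)); // proved shards: [0, 1, 2]
}
```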

Environment Variables

| Variable | Default | Purpose |
|---|---|---|
| CENO_GPU_ENABLE_WITGEN | unset (CPU witgen) | Set to enable GPU witness generation; uses the sequential witgen+prove pipeline. |
| CENO_GPU_DISABLE_WITGEN_KINDS | none | Comma-separated kind tags to disable specific chips' GPU path (e.g. add,keccak,lw); those chips fall back to CPU. |
| CENO_GPU_DEBUG_COMPARE_WITGEN | unset | Enable GPU vs CPU comparison for all chips; runs both paths and diffs results. |

CENO_GPU_DEBUG_COMPARE_WITGEN Coverage

When set, the following comparisons run automatically:

Per-chip (in dispatch.rs, for each opcode circuit):

  • debug_compare_final_lk — GPU LK multiplicity vs CPU assign_instance baseline (all 8 lookup tables)
  • debug_compare_witness — GPU witness matrix vs CPU witness (element-by-element, col-major vs row-major)
  • debug_compare_shardram — GPU shard records (read_records, write_records, addr_accessed) vs CPU
  • debug_compare_shard_ec — GPU compact EC records vs CPU-computed EC points (nonce, x[7], y[7])

Per-chip, Keccak-specific (in chips/keccak.rs):

  • debug_compare_keccak — Combined witness + LK + shard comparison for keccak's rotation-aware layout

Per-shard, E2E level (in e2e.rs):

  • log_shard_ctx_diff — Full shard context comparison after all opcode circuits (addr_accessed, read/write records across all chips merged)
  • log_combined_lk_diff — Merged LK multiplicities after finalize_lk_multiplicities() (catches cross-chip merge issues)

All comparisons output to stderr via eprintln! / tracing::error!, with a default limit of 16 mismatches per category.
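The mismatch-limit behavior can be sketched as follows; `diff_limited` is a hypothetical helper written for illustration, not the PR's actual debug_compare_* code:

```rust
// Hypothetical mismatch-limited comparison in the spirit of the
// debug_compare_* helpers: count all mismatches, report at most `limit`.
fn diff_limited(gpu: &[u64], cpu: &[u64], limit: usize) -> (usize, Vec<usize>) {
    let mut total = 0;
    let mut reported = Vec::new();
    for (i, (g, c)) in gpu.iter().zip(cpu).enumerate() {
        if g != c {
            total += 1;
            if reported.len() < limit {
                reported.push(i); // only the first `limit` positions are kept
            }
        }
    }
    (total, reported)
}

fn main() {
    let gpu = [1u64, 2, 9, 4, 9];
    let cpu = [1u64, 2, 3, 4, 5];
    let (total, idx) = diff_limited(&gpu, &cpu, 16);
    println!("{} mismatches, first at {:?}", total, idx); // 2 mismatches, first at [2, 4]
}
```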

Tests

79 tests total (cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "gpu")

| Category | Count | Location | What it tests |
|---|---|---|---|
| Column map extraction | 33 | chips/*.rs (31 via test_colmap! macro + 2 manual) | Circuit config → column map: all IDs in-range and unique |
| GPU witgen correctness | 23 | chips/*.rs | GPU kernel output vs CPU assign_instance (element-by-element witness comparison) |
| LK+shardram match | 19 | gpu/mod.rs | collect_lk_and_shardram / collect_shardram vs assign_instance baseline |
| LkOp encoding | 1 | utils/mod.rs | LkOp::encode_all() produces correct table/key pairs |
| EC point match | 1 | scheme/septic_curve.rs | GPU Poseidon2+SepticCurve EC point vs CPU to_ec_point |
| Poseidon2 sponge | 1 | scheme/septic_curve.rs | GPU Poseidon2 permutation vs CPU |
| Septic from_x | 1 | scheme/septic_curve.rs | GPU septic_point_from_x vs CPU |

Running Tests

# All GPU tests (requires CUDA device)
CENO_GPU_ENABLE_WITGEN=1 cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "gpu"

# Column map tests only (no CUDA device needed)
cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "test_extract_"

# LK/shardram tests only (no CUDA device needed)
cargo test --features gpu,u16limb_circuit -p ceno_zkvm --lib -- "lk_shardram"

# With debug comparison enabled
CENO_GPU_ENABLE_WITGEN=1 CENO_GPU_DEBUG_COMPARE_WITGEN=1 cargo test --features gpu,u16limb_circuit -p ceno_host -- test_elf

Per-Chip Boilerplate Macros

Three macros in instructions.rs reduce per-chip GPU integration to ~3 lines:

impl Instruction<E> for MyChip {
    // Emit LK ops + shard RAM records (CPU companion for GPU witgen)
    impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| {
        emit_u16_limbs(sink, step.rd().unwrap().value.after);
    });

    // Collect shard RAM records only (when GPU handles LK)
    impl_collect_shardram!(r_insn);

    // GPU dispatch: try GPU → fallback CPU
    impl_gpu_assign!(dispatch::GpuWitgenKind::Add);
}

@Velaciela Velaciela mentioned this pull request Mar 9, 2026
5 tasks
@Velaciela Velaciela force-pushed the feat/gpu-witnessgen branch from c95ee3f to aadf86b on March 24, 2026 14:42
@Velaciela Velaciela changed the title from "(draft) feat: GPU witness generation" to "feat: GPU witness generation (RV32IM + Keccak + ShardRam)" on Mar 25, 2026
@Velaciela Velaciela marked this pull request as ready for review March 25, 2026 02:03
@Velaciela Velaciela (Collaborator, Author) commented:

GPU Witness Generation — Invasive Changes to Existing Codebase

This document lists all changes to existing ceno structures, traits, and flows
that this PR introduces. GPU-only new code (instructions/gpu/) is excluded —
this focuses on what existing code was modified and why.


1. ceno_emul — FFI Layout Changes (+332 / -88 lines)

#[repr(C)] on emulator types

The following types were made #[repr(C)] to enable zero-copy H2D transfer to GPU:

| Type | File | Size | Purpose |
|---|---|---|---|
| StepRecord | tracer.rs | 136 B | Per-step emulator output, bulk H2D |
| Instruction | rv32im.rs | 12 B | Opcode encoding embedded in StepRecord |
| InsnKind | rv32im.rs | 1 B | #[repr(u8)] enum discriminant |
| `MemOp<T>` | tracer.rs | 16/24 B | Read/write ops embedded in StepRecord |
| `Change<T>` | tracer.rs | 2×T | Before/after pair |

Impact: These were previously #[derive(Debug, Clone)] with compiler-chosen layout.
Adding #[repr(C)] pins field order and padding. No behavioral change for CPU code,
but field reordering or insertion now requires updating the CUDA mirror structs.

New types in tracer.rs

  • PackedNextAccessEntry (16B, #[repr(C)]) — 40-bit packed cycle+addr for GPU FA table
  • ShardPlanBuilder — preflight shard planning with cell-count balancing
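The exact field layout of PackedNextAccessEntry is not shown here; the sketch below illustrates one plausible way to fit two 40-bit cycle values plus a 32-bit word address into a 16-byte #[repr(C)] entry. The real encoding in tracer.rs may differ:

```rust
// Hypothetical 40-bit packing sketch: two u64 fields carry
// cycle (40b) + addr[23:0] and next_cycle (40b) + addr[31:24].
const MASK40: u64 = (1 << 40) - 1;

#[repr(C)]
struct PackedEntry {
    lo: u64,
    hi: u64,
} // 16 bytes total, zero-copy H2D friendly

fn pack(cycle: u64, addr: u32, next_cycle: u64) -> PackedEntry {
    debug_assert!(cycle <= MASK40 && next_cycle <= MASK40);
    PackedEntry {
        lo: (cycle & MASK40) | ((addr as u64 & 0xFF_FFFF) << 40),
        hi: (next_cycle & MASK40) | ((addr as u64 >> 24) << 40),
    }
}

fn unpack(e: &PackedEntry) -> (u64, u32, u64) {
    let cycle = e.lo & MASK40;
    let next_cycle = e.hi & MASK40;
    let addr = ((e.lo >> 40) as u32 & 0xFF_FFFF) | (((e.hi >> 40) as u32) << 24);
    (cycle, addr, next_cycle)
}

fn main() {
    assert_eq!(std::mem::size_of::<PackedEntry>(), 16);
    let e = pack(1_000_000_007, 0xDEAD_BEEF, 1_000_000_042);
    println!("{:?}", unpack(&e)); // (1000000007, 3735928559, 1000000042)
}
```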

Layout test

test_step_record_layout_for_gpu verifies byte offsets of all StepRecord fields
at compile time. CUDA side has matching static_assert(sizeof(...)).
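In the same spirit, offsets can be pinned with `offset_of!`; `Demo` below is a hypothetical #[repr(C)] struct standing in for the real StepRecord:

```rust
// Sketch of a layout check like test_step_record_layout_for_gpu:
// with #[repr(C)] the offsets below are guaranteed, so the CUDA mirror
// struct can static_assert the same values.
use std::mem::{offset_of, size_of};

#[repr(C)]
struct Change<T> {
    before: T,
    after: T,
}

#[repr(C)]
struct Demo {
    cycle: u64,      // offset 0
    pc: Change<u32>, // offset 8 (8 bytes: two u32)
    insn: u32,       // offset 16
}

fn main() {
    assert_eq!(offset_of!(Demo, cycle), 0);
    assert_eq!(offset_of!(Demo, pc), 8);
    assert_eq!(offset_of!(Demo, insn), 16);
    assert_eq!(size_of::<Demo>(), 24); // 20 bytes of fields, padded to align 8
    println!("layout ok");
}
```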


2. Instruction<E> Trait — New Methods and Constants

File: ceno_zkvm/src/instructions.rs

| Addition | Purpose |
|---|---|
| `const GPU_LK_SHARDRAM: bool = false` | Opt-in flag: does this chip have GPU LK+shardram support? |
| `fn collect_lk_and_shardram(...)` | CPU companion: collect all LK multiplicities + shard RAM records (without witness replay) |
| `fn collect_shardram(...)` | CPU companion: collect shard RAM records only (GPU handles LK) |

Default implementations return Err(...) — chips must explicitly opt in.

Impact: Existing chips that don't implement GPU support are unaffected (defaults).
The trait's existing assign_instance and assign_instances are unchanged.
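The opt-in pattern looks roughly like this (simplified signatures, no generics; names mirror the PR, bodies are stand-ins):

```rust
// Default trait items return Err, so chips without GPU support compile
// unchanged; GPU-capable chips override both the const and the method.
trait Instruction {
    const GPU_LK_SHARDRAM: bool = false;

    fn collect_lk_and_shardram(&self) -> Result<(), String> {
        Err("no GPU LK+shardram support".into())
    }
}

struct LegacyChip; // unchanged chip: inherits the defaults
impl Instruction for LegacyChip {}

struct AddChip; // opted-in chip
impl Instruction for AddChip {
    const GPU_LK_SHARDRAM: bool = true;
    fn collect_lk_and_shardram(&self) -> Result<(), String> {
        Ok(()) // real code would emit LK ops + shard RAM records here
    }
}

fn main() {
    assert!(!LegacyChip::GPU_LK_SHARDRAM);
    assert!(LegacyChip.collect_lk_and_shardram().is_err());
    assert!(AddChip::GPU_LK_SHARDRAM);
    assert!(AddChip.collect_lk_and_shardram().is_ok());
    println!("opt-in defaults ok");
}
```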

Three macros reduce per-chip boilerplate:

  • impl_collect_lk_and_shardram! — wraps the unsafe CpuLkShardramSink prologue
  • impl_collect_shardram! — one-line delegate to insn_config
  • impl_gpu_assign! — #[cfg(feature = "gpu")] assign_instances override

3. Gadgets — New emit_lk_and_shardram / emit_shardram Methods

File: ceno_zkvm/src/instructions/riscv/insn_base.rs (+253 lines)

Every base gadget (ReadRS1, ReadRS2, WriteRD, ReadMEM, WriteMEM, MemAddr)
gained two new methods:

| Method | What it does |
|---|---|
| `emit_lk_and_shardram(sink, ctx, step)` | Emit LK ops + RAM send events through LkShardramSink |
| `emit_shardram(shard_ctx, step)` | Directly write shard RAM records to ShardContext (no LK) |

Impact: Additive only — existing assign_instance methods are unchanged.
The new methods extract the same logic that assign_instance performed inline,
but route through the LkShardramSink trait instead of directly calling
lk_multiplicity.assert_ux(...).
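The sink indirection can be sketched as follows; the trait name mirrors the PR, but the method set and `CpuSink` here are simplified stand-ins:

```rust
// Routing LK emission through a sink trait instead of calling the
// multiplicity counter directly; a GPU path can provide another impl.
use std::collections::HashMap;

trait LkShardramSink {
    fn assert_ux(&mut self, bits: u32, value: u64);
}

#[derive(Default)]
struct CpuSink {
    lk_counts: HashMap<(u32, u64), u64>, // (table bits, value) -> multiplicity
}

impl LkShardramSink for CpuSink {
    fn assert_ux(&mut self, bits: u32, value: u64) {
        *self.lk_counts.entry((bits, value)).or_insert(0) += 1;
    }
}

// Gadget-style emit helper: the same lookups assign_instance would do
// inline, but generic over the sink.
fn emit_u16_limbs(sink: &mut impl LkShardramSink, word: u32) {
    sink.assert_ux(16, (word & 0xFFFF) as u64);
    sink.assert_ux(16, (word >> 16) as u64);
}

fn main() {
    let mut sink = CpuSink::default();
    emit_u16_limbs(&mut sink, 0x0001_0001);
    println!("{:?}", sink.lk_counts.get(&(16, 1))); // Some(2)
}
```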

Intermediate configs (r_insn.rs, i_insn.rs, b_insn.rs, s_insn.rs, j_insn.rs, im_insn.rs)

Each gained corresponding emit_lk_and_shardram / emit_shardram methods that
compose their gadgets' methods + emit LkOp::Fetch.


4. Per-Chip Circuit Files — GPU Opt-in (+792 / -129 lines across ~20 files)

Each v2 circuit file (arith.rs, logic_circuit.rs, div_circuit_v2.rs, etc.) gained:

const GPU_LK_SHARDRAM: bool = true;  // or conditional match

impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| {
    // chip-specific LK ops
});
impl_collect_shardram!(r_insn);
impl_gpu_assign!(dispatch::GpuWitgenKind::Add);

Impact: Additive — existing assign_instance and construct_circuit unchanged.
The #[cfg(feature = "gpu")] assign_instances override is only compiled with the
gpu feature flag.


5. ShardContext — New Fields and Methods

File: ceno_zkvm/src/e2e.rs (+616 / -199 lines)

New fields

| Field | Type | Purpose |
|---|---|---|
| sorted_next_accesses | `Arc<SortedNextAccesses>` | Pre-sorted packed future-access table for GPU bulk H2D |
| gpu_ec_records | `Vec<u8>` | Raw bytes of GPU-produced compact EC shard records |
| syscall_witnesses | `Arc<Vec<SyscallWitness>>` | Keccak syscall data (previously passed separately) |

New methods

| Method | Purpose |
|---|---|
| new_empty_like() | Clone shard metadata with empty record storage (for debug comparison) |
| insert_read_record() / insert_write_record() | Direct record insertion (GPU D2H path) |
| push_addr_accessed() | Direct addr insertion (GPU D2H path) |
| extend_gpu_ec_records_raw() | Append raw GPU EC record bytes |
| has_gpu_ec_records() / take_gpu_ec_records() | GPU EC record lifecycle |

Renamed method

send() → split into record_send_without_touch() (no addr_accessed tracking) and
send() (which calls record_send_without_touch + push_addr_accessed).
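The split can be illustrated with a stand-in ShardContext (the struct and record types here are invented for the sketch):

```rust
// Sketch of the send() split: record_send_without_touch() appends the
// record only; send() keeps the old behavior by also tracking the address.
#[derive(Default)]
struct ShardCtx {
    records: Vec<(u32, u64)>, // (addr, cycle) stand-in records
    addr_accessed: Vec<u32>,
}

impl ShardCtx {
    fn record_send_without_touch(&mut self, addr: u32, cycle: u64) {
        self.records.push((addr, cycle)); // no addr_accessed tracking
    }
    fn push_addr_accessed(&mut self, addr: u32) {
        self.addr_accessed.push(addr);
    }
    fn send(&mut self, addr: u32, cycle: u64) {
        // behaves like the old send(): record + track the address
        self.record_send_without_touch(addr, cycle);
        self.push_addr_accessed(addr);
    }
}

fn main() {
    let mut ctx = ShardCtx::default();
    ctx.send(0x80, 1);
    // GPU D2H path: addresses arrive separately, so skip the touch.
    ctx.record_send_without_touch(0x84, 2);
    println!("{} records, {:?} touched", ctx.records.len(), ctx.addr_accessed);
}
```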

Pipeline hooks (in generate_witness shard loop)

#[cfg(feature = "gpu")]
flush_shared_ec_buffers(&mut shard_ctx);  // D2H shared GPU buffers

#[cfg(feature = "gpu")]
invalidate_shard_steps_cache();  // free GPU memory

Pipeline mode (in create_proofs_streaming)

New overlap pipeline (default when GPU feature enabled but CENO_GPU_ENABLE_WITGEN unset):
CPU witgen on thread A, GPU prove on thread B, connected by crossbeam::bounded(0) channel.


6. ZKVMWitnesses — GPU ShardRam Pipeline

File: ceno_zkvm/src/structs.rs (+580 / -130 lines)

assign_shared_circuit — new GPU fast path

Added try_assign_shared_circuit_gpu() that keeps data on GPU device:

  1. Takes shared device buffers (EC records + addr_accessed)
  2. GPU sort+dedup addr_accessed
  3. GPU batch EC computation for continuation records
  4. GPU merge+partition records (writes before reads)
  5. GPU ShardRamCircuit witness generation (Poseidon2 + EC tree)

Falls back to CPU path on failure.

gpu_ec_records_to_shard_ram_inputs

Converts raw GPU EC bytes (Vec<u8>) to Vec<ShardRamInput<E>> with pre-computed
EC points. Used in the CPU fallback path.


7. ShardRamCircuit — GPU Witness Generation

File: ceno_zkvm/src/tables/shard_ram.rs (+491 / -14 lines)

New GPU functions

| Function | Purpose |
|---|---|
| try_gpu_assign_instances() | H2D path: CPU records → GPU kernel → D2H witness |
| try_gpu_assign_instances_from_device() | Device path: records already on GPU → kernel → D2H |

Both run a two-phase GPU pipeline:

  1. Per-row kernel: basic fields + Poseidon2 trace (344 witness columns)
  2. EC tree kernel: layer-by-layer binary tree EC summation
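Phase 2's layer-by-layer reduction is an ordinary binary-tree fold; in the sketch below, u64 addition stands in for EC point summation, and each layer corresponds to one kernel launch since it only depends on the previous layer's outputs:

```rust
// Sketch of the Phase-2 binary tree reduction: sum adjacent pairs until
// one element remains, counting layers (kernel launches).
fn tree_sum_layers(mut layer: Vec<u64>) -> (u64, usize) {
    let mut layers = 0;
    while layer.len() > 1 {
        layer = layer
            .chunks(2) // an odd tail element passes through unchanged
            .map(|pair| pair.iter().copied().sum())
            .collect();
        layers += 1;
    }
    (layer[0], layers)
}

fn main() {
    let leaves: Vec<u64> = (1..=5).collect();
    let (total, layers) = tree_sum_layers(leaves);
    println!("sum={} layers={}", total, layers); // sum=15 layers=3
}
```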

Visibility change

ShardRamConfig fields changed from private to pub(crate) to allow
column map extraction in gpu/chips/shard_ram.rs.


8. SepticCurve — New Math Utilities

File: ceno_zkvm/src/scheme/septic_curve.rs (+307 lines)

New CPU-side math for EC point computation (mirrored in CUDA):

| Function | Purpose |
|---|---|
| SepticExtension::frobenius() | Frobenius endomorphism for norm computation |
| SepticExtension::sqrt() | Cipolla's algorithm for field square roots |
| SepticPoint::from_x() | Lift x-coordinate to curve point (used by nonce-finding loop) |
| `QuadraticExtension<F>` | Auxiliary type for Cipolla's algorithm |
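Cipolla's algorithm finds √n by exponentiating in a quadratic extension, which is why the PR adds a `QuadraticExtension<F>` helper. The sketch below runs the same idea over the prime 1_000_000_007 for readability rather than BabyBear's degree-7 extension:

```rust
// Cipolla's square root mod a prime p: find a with a^2 - n a non-residue,
// then compute (a + w)^((p+1)/2) in Fp[w]/(w^2 = a^2 - n).
const P: u128 = 1_000_000_007;

fn pow_mod(mut b: u128, mut e: u128, m: u128) -> u128 {
    let mut r = 1;
    b %= m;
    while e > 0 {
        if e & 1 == 1 { r = r * b % m; }
        b = b * b % m;
        e >>= 1;
    }
    r
}

// Multiply in Fp[w]/(w^2 = t): (x0 + x1 w)(y0 + y1 w)
fn ext_mul(x: (u128, u128), y: (u128, u128), t: u128) -> (u128, u128) {
    (
        (x.0 * y.0 + x.1 * y.1 % P * t) % P,
        (x.0 * y.1 + x.1 * y.0) % P,
    )
}

fn sqrt_mod(n: u128) -> Option<u128> {
    if n == 0 { return Some(0); }
    // Euler's criterion: bail out if n is a non-residue.
    if pow_mod(n, (P - 1) / 2, P) != 1 { return None; }
    // Find a with t = a^2 - n a non-residue (expected ~2 tries).
    let mut a = 0;
    let t = loop {
        let t = (a * a % P + P - n) % P;
        if t != 0 && pow_mod(t, (P - 1) / 2, P) != 1 { break t; }
        a += 1;
    };
    // (a + w)^((p+1)/2) lands in the base field and squares to n.
    let (mut r, mut base, mut e) = ((1, 0), (a, 1), (P + 1) / 2);
    while e > 0 {
        if e & 1 == 1 { r = ext_mul(r, base, t); }
        base = ext_mul(base, base, t);
        e >>= 1;
    }
    Some(r.0)
}

fn main() {
    let x = 123_456u128;
    let r = sqrt_mod(x * x % P).unwrap();
    assert_eq!(r * r % P, x * x % P);
    println!("sqrt ok: {}", r);
}
```

The nonce-finding loop in from_x() uses this the same way: try candidate x values until x³ + ... yields a quadratic residue, then lift via sqrt.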

9. Minor Touches

| File | Change |
|---|---|
| Cargo.toml | gpu feature flag, crossbeam dependency |
| gkr_iop/src/gadgets/is_lt.rs | AssertLtConfig.0.diff field access (already pub) |
| gkr_iop/src/utils/lk_multiplicity.rs | Minor: LkMultiplicity::increment |
| ceno_zkvm/src/gadgets/signed_ext.rs | pub(crate) fn msb() accessor for GPU column map |
| ceno_zkvm/src/gadgets/poseidon2.rs | Column contiguity constants for GPU |
| ceno_zkvm/src/tables/*.rs | pub(crate) visibility on config fields for GPU column map access |
| ceno_zkvm/src/scheme/{cpu,gpu,prover,verifier} | Minor plumbing for GPU proving path |
| ceno_host/tests/test_elf.rs | E2E test adjustments |

Summary

| Category | Nature | Risk |
|---|---|---|
| #[repr(C)] on emulator types | Layout pinning | Low — additive, but field changes now need CUDA sync |
| `Instruction<E>` trait extensions | Additive (defaults provided) | None — existing chips unaffected |
| Gadget emit_* methods | Additive | None — existing assign_instance unchanged |
| ShardContext new fields | Additive (defaults in Default) | Low — Vec::new() / Arc::new() zero-cost |
| send() → record_send_without_touch() + send() | Rename + split | Low — send() still works identically |
| ShardRamConfig visibility | private → pub(crate) | None |
| Pipeline overlap mode | New default behavior | Medium — CPU witgen + GPU prove on separate threads |
| septic_curve.rs math | Additive | None — new functions, existing unchanged |

@hero78119 hero78119 (Collaborator) left a comment:

A quick review regarding the tracer & SortedNextAccesses fields:

mmio_min_max_access: Option<BTreeMap<WordAddr, (WordAddr, WordAddr, WordAddr, WordAddr)>>,
latest_accesses: LatestAccesses,
next_accesses: NextCycleAccess,
next_accesses_vec: Vec<PackedNextAccessEntry>,
Collaborator commented:

We can re-build SortedNextAccesses from self.next_accesses map so we can avoid introduce this new vector field and PackedNextAccessEntry

@Velaciela Velaciela (Collaborator, Author) replied on Mar 25, 2026:
> re-build

Rebuilding is slower than this method (roughly 50 ms vs 200 ms:
CENO_NEXT_ACCESS_SOURCE=preflight vs CENO_NEXT_ACCESS_SOURCE=hashmap).

@hero78119 hero78119 (Collaborator) replied on Mar 25, 2026:
I believe the slowness comes from the sequential traversal of the hashmap.
How about par_extend()?

    info_span!("next_access_from_hashmap").in_scope(|| {
        let total_entries: usize =
            addr_future_accesses.values().map(|pairs| pairs.len()).sum();
        let mut entries = Vec::with_capacity(total_entries);
        entries.par_extend(addr_future_accesses.par_iter().flat_map_iter(
            |(cycle, pairs)| {
                pairs.iter().map(move |&(addr, next_cycle)| {
                    PackedNextAccessEntry::new(*cycle, addr.0, next_cycle)
                })
            },
        ));
        entries
    })

If this works, we can remove the vector field from the tracer, rebuild it only here, and gate it behind the gpu feature.
