Fix SSM kernel bugs and unify Mamba prefill through varlen path #2

Open
lmcafee-nvidia wants to merge 1 commit into oyilmaz-nvidia:onur/ssm-kernels from lmcafee-nvidia:ssm-kernel-fixes
Conversation

@lmcafee-nvidia
Summary

  • conv_state save-before-read: extract initial conv states BEFORE causal_conv1d_varlen_states + tensor_masked_update overwrites the conv_state buffer, fixing corrupted convolution output for restored requests
  • cu_chunk_seqlens OOB: subdivide each sequence into chunks of at most chunk_size tokens instead of passing cu_seqlens directly, fixing out-of-bounds memory access when sequences exceed chunk_size (128) tokens
  • zxBCdt padding mismatch: strip padded tokens from zxBCdt before _ssm_prefill, then pad the output back, fixing shape assertion failures when padded and real token counts differ
  • per-request conv1d with initial_states: loop over requests calling causal_conv1d_fn per-request with initial_states, because causal_conv1d_fn cannot accept both seq_idx and initial_states simultaneously
  • unify prefill path: all Mamba prefill (including chunked) goes through a single varlen SSM kernel call, removing separate chunked-prefill routing and _batch_indices_chunked_prefill / _device_chunked_prefill metadata
  • simplify _dynamic_inference: flat decode + prefill structure with new _dynamic_inference_prefill helper that strips CUDA-graph padding from metadata and data tensors
  • remove deprecated constructor params: use_mem_eff_path, d_state, headdim, ngroups and their warnings.warn blocks
  • add assertion format string in ssd_combined.py for easier debugging
  • add batch_allocate_slots() to MambaMetadata for batch slot allocation

Test plan

  • verify mamba_mixer.py imports without errors
  • run existing SSM kernel unit tests
  • run hybrid model inference end-to-end with decode + prefill mix
  • verify chunked prefill requests work through the unified varlen path

🤖 Generated with Claude Code

Bug fixes:
1. conv_state save-before-read: extract initial conv states BEFORE
   causal_conv1d_varlen_states + tensor_masked_update overwrites the
   conv_state buffer. Previously, initial_conv_states was read AFTER
   the buffer was updated, so restored requests would see their own
   newly-computed final states instead of the pre-existing initial
   states, corrupting the convolution output.
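
   The ordering fix can be sketched as follows. This is a minimal
   illustration, not the repository's code: `prefill_conv`, its
   arguments, and the direct index assignment (standing in for
   `causal_conv1d_varlen_states` + `tensor_masked_update`) are all
   hypothetical.

   ```python
   import torch

   def prefill_conv(conv_state: torch.Tensor, slot_ids: torch.Tensor,
                    new_states: torch.Tensor) -> torch.Tensor:
       # Read the cached initial states FIRST: clone() snapshots the
       # rows before the in-place update below overwrites them. The
       # old (buggy) order did the write first, so restored requests
       # saw their own freshly written final states here.
       initial_conv_states = conv_state[slot_ids].clone()
       # Now it is safe to write the newly computed final states back
       # into the shared buffer (stand-in for tensor_masked_update).
       conv_state[slot_ids] = new_states
       return initial_conv_states
   ```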

2. cu_chunk_seqlens OOB: the SSM Triton kernels allocate per-chunk
   output arrays of size chunk_size (128). Passing cu_seqlens directly
   as cu_chunk_seqlens caused out-of-bounds memory access when any
   sequence exceeded chunk_size tokens. Fix: subdivide each sequence
   into chunks of at most self.chunk_size, producing correct
   cu_chunk_seqlens boundaries.
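
   A sketch of the subdivision, assuming `cu_seqlens` holds cumulative
   sequence boundaries over a packed token stream; the helper name and
   signature are illustrative, not the repository's:

   ```python
   import torch

   def build_cu_chunk_seqlens(cu_seqlens: torch.Tensor,
                              chunk_size: int = 128) -> torch.Tensor:
       # cu_seqlens: e.g. [0, 200, 500] for two sequences of 200 and
       # 300 tokens. Emit boundaries so that every chunk is at most
       # chunk_size tokens and never crosses a sequence boundary.
       boundaries = [0]
       for i in range(len(cu_seqlens) - 1):
           start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
           pos = start
           while pos < end:
               pos = min(pos + chunk_size, end)
               boundaries.append(pos)
       return torch.tensor(boundaries, dtype=torch.int32)
   ```

   With `chunk_size=128`, passing `[0, 200, 500]` yields
   `[0, 128, 200, 328, 456, 500]`: each chunk fits the kernel's
   per-chunk output allocation, whereas the raw `cu_seqlens` would have
   implied a 200- and a 300-token chunk.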

3. zxBCdt padding mismatch: after conv1d, the per-request loop rebuilt
   xBC with only real tokens while dt and z retained padded token count.
   This caused a shape assertion failure in the SSM kernel. Fix: strip
   padded tokens from zxBCdt before _ssm_prefill, then pad the output
   back to the original padded size for downstream residual add.
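
   The strip-then-repad pattern, as a minimal sketch (the helper name,
   the tail placement of CUDA-graph padding, and zero-padding the output
   are assumptions for illustration):

   ```python
   import torch

   def run_prefill_unpadded(zxBCdt: torch.Tensor, num_real_tokens: int,
                            ssm_prefill) -> torch.Tensor:
       # zxBCdt: [padded_tokens, d]; CUDA-graph padding sits at the
       # tail. Run the SSM prefill on real tokens only so its internal
       # shape checks see a consistent token count.
       padded_tokens = zxBCdt.shape[0]
       out_real = ssm_prefill(zxBCdt[:num_real_tokens])
       # Pad the output back to the original padded size so the
       # downstream residual add sees the expected shape.
       out = out_real.new_zeros(padded_tokens, out_real.shape[-1])
       out[:num_real_tokens] = out_real
       return out
   ```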

4. Per-request conv1d with initial_states: causal_conv1d_fn cannot
   accept both seq_idx and initial_states simultaneously. The old code
   passed seq_idx to handle multiple sequences but this zeroes state at
   sequence boundaries instead of using the cached initial states. Fix:
   loop over requests, calling causal_conv1d_fn per-request with
   initial_states and channels-last layout.
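
   The per-request loop can be sketched as below. `conv_fn` stands in
   for `causal_conv1d_fn` (which the PR notes cannot take `seq_idx` and
   `initial_states` together); the function name, the channels-first
   `[d, total_tokens]` layout, and the `(x, initial_states)` call shape
   are illustrative assumptions, not the library's exact signature.

   ```python
   import torch

   def per_request_conv1d(x: torch.Tensor, cu_seqlens: torch.Tensor,
                          initial_states: torch.Tensor, conv_fn):
       # x: [d, total_tokens] packed across requests. Looping per
       # request lets each call receive that request's cached
       # initial_states, instead of seq_idx zeroing the state at every
       # sequence boundary.
       outs = []
       for i in range(len(cu_seqlens) - 1):
           start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
           seg = x[:, start:end].unsqueeze(0)  # [1, d, seqlen]
           outs.append(conv_fn(seg, initial_states[i : i + 1]).squeeze(0))
       return torch.cat(outs, dim=-1)
   ```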

Improvements:
- Unify all Mamba prefill (including chunked) through single varlen SSM
  kernel call, removing separate chunked-prefill routing and the
  _batch_indices_chunked_prefill / _device_chunked_prefill metadata
- Simplify _dynamic_inference to flat decode + prefill structure
- Add _dynamic_inference_prefill helper that strips CUDA-graph padding
  from metadata and data tensors before calling _ssm_prefill
- Remove deprecated constructor parameters (use_mem_eff_path, d_state,
  headdim, ngroups) and their warnings
- Add assertion format string in ssd_combined.py for easier debugging

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>