Fix SSM kernel bugs and unify Mamba prefill through varlen path #2

Status: Open. lmcafee-nvidia wants to merge 1 commit into oyilmaz-nvidia:onur/ssm-kernels.
Conversation
Bug fixes:

1. **conv_state save-before-read:** extract initial conv states BEFORE `causal_conv1d_varlen_states` + `tensor_masked_update` overwrites the `conv_state` buffer. Previously, `initial_conv_states` was read AFTER the buffer was updated, so restored requests would see their own newly computed final states instead of the pre-existing initial states, corrupting the convolution output.
2. **cu_chunk_seqlens OOB:** the SSM Triton kernels allocate per-chunk output arrays of size `chunk_size` (128). Passing `cu_seqlens` directly as `cu_chunk_seqlens` caused out-of-bounds memory access when any sequence exceeded `chunk_size` tokens. Fix: subdivide each sequence into chunks of at most `self.chunk_size`, producing correct `cu_chunk_seqlens` boundaries.
3. **zxBCdt padding mismatch:** after conv1d, the per-request loop rebuilt `xBC` with only real tokens while `dt` and `z` retained the padded token count. This caused a shape assertion failure in the SSM kernel. Fix: strip padded tokens from `zxBCdt` before `_ssm_prefill`, then pad the output back to the original padded size for the downstream residual add.
4. **Per-request conv1d with initial_states:** `causal_conv1d_fn` cannot accept both `seq_idx` and `initial_states` simultaneously. The old code passed `seq_idx` to handle multiple sequences, but this zeroes state at sequence boundaries instead of using the cached initial states. Fix: loop over requests, calling `causal_conv1d_fn` per request with `initial_states` and channels-last layout.
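The chunk subdivision in fix 2 can be sketched in plain Python. `build_cu_chunk_seqlens` and its signature are illustrative stand-ins, not the PR's actual implementation, which operates on tensors:

```python
def build_cu_chunk_seqlens(cu_seqlens, chunk_size):
    """Subdivide each sequence into chunks of at most chunk_size tokens,
    so no per-chunk output array ever spans more than chunk_size tokens."""
    boundaries = [0]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        pos = start
        while pos < end:
            # Each chunk boundary advances by at most chunk_size tokens.
            pos = min(pos + chunk_size, end)
            boundaries.append(pos)
    return boundaries

# Two sequences of lengths 300 and 100 with chunk_size=128:
# the 300-token sequence is split into chunks of 128, 128, and 44.
print(build_cu_chunk_seqlens([0, 300, 400], 128))  # [0, 128, 256, 300, 400]
```

Passing the raw `cu_seqlens` (`[0, 300, 400]`) as chunk boundaries is exactly the out-of-bounds case described above, since 300 exceeds the 128-entry per-chunk allocation.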
Improvements:

- Unify all Mamba prefill (including chunked) through a single varlen SSM kernel call, removing the separate chunked-prefill routing and the `_batch_indices_chunked_prefill` / `_device_chunked_prefill` metadata
- Simplify `_dynamic_inference` to a flat decode + prefill structure
- Add a `_dynamic_inference_prefill` helper that strips CUDA-graph padding from metadata and data tensors before calling `_ssm_prefill`
- Remove deprecated constructor parameters (`use_mem_eff_path`, `d_state`, `headdim`, `ngroups`) and their warnings
- Add an assertion format string in `ssd_combined.py` for easier debugging

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
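The strip-then-pad pattern shared by fix 3 and the `_dynamic_inference_prefill` helper can be sketched as follows. This is a pure-Python stand-in under assumed names (`run_without_padding`, `kernel`), not the PR's tensor code:

```python
def run_without_padding(tokens, num_real, kernel, pad_value=0):
    """Strip CUDA-graph padding before the kernel, restore it after."""
    real = tokens[:num_real]          # strip padded tokens
    out = kernel(real)                # SSM prefill sees only real tokens
    # Pad the output back to the original padded size so the downstream
    # residual add sees the shape it expects.
    out += [pad_value] * (len(tokens) - num_real)
    return out

padded = [1, 2, 3, 0, 0]              # 3 real tokens + 2 padding slots
print(run_without_padding(padded, 3, lambda xs: [x * 10 for x in xs]))
# [10, 20, 30, 0, 0]
```

The shape assertion failure in fix 3 arose because only one of the interleaved streams (`xBC`) was stripped while `dt` and `z` kept the padded length; stripping `zxBCdt` as a whole keeps all three consistent.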
Summary

- Extract initial conv states before `causal_conv1d_varlen_states` + `tensor_masked_update` overwrites the `conv_state` buffer, fixing corrupted convolution output for restored requests
- Subdivide each sequence into chunks of at most `chunk_size` tokens instead of passing `cu_seqlens` directly, fixing out-of-bounds memory access when sequences exceed `chunk_size` (128) tokens
- Strip padded tokens from `zxBCdt` before `_ssm_prefill`, then pad the output back, fixing shape assertion failures when padded and real token counts differ
- Call `causal_conv1d_fn` per-request with `initial_states`, because `causal_conv1d_fn` cannot accept both `seq_idx` and `initial_states` simultaneously
- Unify all Mamba prefill (including chunked) through a single varlen SSM kernel call, removing the `_batch_indices_chunked_prefill` / `_device_chunked_prefill` metadata
- Simplify `_dynamic_inference` to a flat decode + prefill structure with a new `_dynamic_inference_prefill` helper that strips CUDA-graph padding from metadata and data tensors
- Remove deprecated constructor parameters (`use_mem_eff_path`, `d_state`, `headdim`, `ngroups`) and their `warnings.warn` blocks
- Add an assertion format string in `ssd_combined.py` for easier debugging
- Add `batch_allocate_slots()` to `MambaMetadata` for batch slot allocation

Test plan

- `mamba_mixer.py` imports without errors

🤖 Generated with Claude Code
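The save-before-read ordering from bug fix 1 reduces to a simple invariant: copy the initial states out of the shared buffer before overwriting it with the final states. A minimal sketch with illustrative names (`prefill_step`, `conv_state_buffer`), not the PR's actual API:

```python
def prefill_step(conv_state_buffer, compute_final_states):
    """Read initial states BEFORE the shared buffer is overwritten."""
    initial_states = list(conv_state_buffer)      # snapshot before update
    final_states = compute_final_states(initial_states)
    conv_state_buffer[:] = final_states           # overwrite AFTER reading
    return initial_states, final_states

buf = [1, 2, 3]
init, final = prefill_step(buf, lambda s: [x + 100 for x in s])
print(init, final, buf)  # [1, 2, 3] [101, 102, 103] [101, 102, 103]
```

With the pre-fix ordering (read after update), `init` would alias the freshly written final states, which is exactly the corruption restored requests observed.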