Inference | Hybrid prefix caching. #3225
Open
lmcafee-nvidia wants to merge 15 commits into NVIDIA:main from
Conversation
Implement prefix caching support for Mamba states in hybrid Transformer-Mamba models. When KV cache blocks are reused due to shared prefixes, the corresponding Mamba conv/SSM states can also be cached and restored, avoiding redundant recomputation.

Key changes:
- Add fixed-size tensor pool for cached Mamba states with LRU eviction
- Store Mamba state at block boundaries after prefill completes
- Restore cached Mamba state when prefix matches during request scheduling
- Invalidate Mamba state when associated KV blocks are evicted
- Add `--inference-dynamic-batching-prefix-caching-mamba-gb` argument
- Auto-enable chunked prefill when Mamba prefix caching is used

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
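The fixed-size pool with LRU eviction described in this commit can be sketched roughly as follows. This is a minimal illustration, not the PR's actual code: the class name `MambaStatePool`, the attribute names, and the state shapes are all assumptions.

```python
# Hypothetical sketch of a fixed-size Mamba state pool with LRU eviction.
# All names and shapes are illustrative, not the PR's actual API.
import time
import torch


class MambaStatePool:
    def __init__(self, num_slots, conv_state_shape, ssm_state_shape, device="cpu"):
        # Pre-allocated tensors: one slot per cacheable block boundary.
        self.conv_states = torch.zeros((num_slots, *conv_state_shape), device=device)
        self.ssm_states = torch.zeros((num_slots, *ssm_state_shape), device=device)
        self.free_slots = list(range(num_slots))
        self.slot_of_block = {}  # block_hash -> slot index
        self.last_used = {}      # slot -> timestamp, drives LRU ordering
        self.ref_count = {}      # slot -> active references (0 == evictable)

    def store(self, block_hash, conv_state, ssm_state):
        if not self.free_slots:
            self._evict_lru()
        slot = self.free_slots.pop()
        self.conv_states[slot].copy_(conv_state)
        self.ssm_states[slot].copy_(ssm_state)
        self.slot_of_block[block_hash] = slot
        self.last_used[slot] = time.monotonic()
        self.ref_count[slot] = 0
        return slot

    def restore(self, block_hash):
        slot = self.slot_of_block.get(block_hash)
        if slot is None:
            return None  # no cached Mamba state for this prefix block
        self.last_used[slot] = time.monotonic()
        return self.conv_states[slot], self.ssm_states[slot]

    def _evict_lru(self):
        # Only slots with ref_count == 0 may be evicted; error if all active.
        candidates = [s for s in self.last_used if self.ref_count.get(s, 0) == 0]
        if not candidates:
            raise RuntimeError("all Mamba state slots are in active use")
        victim = min(candidates, key=self.last_used.get)
        evicted_hash = next(h for h, s in self.slot_of_block.items() if s == victim)
        del self.slot_of_block[evicted_hash]
        del self.last_used[victim]
        del self.ref_count[victim]
        self.free_slots.append(victim)
```

Storing a third block into a two-slot pool evicts the least recently used entry, after which `restore` on the evicted hash returns `None` and the caller falls back to full prefill.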
…budget

Add 10 new tests covering gaps in Mamba prefix caching:

Eviction edge cases (IMPORTANT):
- test_eviction_all_slots_active_raises_error: Verify RuntimeError when all slots in use
- test_eviction_with_mixed_ref_counts: Verify only ref_count=0 blocks evicted
- test_eviction_ordering_by_timestamp: Verify LRU ordering by timestamp

Integration tests (IMPORTANT):
- test_store_mamba_state_for_complete_blocks: Verify state storage for complete blocks
- test_mamba_state_restored_on_prefix_match: Verify state restoration on prefix match

Memory budget edge cases (MEDIUM):
- test_zero_slots_when_budget_too_small: Verify zero slots for tiny budget
- test_negative_budget_treated_as_disabled: Verify negative budget disables caching
- test_exact_slot_boundary_budget: Verify exact slot count at budget boundary

Stress tests (LOW):
- test_rapid_allocation_eviction_cycle: 100 allocation/eviction cycles
- test_large_number_of_blocks: Many blocks with limited Mamba cache

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
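The memory-budget edge cases tested above (tiny budget gives zero slots, negative budget disables caching, exact boundary yields an exact slot count) imply a simple whole-slot rule. A minimal sketch, assuming a hypothetical helper name and that `bytes_per_slot` is the size of one cached conv+SSM state pair:

```python
# Illustrative sketch of mapping a GB budget to a whole number of cache
# slots; the function name and rounding behavior are assumptions.
def num_mamba_cache_slots(budget_gb: float, bytes_per_slot: int) -> int:
    """Return the number of Mamba state slots; budget <= 0 disables caching."""
    if budget_gb <= 0:
        return 0  # negative or zero budget treated as disabled
    budget_bytes = int(budget_gb * (1024 ** 3))
    return budget_bytes // bytes_per_slot  # floor: never exceed the budget
```

Flooring the division ensures a budget just under one slot yields zero slots rather than over-allocating.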
Break the monolithic TestMambaPrefixCaching class (20 tests) into 8 focused test classes organized by functionality:
- TestBasicMambaStateCaching: Basic Mamba state caching (3)
- TestMambaEviction: LRU eviction tests (2)
- TestMambaEdgeCases: Edge case tests (3)
- TestMambaMemoryAccounting: Memory accounting tests (2)
- TestMambaEvictionEdgeCases: Eviction edge cases (3)
- TestMambaIntegration: Integration with prefix caching (2)
- TestMambaMemoryBudgetEdgeCases: Memory budget edge cases (3)
- TestMambaStress: Stress and robustness tests (2)

All classes inherit from MambaPrefixCachingTestBase, which contains shared setup/teardown and helper methods.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…anch

- Store prefix_caching_mamba_gb on self in __init__ so allocate_mamba_cache() can access it
- Fix MambaInferenceStateConfig import (moved from mamba_metadata to config)
- Enable block_evict_lru in mamba tests (required for LRU timestamp access)
- Fix release_request() -> release_memory_blocks_from_request_indexes()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
For hybrid models, KV prefix matching without corresponding Mamba state causes incorrect output because Mamba is recurrent (state at block N encapsulates all prior tokens). This limits KV prefix-matched blocks to only those that also have cached Mamba state, falling back to full prefill when no Mamba state is available. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…g test Replaces TestCrossConfigEndToEnd's two context-only tests with a single engine-level test that builds a full MambaModel engine stack and verifies output token equivalence across 3 configurations (chunked+prefix, chunked-only, baseline) with interleaving boundaries. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…tates

Previously only the first prefill request's restored Mamba state was used correctly (via the batch kernel); all others went through varlen, which ignores initial states, producing incorrect output. This change generalizes the single batch-kernel call to loop over N requests with initial states.

Key changes:
- Track per-request initial Mamba state flag in dynamic_context
- Generalize mamba_metadata from 1 to N batch-kernel prefill requests
- Reorder scheduler to place mamba-state requests before varlen requests
- Loop batch kernel in mamba_mixer for each request with initial states
- Add E2E test for multiple simultaneous prefills with restored states

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Split hash_to_block_id into kv_hash_to_block_id (all KV blocks) and mamba_hash_to_block_id (only blocks with cached Mamba state). The KV map discovers the full prefix-match extent; the Mamba map identifies which matched blocks have state, for divergence detection.

The engine now enforces chunk breaks at Mamba-meaningful boundaries (KV divergence and last-aligned) so _store_mamba_states_for_completed_prefill can detect and store state at the correct points. Mamba hashes are registered on store and deregistered on both Mamba eviction (LRU) and KV eviction.

Fixes the total_prefilled calculation in _store_mamba_states to use the actual tokens in context rather than the full prompt length, which was wrong for chunked prefills.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
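The two-map lookup described in this commit can be sketched as a pair of contiguous-prefix scans: the KV map bounds the overall match, and the Mamba map finds where cached recurrent state stops covering it. A hedged sketch; the function and map names here mirror the commit message but the logic is an illustrative guess:

```python
# Illustrative sketch of prefix matching with split hash maps.
# kv_hash_to_block_id covers all cached KV blocks; mamba_hash_to_block_id
# covers only blocks that also have stored Mamba state.
def match_prefix(block_hashes, kv_hash_to_block_id, mamba_hash_to_block_id):
    """Return (num_kv_matched, num_mamba_matched) contiguous-prefix counts."""
    num_kv = 0
    for h in block_hashes:
        if h not in kv_hash_to_block_id:
            break  # KV divergence point
        num_kv += 1
    num_mamba = 0
    for h in block_hashes[:num_kv]:
        if h not in mamba_hash_to_block_id:
            break  # Mamba state diverges here even though KV still matches
        num_mamba += 1
    return num_kv, num_mamba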
Three new engine-level tests covering previously untested code paths:
- Divergence boundary: KV match extends beyond Mamba match; Mamba state restored and stored at the divergence boundary
- Last-aligned boundary: non-block-aligned prompt force-chunked at the last-aligned boundary; Mamba state stored after the first chunk
- Mixed kernel routing: continuing chunked prefill (batch kernel) runs alongside fresh prefills (varlen kernel) in a single forward step

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… limits

Three fixes for hybrid Mamba+Attention prefix caching:

1. Remove the unnecessary -1 from the prefix_skip_tokens computation. The old formula `min(num_matched * block_size, chunk_length - 1)` forced reprocessing of an already-matched token. For hybrid models, the skip count is now `min(num_mamba_matched, num_kv_matched)` blocks, ensuring both KV and Mamba state cover the skipped range.

2. Add a direct decode-entry guard: when a block-aligned prompt is fully prefix-matched, effective_chunk_length becomes 0, which crashes Flash Attention. Back off by 1 token so the last prompt token is reprocessed, producing the first output logits.

3. Fix _get_mamba_chunk_limit to skip boundaries already covered by restored Mamba state. Previously it forced unnecessary chunk breaks at kv_divergence and last_aligned even when Mamba state was already cached at those boundaries.

Also fix the chunked prefill auto-enable in arguments.py: it was checking disable_chunked_prefill instead of enable_chunked_prefill.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
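Fixes 1 and 2 in the commit above combine into a small skip computation: take the blocks covered by both caches, then back off one token if the whole prompt would otherwise be skipped. A sketch under the assumption that a single hypothetical helper does both steps (the PR's actual code lives elsewhere and is shaped differently):

```python
# Hedged sketch of the prefix-skip computation; names are assumptions.
def compute_prefix_skip_tokens(num_kv_matched, num_mamba_matched,
                               block_size, prompt_len):
    # Fix 1: both KV and Mamba state must cover the skipped range,
    # so skip only the blocks matched by both caches.
    skip = min(num_kv_matched, num_mamba_matched) * block_size
    # Fix 2: a fully matched, block-aligned prompt would leave an empty
    # chunk; back off one token so the last prompt token is reprocessed
    # and produces the first output logits.
    if skip >= prompt_len:
        skip = prompt_len - 1
    return skip
```

For example, with 4 KV-matched but only 1 Mamba-matched block, only one block of tokens is skipped; a fully matched 64-token prompt skips 63 tokens, not 64.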
Brings in mamba_chunk_scan_combined_varlen and supporting Triton SSM kernels from onur/ssm-kernels branch, enabling all prefill requests (including those with restored initial states) to go through a single varlen SSM kernel call. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
All Mamba prefill requests now go through a single varlen SSM kernel call instead of routing restored-state requests through a per-request batch kernel loop. Adds a Triton varlen causal conv1d kernel as an opt-in alternative to the per-request loop for the conv phase. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ence

1. Varlen unification commit bug: _dynamic_inference_prefill passed CUDA-graph-padded metadata tensors (with -1 batch_indices and zero-length cu_seqlens segments) to the varlen SSM kernel, causing out-of-bounds indexing when computing seq_idx_for_varlen. Fixed by stripping padding down to only real prefill requests.

2. PR 3442 bug: _ssm_prefill set cu_chunk_seqlens = cu_seqlens, treating each sequence as a single chunk. When sequences exceed chunk_size (128), the SSM Triton kernels access the dt and dA_cumsum arrays out of bounds, causing an illegal memory access. Fixed by subdividing sequences into chunk_size-aligned boundaries.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
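The subdivision in fix 2 above amounts to inserting chunk boundaries every `chunk_size` tokens within each sequence while keeping every sequence end as a boundary. A sketch on plain Python lists; the function name is an assumption and the real code operates on CUDA tensors:

```python
# Sketch of subdividing varlen sequence boundaries into chunk_size-aligned
# chunk boundaries, so no chunk exceeds chunk_size tokens.
def build_cu_chunk_seqlens(cu_seqlens, chunk_size):
    """Expand cumulative sequence boundaries into chunk boundaries."""
    bounds = [0]
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        # Emit intermediate boundaries every chunk_size tokens within
        # this sequence, measured from the sequence start.
        pos = start + chunk_size
        while pos < end:
            bounds.append(pos)
            pos += chunk_size
        bounds.append(end)  # a sequence end is always a chunk boundary
    return bounds
```

A 300-token sequence with chunk_size 128 becomes chunks of 128, 128, and 44 tokens, so the kernels' per-chunk dt and dA_cumsum indexing stays in bounds.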
Our new code: strip padded tokens from zxBCdt in _dynamic_inference_prefill before passing to _ssm_prefill, and pad the output back to the padded size for the downstream residual add. Without this, the conv1d loop rebuilds xBC with only real tokens while dt and z retain the padded token count, causing a shape assertion failure in the varlen SSM kernel. Also clean up verbose debug assertion in ssd_combined.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move initial_conv_states extraction before causal_conv1d_varlen_states + tensor_masked_update to prevent reading overwritten state. Add cuda_graph_compatible path to _ssm_prefill that avoids .item() calls by using causal_conv1d_fn with seq_idx and mamba_chunk_scan_combined. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
What does this PR do?
Add Mamba state prefix caching for hybrid Transformer-Mamba models, enabling KV cache prefix sharing to also share corresponding Mamba conv/SSM states.
Key Features
`--inference-dynamic-batching-prefix-caching-mamba-gb` argument controls the memory budget for cached Mamba states

Changes
- Core Implementation (`megatron/core/inference/contexts/dynamic_context.py`, `megatron/core/inference/engines/dynamic_engine.py`)
- Block Allocator (`megatron/core/inference/contexts/dynamic_block_allocator.py`)
- Arguments (`megatron/training/arguments.py`): `--inference-dynamic-batching-prefix-caching-mamba-gb` parameter

Test plan
- `tests/unit_tests/inference/contexts/test_mamba_prefix_caching.py`

Contribution process
```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks
Core 0.8)Code review
The following process is enforced via the CODEOWNERS file for changes into `megatron/core`. For changes outside of `megatron/core`, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch
Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
(Step 1): Add PR label `Expert Review`

(Step 2): Collect the expert reviewers' reviews

Add the `Expert Review` label when your PR is ready for review. Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

Add the `Final Review` label.

(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into `core_r*` release branches, after this PR has been merged, select `Cherry-pick` to open a new PR into the release branch.

For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR
Any member of `core-adlr` and `core-nemo` will be able to merge your PR.