
Inference | Hybrid prefix caching. #3225

Open

lmcafee-nvidia wants to merge 15 commits into NVIDIA:main from lmcafee-nvidia:prefix-caching-mamba

Conversation

@lmcafee-nvidia
Contributor

@lmcafee-nvidia lmcafee-nvidia commented Feb 3, 2026

What does this PR do ?

Add Mamba state prefix caching for hybrid Transformer-Mamba models, enabling KV cache prefix sharing to also share corresponding Mamba conv/SSM states.

Key Features

  • Mamba state caching: When KV cache blocks are reused due to shared prefixes, the corresponding Mamba states can now also be cached and restored, avoiding redundant recomputation
  • Memory-budgeted cache: New --inference-dynamic-batching-prefix-caching-mamba-gb argument controls the memory budget for cached Mamba states
  • LRU eviction: Mamba cache uses LRU eviction based on KV block timestamps when the cache is full
  • Automatic invalidation: When KV blocks are evicted, their associated Mamba states are automatically invalidated
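
The memory-budgeted cache with LRU eviction and invalidation described above can be sketched as follows. This is a minimal illustration only: the class and attribute names (`MambaStateCache`, `slot_bytes`, `_slots`) are assumptions for the sketch, not the PR's actual API.

```python
from collections import OrderedDict

class MambaStateCache:
    """Illustrative fixed-slot cache: budget in GB -> slot count, LRU eviction."""

    def __init__(self, budget_gb: float, slot_bytes: int):
        # A non-positive budget yields zero slots, i.e. caching disabled.
        self.num_slots = max(0, int(budget_gb * 1024**3) // slot_bytes)
        # KV block id -> slot index, ordered from least to most recently used.
        self._slots = OrderedDict()
        self._free = list(range(self.num_slots))

    def allocate(self, block_id: int) -> int:
        """Return a slot for `block_id`, evicting the LRU entry if full."""
        if block_id in self._slots:
            self._slots.move_to_end(block_id)  # refresh recency
            return self._slots[block_id]
        if self._free:
            slot = self._free.pop()
        else:
            # LRU eviction: drop the least recently used block's state.
            _, slot = self._slots.popitem(last=False)
        self._slots[block_id] = slot
        return slot

    def invalidate(self, block_id: int) -> None:
        """Called when a KV block is evicted: drop its Mamba state slot."""
        slot = self._slots.pop(block_id, None)
        if slot is not None:
            self._free.append(slot)
```

The real implementation orders eviction by KV block timestamps rather than by cache-access order, but the slot-reuse and invalidation flow is analogous.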

Changes

Core Implementation (megatron/core/inference/contexts/dynamic_context.py, megatron/core/inference/engines/dynamic_engine.py):

  • Fixed-size tensor pool for cached Mamba conv/SSM states
  • Block-to-slot mapping for associating KV blocks with Mamba cache slots
  • Store Mamba state at block boundaries after prefill completes
  • Restore cached Mamba state when prefix matches during request scheduling
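
The store/restore flow at block boundaries can be sketched as below. All names here are illustrative assumptions (a plain dict stands in for the tensor pool, tuple hashes for block hashes, a string per token for the Mamba state), not the PR's internals.

```python
BLOCK_SIZE = 4  # assumed toy block size for the sketch

def store_states(token_ids, states_per_token, cache):
    """After prefill, cache the Mamba state at each full-block boundary,
    keyed by a hash of the token prefix ending at that boundary."""
    for end in range(BLOCK_SIZE, len(token_ids) + 1, BLOCK_SIZE):
        prefix_hash = hash(tuple(token_ids[:end]))
        cache[prefix_hash] = states_per_token[end - 1]

def restore_state(token_ids, cache):
    """Find the longest cached block-aligned prefix of a new request.

    Returns (num_skippable_tokens, cached_state); (0, None) on no match.
    """
    best = (0, None)
    for end in range(BLOCK_SIZE, len(token_ids) + 1, BLOCK_SIZE):
        prefix_hash = hash(tuple(token_ids[:end]))
        if prefix_hash not in cache:
            break  # prefix diverges here; keep the last match
        best = (end, cache[prefix_hash])
    return best
```

Because Mamba state is recurrent, only the state at the final matched boundary needs to be restored; the request then resumes prefill from that point.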

Block Allocator (megatron/core/inference/contexts/dynamic_block_allocator.py):

  • Invalidate Mamba state when KV blocks are evicted

Arguments (megatron/training/arguments.py):

  • Add --inference-dynamic-batching-prefix-caching-mamba-gb parameter
  • Auto-enable chunked prefill when Mamba prefix caching is used

Test plan

  • Unit tests for cache allocation/deallocation
  • Unit tests for LRU eviction (20 tests in tests/unit_tests/inference/contexts/test_mamba_prefix_caching.py)
  • Tests for memory budget edge cases
  • Tests for state store/restore integration
  • End-to-end inference test with hybrid model

⚠️ For major changes (either in lines of code or in impact), please first share a design doc with the team. If you're unsure of the best way to do so, contact the @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or tag the @mcore-oncall in a comment to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review may be declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval from either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@lmcafee-nvidia lmcafee-nvidia added this to the Core 0.15 milestone Feb 3, 2026
@lmcafee-nvidia lmcafee-nvidia self-assigned this Feb 3, 2026
@lmcafee-nvidia lmcafee-nvidia requested a review from a team as a code owner February 3, 2026 14:49
@lmcafee-nvidia lmcafee-nvidia added the Expert Review label Feb 3, 2026
@lmcafee-nvidia lmcafee-nvidia requested review from a team as code owners February 3, 2026 14:49

copy-pr-bot bot commented Feb 3, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 3, 2026 14:50
@lmcafee-nvidia lmcafee-nvidia requested a review from a team as a code owner February 18, 2026 16:35
@lmcafee-nvidia lmcafee-nvidia force-pushed the prefix-caching-mamba branch 3 times, most recently from ed541ac to a3b62d3, on February 27, 2026 at 14:18
lmcafee-nvidia and others added 10 commits March 3, 2026 12:13
Implement prefix caching support for Mamba states in hybrid Transformer-Mamba
models. When KV cache blocks are reused due to shared prefixes, the corresponding
Mamba conv/SSM states can also be cached and restored, avoiding redundant
recomputation.

Key changes:
- Add fixed-size tensor pool for cached Mamba states with LRU eviction
- Store Mamba state at block boundaries after prefill completes
- Restore cached Mamba state when prefix matches during request scheduling
- Invalidate Mamba state when associated KV blocks are evicted
- Add --inference-dynamic-batching-prefix-caching-mamba-gb argument
- Auto-enable chunked prefill when Mamba prefix caching is used

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…budget

Add 10 new tests covering gaps in Mamba prefix caching:

Eviction edge cases (IMPORTANT):
- test_eviction_all_slots_active_raises_error: Verify RuntimeError when all slots in use
- test_eviction_with_mixed_ref_counts: Verify only ref_count=0 blocks evicted
- test_eviction_ordering_by_timestamp: Verify LRU ordering by timestamp

Integration tests (IMPORTANT):
- test_store_mamba_state_for_complete_blocks: Verify state storage for complete blocks
- test_mamba_state_restored_on_prefix_match: Verify state restoration on prefix match

Memory budget edge cases (MEDIUM):
- test_zero_slots_when_budget_too_small: Verify zero slots for tiny budget
- test_negative_budget_treated_as_disabled: Verify negative budget disables caching
- test_exact_slot_boundary_budget: Verify exact slot count at budget boundary

Stress tests (LOW):
- test_rapid_allocation_eviction_cycle: 100 allocation/eviction cycles
- test_large_number_of_blocks: Many blocks with limited Mamba cache

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Break the monolithic TestMambaPrefixCaching class (20 tests) into
8 focused test classes organized by functionality:

- TestBasicMambaStateCaching: Basic Mamba state caching (3)
- TestMambaEviction: LRU eviction tests (2)
- TestMambaEdgeCases: Edge case tests (3)
- TestMambaMemoryAccounting: Memory accounting tests (2)
- TestMambaEvictionEdgeCases: Eviction edge cases (3)
- TestMambaIntegration: Integration with prefix caching (2)
- TestMambaMemoryBudgetEdgeCases: Memory budget edge cases (3)
- TestMambaStress: Stress and robustness tests (2)

All classes inherit from MambaPrefixCachingTestBase which contains
shared setup/teardown and helper methods.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…anch

- Store prefix_caching_mamba_gb on self in __init__ so allocate_mamba_cache() can access it
- Fix MambaInferenceStateConfig import (moved from mamba_metadata to config)
- Enable block_evict_lru in mamba tests (required for LRU timestamp access)
- Fix release_request() -> release_memory_blocks_from_request_indexes()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
For hybrid models, KV prefix matching without corresponding Mamba state
causes incorrect output because Mamba is recurrent (state at block N
encapsulates all prior tokens). This limits KV prefix-matched blocks to
only those that also have cached Mamba state, falling back to full
prefill when no Mamba state is available.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
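
The rule in the commit above can be sketched in a few lines. This is an illustration under assumed names (`usable_prefix_blocks`, the `mamba_state_available` predicate), not the PR's code: because Mamba state at a boundary encapsulates all prior tokens, the skippable prefix is the largest KV-matched block count whose boundary also has cached Mamba state.

```python
def usable_prefix_blocks(num_kv_matched, mamba_state_available):
    """Limit KV prefix-matched blocks by Mamba state availability.

    num_kv_matched: leading blocks with a KV cache hit.
    mamba_state_available(n): True if the Mamba state captured after the
        first n blocks is cached.
    Returns the number of blocks that may actually be skipped.
    """
    for n in range(num_kv_matched, 0, -1):
        if mamba_state_available(n):
            return n  # resume prefill from this boundary
    return 0  # no usable Mamba state: fall back to full prefill
```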
…g test

Replaces TestCrossConfigEndToEnd's two context-only tests with a single
engine-level test that builds a full MambaModel engine stack and verifies
output token equivalence across 3 configurations (chunked+prefix,
chunked-only, baseline) with interleaving boundaries.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…tates

Previously only the first prefill request's restored Mamba state was used
correctly (via batch kernel); all others went through varlen which ignores
initial states, producing incorrect output. This change generalizes the
single batch-kernel call to loop over N requests with initial states.

Key changes:
- Track per-request initial Mamba state flag in dynamic_context
- Generalize mamba_metadata from 1 to N batch-kernel prefill requests
- Reorder scheduler to place mamba-state requests before varlen requests
- Loop batch kernel in mamba_mixer for each request with initial states
- Add E2E test for multiple simultaneous prefills with restored states

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Split hash_to_block_id into kv_hash_to_block_id (all KV blocks) and
mamba_hash_to_block_id (only blocks with cached Mamba state). The KV map
discovers full prefix match extent, the Mamba map identifies which matched
blocks have state for divergence detection.

Engine now enforces chunk breaks at Mamba-meaningful boundaries (KV divergence
and last-aligned) so _store_mamba_states_for_completed_prefill can detect and
store state at the correct points. Mamba hashes are registered on store and
deregistered on both Mamba eviction (LRU) and KV eviction.

Fixes total_prefilled calculation in _store_mamba_states to use actual tokens
in context rather than full prompt length, which was wrong for chunked prefills.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
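
The two-map split above can be sketched as a single walk over a request's block hashes. This is a hedged illustration: plain hashable keys and the name `match_prefix` stand in for the PR's rolling block hashes and lookup code.

```python
def match_prefix(block_hashes, kv_hash_to_block_id, mamba_hash_to_block_id):
    """Walk leading block hashes against both maps.

    Returns (num_kv_matched, num_mamba_matched): the KV map gives the
    full prefix-match extent, the Mamba map the last matched block that
    also has cached Mamba state (the divergence-detection boundary).
    """
    num_kv, num_mamba = 0, 0
    for i, h in enumerate(block_hashes):
        if h not in kv_hash_to_block_id:
            break  # KV divergence: nothing beyond here matches
        num_kv = i + 1
        if h in mamba_hash_to_block_id:
            num_mamba = i + 1
    return num_kv, num_mamba
```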
Three new engine-level tests covering previously untested code paths:
- Divergence boundary: KV match extends beyond Mamba match, Mamba state
  restored and stored at divergence boundary
- Last-aligned boundary: non-block-aligned prompt forced-chunked at
  last-aligned boundary, Mamba state stored after first chunk
- Mixed kernel routing: continuing chunked prefill (batch kernel) runs
  alongside fresh prefills (varlen kernel) in a single forward step

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
… limits

Three fixes for hybrid Mamba+Attention prefix caching:

1. Remove unnecessary -1 from prefix_skip_tokens computation. The old
   formula `min(num_matched * block_size, chunk_length - 1)` forced
   reprocessing of an already-matched token. For hybrid models, skip
   count is now `min(num_mamba_matched, num_kv_matched)` blocks to
   ensure both KV and Mamba state cover the skipped range.

2. Add direct decode entry guard: when a block-aligned prompt is fully
   prefix-matched, effective_chunk_length becomes 0 which crashes Flash
   Attention. Back off by 1 token so the last prompt token is
   reprocessed, producing first output logits.

3. Fix _get_mamba_chunk_limit to skip boundaries already covered by
   restored Mamba state. Previously forced unnecessary chunk breaks at
   kv_divergence and last_aligned even when Mamba state was already
   cached at those boundaries.

Also fix chunked prefill auto-enable in arguments.py: was checking
disable_chunked_prefill instead of enable_chunked_prefill.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
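
Fix 2 above (the direct decode entry guard) amounts to a one-token backoff, sketched below with an assumed helper name:

```python
def effective_prefill_tokens(prompt_len, skipped_tokens):
    """Tokens left to prefill after prefix skipping, never zero.

    A block-aligned, fully prefix-matched prompt would otherwise leave a
    zero-length chunk, which Flash Attention cannot run; backing off by
    one token reprocesses the last prompt token and produces the first
    output logits.
    """
    remaining = prompt_len - skipped_tokens
    if remaining == 0:
        remaining = 1
    return remaining
```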
lmcafee-nvidia and others added 2 commits March 4, 2026 09:44
Brings in mamba_chunk_scan_combined_varlen and supporting Triton SSM
kernels from onur/ssm-kernels branch, enabling all prefill requests
(including those with restored initial states) to go through a single
varlen SSM kernel call.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
All Mamba prefill requests now go through a single varlen SSM kernel call
instead of routing restored-state requests through a per-request batch
kernel loop. Adds a Triton varlen causal conv1d kernel as an opt-in
alternative to the per-request loop for the conv phase.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
lmcafee-nvidia and others added 3 commits March 4, 2026 11:12
…ence

1. Varlen unification commit bug: _dynamic_inference_prefill passed
   CUDA-graph-padded metadata tensors (with -1 batch_indices and
   zero-length cu_seqlens segments) to the varlen SSM kernel, causing
   out-of-bounds indexing when computing seq_idx_for_varlen. Fixed by
   stripping padding to only real prefill requests.

2. PR 3442 bug: _ssm_prefill set cu_chunk_seqlens = cu_seqlens, treating
   each sequence as a single chunk. When sequences exceed chunk_size
   (128), the SSM Triton kernels access dt and dA_cumsum arrays out of
   bounds, causing illegal memory access. Fixed by subdividing sequences
   into chunk_size-aligned boundaries.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
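
The fix for bug 2 above can be sketched as follows: instead of reusing `cu_seqlens` directly (one chunk per sequence), every sequence is subdivided at `chunk_size`-aligned boundaries. The function name is an assumption for the sketch.

```python
def build_cu_chunk_seqlens(cu_seqlens, chunk_size=128):
    """Expand cumulative sequence offsets into cumulative chunk offsets.

    Each sequence is split into pieces of at most chunk_size tokens so
    the SSM kernels never index dt/dA_cumsum beyond one chunk.
    """
    out = [0]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        pos = start
        while pos < end:
            pos = min(pos + chunk_size, end)  # chunk boundary or sequence end
            out.append(pos)
    return out
```

When no sequence exceeds `chunk_size`, the result coincides with `cu_seqlens`, which is why the original bug only surfaced on longer sequences.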
Our new code: strip padded tokens from zxBCdt in
_dynamic_inference_prefill before passing to _ssm_prefill, and pad the
output back to the padded size for the downstream residual add. Without
this, the conv1d loop rebuilds xBC with only real tokens while dt and z
retain the padded token count, causing a shape assertion failure in the
varlen SSM kernel.

Also clean up verbose debug assertion in ssd_combined.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Move initial_conv_states extraction before causal_conv1d_varlen_states
+ tensor_masked_update to prevent reading overwritten state. Add
cuda_graph_compatible path to _ssm_prefill that avoids .item() calls
by using causal_conv1d_fn with seq_idx and mamba_chunk_scan_combined.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Labels

complexity: high, Expert Review
