Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
ac424ed
Add Mamba prefix caching for hybrid models
lmcafee-nvidia Feb 26, 2026
5b81ca6
Add comprehensive Mamba prefix caching tests for eviction and memory …
lmcafee-nvidia Feb 3, 2026
5754ed9
Refactor Mamba prefix caching tests into separate focused classes
lmcafee-nvidia Feb 3, 2026
ec66cb4
Fix Mamba prefix caching compatibility with current prefix-caching br…
lmcafee-nvidia Feb 17, 2026
f39046b
Couple KV+Mamba prefix matching for hybrid model correctness
lmcafee-nvidia Feb 17, 2026
195e7da
Replace context-level cross-config test with engine-level interleavin…
lmcafee-nvidia Feb 17, 2026
c351bca
Loop batch kernel for multiple prefill requests with restored Mamba s…
lmcafee-nvidia Feb 18, 2026
a30d56b
Two-map hash design for coupled KV+Mamba prefix caching
lmcafee-nvidia Feb 27, 2026
ce52dc0
Add tests for chunked prefill Mamba state and mixed kernel routing
lmcafee-nvidia Feb 27, 2026
e4a98ee
Fix hybrid Mamba prefix caching: skip tokens, zero-chunk guard, chunk…
lmcafee-nvidia Feb 28, 2026
e1ee1c0
Merge PR 3442 (onur/ssm-kernels): varlen SSM kernel with initial_states
lmcafee-nvidia Mar 4, 2026
08ccbf5
Unify Mamba prefill through varlen kernel with initial state support
lmcafee-nvidia Mar 4, 2026
141e253
Fix two varlen SSM kernel bugs causing CUDA errors on hybrid-2b infer…
lmcafee-nvidia Mar 4, 2026
37953e5
Fix zxBCdt padding mismatch in varlen Mamba prefill
lmcafee-nvidia Mar 4, 2026
2c0ded6
Fix conv_state save-before-read bug and add CUDA graph prefill path
lmcafee-nvidia Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions megatron/core/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,14 @@ class InferenceConfig:
Only applies when enable_prefix_caching is True.
"""

prefix_caching_mamba_gb: Optional[float] = None
"""Memory budget (GB) for cached Mamba states in prefix caching.
Required for Mamba prefix caching in hybrid models. If None, Mamba prefix caching is disabled."""

use_triton_conv1d: bool = False
"""Whether to use a Triton varlen conv1d kernel for Mamba prefill instead of
per-request causal_conv1d_fn calls. Only applies to hybrid models with Mamba layers."""

# =================================
# Logging config
# =================================
Expand Down
168 changes: 26 additions & 142 deletions megatron/core/inference/contexts/attention_context/mamba_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,6 @@ def __init__(self, max_requests: int, max_tokens: int):
(self.max_requests,), -1, dtype=torch.int32, device=self.device
)

# Map from the active chunked prefill request to its slot in the static Mamba state buffer
self._batch_indices_chunked_prefill_buffer = torch.full(
(1,), -1, dtype=torch.int32, device=self.device
)

# Map from token id to request id for active prefill requests
self._seq_idx_buffer = torch.full(
(1, self.max_tokens), -1, dtype=torch.int32, device=self.device
Expand All @@ -56,14 +51,6 @@ def __init__(self, max_requests: int, max_tokens: int):
(2,), dtype=torch.int32, device=self.device
)

# Tuple of (
# chunked prefill sequence length,
# total regular prefill sequence length
# )
self._device_chunked_prefill_buffer = torch.zeros(
(2,), dtype=torch.int32, device=self.device
)

# Allocator for Mamba state slots
self.mamba_state_free_slots = torch.arange(
self.max_requests, dtype=torch.int32, device=torch.cuda.current_device()
Expand All @@ -90,11 +77,9 @@ def reset_varlen_metadata(self) -> None:
"""Resets varlen metadata."""
self.batch_indices_decode = None
self.batch_indices_prefill = None
self.batch_indices_chunked_prefill = None
self.cu_seqlens = None
self.seq_idx = None
self.device_decode_prefill = None
self.device_chunked_prefill = None

def update(
self,
Expand All @@ -103,7 +88,6 @@ def update(
cu_seqlens: torch.Tensor,
batch_dimensions: InferenceBatchDimensions,
padded_batch_dimensions: InferenceBatchDimensions,
enable_chunked_prefill: bool,
) -> None:
"""
Updates the dedicated CUDA graph mapping tensor with the indices
Expand All @@ -116,7 +100,6 @@ def update(
cu_seqlens (Tensor): Cumulative sequence lengths.
batch_dimensions (InferenceBatchDimensions): Dimensions of the current batch.
padded_batch_dimensions (InferenceBatchDimensions): Dimensions of the padded batch.
enable_chunked_prefill (bool): Whether chunked prefill is enabled
"""
real_decode_count = batch_dimensions.decode_req_count
real_prefill_count = batch_dimensions.prefill_req_count
Expand All @@ -125,48 +108,6 @@ def update(
padded_prefill_count = padded_batch_dimensions.prefill_req_count
padded_token_count = padded_batch_dimensions.token_count

has_chunked_prefill_req = enable_chunked_prefill and real_prefill_count > 0

# Although the context ensures that the last request is always the designated
# chunked prefill request, what we actually care about is ensuring that any
# prefill request with non-zero initial states is executed through the
# chunked prefill path.
#
# In the batch arrangement passed to this update function, the logic assumes
# the *first* prefill request is the one carrying states.
#
# There are three scenarios:
#
# Scenario A: No prefill request has initial states yet, but the last request
# is the designated chunked prefill request (starting a new chunk).
#
# [ ... Decode Requests ... ] [ Prefill (start) ]
# ^
# |--- First prefill request
# Treated as having states.
# Harmless because actual initial states are 0.
#
# Scenario B: There is exactly 1 prefill request which is a continuing
# chunked prefill request with non-zero initial states.
#
# [ ... Decode Requests ... ] [ Prefill (cont) ]
# ^
# |--- First prefill request
# Has non-zero initial states.
#
# Scenario C: There is a leftover chunked prefill request that is executing
# its last chunk, followed by additional prefill requests.
#
# [ ... Decode Requests ... ] [ Prefill (end) ] [ Prefill (new) ] ...
# ^
# |--- First prefill request
# Has non-zero initial states.
#
# The implementation generalizes to Scenario A as well, where the first prefill
# request is treated as if it has non-zero initial states, which is safe.
# While this results in a minor inefficiency if there is no continuing chunked prefill
# request in a given batch, this case is infrequent.

if padded_decode_count > 0:
# Update decode indices
self._batch_indices_decode_buffer[:real_decode_count].copy_(
Expand All @@ -176,118 +117,61 @@ def update(
self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1
self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count]

# Determine if we have a chunked prefill request and adjust counts for regular prefill
regular_prefill_count = real_prefill_count
chunked_req_idx = -1

if has_chunked_prefill_req:
# The first prefill request is the chunked one
regular_prefill_count -= 1
chunked_req_idx = real_decode_count

# Update chunked prefill indices
self._batch_indices_chunked_prefill_buffer[0] = active_mamba_indices[chunked_req_idx]
self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer
else:
self.batch_indices_chunked_prefill = None

if padded_prefill_count > 0:
# Update prefill indices (excluding chunked prefill from regular prefill buffer)
if regular_prefill_count > 0:
# If chunked prefill exists, regular prefills start after it.
# If no chunked prefill, regular prefills start at real_decode_count.
start_idx = real_decode_count + (1 if has_chunked_prefill_req else 0)

self._batch_indices_prefill_buffer[:regular_prefill_count].copy_(
active_mamba_indices[start_idx : start_idx + regular_prefill_count]
# Update prefill indices (all prefill requests go through varlen)
if real_prefill_count > 0:
prefill_start_idx = real_decode_count
self._batch_indices_prefill_buffer[:real_prefill_count].copy_(
active_mamba_indices[prefill_start_idx : prefill_start_idx + real_prefill_count]
)

if padded_prefill_count > regular_prefill_count:
self._batch_indices_prefill_buffer[regular_prefill_count:padded_prefill_count] = -1
if padded_prefill_count > real_prefill_count:
self._batch_indices_prefill_buffer[
real_prefill_count:padded_prefill_count
] = -1

self.batch_indices_prefill = self._batch_indices_prefill_buffer[:padded_prefill_count]

# Update seq_idx for regular prefills
# If chunked prefill exists, we need to skip its tokens in seq_idx
# Update seq_idx for all prefill requests
prefill_start_req_idx = real_decode_count
end_prefill_req_idx = real_decode_count + real_prefill_count

# Index where regular prefills end in the batch (decode + chunked + regular)
end_regular_prefill_req_idx = (
real_decode_count + regular_prefill_count + (1 if has_chunked_prefill_req else 0)
)
end_regular_prefill_token_idx = cu_seqlens[end_regular_prefill_req_idx]

# Index where regular prefills start
start_regular_prefill_req_idx = real_decode_count + (
1 if has_chunked_prefill_req else 0
)
start_regular_prefill_token_idx = cu_seqlens[start_regular_prefill_req_idx]
start_prefill_token_idx = cu_seqlens[prefill_start_req_idx]
end_prefill_token_idx = cu_seqlens[end_prefill_req_idx]

# The length of tokens belonging to regular prefill requests
seq_len = end_regular_prefill_token_idx - start_regular_prefill_token_idx
seq_len = end_prefill_token_idx - start_prefill_token_idx

if seq_len > 0:
# We subtract start_regular_prefill_req_idx to normalize request IDs to
# 0-based relative to this buffer
# Normalize request IDs to 0-based relative to prefill requests
self._seq_idx_buffer[:, :seq_len].copy_(
token_to_request_idx[
start_regular_prefill_token_idx:end_regular_prefill_token_idx
]
- start_regular_prefill_req_idx
token_to_request_idx[start_prefill_token_idx:end_prefill_token_idx]
- prefill_start_req_idx
)

if padded_token_count > seq_len:
self._seq_idx_buffer[:, seq_len:padded_token_count] = -1
self.seq_idx = self._seq_idx_buffer[:, :padded_token_count]

# Update cu_seqlens for regular prefill requests
# Update cu_seqlens for all prefill requests
self._cu_seqlens_buffer[0] = 0
if regular_prefill_count > 0:
# Copy cu_seqlens for regular prefill requests and normalize by
# subtracting the start token index
start_req_idx = real_decode_count + (1 if has_chunked_prefill_req else 0)
end_req_idx = start_req_idx + regular_prefill_count

self._cu_seqlens_buffer[1 : regular_prefill_count + 1].copy_(
cu_seqlens[start_req_idx + 1 : end_req_idx + 1] - cu_seqlens[start_req_idx]
if real_prefill_count > 0:
self._cu_seqlens_buffer[1 : real_prefill_count + 1].copy_(
cu_seqlens[prefill_start_req_idx + 1 : end_prefill_req_idx + 1]
- cu_seqlens[prefill_start_req_idx]
)

# Pad the rest with the last value (effectively length 0 segments)
last_val = self._cu_seqlens_buffer[regular_prefill_count]
self._cu_seqlens_buffer[regular_prefill_count + 1 : padded_prefill_count + 1].fill_(
last_val = self._cu_seqlens_buffer[real_prefill_count]
self._cu_seqlens_buffer[real_prefill_count + 1 : padded_prefill_count + 1].fill_(
last_val
)
self.cu_seqlens = self._cu_seqlens_buffer[: padded_prefill_count + 1]

if padded_decode_count > 0 and padded_prefill_count > 0:
self._device_decode_prefill_buffer[0] = real_decode_count
# This describes the number of items in the prefill tensor relative to the
# decode tensor. If chunked prefill is present, it is included in the
# "prefill" part of the main split.
self._device_decode_prefill_buffer[1] = regular_prefill_count + (
1 if has_chunked_prefill_req else 0
)
self._device_decode_prefill_buffer[1] = real_prefill_count
self.device_decode_prefill = self._device_decode_prefill_buffer

# If using chunked prefill for this batch, store the number of chunked tokens
# and the number of regular prefill tokens
if has_chunked_prefill_req:
# Chunked request is the first prefill request (index real_decode_count)
chunked_prefill_token_count = (
cu_seqlens[real_decode_count + 1] - cu_seqlens[real_decode_count]
)

# Regular prefill tokens are everything after the chunked request tokens
regular_prefill_token_count = 0
if regular_prefill_count > 0:
regular_prefill_token_count = (
cu_seqlens[real_decode_count + 1 + regular_prefill_count]
- cu_seqlens[real_decode_count + 1]
)

self._device_chunked_prefill_buffer[0] = chunked_prefill_token_count
self._device_chunked_prefill_buffer[1] = regular_prefill_token_count
self.device_chunked_prefill = self._device_chunked_prefill_buffer

def allocate_slot(self) -> Optional[int]:
"""
Allocates a new slot for a request in the Mamba state buffers.
Expand Down
44 changes: 38 additions & 6 deletions megatron/core/inference/contexts/dynamic_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,11 @@ def __init__(
(self.total_count,), -1, dtype=torch.int64, device=torch.cuda.current_device()
)

# Hash-to-block mapping for O(1) prefix lookup
self.hash_to_block_id: Dict[int, int] = {}
# Hash-to-block mapping for O(1) prefix lookup (KV blocks)
self.kv_hash_to_block_id: Dict[int, int] = {}

# Hash-to-block mapping for blocks with cached Mamba state (1:1 with mamba slots)
self.mamba_hash_to_block_id: Dict[int, int] = {}

# Reference count per block: 0 = cached (evictable), >0 = actively used
self.block_ref_counts = torch.zeros(
Expand Down Expand Up @@ -216,7 +219,8 @@ def reset(self) -> None:
self.block_hashes.fill_(-1)

# Reset prefix caching state
self.hash_to_block_id.clear()
self.kv_hash_to_block_id.clear()
self.mamba_hash_to_block_id.clear()
self.block_ref_counts.fill_(0)
if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU:
self.block_timestamps.fill_(0)
Expand All @@ -237,7 +241,26 @@ def register_block_hashes(self, block_ids: list[int], block_hashes: list[int]) -
id_tensor = torch.tensor(block_ids, dtype=torch.int64, device=self.block_hashes.device)
hash_tensor = torch.tensor(block_hashes, dtype=torch.int64, device=self.block_hashes.device)
self.block_hashes[id_tensor] = hash_tensor
self.hash_to_block_id.update(zip(block_hashes, block_ids))
self.kv_hash_to_block_id.update(zip(block_hashes, block_ids))

def register_mamba_block_hash(self, block_id: int, block_hash: int) -> None:
    """Record that *block_id* holds cached Mamba state under *block_hash*.

    The mamba hash map is kept 1:1 with the Mamba state slots so that a
    later prefix lookup can locate the KV block whose Mamba state is still
    resident. Registering the same hash again simply points it at the new
    block.

    Args:
        block_id: The KV block ID that has cached mamba state.
        block_hash: The hash value for this block.
    """
    self.mamba_hash_to_block_id[block_hash] = block_id

def deregister_mamba_block_hash(self, block_id: int) -> None:
    """Remove a block from the mamba hash map (when its Mamba state is evicted).

    Args:
        block_id: The KV block ID whose mamba state was evicted.
    """
    hash_val = self.block_hashes[block_id].item()
    # -1 is the "no hash" sentinel used throughout this allocator
    # (block_hashes is initialized/reset to -1, and _deregister_blocks
    # filters only -1 from the hash set). The previous check
    # `hash_val > 0` also skipped legitimate hashes that are zero or
    # negative, leaving stale entries in mamba_hash_to_block_id.
    if hash_val != -1:
        self.mamba_hash_to_block_id.pop(hash_val, None)

def _deregister_blocks(self, block_ids: Tensor) -> None:
"""Remove blocks from prefix caching state and return to free pool.
Expand All @@ -255,12 +278,21 @@ def _deregister_blocks(self, block_ids: Tensor) -> None:
block_ids_i64 = block_ids.to(torch.int64)
hashes = self.block_hashes[block_ids_i64].tolist()

# Remove from hash_to_block_id dict (set ops + C-level map, no Python loop)
# Remove from kv_hash_to_block_id dict (set ops + C-level map, no Python loop)
keys_to_delete = set(hashes) - {-1}
deque(
map(self.hash_to_block_id.pop, keys_to_delete & self.hash_to_block_id.keys()), maxlen=0
map(self.kv_hash_to_block_id.pop, keys_to_delete & self.kv_hash_to_block_id.keys()), maxlen=0
)

# Also remove from mamba_hash_to_block_id (KV eviction implies mamba invalidation)
mamba_keys = keys_to_delete & self.mamba_hash_to_block_id.keys()
if mamba_keys:
deque(map(self.mamba_hash_to_block_id.pop, mamba_keys), maxlen=0)

# Invalidate Mamba state for evicted blocks (if Mamba prefix caching is enabled)
for block_id_int in block_ids.tolist():
self.context.invalidate_mamba_state_for_block(block_id_int)

# Reset block state (batched tensor ops)
self.block_hashes[block_ids] = -1
self.block_ref_counts[block_ids] = 0
Expand Down
Loading