diff --git a/megatron/core/inference/config.py b/megatron/core/inference/config.py index b9bd1f8421e..f4445814af8 100644 --- a/megatron/core/inference/config.py +++ b/megatron/core/inference/config.py @@ -209,6 +209,14 @@ class InferenceConfig: Only applies when enable_prefix_caching is True. """ + prefix_caching_mamba_gb: Optional[float] = None + """Memory budget (GB) for cached Mamba states in prefix caching. + Required for Mamba prefix caching in hybrid models. If None, Mamba prefix caching is disabled.""" + + use_triton_conv1d: bool = False + """Whether to use a Triton varlen conv1d kernel for Mamba prefill instead of + per-request causal_conv1d_fn calls. Only applies to hybrid models with Mamba layers.""" + # ================================= # Logging config # ================================= diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index bacaf882944..5f94ea911c2 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -36,11 +36,6 @@ def __init__(self, max_requests: int, max_tokens: int): (self.max_requests,), -1, dtype=torch.int32, device=self.device ) - # Map from the active chunked prefill request to its slot in the static Mamba state buffer - self._batch_indices_chunked_prefill_buffer = torch.full( - (1,), -1, dtype=torch.int32, device=self.device - ) - # Map from token id to request id for active prefill requests self._seq_idx_buffer = torch.full( (1, self.max_tokens), -1, dtype=torch.int32, device=self.device @@ -56,14 +51,6 @@ def __init__(self, max_requests: int, max_tokens: int): (2,), dtype=torch.int32, device=self.device ) - # Tuple of ( - # chunked prefill sequence length, - # total regular prefill sequence length - # ) - self._device_chunked_prefill_buffer = torch.zeros( - (2,), dtype=torch.int32, device=self.device - ) - # Allocator 
for Mamba state slots self.mamba_state_free_slots = torch.arange( self.max_requests, dtype=torch.int32, device=torch.cuda.current_device() @@ -90,11 +77,9 @@ def reset_varlen_metadata(self) -> None: """Resets varlen metadata.""" self.batch_indices_decode = None self.batch_indices_prefill = None - self.batch_indices_chunked_prefill = None self.cu_seqlens = None self.seq_idx = None self.device_decode_prefill = None - self.device_chunked_prefill = None def update( self, @@ -103,7 +88,6 @@ def update( cu_seqlens: torch.Tensor, batch_dimensions: InferenceBatchDimensions, padded_batch_dimensions: InferenceBatchDimensions, - enable_chunked_prefill: bool, ) -> None: """ Updates the dedicated CUDA graph mapping tensor with the indices @@ -116,7 +100,6 @@ def update( cu_seqlens (Tensor): Cumulative sequence lengths. batch_dimensions (InferenceBatchDimensions): Dimensions of the current batch. padded_batch_dimensions (InferenceBatchDimensions): Dimensions of the padded batch. - enable_chunked_prefill (bool): Whether chunked prefill is enabled """ real_decode_count = batch_dimensions.decode_req_count real_prefill_count = batch_dimensions.prefill_req_count @@ -125,48 +108,6 @@ def update( padded_prefill_count = padded_batch_dimensions.prefill_req_count padded_token_count = padded_batch_dimensions.token_count - has_chunked_prefill_req = enable_chunked_prefill and real_prefill_count > 0 - - # Although the context ensures that the last request is always the designated - # chunked prefill request, what we actually care about is ensuring that any - # prefill request with non-zero initial states is executed through the - # chunked prefill path. - # - # In the batch arrangement passed to this update function, the logic assumes - # the *first* prefill request is the one carrying states. - # - # There are three scenarios: - # - # Scenario A: No prefill request has initial states yet, but the last request - # is the designated chunked prefill request (starting a new chunk). - # - # [ ... 
Decode Requests ... ] [ Prefill (start) ] - # ^ - # |--- First prefill request - # Treated as having states. - # Harmless because actual initial states are 0. - # - # Scenario B: There is exactly 1 prefill request which is a continuing - # chunked prefill request with non-zero initial states. - # - # [ ... Decode Requests ... ] [ Prefill (cont) ] - # ^ - # |--- First prefill request - # Has non-zero initial states. - # - # Scenario C: There is a leftover chunked prefill request that is executing - # its last chunk, followed by additional prefill requests. - # - # [ ... Decode Requests ... ] [ Prefill (end) ] [ Prefill (new) ] ... - # ^ - # |--- First prefill request - # Has non-zero initial states. - # - # The implementation generalizes to Scenario A as well, where the first prefill - # request is treated as if it has non-zero initial states, which is safe. - # While this results in a minor inefficiency f there is no continuing chunked prefill - # request in a given batch, this case is infrequent. 
- if padded_decode_count > 0: # Update decode indices self._batch_indices_decode_buffer[:real_decode_count].copy_( @@ -176,118 +117,61 @@ def update( self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1 self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count] - # Determine if we have a chunked prefill request and adjust counts for regular prefill - regular_prefill_count = real_prefill_count - chunked_req_idx = -1 - - if has_chunked_prefill_req: - # The first prefill request is the chunked one - regular_prefill_count -= 1 - chunked_req_idx = real_decode_count - - # Update chunked prefill indices - self._batch_indices_chunked_prefill_buffer[0] = active_mamba_indices[chunked_req_idx] - self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer - else: - self.batch_indices_chunked_prefill = None - if padded_prefill_count > 0: - # Update prefill indices (excluding chunked prefill from regular prefill buffer) - if regular_prefill_count > 0: - # If chunked prefill exists, regular prefills start after it. - # If no chunked prefill, regular prefills start at real_decode_count. 
- start_idx = real_decode_count + (1 if has_chunked_prefill_req else 0) - - self._batch_indices_prefill_buffer[:regular_prefill_count].copy_( - active_mamba_indices[start_idx : start_idx + regular_prefill_count] + # Update prefill indices (all prefill requests go through varlen) + if real_prefill_count > 0: + prefill_start_idx = real_decode_count + self._batch_indices_prefill_buffer[:real_prefill_count].copy_( + active_mamba_indices[prefill_start_idx : prefill_start_idx + real_prefill_count] ) - if padded_prefill_count > regular_prefill_count: - self._batch_indices_prefill_buffer[regular_prefill_count:padded_prefill_count] = -1 + if padded_prefill_count > real_prefill_count: + self._batch_indices_prefill_buffer[ + real_prefill_count:padded_prefill_count + ] = -1 self.batch_indices_prefill = self._batch_indices_prefill_buffer[:padded_prefill_count] - # Update seq_idx for regular prefills - # If chunked prefill exists, we need to skip its tokens in seq_idx + # Update seq_idx for all prefill requests + prefill_start_req_idx = real_decode_count + end_prefill_req_idx = real_decode_count + real_prefill_count - # Index where regular prefills end in the batch (decode + chunked + regular) - end_regular_prefill_req_idx = ( - real_decode_count + regular_prefill_count + (1 if has_chunked_prefill_req else 0) - ) - end_regular_prefill_token_idx = cu_seqlens[end_regular_prefill_req_idx] - - # Index where regular prefills start - start_regular_prefill_req_idx = real_decode_count + ( - 1 if has_chunked_prefill_req else 0 - ) - start_regular_prefill_token_idx = cu_seqlens[start_regular_prefill_req_idx] + start_prefill_token_idx = cu_seqlens[prefill_start_req_idx] + end_prefill_token_idx = cu_seqlens[end_prefill_req_idx] - # The length of tokens belonging to regular prefill requests - seq_len = end_regular_prefill_token_idx - start_regular_prefill_token_idx + seq_len = end_prefill_token_idx - start_prefill_token_idx if seq_len > 0: - # We subtract start_regular_prefill_req_idx to 
normalize request IDs to - # 0-based relative to this buffer + # Normalize request IDs to 0-based relative to prefill requests self._seq_idx_buffer[:, :seq_len].copy_( - token_to_request_idx[ - start_regular_prefill_token_idx:end_regular_prefill_token_idx - ] - - start_regular_prefill_req_idx + token_to_request_idx[start_prefill_token_idx:end_prefill_token_idx] + - prefill_start_req_idx ) if padded_token_count > seq_len: self._seq_idx_buffer[:, seq_len:padded_token_count] = -1 self.seq_idx = self._seq_idx_buffer[:, :padded_token_count] - # Update cu_seqlens for regular prefill requests + # Update cu_seqlens for all prefill requests self._cu_seqlens_buffer[0] = 0 - if regular_prefill_count > 0: - # Copy cu_seqlens for regular prefill requests and normalize by - # subtracting the start token index - start_req_idx = real_decode_count + (1 if has_chunked_prefill_req else 0) - end_req_idx = start_req_idx + regular_prefill_count - - self._cu_seqlens_buffer[1 : regular_prefill_count + 1].copy_( - cu_seqlens[start_req_idx + 1 : end_req_idx + 1] - cu_seqlens[start_req_idx] + if real_prefill_count > 0: + self._cu_seqlens_buffer[1 : real_prefill_count + 1].copy_( + cu_seqlens[prefill_start_req_idx + 1 : end_prefill_req_idx + 1] + - cu_seqlens[prefill_start_req_idx] ) # Pad the rest with the last value (effectively length 0 segments) - last_val = self._cu_seqlens_buffer[regular_prefill_count] - self._cu_seqlens_buffer[regular_prefill_count + 1 : padded_prefill_count + 1].fill_( + last_val = self._cu_seqlens_buffer[real_prefill_count] + self._cu_seqlens_buffer[real_prefill_count + 1 : padded_prefill_count + 1].fill_( last_val ) self.cu_seqlens = self._cu_seqlens_buffer[: padded_prefill_count + 1] if padded_decode_count > 0 and padded_prefill_count > 0: self._device_decode_prefill_buffer[0] = real_decode_count - # This describes the number of items in the prefill tensor relative to the - # decode tensor. 
If chunked prefill is present, it is included in the - # "prefill" part of the main split. - self._device_decode_prefill_buffer[1] = regular_prefill_count + ( - 1 if has_chunked_prefill_req else 0 - ) + self._device_decode_prefill_buffer[1] = real_prefill_count self.device_decode_prefill = self._device_decode_prefill_buffer - # If using chunked prefill for this batch, store the number of chunked tokens - # and the number of regular prefill tokens - if has_chunked_prefill_req: - # Chunked request is the first prefill request (index real_decode_count) - chunked_prefill_token_count = ( - cu_seqlens[real_decode_count + 1] - cu_seqlens[real_decode_count] - ) - - # Regular prefill tokens are everything after the chunked request tokens - regular_prefill_token_count = 0 - if regular_prefill_count > 0: - regular_prefill_token_count = ( - cu_seqlens[real_decode_count + 1 + regular_prefill_count] - - cu_seqlens[real_decode_count + 1] - ) - - self._device_chunked_prefill_buffer[0] = chunked_prefill_token_count - self._device_chunked_prefill_buffer[1] = regular_prefill_token_count - self.device_chunked_prefill = self._device_chunked_prefill_buffer - def allocate_slot(self) -> Optional[int]: """ Allocates a new slot for a request in the Mamba state buffers. 
diff --git a/megatron/core/inference/contexts/dynamic_block_allocator.py b/megatron/core/inference/contexts/dynamic_block_allocator.py index abfb7278b14..89b6f34c322 100644 --- a/megatron/core/inference/contexts/dynamic_block_allocator.py +++ b/megatron/core/inference/contexts/dynamic_block_allocator.py @@ -57,8 +57,11 @@ def __init__( (self.total_count,), -1, dtype=torch.int64, device=torch.cuda.current_device() ) - # Hash-to-block mapping for O(1) prefix lookup - self.hash_to_block_id: Dict[int, int] = {} + # Hash-to-block mapping for O(1) prefix lookup (KV blocks) + self.kv_hash_to_block_id: Dict[int, int] = {} + + # Hash-to-block mapping for blocks with cached Mamba state (1:1 with mamba slots) + self.mamba_hash_to_block_id: Dict[int, int] = {} # Reference count per block: 0 = cached (evictable), >0 = actively used self.block_ref_counts = torch.zeros( @@ -216,7 +219,8 @@ def reset(self) -> None: self.block_hashes.fill_(-1) # Reset prefix caching state - self.hash_to_block_id.clear() + self.kv_hash_to_block_id.clear() + self.mamba_hash_to_block_id.clear() self.block_ref_counts.fill_(0) if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: self.block_timestamps.fill_(0) @@ -237,7 +241,26 @@ def register_block_hashes(self, block_ids: list[int], block_hashes: list[int]) - id_tensor = torch.tensor(block_ids, dtype=torch.int64, device=self.block_hashes.device) hash_tensor = torch.tensor(block_hashes, dtype=torch.int64, device=self.block_hashes.device) self.block_hashes[id_tensor] = hash_tensor - self.hash_to_block_id.update(zip(block_hashes, block_ids)) + self.kv_hash_to_block_id.update(zip(block_hashes, block_ids)) + + def register_mamba_block_hash(self, block_id: int, block_hash: int) -> None: + """Register a block in the mamba hash map (1:1 with mamba state). + + Args: + block_id: The KV block ID that has cached mamba state. + block_hash: The hash value for this block. 
+ """ + self.mamba_hash_to_block_id[block_hash] = block_id + + def deregister_mamba_block_hash(self, block_id: int) -> None: + """Remove a block from the mamba hash map (when mamba state is evicted). + + Args: + block_id: The KV block ID whose mamba state was evicted. + """ + hash_val = self.block_hashes[block_id].item() + if hash_val > 0: + self.mamba_hash_to_block_id.pop(hash_val, None) def _deregister_blocks(self, block_ids: Tensor) -> None: """Remove blocks from prefix caching state and return to free pool. @@ -255,12 +278,21 @@ def _deregister_blocks(self, block_ids: Tensor) -> None: block_ids_i64 = block_ids.to(torch.int64) hashes = self.block_hashes[block_ids_i64].tolist() - # Remove from hash_to_block_id dict (set ops + C-level map, no Python loop) + # Remove from kv_hash_to_block_id dict (set ops + C-level map, no Python loop) keys_to_delete = set(hashes) - {-1} deque( - map(self.hash_to_block_id.pop, keys_to_delete & self.hash_to_block_id.keys()), maxlen=0 + map(self.kv_hash_to_block_id.pop, keys_to_delete & self.kv_hash_to_block_id.keys()), maxlen=0 ) + # Also remove from mamba_hash_to_block_id (KV eviction implies mamba invalidation) + mamba_keys = keys_to_delete & self.mamba_hash_to_block_id.keys() + if mamba_keys: + deque(map(self.mamba_hash_to_block_id.pop, mamba_keys), maxlen=0) + + # Invalidate Mamba state for evicted blocks (if Mamba prefix caching is enabled) + for block_id_int in block_ids.tolist(): + self.context.invalidate_mamba_state_for_block(block_id_int) + # Reset block state (batched tensor ops) self.block_hashes[block_ids] = -1 self.block_ref_counts[block_ids] = 0 diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 29923bd08ad..810896620af 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -250,6 +250,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC # 
Prefix caching configuration self.enable_prefix_caching = inference_config.enable_prefix_caching self.prefix_caching_eviction_policy = inference_config.prefix_caching_eviction_policy + self.prefix_caching_mamba_gb = inference_config.prefix_caching_mamba_gb + + # Mamba conv1d kernel selection + self.use_triton_conv1d = inference_config.use_triton_conv1d # Step counter (used for LRU timestamps in prefix caching) self.step_count = 0 @@ -356,12 +360,12 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) assert self.block_size_bytes > 0 - mamba_states_memory_per_request = 0 + self.mamba_states_memory_per_request = 0 if self.is_hybrid_model: - mamba_states_memory_per_request += math.prod(self.mamba_conv_states_shape) - mamba_states_memory_per_request += math.prod(self.mamba_ssm_states_shape) - mamba_states_memory_per_request *= self.num_mamba_layers - mamba_states_memory_per_request *= dtype_size_bytes + self.mamba_states_memory_per_request += math.prod(self.mamba_conv_states_shape) + self.mamba_states_memory_per_request += math.prod(self.mamba_ssm_states_shape) + self.mamba_states_memory_per_request *= self.num_mamba_layers + self.mamba_states_memory_per_request *= dtype_size_bytes # Unified memory and general tensor management. 
self.unified_memory_level = inference_config.unified_memory_level @@ -412,7 +416,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC # Calculate total memory before partition total_memory = buffer_size_bytes + paused_buffer_size_bytes mamba_memory_bytes = total_memory * mamba_memory_ratio - mamba_max_requests = int(mamba_memory_bytes // mamba_states_memory_per_request) + mamba_max_requests = int(mamba_memory_bytes // self.mamba_states_memory_per_request) # Reduce buffer sizes for KV cache buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio)) @@ -423,11 +427,11 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC paused_block_count = paused_buffer_size_bytes // self.block_size_bytes else: block_count = buffer_size_bytes // ( - self.block_size_bytes + mamba_states_memory_per_request + self.block_size_bytes + self.mamba_states_memory_per_request ) block_count = max(2, block_count) # need >= 1 active block + 1 dummy block paused_block_count = paused_buffer_size_bytes // ( - self.block_size_bytes + mamba_states_memory_per_request + self.block_size_bytes + self.mamba_states_memory_per_request ) # If using pipeline parallelism synchronize the total block count in case the @@ -659,6 +663,47 @@ def _allocate_mamba_states(self): else: self.mamba_metadata = None + def _allocate_mamba_cache(self): + """Allocate fixed-size tensor pool for Mamba prefix caching.""" + # Calculate max slots from memory budget + prefix_caching_mamba_gb = self.prefix_caching_mamba_gb + if (self.is_hybrid_model and self.enable_prefix_caching and + prefix_caching_mamba_gb is not None and prefix_caching_mamba_gb > 0): + prefix_caching_mamba_bytes = int(prefix_caching_mamba_gb * (1024 ** 3)) + self.max_mamba_cache_slots = prefix_caching_mamba_bytes // self.mamba_states_memory_per_request + else: + self.max_mamba_cache_slots = 0 + + if self.max_mamba_cache_slots > 0: + # Fixed-size tensor pool for cached Mamba states + 
self.mamba_cache_conv_states = torch.empty( + (self.num_mamba_layers, self.max_mamba_cache_slots) + self.mamba_conv_states_shape, + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + self.mamba_cache_ssm_states = torch.empty( + (self.num_mamba_layers, self.max_mamba_cache_slots) + self.mamba_ssm_states_shape, + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + + # Mapping: block_id -> mamba_slot_id (-1 if no cached state) + total_blocks = self.block_allocator.total_count + self.block_to_mamba_slot = torch.full( + (total_blocks,), -1, dtype=torch.int32, device=torch.cuda.current_device() + ) + + # Reverse mapping for eviction: slot_id -> block_id + self.mamba_slot_to_block = torch.full( + (self.max_mamba_cache_slots,), -1, dtype=torch.int32, device=torch.cuda.current_device() + ) + + # Free slot pool (similar to MambaMetadata pattern) + self.mamba_cache_free_slots = torch.arange( + self.max_mamba_cache_slots, dtype=torch.int32, device=torch.cuda.current_device() + ) + self.mamba_cache_free_count = self.max_mamba_cache_slots + def initialize_all_tensors(self) -> None: """Allocate all GPU state during initial construction.""" # Mark allocated. @@ -734,6 +779,7 @@ def initialize_all_tensors(self) -> None: with ctx_manager: self._allocate_memory_buffer() self._allocate_mamba_states() + self._allocate_mamba_cache() # Reset attention and Mamba state. self.reset_attention_state() @@ -1100,6 +1146,144 @@ def reset_mamba_state(self) -> None: if self.is_hybrid_model: self.mamba_metadata.reset() + # ========================================================================= + # Mamba prefix caching methods + # ========================================================================= + + def _allocate_mamba_cache_slot(self, block_id: int) -> int: + """Allocate a Mamba cache slot for a block, evicting LRU if needed. 
+ + Args: + block_id: The KV block ID to associate with this Mamba state + + Returns: + Slot index for storing the Mamba state + """ + # Check if block already has a slot + existing_slot = self.block_to_mamba_slot[block_id].item() + if existing_slot >= 0: + return existing_slot + + # Try to get a free slot + if self.mamba_cache_free_count > 0: + self.mamba_cache_free_count -= 1 + slot = self.mamba_cache_free_slots[self.mamba_cache_free_count].item() + else: + # Need to evict LRU - find oldest block with Mamba state + slot = self._evict_lru_mamba_slot() + + # Update mappings + self.block_to_mamba_slot[block_id] = slot + self.mamba_slot_to_block[slot] = block_id + return slot + + def _evict_lru_mamba_slot(self) -> int: + """Evict the least recently used Mamba state and return its slot. + + Uses block_timestamps from KV block allocator for LRU ordering. + Only evicts Mamba state - KV cache remains intact. + """ + # Find all blocks with Mamba state that are not currently in use + # (ref_count == 0 means cached but not actively used) + candidates = [] + for slot in range(self.max_mamba_cache_slots): + block_id = self.mamba_slot_to_block[slot].item() + if block_id >= 0: + ref_count = self.block_allocator.block_ref_counts[block_id].item() + if ref_count == 0: + timestamp = self.block_allocator.block_timestamps[block_id].item() + candidates.append((timestamp, slot, block_id)) + + if not candidates: + raise RuntimeError("Cannot evict Mamba state - all slots in active use") + + # Sort by timestamp (oldest first) and evict + candidates.sort(key=lambda x: x[0]) + _, slot, block_id = candidates[0] + + # Clear mappings (but keep KV cache intact) + self.block_to_mamba_slot[block_id] = -1 + self.mamba_slot_to_block[slot] = -1 + + # Remove from mamba hash map + self.block_allocator.deregister_mamba_block_hash(block_id) + + return slot + + def store_mamba_state_for_block(self, block_id: int, request_idx: int) -> None: + """Copy current per-request Mamba state to block cache. 
+ + Called after prefill completes a block that needs caching. + + Args: + block_id: The KV block ID to store state for + request_idx: Index of request whose Mamba state to store + """ + if self.max_mamba_cache_slots == 0: + return # Mamba caching disabled + + # Allocate slot (may trigger LRU eviction) + slot = self._allocate_mamba_cache_slot(block_id) + + # Copy from per-request state to cache + mamba_idx = self.mamba_metadata.request_to_mamba_state_idx[request_idx].item() + self.mamba_cache_conv_states[:, slot] = self.mamba_conv_states[:, mamba_idx].clone() + self.mamba_cache_ssm_states[:, slot] = self.mamba_ssm_states[:, mamba_idx].clone() + + def has_mamba_state_for_block(self, block_id: int) -> bool: + """Check if a block has valid cached Mamba state. + + Args: + block_id: The KV block ID to check + + Returns: + True if block has cached Mamba state, False otherwise + """ + if self.max_mamba_cache_slots == 0: + return False + return self.block_to_mamba_slot[block_id].item() >= 0 + + def restore_mamba_state_from_block(self, request_idx: int, block_id: int) -> bool: + """Initialize request's Mamba state from cached block state. + + Args: + request_idx: Index of request to restore state for + block_id: The KV block ID containing cached state + + Returns: + True if state was restored, False if no cached state + """ + if self.max_mamba_cache_slots == 0: + return False + + slot = self.block_to_mamba_slot[block_id].item() + if slot < 0: + return False + + mamba_idx = self.mamba_metadata.request_to_mamba_state_idx[request_idx].item() + self.mamba_conv_states[:, mamba_idx] = self.mamba_cache_conv_states[:, slot].clone() + self.mamba_ssm_states[:, mamba_idx] = self.mamba_cache_ssm_states[:, slot].clone() + return True + + def invalidate_mamba_state_for_block(self, block_id: int) -> None: + """Invalidate Mamba state when KV block is evicted. + + Called when a KV block is evicted - must also free its Mamba slot. 
+ + Args: + block_id: The KV block ID being evicted + """ + if self.max_mamba_cache_slots == 0: + return + + slot = self.block_to_mamba_slot[block_id].item() + if slot >= 0: + # Return slot to free pool + self.block_to_mamba_slot[block_id] = -1 + self.mamba_slot_to_block[slot] = -1 + self.mamba_cache_free_slots[self.mamba_cache_free_count] = slot + self.mamba_cache_free_count += 1 + def add_dummy_requests_parallel( self, requests: Sequence[DynamicInferenceRequest], *, count_as_prefill: bool = True ) -> None: @@ -1482,7 +1666,6 @@ def initialize_attention_state( cu_seqlens, batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, - enable_chunked_prefill=self.is_chunked_prefill_enabled(), ) if self.moe_enable_routing_replay: @@ -1612,6 +1795,20 @@ def _compute_prefix_match( Shared by check_availability (budget checks) and add_request (execution). + For hybrid models, matched_block_ids covers the full KV match (for block + reuse and ref counting), but prefix_skip_tokens is based only on the + Mamba-cached range. The four-part prefill is: + 1. Skip (0 to num_mamba_matched * block_size): KV + Mamba cached. + 2. Divergence (num_mamba_matched to num_kv_matched blocks): KV blocks + reused, Mamba computes recurrent state. Engine stores Mamba state + in the allocator at kv_divergence_token. + 3. Fresh aligned (num_kv_matched blocks to last_aligned_token): nothing + cached. Engine stores Mamba state in the allocator at + last_aligned_token. + 4. Fresh partial (last_aligned_token to end): partial block. Mamba + state is maintained in the context's per-request state but not + stored in the allocator (not at a block boundary). 
+ Returns: Tuple of (matched_block_ids, num_blocks_from_pool, already_allocated_blocks, overall_required_blocks, @@ -1642,7 +1839,15 @@ def _compute_prefix_match( block_aligned = finished % self.block_size_tokens == 0 if num_matched > 0 and block_aligned: - prefix_skip_tokens = min(num_matched * self.block_size_tokens, chunk_length - 1) + is_chunked_prefill = finished > 0 + if self.is_hybrid_model and not is_chunked_prefill: + # Skip only through Mamba-cached blocks. The full KV match + # stays in matched_block_ids for block reuse and ref counting. + num_mamba_matched = getattr(req, '_mamba_num_matched_blocks', 0) + num_skippable = min(num_mamba_matched, num_matched) + else: + num_skippable = num_matched + prefix_skip_tokens = min(num_skippable * self.block_size_tokens, chunk_length) else: prefix_skip_tokens = 0 @@ -1713,7 +1918,7 @@ def _find_matching_prefix_blocks( return [], 0 hashes = req.precomputed_block_hashes[start_block:end_block] - hash_to_block = self.block_allocator.hash_to_block_id + hash_to_block = self.block_allocator.kv_hash_to_block_id # Batch dict lookups via C-level map() — faster than Python for loop block_ids = list(map(hash_to_block.get, hashes)) @@ -1770,6 +1975,17 @@ def add_request(self, req: DynamicInferenceRequest, chunk_length: Optional[int] effective_chunk_length, ) = self._compute_prefix_match(req, chunk_length) num_matched_blocks = len(matched_block_ids) + + # Direct decode entry: when all prompt blocks are matched on a + # block-aligned boundary, effective_chunk_length is 0. Back off by 1 + # token so the last prompt token is reprocessed, producing the first + # output logits. After update_requests the request transitions to + # decode mode with kv_offset == chunk_length, identical to a request + # that completed its full prefill normally. 
+ if effective_chunk_length == 0: + prefix_skip_tokens -= 1 + effective_chunk_length = 1 + effective_kv_offset = req.finished_chunk_token_count + prefix_skip_tokens # Slice tokens to skip matched prefix @@ -1911,12 +2127,24 @@ def _register_range(start: int, end: int): mamba_idx = self.mamba_metadata.allocate_slot() if mamba_idx is None: raise ContextOverflowError(req.request_id, "No Mamba slots available") - - # Initialize the allocated Mamba state - self.mamba_conv_states[:, mamba_idx] = 0.0 - self.mamba_ssm_states[:, mamba_idx] = 0.0 self.mamba_metadata.request_to_mamba_state_idx[self.total_request_count] = mamba_idx + # Check if we should restore Mamba state from cache (prefix match) + num_mamba_matched = getattr(req, '_mamba_num_matched_blocks', 0) + restored = False + if num_mamba_matched > 0 and num_matched_blocks > 0: + # Find the block ID of the last Mamba-matched block + last_mamba_block_idx = num_mamba_matched - 1 + last_mamba_block_id = matched_block_ids[last_mamba_block_idx] + restored = self.restore_mamba_state_from_block( + self.total_request_count, last_mamba_block_id + ) + + if not restored: + # No cached state - initialize to zero + self.mamba_conv_states[:, mamba_idx] = 0.0 + self.mamba_ssm_states[:, mamba_idx] = 0.0 + self.active_token_count += effective_chunk_length self.lifetime_prefill_token_count += effective_chunk_length self.total_request_count += 0 if req.finished_chunk_token_count > 0 else 1 @@ -1997,7 +2225,7 @@ def release_memory_blocks_from_request_indexes(self, request_indexes) -> None: # tensor. self.request_to_kv_block_ids[request_indexes] = -1 - # Free Mamba slots. + # Free Mamba slots and clear initial-state flags. 
if self.is_hybrid_model: self.mamba_metadata.free_slots(request_indexes) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 5de3a4108ed..01ce50f5d04 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -3,6 +3,7 @@ import asyncio import concurrent.futures import logging +import math import multiprocessing import socket import struct @@ -1163,6 +1164,120 @@ def get_prefix_coordination_metrics(self) -> dict: """ return {"waits": self._prefix_coordination_waits} + def _compute_mamba_prefill_boundaries( + self, + req: DynamicInferenceRequest, + num_matched_blocks: int, + ) -> tuple: + """Compute token boundaries for three-part Mamba prefill. + + Args: + req: The inference request + num_matched_blocks: Number of prefix blocks with cached Mamba state + + Returns: + Tuple of (divergence_token, last_aligned_token, prompt_length) + """ + block_size = self.context.block_size_tokens + prompt_length = len(req.prompt_tokens) + divergence_token = num_matched_blocks * block_size + last_aligned_token = (prompt_length // block_size) * block_size + return divergence_token, last_aligned_token, prompt_length + + def _find_mamba_divergence_block(self, matched_kv_blocks: list) -> int: + """Find the last KV block that also has cached Mamba state. + + Args: + matched_kv_blocks: List of KV block IDs that matched the prefix + + Returns: + Number of blocks with valid Mamba state (0 if none) + """ + for i in range(len(matched_kv_blocks) - 1, -1, -1): + if self.context.has_mamba_state_for_block(matched_kv_blocks[i]): + return i + 1 + return 0 # No Mamba state cached for any matched block + + def _store_mamba_states_for_completed_prefill(self): + """Store Mamba state at block boundaries after prefill. + + Called after prefill forward pass completes. 
def _store_mamba_states_for_completed_prefill(self):
    """Store Mamba state at block boundaries after prefill.

    Called after prefill forward pass completes. For each request that
    was part of this prefill, stores Mamba state at meaningful boundaries:
    the KV divergence block and the last-aligned block.

    Only stores state if:
    - The chunk ends at a mamba-meaningful boundary (KV divergence or last-aligned)
    - The request has at least one complete block
    - The block doesn't already have cached Mamba state
    """
    block_size = self.context.block_size_tokens

    # Active (non-paused) requests occupy the context index range
    # [paused_request_count, total_request_count).
    for req_idx in range(self.context.paused_request_count,
            self.context.total_request_count):
        # NOTE: .item() forces a GPU->CPU sync per request; this runs
        # outside any CUDA graph capture (see caller's guard).
        request_id = self.context.request_ids[req_idx].item()
        req = self.get_request(request_id)

        # Calculate total tokens actually in context (processed through forward).
        # For continuing chunks, remaining_prompt_tokens still holds unprocessed
        # tokens, so subtract from total prompt length.
        total_prefilled = len(req.prompt_tokens) - len(req.remaining_prompt_tokens)

        # For hybrid mamba: only store when chunk ends at a meaningful boundary.
        # The boundary attrs are set by the scheduler when the request is first
        # admitted; getattr defaults cover requests that never got them
        # (e.g. decode-phase requests still in the active range).
        kv_divergence = getattr(req, '_kv_divergence_token', 0)
        last_aligned = getattr(req, '_mamba_last_aligned_token', 0)
        is_kv_divergence = (kv_divergence > 0 and total_prefilled == kv_divergence)
        is_last_aligned = (last_aligned > 0 and total_prefilled == last_aligned)
        if not (is_kv_divergence or is_last_aligned):
            continue

        # Find the last complete block index
        num_complete_blocks = total_prefilled // block_size
        if num_complete_blocks == 0:
            continue  # No complete blocks to store

        # Get the block ID of the last complete block
        last_complete_block_idx = num_complete_blocks - 1
        if last_complete_block_idx >= self.context.request_kv_block_counts[req_idx].item():
            continue  # Safety check

        block_id = self.context.request_to_kv_block_ids[req_idx, last_complete_block_idx].item()
        if block_id < 0:
            continue  # Invalid block ID

        # Store mamba state and register in mamba hash map so future
        # requests sharing this prefix can restore it.
        already_has = self.context.has_mamba_state_for_block(block_id)
        if not already_has:
            self.context.store_mamba_state_for_block(block_id, req_idx)
            if req.precomputed_block_hashes and last_complete_block_idx < len(req.precomputed_block_hashes):
                block_hash = req.precomputed_block_hashes[last_complete_block_idx]
                self.context.block_allocator.register_mamba_block_hash(block_id, block_hash)

def _get_mamba_chunk_limit(self, req) -> Optional[int]:
    """Return max chunk_length based on next mamba boundary, or None.

    Two boundaries matter for mamba state storage: KV divergence (where
    KV match ends) and last-aligned (last complete block in prompt).
    The engine must break chunks at these points so _store_mamba_states
    can detect and store state at the correct boundaries.
    """
    # Tokens already processed in earlier chunks of this request.
    finished = req.finished_chunk_token_count
    kv_divergence = getattr(req, '_kv_divergence_token', 0)
    last_aligned = getattr(req, '_mamba_last_aligned_token', 0)

    # Skip boundaries already covered by restored Mamba state. The state
    # at those boundaries is already cached from a previous request, so
    # no chunk break is needed for storage.
    mamba_covered = (
        getattr(req, '_mamba_num_matched_blocks', 0) * self.context.block_size_tokens
    )

    # Return the distance to the nearest boundary still ahead of us;
    # KV divergence (if any) always precedes last-aligned.
    if kv_divergence > mamba_covered and finished < kv_divergence:
        return kv_divergence - finished
    elif last_aligned > mamba_covered and finished < last_aligned:
        return last_aligned - finished
    else:
        return None
+ prefix_caching_enabled = self.context.enable_prefix_caching if prefix_caching_enabled: pending_block_hashes = set() @@ -1200,7 +1319,7 @@ def schedule_non_chunked_prefill(self): # Add these hashes to pending. if prefix_caching_enabled: for block_hash in req.precomputed_block_hashes: - if block_hash not in self.context.block_allocator.hash_to_block_id: + if block_hash not in self.context.block_allocator.kv_hash_to_block_id: pending_block_hashes.add(block_hash) self.context.add_request(req) self._loop.call_soon_threadsafe( @@ -1259,6 +1378,33 @@ def schedule_chunked_prefill(self): ) continue + # For hybrid models with Mamba prefix caching: compute Mamba-aware boundaries + if (self.context.is_hybrid_model and + self.context.max_mamba_cache_slots > 0 and + not is_continuing_chunked_prefill): + # Find matching KV blocks + num_hashes = len(req.precomputed_block_hashes) if req.precomputed_block_hashes else 0 + matched_blocks, _ = self.context._find_matching_prefix_blocks(req, 0, num_hashes) + num_kv_matched = len(matched_blocks) + + # Find how many of those also have cached Mamba state + if num_kv_matched > 0: + num_mamba_matched = self._find_mamba_divergence_block(matched_blocks) + else: + num_mamba_matched = 0 + + # Store for use in add_request() and chunk break enforcement + req._mamba_num_matched_blocks = num_mamba_matched + req._mamba_divergence_token, req._mamba_last_aligned_token, _ = \ + self._compute_mamba_prefill_boundaries(req, num_mamba_matched) + # KV divergence is only meaningful when mamba match is active and KV + # match extends beyond it. Otherwise, add_request truncates the KV + # match to 0, making the divergence point meaningless. 
+ if num_mamba_matched > 0 and num_kv_matched > num_mamba_matched: + req._kv_divergence_token = num_kv_matched * self.context.block_size_tokens + else: + req._kv_divergence_token = 0 + # Use remaining prompt tokens for scheduling decisions remaining_len = len(req.remaining_prompt_tokens) token_fully_can_be_added = ( @@ -1269,11 +1415,19 @@ def schedule_chunked_prefill(self): request_can_be_added = is_continuing_chunked_prefill or request_can_be_added if request_can_be_added and kv_cache_available: - if token_fully_can_be_added: + # Compute mamba chunk limit for hybrid models + mamba_limit = None + if self.context.is_hybrid_model and self.context.max_mamba_cache_slots > 0: + mamba_limit = self._get_mamba_chunk_limit(req) + + # Check if mamba boundary requires a chunk break even when tokens fully fit + mamba_forces_chunk = (mamba_limit is not None and remaining_len > mamba_limit) + + if token_fully_can_be_added and not mamba_forces_chunk: # Add these hashes to pending. if prefix_caching_enabled: for block_hash in req.precomputed_block_hashes: - if block_hash not in self.context.block_allocator.hash_to_block_id: + if block_hash not in self.context.block_allocator.kv_hash_to_block_id: pending_block_hashes.add(block_hash) self.context.chunked_prefill_request_id = -1 self.context.add_request(req) @@ -1286,13 +1440,15 @@ def schedule_chunked_prefill(self): self.waiting_request_ids.popleft() # Only this case we keep checking the rest of the waiting queue can_schedule = True - elif token_partially_can_be_added: + elif token_partially_can_be_added or mamba_forces_chunk: # Add these hashes to pending. 
if prefix_caching_enabled: for block_hash in req.precomputed_block_hashes: - if block_hash not in self.context.block_allocator.hash_to_block_id: + if block_hash not in self.context.block_allocator.kv_hash_to_block_id: pending_block_hashes.add(block_hash) - chunk_length = self.context.max_tokens - self.context.active_token_count + chunk_length = min(remaining_len, self.context.max_tokens - self.context.active_token_count) + if mamba_limit is not None: + chunk_length = min(chunk_length, mamba_limit) # If this chunk would leave exactly 1 token for the final chunk, reduce this # chunk by 1 so the final chunk has 2 tokens. This avoids the edge case where @@ -1363,6 +1519,11 @@ async def async_forward(self) -> Tuple[Dict, Dict, float]: step_time = self.step_start_event.elapsed_time(self.step_end_event) / 1e3 self.context.step_count += 1 + # Store Mamba state for blocks that completed during this prefill + if (not is_decode_only and self.context.enable_prefix_caching + and self.context.is_hybrid_model and self.context.max_mamba_cache_slots > 0): + self._store_mamba_states_for_completed_prefill() + range_pop() if ( diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 896ed5300a9..7bd3da50098 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -65,6 +65,19 @@ HAVE_MAMBA_SSM = True except ImportError: + mamba_chunk_scan_combined = None + mamba_split_conv1d_scan_combined = None + HAVE_MAMBA_SSM = False + +try: + from megatron.core.ssm.ops.ssd_combined import mamba_chunk_scan_combined_varlen + + HAVE_SSM_OPS_VARLEN = True +except ImportError: + mamba_chunk_scan_combined_varlen = None + HAVE_SSM_OPS_VARLEN = False + +if not HAVE_MAMBA_SSM: from unittest.mock import MagicMock RMSNormGated = MagicMock() @@ -505,67 +518,69 @@ def _dynamic_inference_prefill( conv_state: torch.Tensor, ssm_state: torch.Tensor, ) -> torch.Tensor: - """Helper to run dynamic inference prefill (chunked prefill request separately).""" - 
metadata = context.mamba_metadata - - # Use the regular prefill request count to determine if regular - # prefill path needs to be run when chunked prefill is enabled - prefill_req_count = context.batch_dimensions.prefill_req_count - - # Padded prefill token count - prefill_token_count = zxBCdt.shape[0] - - enable_chunked_prefill = context.is_chunked_prefill_enabled() + """Helper to run dynamic inference prefill. - y_chunked = None - y_regular = None + All prefill requests (with or without initial Mamba states) are processed + together through the unified varlen path. - # Chunked prefill - if enable_chunked_prefill: - y_chunked = self._ssm_prefill( - zxBCdt[: metadata.device_chunked_prefill[0]], - conv_state=conv_state, - ssm_state=ssm_state, - batch_indices=metadata.batch_indices_chunked_prefill, - is_chunked_prefill=True, - ) - - # Update zxBCdt to contain the remaining slice for regular prefill - zxBCdt_remainder = torch.empty_like(zxBCdt) - tensor_get_slice_after( - zxBCdt, zxBCdt_remainder, metadata.device_chunked_prefill, check_bounds=False - ) - zxBCdt = zxBCdt_remainder + During CUDA graph capture, GPU-to-CPU sync (.item()) is forbidden, so we + pass padded tensors directly to the old mamba_chunk_scan_combined kernel + which handles padding natively via seq_idx. No requests have initial + states during graph capture, so this produces correct results. + """ + metadata = context.mamba_metadata + real_prefill_count = context.batch_dimensions.prefill_req_count + if real_prefill_count <= 0: + return None - # Regular prefill - if not enable_chunked_prefill or prefill_req_count > 1: - y_regular = self._ssm_prefill( + if context.is_creating_cuda_graphs: + # CUDA graph capture: pass padded tensors directly (no .item() calls). 
+ return self._ssm_prefill( zxBCdt, conv_state=conv_state, ssm_state=ssm_state, seq_idx=metadata.seq_idx, cu_seqlens=metadata.cu_seqlens, - return_varlen_states=True, batch_indices=metadata.batch_indices_prefill, + cuda_graph_compatible=True, ) - # Merge chunked prefill and regular prefill results - if y_chunked is not None and y_regular is not None: - y_combined = torch.empty_like(y_regular) - tensor_merge( - y_chunked, y_regular, metadata.device_chunked_prefill, output_tensor=y_combined - ) - return y_combined - elif y_chunked is not None: - y_prefill = torch.empty( - (prefill_token_count, 1, y_chunked.shape[-1]), - dtype=y_chunked.dtype, - device=y_chunked.device, + # Strip CUDA-graph padding from metadata tensors. The padded entries + # have -1 batch_indices and zero-length segments in cu_seqlens, which + # would cause out-of-bounds indexing in the varlen SSM kernel. + cu_seqlens = metadata.cu_seqlens[: real_prefill_count + 1] + batch_indices = metadata.batch_indices_prefill[:real_prefill_count] + real_token_count = cu_seqlens[-1].item() + seq_idx = ( + metadata.seq_idx[:, :real_token_count] + if metadata.seq_idx is not None + else None + ) + + # Also strip padded tokens from the data tensor itself. zxBCdt has + # shape (padded_total_tokens, 1, d); trim to real tokens so that all + # downstream tensors (z, xBC, dt) have consistent shapes. + padded_token_count = zxBCdt.shape[0] + zxBCdt = zxBCdt[:real_token_count] + + y_prefill = self._ssm_prefill( + zxBCdt, + conv_state=conv_state, + ssm_state=ssm_state, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + batch_indices=batch_indices, + use_triton_conv1d=context.use_triton_conv1d, + ) + + # Pad output back to padded token count. The caller (residual add, + # tensor_merge) expects the output to match the padded input shape. 
+ if y_prefill.shape[0] < padded_token_count: + y_prefill = F.pad( + y_prefill, (0, 0, 0, 0, 0, padded_token_count - y_prefill.shape[0]) ) - y_prefill[: metadata.device_chunked_prefill[0]] = y_chunked - return y_prefill - else: - return y_regular + + return y_prefill def _decode( self, hidden_states, conv_state, ssm_state, batch_indices: Optional[torch.Tensor] = None @@ -694,9 +709,9 @@ def _ssm_prefill( ssm_state: Optional[torch.Tensor], seq_idx: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, - return_varlen_states: bool = False, batch_indices: Optional[torch.Tensor] = None, - is_chunked_prefill: bool = False, + use_triton_conv1d: bool = False, + cuda_graph_compatible: bool = False, ) -> torch.Tensor: """ Performs SSM computation for inference prefill step. @@ -708,18 +723,19 @@ def _ssm_prefill( ssm_state: The selective scan state tensor for inference. seq_idx: A map from token index to request index for variable-length sequences. cu_seqlens: Cumulative sequence lengths for variable-length sequences. - return_varlen_states: Whether to return variable-length states from the SSM kernel. batch_indices: A map from batch id to position in the Mamba state tensors for dynamic inference. - is_chunked_prefill: Whether the request is a chunked prefill request. + use_triton_conv1d: Whether to use the Triton varlen conv1d kernel instead + of per-request causal_conv1d_fn calls. + cuda_graph_compatible: When True, avoids GPU-to-CPU sync (.item()) by + using causal_conv1d_fn with seq_idx and mamba_chunk_scan_combined + instead of the varlen kernel. Used during CUDA graph capture when + no requests have initial states. Returns: The output tensor of shape (l, b, d). 
""" is_dynamic_batching = seq_idx is not None - assert not ( - is_dynamic_batching and is_chunked_prefill - ), "Cannot use chunked prefill with dynamic batching" # transpose: l b pd --> b l pd zxBCdt = rearrange(zxBCdt, "l b d -> b l d").contiguous() @@ -738,55 +754,92 @@ def _ssm_prefill( ) # Compute short convolution - initial_conv_state = None if conv_state is not None and is_dynamic_batching: - # xBC should have shape (b l d) for causal_conv1d_varlen_states assert batch_indices is not None + + # Extract initial conv states BEFORE saving new ones. The conv_state + # buffer holds the previous state (zeros for fresh requests, cached + # values for restored requests). We must read it before overwriting. + initial_conv_states = conv_state[batch_indices, :, 1:] # (num_reqs, conv_dim, d_conv-1) + + # Save final conv states from the input sequence conv_varlen_states = causal_conv1d_varlen_states( xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] ) tensor_masked_update(conv_state, batch_indices, conv_varlen_states) - # Maintain channels-last memory layout to use seq_idx for causal_conv1d_fn - # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L174 # pylint: disable=line-too-long - xBC = xBC.transpose(1, 2) - elif is_chunked_prefill: - # Maintain channels-last memory layout to use initial_states for causal_conv1d_fn - # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L200 # pylint: disable=line-too-long - assert batch_indices is not None - initial_conv_state = ( - conv_state[batch_indices, :, 1:].permute(0, 2, 1).contiguous().transpose(1, 2) - ) - xBC = xBC.transpose(1, 2) - tensor_masked_update( - conv_state, batch_indices, F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) - ) + conv_weight = rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w") + conv_bias = self.cp.get_conv1d_bias() + + if cuda_graph_compatible: + # CUDA graph 
capture: use causal_conv1d_fn with seq_idx (no .item() + # calls, handles padding via seq_idx). No initial states during + # graph capture, so zeroing at sequence boundaries is correct. + xBC = xBC.transpose(1, 2) # (1, L, D) -> (1, D, L) + xBC = causal_conv1d_fn( + x=xBC, + weight=conv_weight, + bias=conv_bias, + activation=self.activation, + seq_idx=seq_idx, + ) + xBC = rearrange(xBC, "b d l -> b l d").contiguous() + elif use_triton_conv1d: + from megatron.core.ssm.ops.causal_conv1d_varlen import causal_conv1d_varlen_fn + + xBC_out = causal_conv1d_varlen_fn( + x=xBC.squeeze(0).contiguous(), # (total_tokens, conv_dim) + weight=conv_weight, + bias=conv_bias, + cu_seqlens=cu_seqlens, + initial_states=initial_conv_states, + activation=self.activation, + ) + xBC = xBC_out.unsqueeze(0) # (1, total_tokens, conv_dim) + else: + # Per-request loop calling causal_conv1d_fn with initial_states. + # causal_conv1d_fn requires channels-last memory layout when using + # initial_states: create x as (1, seq_len, conv_dim) then transpose + # to get stride pattern (seq_len*conv_dim, 1, conv_dim). 
+ num_requests = cu_seqlens.shape[0] - 1 + xBC_parts = [] + for r in range(num_requests): + start = cu_seqlens[r].item() + end = cu_seqlens[r + 1].item() + if end <= start: + continue + # xBC is (1, total_tokens, conv_dim); slice gives channels-last via transpose + xBC_r = xBC[:, start:end, :].transpose(1, 2) # channels-last (1, C, L) + init_r = initial_conv_states[r : r + 1] # (1, conv_dim, d_conv-1) + init_r = init_r.permute(0, 2, 1).contiguous().transpose(1, 2) # channels-last + xBC_r = causal_conv1d_fn( + x=xBC_r, + weight=conv_weight, + bias=conv_bias, + activation=self.activation, + initial_states=init_r, + ) + xBC_parts.append(xBC_r.transpose(1, 2).contiguous()) # (1, L, C) + xBC = torch.cat(xBC_parts, dim=1) # (1, total_tokens, conv_dim) else: - # transpose: b l pd --> b pd l + # Non-dynamic-batching path (static batching / training fallback) xBC = rearrange(xBC, "b l d -> b d l").contiguous() if conv_state is not None: - # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv - # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise. 
- conv_state.copy_( - F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) - ) # Update state (B D W) - - seqlen = xBC.size(2) - if causal_conv1d_fn is None: - xBC = self.act(self.cp.conv1d(xBC)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - xBC = causal_conv1d_fn( - x=xBC, - weight=rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), - bias=self.cp.get_conv1d_bias(), - activation=self.activation, - seq_idx=seq_idx, - initial_states=initial_conv_state, - ) + conv_state.copy_(F.pad(xBC, (self.d_conv - xBC.shape[-1], 0))) - # transpose b pd l --> b l pd - xBC = rearrange(xBC, "b d l -> b l d").contiguous() + seqlen = xBC.size(2) + if causal_conv1d_fn is None: + xBC = self.act(self.cp.conv1d(xBC)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + xBC = causal_conv1d_fn( + x=xBC, + weight=rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), + bias=self.cp.get_conv1d_bias(), + activation=self.activation, + seq_idx=seq_idx, + ) + xBC = rearrange(xBC, "b d l -> b l d").contiguous() x, B, C = torch.split( xBC, @@ -798,66 +851,130 @@ def _ssm_prefill( dim=-1, ) - # TODO Vijay: fuse most of the transposes with the GEMMS x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim).contiguous() dt = dt.contiguous() B = rearrange(B, "b l (g n) -> b l g n", n=self.d_state).contiguous() C = rearrange(C, "b l (g n) -> b l g n", n=self.d_state).contiguous() z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim).contiguous() - # If `rmsnorm == False`, then the norm inside `mamba_chunk_scan_combined` will be used. - # In this case, if `cp_size > 1` then that norm could be performed on less heads than if - # `cp_size == 1` (groups of heads can be sharded across CP ranks), which would be - # mathematically incorrect, and potentially arithmetically unstable. 
assert ( self.cp.cp_size == 1 or self.rmsnorm ), "Context parallel not supported for use_mem_eff_path==False and rmsnorm==False" - if is_chunked_prefill: + if is_dynamic_batching and cuda_graph_compatible: + # CUDA graph capture: use mamba_chunk_scan_combined with seq_idx + # (no .item() calls, handles padding natively). Tensors keep their + # batch dimension (b=1), matching the non-dynamic-batching path. + y = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + self.chunk_size, + D=( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ), + z=z if not self.rmsnorm else None, + dt_bias=self.cp.get_dt_bias().float(), + dt_softplus=True, + return_final_states=True, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + return_varlen_states=True, + initial_states=None, + ) + y, _, ssm_varlen_states = y + tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) + elif is_dynamic_batching: + # Unified varlen SSM path: all prefill requests through single kernel call initial_ssm_state = ssm_state[batch_indices] - else: - initial_ssm_state = None - # Note that both `seq_idx` and `cu_seqlens` must be passed in - # for variable length generation. 
- # See https://github.com/state-spaces/mamba/blob/e0761ece1db07e0949dd88b4f4cd440420a19fd9/tests/test_generation.py#L97 # pylint: disable=line-too-long - y = mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - self.chunk_size, - D=( - rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.cp.get_D() - ), - z=z if not self.rmsnorm else None, - dt_bias=self.cp.get_dt_bias().float(), - dt_softplus=True, - return_final_states=ssm_state is not None, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - return_varlen_states=return_varlen_states, - initial_states=initial_ssm_state, - ) + x = x.squeeze(0) + dt = dt.squeeze(0) + A = A.squeeze(0) + B = B.squeeze(0) + C = C.squeeze(0) + z = z.squeeze(0) + y = torch.empty_like(x) + + # Build cu_chunk_seqlens: subdivide each sequence into chunks of + # at most self.chunk_size tokens. The SSM kernels allocate per-chunk + # output arrays of size chunk_size, so passing cu_seqlens directly + # would cause out-of-bounds access when sequences are longer than + # chunk_size. 
+ chunk_boundaries = [0] + num_seqs = cu_seqlens.numel() - 1 + for i in range(num_seqs): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + pos = start + self.chunk_size + while pos < end: + chunk_boundaries.append(pos) + pos += self.chunk_size + chunk_boundaries.append(end) + cu_chunk_seqlens = cu_seqlens.new_tensor(chunk_boundaries) + + seq_idx_for_varlen = None + if seq_idx is not None: + chunk_starts = cu_chunk_seqlens[:-1] + seq_idx_for_varlen = seq_idx[0, chunk_starts].contiguous() + + ssm_varlen_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=None, + seq_idx=seq_idx_for_varlen, + out=y, + D=( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ), + z=z if not self.rmsnorm else None, + dt_bias=self.cp.get_dt_bias().float(), + initial_states=initial_ssm_state, + return_intermediate_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + state_dtype=ssm_state.dtype, + ) - if ssm_state is not None: - if return_varlen_states: - assert batch_indices is not None + y = y.unsqueeze(0) + z = z.unsqueeze(0) - y, _, ssm_varlen_states = y + tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) + else: + # Non-dynamic-batching path (static batching) + initial_ssm_state = None + y = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + self.chunk_size, + D=( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ), + z=z if not self.rmsnorm else None, + dt_bias=self.cp.get_dt_bias().float(), + dt_softplus=True, + return_final_states=ssm_state is not None, + initial_states=initial_ssm_state, + ) - # This has to be varlen_states, NOT last_state - # See reference implementation: - # 
https://github.com/state-spaces/mamba/blob/e0761ece1db07e0949dd88b4f4cd440420a19fd9/mamba_ssm/modules/mamba2.py#L267 # pylint: disable=line-too-long - tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) - elif is_chunked_prefill: - assert batch_indices is not None - y, last_state = y - tensor_masked_update(ssm_state, batch_indices, last_state) - else: + if ssm_state is not None: y, last_state = y ssm_state.copy_(last_state) diff --git a/megatron/core/ssm/ops/__init__.py b/megatron/core/ssm/ops/__init__.py new file mode 100644 index 00000000000..03b4a09a529 --- /dev/null +++ b/megatron/core/ssm/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Triton kernels for Mamba SSM (adapted from vLLM / state-spaces/mamba). + +from .ssd_combined import mamba_chunk_scan_combined_varlen + +__all__ = ["mamba_chunk_scan_combined_varlen"] diff --git a/megatron/core/ssm/ops/causal_conv1d_varlen.py b/megatron/core/ssm/ops/causal_conv1d_varlen.py new file mode 100644 index 00000000000..4c2c4c05c9d --- /dev/null +++ b/megatron/core/ssm/ops/causal_conv1d_varlen.py @@ -0,0 +1,240 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Triton varlen depthwise causal 1D convolution with per-sequence initial states and fused SiLU. + +Supports packed variable-length sequences where `causal_conv1d_fn` cannot accept +both `seq_idx` and `initial_states` simultaneously. 
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_T": 128, "BLOCK_C": 64}, num_warps=4),
        triton.Config({"BLOCK_T": 128, "BLOCK_C": 128}, num_warps=4),
        triton.Config({"BLOCK_T": 256, "BLOCK_C": 64}, num_warps=4),
        triton.Config({"BLOCK_T": 256, "BLOCK_C": 128}, num_warps=8),
    ],
    key=["conv_dim"],
)
@triton.jit
def _causal_conv1d_varlen_kernel(
    x_ptr,  # (total_tokens, conv_dim), channels-last packed input
    weight_ptr,  # (conv_dim, WIDTH) depthwise filter taps
    bias_ptr,  # (conv_dim,) — NOTE(review): loaded unconditionally; assumes bias is never null — confirm callers
    seq_idx_ptr,  # (total_tokens,) request id per token
    seq_start_ptr,  # (total_tokens,) start offset of the owning request per token
    initial_states_ptr,  # (num_reqs, conv_dim, WIDTH-1) or dummy when HAS_INITIAL_STATES=False
    out_ptr,  # (total_tokens, conv_dim) output, same dtype as x
    total_tokens,
    conv_dim: tl.constexpr,
    initial_states_stride_req,
    initial_states_stride_dim,
    WIDTH: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_C: tl.constexpr,
    HAS_INITIAL_STATES: tl.constexpr,
):
    """Depthwise causal conv1d over packed varlen sequences with initial states and SiLU.

    Fully vectorized over BLOCK_T tokens x BLOCK_C channels per thread block.
    Grid: (cdiv(conv_dim, BLOCK_C), cdiv(total_tokens, BLOCK_T)).
    """
    pid_c = tl.program_id(0)
    pid_t = tl.program_id(1)

    c_off = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)  # (BLOCK_C,)
    c_mask = c_off < conv_dim
    t_off = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # (BLOCK_T,)
    t_mask = t_off < total_tokens

    # Load bias: (BLOCK_C,) broadcast to (BLOCK_T, BLOCK_C); accumulate in fp32.
    bias = tl.load(bias_ptr + c_off, mask=c_mask, other=0.0).to(tl.float32)
    acc = tl.zeros((BLOCK_T, BLOCK_C), dtype=tl.float32) + bias[None, :]

    # Load per-token request ID and request start position
    req_id = tl.load(seq_idx_ptr + t_off, mask=t_mask, other=0)  # (BLOCK_T,)
    req_start = tl.load(seq_start_ptr + t_off, mask=t_mask, other=0)  # (BLOCK_T,)

    # Unrolled convolution over WIDTH taps (typically 4)
    for j in tl.static_range(WIDTH):
        # Load weight column j: (BLOCK_C,)
        w_j = tl.load(weight_ptr + c_off * WIDTH + j, mask=c_mask, other=0.0).to(tl.float32)

        # Source position for this tap
        src = t_off - (WIDTH - 1) + j  # (BLOCK_T,)
        in_seq = src >= req_start  # (BLOCK_T,) — True if source is within the sequence

        # Load from x for in-sequence positions (mask out out-of-bounds).
        # src is clamped to 0 so the masked-out lanes still compute a
        # legal address; their value is discarded via `other=0.0`.
        src_safe = tl.maximum(src, 0)
        x_val = tl.load(
            x_ptr + src_safe[:, None] * conv_dim + c_off[None, :],
            mask=t_mask[:, None] & c_mask[None, :] & in_seq[:, None],
            other=0.0,
        ).to(tl.float32)  # (BLOCK_T, BLOCK_C)

        if HAS_INITIAL_STATES:
            # For tokens where src < req_start, read the tap from the
            # request's cached conv history instead of x. Column mapping:
            # the state holds the last (WIDTH-1) values before the sequence.
            state_col = (WIDTH - 1) - (req_start - src)  # (BLOCK_T,)
            valid_state = (~in_seq) & (state_col >= 0)  # (BLOCK_T,)
            state_col_safe = tl.maximum(state_col, 0)

            state_val = tl.load(
                initial_states_ptr
                + req_id[:, None] * initial_states_stride_req
                + c_off[None, :] * initial_states_stride_dim
                + state_col_safe[:, None],
                mask=t_mask[:, None] & c_mask[None, :] & valid_state[:, None],
                other=0.0,
            ).to(tl.float32)  # (BLOCK_T, BLOCK_C)

            tap = tl.where(in_seq[:, None], x_val, state_val)
        else:
            tap = x_val

        acc += tap * w_j[None, :]

    # SiLU activation: x * sigmoid(x)
    sigmoid_acc = 1.0 / (1.0 + tl.exp(-acc))
    result = acc * sigmoid_acc

    # Store output (cast back to input dtype)
    tl.store(
        out_ptr + t_off[:, None] * conv_dim + c_off[None, :],
        result,
        mask=t_mask[:, None] & c_mask[None, :],
    )
def causal_conv1d_varlen_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    cu_seqlens: torch.Tensor,
    initial_states: torch.Tensor = None,
    activation: str = "silu",
) -> torch.Tensor:
    """Depthwise causal 1D convolution over packed variable-length sequences.

    Unlike `causal_conv1d_fn` (which treats `seq_idx` and `initial_states`
    as mutually exclusive), this wrapper supports per-sequence boundaries
    and per-request initial conv states at the same time, with a fused SiLU.

    Args:
        x: Packed input of shape (total_tokens, conv_dim), channels-last.
        weight: Depthwise filter of shape (conv_dim, d_conv).
        bias: Bias of shape (conv_dim,).
        cu_seqlens: Cumulative sequence lengths, shape (num_requests + 1,), int32.
        initial_states: Optional per-request conv history of shape
            (num_requests, conv_dim, d_conv - 1); zeros are assumed when None.
        activation: Must be "silu".

    Returns:
        Tensor of shape (total_tokens, conv_dim).
    """
    assert activation == "silu", f"Only silu activation is supported, got {activation}"
    assert x.is_contiguous(), "x must be contiguous"
    assert weight.is_contiguous(), "weight must be contiguous"

    total_tokens, conv_dim = x.shape
    d_conv = weight.shape[1]
    num_requests = cu_seqlens.shape[0] - 1

    out = torch.empty_like(x)

    # Expand cu_seqlens into two per-token lookup tables for the kernel:
    # token_req[t] = owning request id, token_start[t] = that request's
    # first token offset in the packed layout.
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    req_ids = torch.arange(num_requests, device=x.device, dtype=torch.int32)
    token_req = torch.repeat_interleave(req_ids, lengths)
    token_start = torch.repeat_interleave(cu_seqlens[:-1], lengths).to(torch.int32)

    if initial_states is None:
        # The kernel still needs a valid pointer; pass a 1-element dummy
        # and tell it (via HAS_INITIAL_STATES=False) never to read it.
        states = torch.empty((1, 1, 1), dtype=x.dtype, device=x.device)
        stride_req = 1
        stride_dim = 1
        with_states = False
    else:
        assert initial_states.shape == (num_requests, conv_dim, d_conv - 1)
        states = initial_states
        stride_req = initial_states.stride(0)
        stride_dim = initial_states.stride(1)
        with_states = True

    def grid(meta):
        return (
            triton.cdiv(conv_dim, meta["BLOCK_C"]),
            triton.cdiv(total_tokens, meta["BLOCK_T"]),
        )

    _causal_conv1d_varlen_kernel[grid](
        x,
        weight,
        bias,
        token_req,
        token_start,
        states,
        out,
        total_tokens,
        conv_dim,
        stride_req,
        stride_dim,
        WIDTH=d_conv,
        HAS_INITIAL_STATES=with_states,
    )

    return out
states and SiLU. + + This is a reference implementation for testing. Processes each request and token + sequentially. + """ + total_tokens, conv_dim = x.shape + d_conv = weight.shape[1] + num_requests = cu_seqlens.shape[0] - 1 + + for r in range(num_requests): + start = cu_seqlens[r].item() + end = cu_seqlens[r + 1].item() + seq_len = end - start + + if seq_len == 0: + continue + + if initial_states is not None: + init_state = initial_states[r] # (conv_dim, d_conv - 1) + else: + init_state = torch.zeros( + (conv_dim, d_conv - 1), dtype=x.dtype, device=x.device + ) + + x_seq = x[start:end] # (seq_len, conv_dim) + + for t in range(seq_len): + acc = bias.float() # (conv_dim,) + for j in range(d_conv): + src_pos = t - (d_conv - 1) + j + if src_pos < 0: + state_col = (d_conv - 1) + src_pos + if state_col >= 0 and state_col < d_conv - 1: + tap = init_state[:, state_col].float() + else: + tap = torch.zeros(conv_dim, dtype=torch.float32, device=x.device) + else: + tap = x_seq[src_pos].float() + + acc = acc + tap * weight[:, j].float() + + result = acc * torch.sigmoid(acc) + out[start + t] = result.to(out.dtype) diff --git a/megatron/core/ssm/ops/ssd_bmm.py b/megatron/core/ssm/ops/ssd_bmm.py new file mode 100644 index 00000000000..57731ba5f98 --- /dev/null +++ b/megatron/core/ssm/ops/ssd_bmm.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py +# Adapted from vLLM project (Apache-2.0). 
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["chunk_size", "K", "IS_CAUSAL"],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    out_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    seqlen,
    chunk_size: tl.constexpr,
    K: tl.constexpr,
    ngroups: tl.constexpr,
    stride_a_seqlen: tl.int64,
    stride_a_head: tl.int64,
    stride_ak: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_bk: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_outm: tl.int64,
    stride_outn: tl.constexpr,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Per-chunk batched matmul out[c, h] = a_chunk @ b_chunk.T for varlen chunks.

    Chunk c covers tokens [cu_chunk_seqlens[c], cu_chunk_seqlens[c+1]) of the
    packed (seqlen, ngroups, K) inputs. Each program computes one
    (BLOCK_SIZE_M x BLOCK_SIZE_N) tile of the (chunk_size x chunk_size)
    output for one (chunk, group) pair. When IS_CAUSAL, tiles strictly
    above the diagonal are skipped entirely.
    Grid: (cdiv(chunk_size, BM) * cdiv(chunk_size, BN), num_chunks * ngroups).
    """
    # axis=1 enumerates (chunk, group) pairs; decompose into the two ids.
    pid_ch = tl.program_id(axis=1).to(tl.int64)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    # axis=0 enumerates output tiles in row-major (m, n) order.
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    if IS_CAUSAL:
        # Entire tile lies above the diagonal: nothing to compute.
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return

    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)

    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
    # Actual number of tokens in this chunk (may be < chunk_size for the
    # final chunk of a sequence); rows/cols past it are masked to zero.
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # compute a * b.T, accumulating in fp32 over K in BLOCK_SIZE_K slabs
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit)
            & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        ).to(dot_dtype)
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
            & (offs_n[None, :] < chunk_size_limit),
            other=0.0,
        ).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Recompute tile offsets (same values as above) for the store phase.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    out = acc.to(out_ptr.dtype.element_ty)
    out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
    # NOTE(review): the store mask bounds by chunk_size, not
    # chunk_size_limit, so the padding region of a short final chunk is
    # written with zeros from the masked accumulation — presumably
    # consumers ignore rows/cols >= chunk_size_limit; confirm.
    tl.store(
        out_ptrs,
        out,
        mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
    )
+ cu_chunk_seqlens: (nchunks+1,) + causal: if True, then out[i, j] for i < j will be arbitrary, only out[i, j] for i >= j are + guaranteed to be correct. + Return: + out: (nchunks, ngroups, chunk_size, chunk_size) + """ + seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if a.stride(-1) != 1 and a.stride(0) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(0) != 1: + b = b.contiguous() + + nchunks = len(cu_chunk_seqlens) - 1 + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype + ) + dot_dtype = ( + tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 + else ( + tl.float16 + if a.dtype == torch.float16 or b.dtype == torch.float16 + else tl.float32 + ) + ) + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), + nchunks * ngroups, + ) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a_ptr=a, + b_ptr=b, + out_ptr=out, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + seqlen=seqlen, + chunk_size=chunk_size, + K=k, + ngroups=ngroups, + stride_a_seqlen=a.stride(0), + stride_a_head=a.stride(1), + stride_ak=a.stride(2), + stride_b_seqlen=b.stride(0), + stride_b_head=b.stride(1), + stride_bk=b.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_outm=out.stride(-2), + stride_outn=out.stride(-1), + IS_CAUSAL=causal, + dot_dtype=dot_dtype, + ) + return out diff --git a/megatron/core/ssm/ops/ssd_chunk_scan.py b/megatron/core/ssm/ops/ssd_chunk_scan.py new file mode 100644 index 00000000000..a1715935c97 --- /dev/null +++ b/megatron/core/ssm/ops/ssd_chunk_scan.py @@ -0,0 +1,453 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py +# Adapted from vLLM project (Apache-2.0). 
+ +from packaging import version + +import triton +import triton.language as tl + +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimensions + chunk_size: tl.constexpr, + hdim: tl.constexpr, + dstate: tl.constexpr, + seqlen, + nheads_ngroups_ratio: tl.constexpr, + # Strides + stride_cb_chunk: tl.int64, + stride_cb_head: tl.int64, + 
stride_cb_csize_m: tl.int64, + stride_cb_csize_k: tl.constexpr, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_z_seqlen: tl.int64, + stride_z_head: tl.int64, + stride_z_hdim: tl.constexpr, + stride_out_seqlen: tl.int64, + stride_out_head: tl.int64, + stride_out_hdim: tl.constexpr, + stride_dt_chunk: tl.int64, + stride_dt_head: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_head: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, + stride_C_seqlen: tl.int64, + stride_C_head: tl.int64, + stride_C_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, + stride_D_head: tl.constexpr, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += ( + chunk_seqlen_start * stride_C_seqlen + + (pid_h // nheads_ngroups_ratio) * 
stride_C_head + ) + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + seq_idx_ptr += pid_c * stride_seq_idx_chunk + seq_idx = tl.load(seq_idx_ptr) + seq_idx_prev = tl.load( + seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1 + ) + + if HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states_ptr = ( + initstates_ptr + + seq_idx * stride_init_states_batch + + pid_h * stride_init_states_head + ) + prev_states_hdim = stride_init_states_hdim + prev_states_dstate = stride_init_states_dstate + else: + prev_states_ptr = ( + states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head + ) + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load( + dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0 + ).to(tl.float32) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K + ) + C_ptrs = C_ptr + ( + offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate + ) + + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate), + other=0.0, + ) + + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + # if no init states AND starting a new sequence, we need zeros + prev_states = tl.zeros( + (BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + ) + else: + # otherwise read the 
previous state + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + + acc = tl.dot(C, prev_states) * scale_m[:, None] + + else: + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states = tl.zeros( + (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + ) + else: + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + ( + offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k + ) + x_ptrs = x_ptr + ( + offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = ( + chunk_size_limit + if not IS_CAUSAL + else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + ) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load( + cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to( + tl.float32 + ) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. 
+ # So we don't need masking wrt seq_idx here. + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load( + x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), + other=0.0, + ) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load( + D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0 + ).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load( + x_ptr + + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + ( + stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :] + ) + z = tl.load( + z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim + ) + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), + ) + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + cu_chunk_seqlens, + out, + seq_idx, + 
D=None, + z=None, + initial_states=None, +): + assert seq_idx is not None, "this implementation requires seq_idx" + + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (seqlen, ngroups, dstate) + assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if z is not None: + assert z.shape == x.shape + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + assert states.shape == (nchunks, nheads, headdim, dstate) + assert seq_idx.shape == (nchunks,) + + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) + + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + initial_states_strides = ( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ) + + _chunk_scan_fwd_kernel[grid]( + cb_ptr=cb, + x_ptr=x, + z_ptr=z, + out_ptr=out, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + seq_idx_ptr=seq_idx, + C_ptr=C, + states_ptr=states, + D_ptr=D, + initstates_ptr=initial_states, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + chunk_size=chunk_size, + hdim=headdim, + dstate=dstate, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_cb_chunk=cb.stride(0), + stride_cb_head=cb.stride(1), + stride_cb_csize_m=cb.stride(2), + stride_cb_csize_k=cb.stride(3), + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_z_seqlen=z_strides[0], + stride_z_head=z_strides[1], + stride_z_hdim=z_strides[2], + stride_out_seqlen=out.stride(0), + stride_out_head=out.stride(1), + stride_out_hdim=out.stride(2), + stride_dt_chunk=dt.stride(1), + stride_dt_head=dt.stride(0), + 
stride_dt_csize=dt.stride(2), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_seq_idx_chunk=seq_idx.stride(0), + stride_C_seqlen=C.stride(0), + stride_C_head=C.stride(1), + stride_C_dstate=C.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_init_states_batch=initial_states_strides[0], + stride_init_states_head=initial_states_strides[1], + stride_init_states_hdim=initial_states_strides[2], + stride_init_states_dstate=initial_states_strides[3], + stride_D_head=D.stride(0) if D is not None else 0, + IS_CAUSAL=True, + HAS_D=D is not None, + D_HAS_HDIM=D.dim() == 2 if D is not None else True, + HAS_Z=z is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return diff --git a/megatron/core/ssm/ops/ssd_chunk_state.py b/megatron/core/ssm/ops/ssd_chunk_state.py new file mode 100644 index 00000000000..9e2fdaf867b --- /dev/null +++ b/megatron/core/ssm/ops/ssd_chunk_state.py @@ -0,0 +1,718 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py +# Adapted from vLLM project (Apache-2.0). 
+ +import torch +import triton +import triton.language as tl +from packaging import version + +try: + TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") +except: + raise ImportError("Triton version 3.0.0 or higher is required") + +if TRITON3: + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_H": 2}), + triton.Config({"BLOCK_SIZE_H": 4}), + triton.Config({"BLOCK_SIZE_H": 8}), + triton.Config({"BLOCK_SIZE_H": 16}), + triton.Config({"BLOCK_SIZE_H": 32}), + triton.Config({"BLOCK_SIZE_H": 64}), + ], + key=["chunk_size", "nheads"], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimension + seqlen, + nheads: tl.constexpr, + chunk_size: tl.constexpr, + dt_min: tl.constexpr, + dt_max: tl.constexpr, + # Strides + stride_dt_seqlen: tl.int64, + stride_dt_head: tl.constexpr, + stride_A_head: tl.constexpr, + stride_dt_bias_head: tl.constexpr, + stride_dt_out_head: tl.int64, + stride_dt_out_chunk: tl.int64, + stride_dt_out_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=0).to(tl.int64) + pid_h = tl.program_id(axis=1) + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + dt_ptr += chunk_seqlen_start * stride_dt_seqlen + dt_out_ptr += pid_c * stride_dt_out_chunk + dA_cumsum_ptr += 
pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + ( + offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen + ) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + ( + offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize + ) + dA_cs_ptrs = dA_cumsum_ptr + ( + offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize + ) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + + dt = tl.load( + dt_ptrs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), + other=0.0, + ).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load( + dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0 + ).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + + dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0 + ) + tl.store( + dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store( + dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, 
+ num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["hdim", "dstate", "chunk_size"], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimensions + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, + seqlen, + nheads_ngroups_ratio: tl.constexpr, + # Strides + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + b_ptr += ( + chunk_seqlen_start * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * 
stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to( + tl.float32 + ) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) + c_mask = 
(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), +): + seqlen, nheads = dt.shape + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + nchunks = cu_chunk_seqlens.shape[0] - 1 + dt_out = torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + dA_cumsum = torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt_ptr=dt, + A_ptr=A, + dt_bias_ptr=dt_bias, + dt_out_ptr=dt_out, + dA_cumsum_ptr=dA_cumsum, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + seqlen=seqlen, + nheads=nheads, + chunk_size=chunk_size, + dt_min=dt_limit[0], + dt_max=dt_limit[1], + stride_dt_seqlen=dt.stride(0), + stride_dt_head=dt.stride(1), + stride_A_head=A.stride(0), + stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0, + stride_dt_out_head=dt_out.stride(0), + stride_dt_out_chunk=dt_out.stride(1), + stride_dt_out_csize=dt_out.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + DT_SOFTPLUS=dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd( + B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True +): + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + + if states is not None: + assert states.shape == 
(nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty( + (nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype + ) + + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x_ptr=x, + b_ptr=B, + states_ptr=states, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + hdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_b_seqlen=B.stride(0), + stride_b_head=B.stride(1), + stride_b_dstate=B.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_dt_head=dt.stride(0), + stride_dt_chunk=dt.stride(1), + stride_dt_csize=dt.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + ) + return states + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + 
triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["hdim", "dstate", "chunk_size"], +) +@triton.jit +def _chunk_state_varlen_kernel( + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + last_chunk_indices_ptr, + cu_chunk_seqlens_ptr, + states_ptr, + initstates_ptr, + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, + nheads_ngroups_ratio: tl.constexpr, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_chunk_states_chunk: tl.int64, + stride_chunk_states_head: tl.int64, + stride_chunk_states_hdim: tl.int64, + stride_chunk_states_dstate: tl.constexpr, + stride_states_batch: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, + USE_LAST_CHUNK_INDICES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + start_idx = tl.load(cu_seqlens_ptr + pid_b) + if 
USE_LAST_CHUNK_INDICES: + pid_c = tl.load(last_chunk_indices_ptr + pid_b).to(tl.int64) + chunk_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_size_limit = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - chunk_start + else: + pid_c = (end_idx - 1) // chunk_size + chunk_start = pid_c * chunk_size + chunk_size_limit = end_idx - chunk_start + b_ptr += ( + chunk_start * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += chunk_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += ( + pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + ) + + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load( + dA_cumsum_ptr + (end_idx - 1 - chunk_start) * stride_dA_cs_csize + ).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + start_idx_cur = tl.maximum(start_idx - chunk_start, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) + & (offs_k[None, :] < chunk_size_limit - k) + & (offs_k[None, :] >= start_idx_cur - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) + & (offs_n[None, :] < dstate) + & (offs_k[:, None] >= start_idx_cur - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + 
def chunk_state_varlen(
    B,
    x,
    dt,
    dA_cumsum,
    cu_seqlens,
    chunk_states,
    initial_states=None,
    last_chunk_indices=None,
    cu_chunk_seqlens=None,
):
    """Compute per-sequence final SSM state from chunk states (correct when sequences share chunks).

    Args (shapes enforced by the asserts below):
        B: (total_seqlen, ngroups, dstate) input-projection tensor.
        x: (total_seqlen, nheads, headdim) input tensor.
        dt: (nheads, nchunks, chunk_size) per-chunk time deltas.
        dA_cumsum: (nheads, nchunks, chunk_size) chunked cumulative sum of A*dt.
        cu_seqlens: (batch + 1,) cumulative sequence lengths of the packed batch.
        chunk_states: (nchunks, nheads, headdim, dstate) per-chunk boundary states.
        initial_states: optional (batch, nheads, headdim, dstate) starting states.
        last_chunk_indices: optional (batch,) index of each sequence's last chunk;
            used together with cu_chunk_seqlens to locate the final chunk exactly.
        cu_chunk_seqlens: optional cumulative chunk boundaries (token offsets).

    Returns:
        (batch, nheads, headdim, dstate) tensor holding each sequence's state at
        its exact last token, in chunk_states' dtype.
    """
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
    # Both auxiliary index tensors must be supplied to take the exact-last-chunk path.
    use_last_chunk = (
        last_chunk_indices is not None and cu_chunk_seqlens is not None
    )
    if use_last_chunk:
        last_chunk_indices = last_chunk_indices.contiguous().to(x.device)
        cu_chunk_seqlens = cu_chunk_seqlens.contiguous().to(x.device)
    else:
        # Dummy tensor: the kernel ignores it when USE_LAST_CHUNK_INDICES is False
        # and derives the chunk index from end_idx // chunk_size instead.
        last_chunk_indices = torch.zeros(1, dtype=torch.int64, device=x.device)
        cu_chunk_seqlens = cu_seqlens

    states = torch.empty(
        batch,
        nheads,
        headdim,
        dstate,
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )
    # Zero strides stand in when no initial states exist (kernel guards on HAS_INITSTATES).
    initial_states_strides = (
        (
            initial_states.stride(0),
            initial_states.stride(1),
            initial_states.stride(2),
            initial_states.stride(3),
        )
        if initial_states is not None
        else (0, 0, 0, 0)
    )
    # Grid: (hdim x dstate tiles flattened, batch, heads).
    grid = lambda META: (
        triton.cdiv(headdim, META["BLOCK_SIZE_M"])
        * triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
        batch,
        nheads,
    )
    # Pin the launch to x's device; assumes x lives on a CUDA device — TODO confirm.
    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x_ptr=x,
            b_ptr=B,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            chunk_states_ptr=chunk_states,
            cu_seqlens_ptr=cu_seqlens,
            last_chunk_indices_ptr=last_chunk_indices,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            states_ptr=states,
            initstates_ptr=initial_states,
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            nheads_ngroups_ratio=nheads // ngroups,
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_chunk_states_chunk=chunk_states.stride(0),
            stride_chunk_states_head=chunk_states.stride(1),
            stride_chunk_states_hdim=chunk_states.stride(2),
            stride_chunk_states_dstate=chunk_states.stride(3),
            stride_states_batch=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_init_states_batch=initial_states_strides[0],
            stride_init_states_head=initial_states_strides[1],
            stride_init_states_hdim=initial_states_strides[2],
            stride_init_states_dstate=initial_states_strides[3],
            HAS_INITSTATES=initial_states is not None,
            USE_LAST_CHUNK_INDICES=use_last_chunk,
        )
    return states
def is_int_pow_2(n):
    """Return True iff *n* is an integer and a positive power of two."""
    if not isinstance(n, int):
        return False
    # A positive power of two has exactly one bit set, so n & (n - 1) clears it.
    return n > 0 and n & (n - 1) == 0
def _mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    return_intermediate_states=False,
    seq_idx=None,
    cu_seqlens=None,
    cu_chunk_seqlens=None,
    last_chunk_indices=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    state_dtype=None,
):
    """Forward pass of the chunked Mamba-2 SSD scan for varlen (packed) input.

    Writes the scan output into the preallocated ``out`` tensor in place and
    returns either the per-chunk intermediate states (if
    ``return_intermediate_states``) or each sequence's final state at its exact
    last token via ``chunk_state_varlen``.

    Shapes (enforced by the asserts below):
        x: (seqlen, nheads, headdim); dt: (seqlen, nheads); A: (nheads,);
        B, C: (seqlen, ngroups, dstate); z (optional): same as x;
        D (optional): (nheads, headdim) or (nheads,);
        seq_idx (optional): (nchunks,) — one entry per (pseudo) chunk;
        initial_states (optional): (batch, nheads, headdim, dstate).
    """
    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
    seqlen, nheads, headdim = x.shape
    _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate), f"B.shape={B.shape} != ({seqlen}, {ngroups}, {dstate})"
    assert dt.shape == (seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if seq_idx is not None:
        assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
    # The Triton kernels require the innermost dimension to be contiguous.
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (
        x.stride(-1) != 1 and x.stride(0) != 1
    ):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (
        z is not None and z.stride(-1) != 1 and z.stride(0) != 1
    ):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()
    assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"

    if initial_states is not None:
        assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)

    # This function executes 5 sub-functions for computing mamba
    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
    #   which has a minimal implementation to understand the below operations
    # - as explained by the blog, mamba is a special case of causal attention
    # - the idea is to chunk the attention matrix and compute each
    #   submatrix separately using different optimizations.
    # - see the blog and paper for a visualization of the submatrices
    #   which we refer to in the comments below

    # 1. Compute chunked cumsum of A * dt
    # - here dt may go through a softplus activation
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt,
        A,
        chunk_size,
        cu_chunk_seqlens,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    states = _chunk_state_fwd(
        B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
    )

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    # - for handling chunked prefill, this requires i) initial_states and
    #   ii) seq_idx to be all specified.
    # - When a new seq_idx is detected, we will stop passing the prev_state
    #   and switch accordingly to the init_state corresponding to the new seq_idx.
    states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum,  # (nheads, nchunks, chunk_size)
        cu_chunk_seqlens,
        initial_states=rearrange(initial_states, "... p n -> ... (p n)")
        if initial_states is not None
        else None,  # (batch, nheads, headdim*dstate)
        seq_idx=seq_idx,
        out_dtype=state_dtype if state_dtype is not None else C.dtype,
    )
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)

    # 4. Compute batched matrix multiply for C_j^T B_i terms
    CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)

    # 5. Scan and compute the diagonal blocks, taking into
    # account past causal states.
    # - if initial states are provided, then states information will be
    #   augmented with initial_states.
    # - to do this properly, we need to account for example changes in
    #   the continuous batch, therefore we introduce pseudo chunks, which is
    #   a chunk that is split up each time an example changes.
    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
    #   a seq_idx change, in which case we take states information from
    #   init_states.
    _chunk_scan_fwd(
        CB,
        x,
        dt,
        dA_cumsum,
        C,
        states,
        cu_chunk_seqlens,
        out,  # in-place update
        seq_idx,
        D=D,
        z=z,
        initial_states=initial_states,
    )

    if return_intermediate_states:
        return states
    else:
        # Per-sequence final state at exact last token (correct when sequences share chunks)
        return chunk_state_varlen(
            B,
            x,
            dt,
            dA_cumsum,
            cu_seqlens,
            states,
            initial_states=initial_states,
            last_chunk_indices=last_chunk_indices,
            cu_chunk_seqlens=cu_chunk_seqlens,
        )
def mamba_chunk_scan_combined_varlen(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    cu_seqlens,
    cu_chunk_seqlens,
    last_chunk_indices,
    seq_idx,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_intermediate_states=False,
    state_dtype=None,
):
    """Public varlen entry point for the chunked Mamba-2 SSD scan.

    Argument:
        x: (seqlen, nheads, headdim)
        dt: (seqlen, nheads)
        A: (nheads)
        B: (seqlen, ngroups, dstate)
        C: (seqlen, ngroups, dstate)
        chunk_size: int
        cu_seqlens: (batch + 1,)
        cu_chunk_seqlens: (nchunks + 1,)
        last_chunk_indices: (batch,)
        seq_idx: (nchunks,)
        out: (seqlen, nheads, headdim) preallocated output tensor, written in place
        D: (nheads, headdim) or (nheads,)
        z: (seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        dt_softplus: Whether to apply softplus to dt
        dt_limit: (min, max) clamp range applied to dt
        return_intermediate_states: If True, return per-chunk states instead of
            the per-sequence final states
        state_dtype: The data type of the ssm state
    Return:
        varlen_states: (batch, nheads, headdim, dstate)
    """

    assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
    assert seq_idx is not None

    # Thin wrapper: all computation happens in the fused forward below.
    varlen_states = _mamba_chunk_scan_combined_fwd(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        out,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        return_intermediate_states=return_intermediate_states,
        seq_idx=seq_idx,
        cu_seqlens=cu_seqlens,
        cu_chunk_seqlens=cu_chunk_seqlens,
        last_chunk_indices=last_chunk_indices,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        state_dtype=state_dtype,
    )

    return varlen_states
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
    ],
    key=["dim"],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr,
    out_ptr,
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    cu_chunk_seqlens_ptr,  # NOTE(review): not read in this kernel — confirm it can be dropped from the launch
    # Matrix dimensions
    dim: tl.constexpr,
    nchunks,
    seqlen,  # NOTE(review): not read in this kernel
    chunk_size: tl.constexpr,
    # Strides
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_dim: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_out_dim: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_initstates_batch: tl.int64,
    stride_initstates_head: tl.int64,
    stride_initstates_dim: tl.constexpr,
    stride_seq_idx_chunk: tl.constexpr,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per (BLOCK_SIZE slice of dim, head); chunks are scanned sequentially
    # to carry the inter-chunk recurrence states <- exp(dA_cs) * states + new_states.
    pid_h = tl.program_id(axis=1)
    pid_m = tl.program_id(axis=0)

    states_ptr += pid_h * stride_states_head
    # Point at the last cumsum entry of each chunk: the chunk's total decay factor.
    dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
    out_ptr += pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    # Seed the running state: sequence 0's initial state if provided, else zeros.
    if HAS_INITSTATES:
        initstates_ptrs = (
            initstates_ptr
            + pid_h * stride_initstates_head
            + offs_m * stride_initstates_dim
        )

        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    else:
        states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    prev_seq_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
        # we have started a new sequence
        if prev_seq_idx != seq_idx:
            # Reset the carried state at sequence boundaries instead of
            # propagating across unrelated requests in the packed batch.
            if HAS_INITSTATES:
                initstates_ptrs = (
                    initstates_ptr
                    + seq_idx * stride_initstates_batch
                    + pid_h * stride_initstates_head
                    + offs_m * stride_initstates_dim
                )
                states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
                    tl.float32
                )
            else:
                states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

            prev_seq_idx = seq_idx
        # Inter-chunk recurrence: decay the carried state, then add this chunk's state.
        states = tl.exp(dA_cs) * states + new_states
        tl.store(out_ptrs, states, mask=offs_m < dim)

        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk
def _state_passing_fwd(
    states,
    dA_cumsum,
    cu_chunk_seqlens,
    seq_idx,
    initial_states=None,
    out_dtype=None,
):
    """Launch the inter-chunk state-passing recurrence.

    Args:
        states: (nchunks, nheads, dim) per-chunk states from ``_chunk_state_fwd``.
        dA_cumsum: (nheads, nchunks, chunk_size) chunked cumsum of A*dt.
        cu_chunk_seqlens: cumulative chunk boundaries (forwarded to the kernel;
            the kernel does not currently read it).
        seq_idx: (nchunks,) sequence index per chunk — required (dereferenced
            unconditionally below).
        initial_states: optional (batch, nheads, dim) per-sequence initial states.
        out_dtype: dtype of the returned tensor; defaults to states.dtype.

    Returns:
        (nchunks, nheads, dim) states carried into each chunk boundary.
    """
    nchunks, nheads, dim = states.shape
    chunk_size = dA_cumsum.shape[-1]
    assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
    seqlen = seq_idx.shape[-1]
    out_dtype = states.dtype if out_dtype is None else out_dtype
    out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)

    # Zero strides stand in when no initial states exist (kernel guards on HAS_INITSTATES).
    initial_states_strides = (
        (initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
        if initial_states is not None
        else (0, 0, 0)
    )

    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
    with torch.cuda.device(states.device.index):
        _state_passing_fwd_kernel[grid](
            states_ptr=states,
            out_ptr=out,
            dA_cs_ptr=dA_cumsum,
            initstates_ptr=initial_states,
            seq_idx_ptr=seq_idx,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            dim=dim,
            nchunks=nchunks,
            # NOTE(review): seq_idx is always non-None here (seq_idx.shape was read
            # above), so these conditionals are effectively dead — confirm and simplify.
            seqlen=seqlen if seq_idx is not None else 0,
            chunk_size=chunk_size if seq_idx is not None else 0,
            stride_states_chunk=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_dim=states.stride(2),
            stride_out_chunk=out.stride(0),
            stride_out_head=out.stride(1),
            stride_out_dim=out.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_initstates_batch=initial_states_strides[0],
            stride_initstates_head=initial_states_strides[1],
            stride_initstates_dim=initial_states_strides[2],
            stride_seq_idx_chunk=seq_idx.stride(0),
            HAS_INITSTATES=initial_states is not None,
        )
    return out
+ ) + if args.cuda_graph_impl == "local" and args.expert_model_parallel_size > 1 and args.transformer_impl != "inference_optimized": assert args.moe_pad_experts_for_cuda_graph_inference, \ "--moe-pad-experts-for-cuda-graph-inference must be set when using CUDA graphs with expert parallelism" @@ -1825,6 +1835,18 @@ def _add_inference_args(parser): '"ref_zero" (default) immediately returns blocks to the ' 'free pool when ref_count hits 0. "lru" keeps blocks ' 'cached and evicts via LRU only when space is needed.') + group.add_argument('--inference-dynamic-batching-prefix-caching-mamba-gb', + type=float, default=None, + help='Memory budget (GB) for cached Mamba states in prefix caching. ' + 'Required for Mamba prefix caching in hybrid models. ' + 'If not specified, Mamba prefix caching is disabled. ' + 'When enabled, chunked prefill is automatically enabled if disabled.') + group.add_argument('--inference-dynamic-batching-mamba-triton-conv1d', + '--no-inference-dynamic-batching-mamba-triton-conv1d', + action='store_true', default=False, + dest='inference_dynamic_batching_mamba_triton_conv1d', + help='Use Triton varlen conv1d kernel for Mamba prefill instead of ' + 'per-request causal_conv1d_fn calls.') group.add_argument('--inference-dynamic-batching-cuda-graph-max-tokens', type=int, default=16384, help='Maximum number of tokens to capture in a cuda graph.') diff --git a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py index dd34061888e..c82fe1ebb26 100644 --- a/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py +++ b/tests/unit_tests/inference/contexts/attention_metadata/test_mamba_metadata.py @@ -101,7 +101,7 @@ def test_update_decode_only_exact_match(self, metadata_context): expected_decode = torch.arange(4, dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) 
assert metadata_context.batch_indices_prefill is None - assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.batch_kernel_batch_indices is None assert metadata_context.device_decode_prefill is None assert metadata_context.cu_seqlens is None assert metadata_context.seq_idx is None @@ -125,7 +125,7 @@ def test_update_decode_only_padded(self, metadata_context): ) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) assert metadata_context.batch_indices_prefill is None - assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.batch_kernel_batch_indices is None assert metadata_context.device_decode_prefill is None @pytest.mark.internal @@ -144,7 +144,7 @@ def test_update_chunked_enabled_no_prefill_reqs(self, metadata_context): # Should behave exactly like decode-only (chunked logic skipped if real_prefill == 0) expected_decode = torch.tensor([0, 1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_decode, expected_decode) - assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.batch_kernel_batch_indices is None assert metadata_context.batch_indices_prefill is None assert metadata_context.cu_seqlens is None assert metadata_context.seq_idx is None @@ -180,7 +180,7 @@ def test_update_prefill_only_exact(self, metadata_context): assert torch.equal(metadata_context.seq_idx, expected_seq_idx) assert metadata_context.batch_indices_decode is None - assert metadata_context.batch_indices_chunked_prefill is None + assert metadata_context.batch_kernel_batch_indices is None assert metadata_context.device_decode_prefill is None @pytest.mark.internal @@ -214,7 +214,7 @@ def test_update_prefill_only_padded(self, metadata_context): assert torch.equal(metadata_context.seq_idx, expected_seq_idx) assert metadata_context.batch_indices_decode is None - assert metadata_context.batch_indices_chunked_prefill is None + 
assert metadata_context.batch_kernel_batch_indices is None assert metadata_context.device_decode_prefill is None # ------------------------------------------------------------------------- @@ -328,7 +328,7 @@ def test_update_chunked_prefill_mixed_exact(self, metadata_context): ) assert torch.equal(metadata_context.device_chunked_prefill, expected_device_chunked_prefill) - assert metadata_context.batch_indices_chunked_prefill[0] == 1 + assert metadata_context.batch_kernel_batch_indices[0] == 1 expected_prefill = torch.tensor([2, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -369,7 +369,7 @@ def test_update_chunked_prefill_mixed_padded(self, metadata_context): ) assert torch.equal(metadata_context.device_chunked_prefill, expected_device_chunked_prefill) - assert metadata_context.batch_indices_chunked_prefill[0] == 2 + assert metadata_context.batch_kernel_batch_indices[0] == 2 expected_prefill = torch.tensor([3, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) @@ -405,7 +405,7 @@ def test_update_chunked_only_padded(self, metadata_context): assert metadata_context.batch_indices_decode is None - assert metadata_context.batch_indices_chunked_prefill[0] == 0 + assert metadata_context.batch_kernel_batch_indices[0] == 0 expected_prefill = torch.tensor([-1, -1], dtype=torch.int32, device=metadata_context.device) assert torch.equal(metadata_context.batch_indices_prefill, expected_prefill) diff --git a/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py b/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py index bc8db76ecfb..9edd14b1ee8 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_prefix_caching.py @@ -213,7 +213,7 @@ def test_disabled_allocator_has_no_caching_state(self): ctx_disabled = 
self._ctx(enable_prefix_caching=False) alloc_d = ctx_disabled.block_allocator assert not hasattr(alloc_d, 'block_hashes') - assert not hasattr(alloc_d, 'hash_to_block_id') + assert not hasattr(alloc_d, 'kv_hash_to_block_id') assert not hasattr(alloc_d, 'block_ref_counts') ctx_rz = self._ctx(prefix_caching_eviction_policy=PrefixCachingEvictionPolicy.REF_ZERO) @@ -318,12 +318,12 @@ def test_lru_ref_decrement_preserves_cached_blocks(self): ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) assert alloc.block_ref_counts[b0].item() == 1 - assert b0_hash in alloc.hash_to_block_id + assert b0_hash in alloc.kv_hash_to_block_id ctx.release_memory_blocks_from_request_indexes(torch.tensor([1])) assert alloc.block_ref_counts[b0].item() == 0 assert alloc.block_ref_counts[b1].item() == 0 - assert b0_hash in alloc.hash_to_block_id, "LRU keeps cached blocks" + assert b0_hash in alloc.kv_hash_to_block_id, "LRU keeps cached blocks" @pytest.mark.internal def test_lru_cached_blocks_reused_by_new_request(self): @@ -339,7 +339,7 @@ def test_lru_cached_blocks_reused_by_new_request(self): ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) ctx.total_request_count = 0 assert alloc.block_ref_counts[b0].item() == 0 - assert alloc.block_hashes[b0].item() in alloc.hash_to_block_id + assert alloc.block_hashes[b0].item() in alloc.kv_hash_to_block_id ctx.add_request(self._req(ctx, prompt.clone(), request_id=2)) assert self._block_ids(ctx, 0, 2) == [b0, b1] @@ -395,12 +395,12 @@ def test_refzero_deregisters_on_last_release(self): # Release first: ref=1, hash persists ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) assert alloc.block_ref_counts[b0].item() == 1 - assert b0_hash in alloc.hash_to_block_id + assert b0_hash in alloc.kv_hash_to_block_id # Release second: ref=0, hash removed, blocks returned ctx.release_memory_blocks_from_request_indexes(torch.tensor([1])) assert alloc.block_ref_counts[b0].item() == 0 - assert b0_hash not in 
alloc.hash_to_block_id + assert b0_hash not in alloc.kv_hash_to_block_id assert alloc.block_hashes[b0].item() == -1 assert alloc.block_hashes[b1].item() == -1 assert alloc.total_avail == avail_before + 2 @@ -496,8 +496,8 @@ def test_blocks_discoverable_after_add_request(self): b0, b1 = self._block_ids(ctx, 0, 2) h0, h1 = req.precomputed_block_hashes - assert alloc.hash_to_block_id.get(h0) == b0 - assert alloc.hash_to_block_id.get(h1) == b1 + assert alloc.kv_hash_to_block_id.get(h0) == b0 + assert alloc.kv_hash_to_block_id.get(h1) == b1 assert alloc.block_hashes[b0].item() == h0 assert alloc.block_hashes[b1].item() == h1 @@ -538,7 +538,7 @@ def test_decode_does_not_register_completed_blocks(self): @pytest.mark.internal def test_second_request_finds_registered_blocks(self): - """After req1 registers 3 blocks, req2's hashes all resolve in hash_to_block_id.""" + """After req1 registers 3 blocks, req2's hashes all resolve in kv_hash_to_block_id.""" ctx = self._ctx() bs = ctx.block_size_tokens alloc = ctx.block_allocator @@ -549,7 +549,7 @@ def test_second_request_finds_registered_blocks(self): req2 = self._req(ctx, prompt.clone(), request_id=2) for h in req2.precomputed_block_hashes: - assert h in alloc.hash_to_block_id, f"Hash {h} should be discoverable" + assert h in alloc.kv_hash_to_block_id, f"Hash {h} should be discoverable" # ========================================================================= diff --git a/tests/unit_tests/inference/contexts/test_mamba_prefix_caching.py b/tests/unit_tests/inference/contexts/test_mamba_prefix_caching.py new file mode 100644 index 00000000000..a44ff4aa374 --- /dev/null +++ b/tests/unit_tests/inference/contexts/test_mamba_prefix_caching.py @@ -0,0 +1,2464 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Tests for Mamba prefix caching in hybrid models. 
def _set_rounder(value):
    """Force every DynamicInferenceContext rounding knob to *value*."""
    for attr in ("ROUNDER", "TOKEN_ROUNDER", "REQUEST_ROUNDER"):
        setattr(DynamicInferenceContext, attr, value)
def _build_hybrid_context(
    block_size=32,
    max_tokens=256,
    max_requests=8,  # NOTE(review): not referenced in the body — confirm whether it should feed InferenceConfig
    buffer_size_gb=0.01,
    enable_prefix_caching=True,
    prefix_caching_mamba_gb=0.001,
    num_layers=4,
    kv_channels=8,
    num_attention_heads=2,
    rounder=64,
    layer_type_list=None,
    params_dtype=torch.float32,
    pp_size=1,
    eviction_policy=PrefixCachingEvictionPolicy.REF_ZERO,
) -> DynamicInferenceContext:
    """Build a DynamicInferenceContext configured for a hybrid Mamba model.

    Sets the context rounders to ``rounder``, defaults the layer layout to a
    small Mamba/MLP/Attention/MLP hybrid, and wires the Mamba state shapes into
    the inference config. All sizes are test-scale.
    """
    _set_rounder(rounder)

    if layer_type_list is None:
        layer_type_list = [Symbols.MAMBA, Symbols.MLP, Symbols.ATTENTION, Symbols.MLP]

    # Fixed conv/SSM state shapes for the test model; not derived from the config above.
    mamba_conv_states_shape = (544, 4)
    mamba_ssm_states_shape = (8, 64, 16)
    mamba_inference_state_config = MambaInferenceStateConfig(
        layer_type_list, mamba_conv_states_shape, mamba_ssm_states_shape
    )

    transformer_config = TransformerConfig(
        params_dtype=params_dtype,
        num_layers=num_layers,
        kv_channels=kv_channels,
        num_attention_heads=num_attention_heads,
        hidden_size=kv_channels * num_attention_heads,
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=pp_size,
        use_cpu_initialization=True,
    )
    inference_config = InferenceConfig(
        max_sequence_length=1024,
        buffer_size_gb=buffer_size_gb,
        # Paused buffer is sized as a fixed fraction of the main buffer.
        paused_buffer_size_gb=0.2 * buffer_size_gb,
        block_size_tokens=block_size,
        max_tokens=max_tokens,
        mamba_inference_state_config=mamba_inference_state_config,
        use_flashinfer_fused_rope=None,
        unified_memory_level=0,
        enable_prefix_caching=enable_prefix_caching,
        prefix_caching_mamba_gb=prefix_caching_mamba_gb,
        prefix_caching_eviction_policy=eviction_policy,
    )
    return DynamicInferenceContext(
        model_config=transformer_config, inference_config=inference_config
    )
prompt_tokens=prompt_tokens, + sampling_params=SamplingParams(num_tokens_to_generate=num_tokens_to_generate), + block_size_tokens=block_size, + enable_prefix_caching=enable_prefix_caching, + ) + + + +# ============================================================================= +# Test Classes +# ============================================================================= + + +class TestMambaCacheOperations: + """Tests for basic Mamba state store, restore, and invalidation.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_store_and_restore_mamba_state(self): + """Store Mamba state for a block, then restore it to a different request slot.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01) + block_size = ctx.block_size_tokens + + req = _make_request(1, block_size * 2, block_size) + ctx.add_request(req) + + # Write known values into request's Mamba state + mamba_idx = ctx.mamba_metadata.request_to_mamba_state_idx[0].item() + ctx.mamba_conv_states[:, mamba_idx] = 1.0 + ctx.mamba_ssm_states[:, mamba_idx] = 2.0 + + # Store for block 0 + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + ctx.store_mamba_state_for_block(block_0_id, 0) + + # Overwrite request state, then restore + ctx.mamba_conv_states[:, mamba_idx] = 0.0 + ctx.mamba_ssm_states[:, mamba_idx] = 0.0 + + restored = ctx.restore_mamba_state_from_block(0, block_0_id) + assert restored + assert torch.allclose( + ctx.mamba_conv_states[:, mamba_idx], + torch.ones_like(ctx.mamba_conv_states[:, mamba_idx]), + ) + assert torch.allclose( + ctx.mamba_ssm_states[:, mamba_idx], + torch.full_like(ctx.mamba_ssm_states[:, mamba_idx], 2.0), + ) + + @pytest.mark.internal + def test_has_mamba_state_for_block(self): + """has_mamba_state_for_block returns True only after 
store.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01) + block_size = ctx.block_size_tokens + + req = _make_request(1, block_size * 2, block_size) + ctx.add_request(req) + + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + assert not ctx.has_mamba_state_for_block(block_0_id) + + ctx.store_mamba_state_for_block(block_0_id, 0) + assert ctx.has_mamba_state_for_block(block_0_id) + + @pytest.mark.internal + def test_mamba_state_invalidated_on_block_eviction(self): + """invalidate_mamba_state_for_block clears stored state.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01) + block_size = ctx.block_size_tokens + + req = _make_request(1, block_size * 2, block_size) + ctx.add_request(req) + + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + ctx.store_mamba_state_for_block(block_0_id, 0) + assert ctx.has_mamba_state_for_block(block_0_id) + + ctx.invalidate_mamba_state_for_block(block_0_id) + assert not ctx.has_mamba_state_for_block(block_0_id) + + +class TestMambaCacheEviction: + """Tests for Mamba LRU eviction behavior.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_lru_eviction_when_pool_full(self): + """When the Mamba cache is full, the LRU slot (oldest timestamp, ref_count=0) is evicted.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.001, eviction_policy=PrefixCachingEvictionPolicy.LRU) + max_slots = ctx.max_mamba_cache_slots + assert max_slots > 0 + + # Fill all slots manually + for i in range(max_slots): + ctx.block_to_mamba_slot[i] = i + ctx.mamba_slot_to_block[i] = i + ctx.block_allocator.block_ref_counts[i] = 0 + ctx.block_allocator.block_timestamps[i] = i * 100 # block 0 oldest + ctx.mamba_cache_free_count = 0 + + # Allocate a new slot -- should evict block 0 (oldest) + 
new_block_id = max_slots + 1 + slot = ctx._allocate_mamba_cache_slot(new_block_id) + assert slot >= 0 + assert ctx.block_to_mamba_slot[0].item() == -1, "Block 0 should be evicted" + assert ctx.block_to_mamba_slot[new_block_id].item() == slot + + @pytest.mark.internal + def test_eviction_frees_slot_for_reuse(self): + """Invalidating a block returns its slot to the free pool.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01) + initial_free = ctx.mamba_cache_free_count + + slot = ctx._allocate_mamba_cache_slot(0) + assert ctx.mamba_cache_free_count == initial_free - 1 + + ctx.invalidate_mamba_state_for_block(0) + assert ctx.mamba_cache_free_count == initial_free + + slot2 = ctx._allocate_mamba_cache_slot(1) + assert slot2 == slot, "Should reuse freed slot" + + @pytest.mark.internal + def test_evicted_block_not_prefix_matchable(self): + """After Mamba state is evicted for a block, that block should NOT be prefix-matched + on a hybrid model (due to the coupled KV+Mamba fix).""" + # Use very small Mamba cache (1-2 slots) so eviction is easy to trigger + ctx = _build_hybrid_context( + block_size=32, + buffer_size_gb=0.01, + prefix_caching_mamba_gb=0.001, + max_tokens=None, + max_requests=8, + ) + block_size = ctx.block_size_tokens + max_slots = ctx.max_mamba_cache_slots + assert max_slots >= 1, f"Need at least 1 Mamba slot, got {max_slots}" + + # --- Request A: 64 tokens (2 blocks) --- + a_idx = ctx.total_request_count # 0 + prompt_a = torch.arange(block_size * 2, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt_a, block_size) + ctx.add_request(req_a) + + + # Store Mamba state for block 1 (last complete block of A) + block_1_id = ctx.request_to_kv_block_ids[a_idx][1].item() + ctx.store_mamba_state_for_block(block_1_id, a_idx) + assert ctx.has_mamba_state_for_block(block_1_id) + + # Release A so its blocks become cached (ref_count=0) + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + + # --- Request C: different 64 
tokens (forces Mamba eviction) --- + c_idx = ctx.total_request_count # 1 + prompt_c = torch.arange(1000, 1000 + block_size * 2, device=torch.cuda.current_device()) + req_c = _make_request(2, prompt_c, block_size) + ctx.add_request(req_c) + + + # Store Mamba state for C's last block -- may evict block_1's Mamba state + c_block_1_id = ctx.request_to_kv_block_ids[c_idx][1].item() + ctx.store_mamba_state_for_block(c_block_1_id, c_idx) + + # Release C + ctx.release_memory_blocks_from_request_indexes(torch.tensor([c_idx])) + + # --- Request B: same prefix as A --- + b_idx = ctx.total_request_count # 2 + req_b = _make_request(3, prompt_a.clone(), block_size) + # Simulate engine: no Mamba state for matched blocks + req_b._mamba_num_matched_blocks = 0 + ctx.add_request(req_b) + + # B's Mamba state should be zero-initialized (not restored from cache) + mamba_idx_b = ctx.mamba_metadata.request_to_mamba_state_idx[b_idx].item() + assert torch.all(ctx.mamba_conv_states[:, mamba_idx_b] == 0.0), \ + "B's Mamba state should be zero (no cache restore)" + + +class TestMambaPrefixMatching: + """Tests for the coupled KV+Mamba prefix matching fix.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_hybrid_model_no_mamba_budget_no_prefix_match(self): + """With prefix_caching_mamba_gb=None, hybrid model should never prefix-match + because there's no Mamba state to restore.""" + ctx = _build_hybrid_context( + block_size=32, + buffer_size_gb=0.01, + enable_prefix_caching=True, + prefix_caching_mamba_gb=None, # No Mamba budget + max_tokens=None, + ) + block_size = ctx.block_size_tokens + assert ctx.max_mamba_cache_slots == 0 + + # Add request A + a_idx = ctx.total_request_count # 0 + prompt = torch.arange(block_size * 2, device=torch.cuda.current_device()) + 
req_a = _make_request(1, prompt, block_size) + ctx.add_request(req_a) + + + # Release A + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + + # Add request B with same prefix -- no _mamba_num_matched_blocks set + b_idx = ctx.total_request_count # 1 + req_b = _make_request(2, prompt.clone(), block_size) + ctx.add_request(req_b) + + # The coupled fix: getattr(req, '_mamba_num_matched_blocks', 0) == 0 + # so num_matched_blocks gets set to 0, even though KV blocks are cached. + # B's Mamba state should be zero-initialized + mamba_idx_b = ctx.mamba_metadata.request_to_mamba_state_idx[b_idx].item() + assert torch.all(ctx.mamba_conv_states[:, mamba_idx_b] == 0.0), \ + "With no Mamba budget, B should get zero-init Mamba state" + + @pytest.mark.internal + def test_hybrid_model_with_mamba_budget_prefix_matches(self): + """With Mamba budget and cached state, prefix matching should work correctly.""" + ctx = _build_hybrid_context( + block_size=32, + buffer_size_gb=0.01, + enable_prefix_caching=True, + prefix_caching_mamba_gb=0.01, + max_tokens=None, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + block_size = ctx.block_size_tokens + assert ctx.max_mamba_cache_slots > 0 + + # Add and process request A + a_idx = ctx.total_request_count # 0 + prompt = torch.arange(block_size * 2, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt, block_size) + ctx.add_request(req_a) + + + # Store Mamba state for A's blocks + mamba_idx_a = ctx.mamba_metadata.request_to_mamba_state_idx[a_idx].item() + ctx.mamba_conv_states[:, mamba_idx_a] = 7.0 + ctx.mamba_ssm_states[:, mamba_idx_a] = 14.0 + + block_0_id = ctx.request_to_kv_block_ids[a_idx][0].item() + block_1_id = ctx.request_to_kv_block_ids[a_idx][1].item() + ctx.store_mamba_state_for_block(block_0_id, a_idx) + ctx.store_mamba_state_for_block(block_1_id, a_idx) + + # Release A + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + + # Add request B with same prefix and 
_mamba_num_matched_blocks set + b_idx = ctx.total_request_count # 1 + req_b = _make_request(2, prompt.clone(), block_size) + req_b._mamba_num_matched_blocks = 2 # Both blocks have Mamba state + ctx.add_request(req_b) + + # B should have restored Mamba state from the last matched block (block 1) + mamba_idx_b = ctx.mamba_metadata.request_to_mamba_state_idx[b_idx].item() + assert torch.allclose( + ctx.mamba_conv_states[:, mamba_idx_b], + torch.full_like(ctx.mamba_conv_states[:, mamba_idx_b], 7.0), + ), "B should have restored Mamba conv state from block 1" + assert torch.allclose( + ctx.mamba_ssm_states[:, mamba_idx_b], + torch.full_like(ctx.mamba_ssm_states[:, mamba_idx_b], 14.0), + ), "B should have restored Mamba SSM state from block 1" + + @pytest.mark.internal + def test_mamba_match_limits_kv_match(self): + """KV matches 3 blocks but Mamba only has state for 1 → effective match = 1.""" + ctx = _build_hybrid_context( + block_size=32, + buffer_size_gb=0.01, + enable_prefix_caching=True, + prefix_caching_mamba_gb=0.001, # Very small -- few slots + max_tokens=None, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + block_size = ctx.block_size_tokens + + # Add request A with 3 complete blocks (96 tokens) + a_idx = ctx.total_request_count # 0 + prompt = torch.arange(block_size * 3, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt, block_size) + ctx.add_request(req_a) + + + # Store Mamba state ONLY for block 0 + mamba_idx_a = ctx.mamba_metadata.request_to_mamba_state_idx[a_idx].item() + ctx.mamba_conv_states[:, mamba_idx_a] = 99.0 + ctx.mamba_ssm_states[:, mamba_idx_a] = 99.0 + block_0_id = ctx.request_to_kv_block_ids[a_idx][0].item() + ctx.store_mamba_state_for_block(block_0_id, a_idx) + + # Release A + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + + # Add request B with same 96 tokens + req_b = _make_request(2, prompt.clone(), block_size) + # Engine computed: KV matches 3, but Mamba only for 1 + 
req_b._mamba_num_matched_blocks = 1 + b_idx = ctx.total_request_count # B's index before add_request increments it + ctx.add_request(req_b) + + # B should have restored Mamba state from block 0 + mamba_idx_b = ctx.mamba_metadata.request_to_mamba_state_idx[b_idx].item() + assert torch.allclose( + ctx.mamba_conv_states[:, mamba_idx_b], + torch.full_like(ctx.mamba_conv_states[:, mamba_idx_b], 99.0), + ), "B should restore Mamba state from block 0" + + # B should have matched only 1 block (not 3), meaning it allocated 2 new blocks + # Verify by checking that B has blocks assigned for its 3-block request + b_blocks = ctx.request_to_kv_block_ids[b_idx][:3].tolist() + # Block 0 should be shared (same as A's block 0) + assert b_blocks[0] == block_0_id, "Block 0 should be shared from A" + + @pytest.mark.internal + def test_mamba_match_zero_limits_all_kv_matches(self): + """KV matches 2 blocks but Mamba has state for 0 → effective match = 0.""" + ctx = _build_hybrid_context( + block_size=32, + buffer_size_gb=0.01, + enable_prefix_caching=True, + prefix_caching_mamba_gb=0.01, + max_tokens=None, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + block_size = ctx.block_size_tokens + + # Add request A + a_idx = ctx.total_request_count # 0 + prompt = torch.arange(block_size * 2, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt, block_size) + ctx.add_request(req_a) + + + # Don't store any Mamba state for A's blocks + # Release A + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + + # Add request B with same prefix, but engine says 0 Mamba matches + b_idx = ctx.total_request_count # 1 + req_b = _make_request(2, prompt.clone(), block_size) + req_b._mamba_num_matched_blocks = 0 + ctx.add_request(req_b) + + # B should have zero-init Mamba state (no restore) + mamba_idx_b = ctx.mamba_metadata.request_to_mamba_state_idx[b_idx].item() + assert torch.all(ctx.mamba_conv_states[:, mamba_idx_b] == 0.0), \ + "B should get zero-init Mamba state 
when _mamba_num_matched_blocks=0" + + +def _skip_if_mamba_sequence_packing_not_available(): + """Skip if Mamba sequence packing is not available.""" + sequence_packing_available, reason = _check_mamba_sequence_packing_support() + if not sequence_packing_available: + pytest.skip(reason) + + +def _build_engine( + enable_chunked_prefill, + enable_prefix_caching, + prefix_caching_mamba_gb, + block_size_tokens, + max_tokens, + max_requests, + vocab_size, + max_sequence_length, + buffer_size_gb, + seed, + eviction_policy=PrefixCachingEvictionPolicy.REF_ZERO, +): + """Build a full MambaModel engine stack for end-to-end testing. + + Returns (engine, model) tuple. + """ + _set_rounder(4) + + # Seed RNG for reproducible model weights. + random.seed(seed) + torch.manual_seed(seed) + model_parallel_cuda_manual_seed( + seed=seed, + inference_rng_tracker=True, + use_cudagraphable_rng=False, + force_reset_rng=True, + ) + + transformer_config = TransformerConfig( + params_dtype=torch.bfloat16, + num_layers=3, # 1 Mamba + 1 attention + 1 MLP + hidden_size=256, + mamba_num_heads=16, + num_attention_heads=16, + use_cpu_initialization=True, + cuda_graph_impl="none", + inference_rng_tracker=True, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + pipeline_dtype=torch.bfloat16, + add_bias_linear=True, + is_hybrid_model=True, + ) + + model = MambaModel( + config=transformer_config, + mamba_stack_spec=mamba_stack_spec, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + parallel_output=True, + hybrid_attention_ratio=0.3, + hybrid_mlp_ratio=0.3, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + ).cuda() + + for param in model.parameters(): + param.data = param.data.to(transformer_config.params_dtype) + model.eval() + + mamba_inference_state_config = MambaInferenceStateConfig.from_model(model) + + context = DynamicInferenceContext( + model_config=transformer_config, + 
inference_config=InferenceConfig( + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + paused_buffer_size_gb=0.2 * buffer_size_gb, + block_size_tokens=block_size_tokens, + max_tokens=max_tokens, + max_requests=max_requests, + mamba_inference_state_config=mamba_inference_state_config, + enable_chunked_prefill=enable_chunked_prefill, + enable_prefix_caching=enable_prefix_caching, + prefix_caching_mamba_gb=prefix_caching_mamba_gb, + prefix_caching_eviction_policy=eviction_policy, + materialize_only_last_token_logits=False, + use_flashinfer_fused_rope=None, + unified_memory_level=0, + ), + ) + + wrapped = GPTInferenceWrapper(model, context) + wrapped.model_is_pipeline_parallel = not ( + parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage() + ) + + controller = TextGenerationController( + inference_wrapped_model=wrapped, + tokenizer=types.SimpleNamespace( + vocab_size=vocab_size, + detokenize=lambda tokens: "text", + ), + ) + + _CudagraphGlobalRecord.cudagraph_created = False + _CudagraphGlobalRecord.cudagraph_record = [] + CudaGraphManager.global_mempool = None + + engine = DynamicInferenceEngine(controller, context) + return engine, model + + +def _install_deterministic_mock_forward(engine, vocab_size): + """Replace the model's forward with a deterministic mock. + + The mock produces logits that are a deterministic function of position_ids, + ensuring identical output tokens across engine configurations (chunked vs + non-chunked, prefix-cached vs fresh) while avoiding TE GEMM issues. + + The key insight: logits depend on position_ids (not input_ids), so + regardless of how the prompt is chunked, the logit at position P is + always the same. This simulates a position-aware model where the output + is independent of chunking boundaries. 
+ """ + def mock_forward(input_ids, position_ids, attention_mask, *args, **kwargs): + batch, seq_len = input_ids.shape + logits = torch.zeros( + batch, seq_len, vocab_size, device=input_ids.device, dtype=torch.bfloat16 + ) + for b in range(batch): + for s in range(seq_len): + pos = position_ids[b, s].item() + # Next token = (position + 7) % vocab_size, avoiding 0 and termination_id + predicted = (pos + 7) % (vocab_size - 1) + 1 + logits[b, s, predicted] = 10.0 + return logits + + model = engine.controller.inference_wrapped_model.model + model.forward = mock_forward + + +def _run_to_completion(engine, requests): + """Add requests, step engine until all complete, return generated tokens per request_id.""" + for req in requests: + engine._add_request(req) + + results = {} + while engine.has_unfinished_requests(): + result = engine.step_modern() + for record in result["finished_request_records"]: + finished = record.merge() + results[finished.request_id] = list(finished.generated_tokens) + + return results + + +class TestCrossConfigEndToEnd: + """Engine-level test verifying that chunked_prefill x prefix_caching + configurations produce identical output tokens through a MambaModel engine + stack with a deterministic mock forward. + + Scenario: + block_size=32, max_tokens=80, num_tokens_to_generate=4 + Request A: 64 tokens (2 blocks) + Request B: 100 tokens (first 64 shared with A, 36 unique) + + When B is scheduled with chunked prefill + prefix caching: + Block: |---block 0---|---block 1---|---block 2---|--block 3--| + Tokens: 0 32 64 96 100 + |<-- prefix match (Mamba cached) -->| + |<----------- chunk 1 (80 tokens) ---------->| + |<- chunk 2 ->| + + Three configs compared: + 1. chunked=True, prefix=True (interleaved boundaries) + 2. chunked=True, prefix=False (full prefill in chunks) + 3. chunked=False, prefix=False (full prefill at once) + + All three must produce identical output tokens for both A and B. 
+ + Uses a deterministic mock forward (position-based logits) to ensure + reproducible comparisons while exercising the full engine scheduling, + chunked prefill, and prefix caching codepaths. + """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_interleaving_boundaries(self): + _skip_if_mamba_sequence_packing_not_available() + + # --- Parameters --- + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 80 + max_sequence_length = 256 + buffer_size_gb = 0.1 + num_tokens_to_generate = 4 + + # Seed once for deterministic prompt generation. + torch.manual_seed(seed) + + # Shared prefix (64 tokens = 2 blocks) + B's unique suffix (36 tokens) + prompt_a = torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + prompt_b = torch.cat([ + prompt_a, + torch.randint(0, vocab_size - 1, (36,), device="cuda", dtype=torch.int64), + ]) + + configs = [ + { + "name": "chunked+prefix", + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "prefix_caching_mamba_gb": 0.01, + "max_tokens": max_tokens, + "eviction_policy": PrefixCachingEvictionPolicy.LRU, + }, + { + "name": "chunked_only", + "enable_chunked_prefill": True, + "enable_prefix_caching": False, + "prefix_caching_mamba_gb": None, + "max_tokens": max_tokens, + }, + { + "name": "baseline", + "enable_chunked_prefill": False, + "enable_prefix_caching": False, + "prefix_caching_mamba_gb": None, + "max_tokens": None, # no chunking limit needed + }, + ] + + all_results = {} + for config in configs: + # Re-init model parallel for each config + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + + engine, model = _build_engine( + enable_chunked_prefill=config["enable_chunked_prefill"], + 
enable_prefix_caching=config["enable_prefix_caching"], + prefix_caching_mamba_gb=config["prefix_caching_mamba_gb"], + block_size_tokens=block_size, + max_tokens=config["max_tokens"], + max_requests=4, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=config.get("eviction_policy", PrefixCachingEvictionPolicy.REF_ZERO), + ) + _install_deterministic_mock_forward(engine, vocab_size) + + # --- Request A: 64 tokens → run to completion --- + req_a = DynamicInferenceRequest( + request_id=0, + prompt_tokens=prompt_a.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=config["enable_prefix_caching"], + ) + results_a = _run_to_completion(engine, [req_a]) + + # --- Request B: 100 tokens (shares first 64 with A) → run to completion --- + req_b = DynamicInferenceRequest( + request_id=1, + prompt_tokens=prompt_b.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=config["enable_prefix_caching"], + ) + results_b = _run_to_completion(engine, [req_b]) + + all_results[config["name"]] = { + "a": results_a[0], + "b": results_b[1], + } + + # --- Assertions: all configs must produce identical tokens --- + names = list(all_results.keys()) + for i in range(1, len(names)): + for req_label in ("a", "b"): + tokens_ref = all_results[names[0]][req_label] + tokens_cur = all_results[names[i]][req_label] + assert tokens_ref == tokens_cur, ( + f"Request {req_label} mismatch between '{names[0]}' and '{names[i]}':\n" + f" {names[0]}: {tokens_ref}\n" + f" {names[i]}: {tokens_cur}" + ) + + +class TestBudgetZero: + """Tests for zero-budget and disabled Mamba caching scenarios.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, 
pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_zero_mamba_budget_disables_mamba_caching(self): + """With prefix_caching_mamba_gb=None, max_mamba_cache_slots should be 0.""" + ctx = _build_hybrid_context( + enable_prefix_caching=True, + prefix_caching_mamba_gb=None, + ) + assert ctx.max_mamba_cache_slots == 0 + + @pytest.mark.internal + def test_zero_mamba_budget_with_prefix_caching_still_works_for_non_hybrid(self): + """Non-hybrid model with prefix caching should work normally without Mamba cache.""" + _set_rounder(64) + + transformer_config = TransformerConfig( + params_dtype=torch.float32, + num_layers=4, + kv_channels=8, + num_attention_heads=2, + hidden_size=16, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + use_cpu_initialization=True, + ) + inference_config = InferenceConfig( + max_sequence_length=1024, + buffer_size_gb=0.01, + paused_buffer_size_gb=0.002, + block_size_tokens=32, + max_tokens=None, + mamba_inference_state_config=None, # Non-hybrid + use_flashinfer_fused_rope=None, + unified_memory_level=0, + enable_prefix_caching=True, + prefix_caching_mamba_gb=None, + ) + ctx = DynamicInferenceContext( + model_config=transformer_config, inference_config=inference_config + ) + + assert ctx.max_mamba_cache_slots == 0 + assert not ctx.is_hybrid_model + + # Prefix caching should still work for KV blocks + block_size = ctx.block_size_tokens + prompt = torch.arange(block_size * 2, device=torch.cuda.current_device()) + + a_idx = ctx.total_request_count # 0 + req_a = _make_request(1, prompt, block_size) + ctx.add_request(req_a) + + + a_block_0 = ctx.request_to_kv_block_ids[a_idx][0].item() + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + + # Request B with same prefix should get prefix match (no hybrid limitation) + b_idx = ctx.total_request_count # 1 + req_b = _make_request(2, 
prompt.clone(), block_size) + ctx.add_request(req_b) + + b_block_0 = ctx.request_to_kv_block_ids[b_idx][0].item() + assert b_block_0 == a_block_0, "Non-hybrid model should still prefix-match KV blocks" + + @pytest.mark.internal + def test_negative_mamba_budget_disables_caching(self): + """Negative budget should result in zero Mamba slots.""" + ctx = _build_hybrid_context( + enable_prefix_caching=True, + prefix_caching_mamba_gb=-0.01, + ) + assert ctx.max_mamba_cache_slots == 0 + + @pytest.mark.internal + def test_tiny_mamba_budget_zero_slots(self): + """Extremely tiny budget that can't fit even 1 slot.""" + ctx = _build_hybrid_context( + enable_prefix_caching=True, + prefix_caching_mamba_gb=1e-12, + ) + assert ctx.max_mamba_cache_slots == 0 + + +class TestMultiplePrefillWithInitialStates: + """Engine-level test verifying numerical correctness when multiple prefill + requests with restored Mamba states run simultaneously. + + Scenario: + block_size=32, max_tokens=256 + + 1. Request A: 128-token prompt -> run to completion + - Stores Mamba states at block boundaries (blocks 0-3) + + 2. Request B: same 64-token prefix as A + 32 unique tokens (total 96) + - Restores Mamba state from block 1 (divergence at token 64) + + 3. Request C: same 64-token prefix as A + 32 different unique tokens (total 96) + - Also restores Mamba state from block 1 + + 4. Schedule B and C simultaneously -> both have initial states + -> Both go through unified varlen path + + 5. 
Compare: run B alone and C alone in separate engine instances + -> outputs must match the simultaneous run + """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_multiple_prefill_with_restored_states(self): + _skip_if_mamba_sequence_packing_not_available() + + # --- Parameters --- + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 256 + max_requests = 8 + max_sequence_length = 512 + buffer_size_gb = 0.1 + num_tokens_to_generate = 4 + + # Generate deterministic prompts + torch.manual_seed(seed) + shared_prefix = torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + suffix_b = torch.randint(0, vocab_size - 1, (32,), device="cuda", dtype=torch.int64) + suffix_c = torch.randint(0, vocab_size - 1, (32,), device="cuda", dtype=torch.int64) + prompt_a = torch.cat([ + shared_prefix, + torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64), + ]) # 128 tokens + prompt_b = torch.cat([shared_prefix, suffix_b]) # 96 tokens + prompt_c = torch.cat([shared_prefix, suffix_c]) # 96 tokens + + engine_kwargs = dict( + enable_chunked_prefill=True, + enable_prefix_caching=True, + prefix_caching_mamba_gb=0.01, + block_size_tokens=block_size, + max_tokens=max_tokens, + max_requests=max_requests, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + + def make_req(req_id, prompt): + return DynamicInferenceRequest( + request_id=req_id, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=True, + ) + + # --- Simultaneous run: A first, then B and C together --- + Utils.destroy_model_parallel() + 
Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + engine_sim, _ = _build_engine(**engine_kwargs) + _install_deterministic_mock_forward(engine_sim, vocab_size) + + # Run A to completion (populates Mamba state cache) + results_a_sim = _run_to_completion(engine_sim, [make_req(0, prompt_a)]) + assert 0 in results_a_sim + + # Run B and C simultaneously (both should get restored Mamba states) + engine_sim._add_request(make_req(1, prompt_b)) + engine_sim._add_request(make_req(2, prompt_c)) + results_bc_sim = {} + while engine_sim.has_unfinished_requests(): + engine_sim.schedule_waiting_requests() + result = engine_sim.step_modern() + for record in result["finished_request_records"]: + finished = record.merge() + results_bc_sim[finished.request_id] = list(finished.generated_tokens) + + assert 1 in results_bc_sim, "Request B did not complete" + assert 2 in results_bc_sim, "Request C did not complete" + + # --- Individual run: A then B alone --- + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + engine_b, _ = _build_engine(**engine_kwargs) + _install_deterministic_mock_forward(engine_b, vocab_size) + + _run_to_completion(engine_b, [make_req(0, prompt_a)]) + results_b_individual = _run_to_completion(engine_b, [make_req(1, prompt_b)]) + + # --- Individual run: A then C alone --- + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + engine_c, _ = _build_engine(**engine_kwargs) + _install_deterministic_mock_forward(engine_c, vocab_size) + + _run_to_completion(engine_c, [make_req(0, prompt_a)]) + results_c_individual = _run_to_completion(engine_c, [make_req(2, prompt_c)]) + + # --- Assertions --- + assert results_bc_sim[1] == results_b_individual[1], ( + f"Request B mismatch:\n" + f" simultaneous: {results_bc_sim[1]}\n" + f" individual: {results_b_individual[1]}" 
+ ) + assert results_bc_sim[2] == results_c_individual[2], ( + f"Request C mismatch:\n" + f" simultaneous: {results_bc_sim[2]}\n" + f" individual: {results_c_individual[2]}" + ) + + # Verify non-trivial output + assert len(results_bc_sim[1]) == num_tokens_to_generate + assert len(results_bc_sim[2]) == num_tokens_to_generate + + +class TestMambaHashMap: + """Tests for the two-map design: kv_hash_to_block_id + mamba_hash_to_block_id.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_mamba_hash_registered_on_store(self): + """Storing Mamba state for a block registers its hash in mamba_hash_to_block_id.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01) + block_size = ctx.block_size_tokens + alloc = ctx.block_allocator + + req = _make_request(1, block_size * 2, block_size) + ctx.add_request(req) + + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + block_0_hash = alloc.block_hashes[block_0_id].item() + + # Before store: hash in kv map but not mamba map + assert block_0_hash in alloc.kv_hash_to_block_id + assert block_0_hash not in alloc.mamba_hash_to_block_id + + # Store and register + ctx.store_mamba_state_for_block(block_0_id, 0) + alloc.register_mamba_block_hash(block_0_id, block_0_hash) + + # After store: hash in both maps + assert block_0_hash in alloc.kv_hash_to_block_id + assert block_0_hash in alloc.mamba_hash_to_block_id + assert alloc.mamba_hash_to_block_id[block_0_hash] == block_0_id + + @pytest.mark.internal + def test_mamba_hash_removed_on_mamba_eviction(self): + """Mamba eviction removes hash from mamba_hash_to_block_id but keeps kv_hash_to_block_id.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.001, eviction_policy=PrefixCachingEvictionPolicy.LRU) + block_size = ctx.block_size_tokens + 
alloc = ctx.block_allocator + max_slots = ctx.max_mamba_cache_slots + assert max_slots >= 1 + + # Add request, store mamba state, register hash + req = _make_request(1, block_size * 2, block_size) + ctx.add_request(req) + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + block_0_hash = alloc.block_hashes[block_0_id].item() + ctx.store_mamba_state_for_block(block_0_id, 0) + alloc.register_mamba_block_hash(block_0_id, block_0_hash) + + assert block_0_hash in alloc.mamba_hash_to_block_id + + # Release request so blocks become evictable (ref_count = 0) + ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) + + # Make block_0 the oldest (timestamp=0) so it gets evicted first + alloc.block_timestamps[block_0_id] = 0 + + # Fill remaining mamba slots with valid block IDs (use low IDs that + # are within total_count and not already mapped) + filled = 0 + for i in range(alloc.total_count): + if filled >= max_slots: + break + if ctx.block_to_mamba_slot[i].item() < 0: + alloc.block_ref_counts[i] = 0 + alloc.block_timestamps[i] = 1000 + filled # newer than block_0 + ctx._allocate_mamba_cache_slot(i) + filled += 1 + + # After eviction: mamba hash gone, kv hash remains + assert block_0_hash not in alloc.mamba_hash_to_block_id + assert block_0_hash in alloc.kv_hash_to_block_id + + @pytest.mark.internal + def test_mamba_hash_removed_on_kv_eviction(self): + """KV block eviction removes hash from both kv and mamba hash maps.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01, buffer_size_gb=0.01, rounder=1, eviction_policy=PrefixCachingEvictionPolicy.LRU) + block_size = ctx.block_size_tokens + alloc = ctx.block_allocator + + # Add request, store mamba state, register hash + req = _make_request(1, block_size * 2, block_size) + ctx.add_request(req) + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + block_1_id = ctx.request_to_kv_block_ids[0][1].item() + block_0_hash = alloc.block_hashes[block_0_id].item() + 
ctx.store_mamba_state_for_block(block_0_id, 0) + alloc.register_mamba_block_hash(block_0_id, block_0_hash) + + assert block_0_hash in alloc.kv_hash_to_block_id + assert block_0_hash in alloc.mamba_hash_to_block_id + + # Release so blocks become cached (ref_count=0) + ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) + ctx.total_request_count = 0 + + # Directly evict via _deregister_blocks (simulates LRU eviction) + blocks_to_evict = torch.tensor([block_0_id, block_1_id], + device=torch.cuda.current_device(), + dtype=torch.int32) + alloc._deregister_blocks(blocks_to_evict) + + # After KV eviction: both hashes should be gone + assert block_0_hash not in alloc.kv_hash_to_block_id + assert block_0_hash not in alloc.mamba_hash_to_block_id + + @pytest.mark.internal + def test_reset_clears_mamba_hash_map(self): + """reset() clears mamba_hash_to_block_id.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01) + block_size = ctx.block_size_tokens + alloc = ctx.block_allocator + + req = _make_request(1, block_size * 2, block_size) + ctx.add_request(req) + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + block_0_hash = alloc.block_hashes[block_0_id].item() + ctx.store_mamba_state_for_block(block_0_id, 0) + alloc.register_mamba_block_hash(block_0_id, block_0_hash) + + assert len(alloc.mamba_hash_to_block_id) > 0 + + alloc.reset() + assert len(alloc.mamba_hash_to_block_id) == 0 + assert len(alloc.kv_hash_to_block_id) == 0 + + @pytest.mark.internal + def test_kv_match_extends_beyond_mamba_match(self): + """KV match can cover more blocks than mamba match (two-map design).""" + ctx = _build_hybrid_context( + prefix_caching_mamba_gb=0.001, + buffer_size_gb=0.01, + max_tokens=None, + ) + block_size = ctx.block_size_tokens + alloc = ctx.block_allocator + + # Add request A with 3 blocks + prompt = torch.arange(block_size * 3, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt, block_size) + ctx.add_request(req_a) + + # Register mamba 
state only for block 0 + block_0_id = ctx.request_to_kv_block_ids[0][0].item() + block_0_hash = alloc.block_hashes[block_0_id].item() + ctx.store_mamba_state_for_block(block_0_id, 0) + alloc.register_mamba_block_hash(block_0_id, block_0_hash) + + # All 3 blocks in kv map + for i in range(3): + bid = ctx.request_to_kv_block_ids[0][i].item() + h = alloc.block_hashes[bid].item() + assert h in alloc.kv_hash_to_block_id, f"Block {i} should be in kv map" + + # Only block 0 in mamba map + assert block_0_hash in alloc.mamba_hash_to_block_id + for i in range(1, 3): + bid = ctx.request_to_kv_block_ids[0][i].item() + h = alloc.block_hashes[bid].item() + assert h not in alloc.mamba_hash_to_block_id, f"Block {i} should NOT be in mamba map" + + +class TestChunkedPrefillMambaState: + """Tests for chunked prefill requests that receive intermediate Mamba state. + + Covers two critical scenarios: + 1a. A request whose KV match extends beyond its Mamba match gets its KV match + truncated at the divergence boundary, restores Mamba state from cache for + the matched portion, and stores new Mamba state at the divergence boundary. + 1b. A request with a non-block-aligned prompt and no prefix match gets + forced-chunked at the last-aligned boundary, with Mamba state stored after + the first chunk and continued via per-request state in the second chunk. + """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_divergence_boundary_restores_and_stores_mamba_state(self): + """Request C shares 96 tokens with B (3 KV blocks matched), but only 2 blocks + have Mamba state. C restores Mamba from block 1, processes 32 effective tokens, + and stores Mamba state for block 2 (the divergence boundary). 
+ + Sequence: + A (64 tokens) → Mamba stored for block 1 + B (128 tokens, first 64 shared with A) → Mamba stored for block 3 + C (96 tokens, first 96 shared with B) → KV match=3, Mamba match=2 + → restores from block 1, stores at block 2 + + Verified by comparing C's output between chunked+prefix and chunked_only configs. + """ + _skip_if_mamba_sequence_packing_not_available() + + # --- Parameters --- + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 256 + max_sequence_length = 256 + buffer_size_gb = 0.1 + num_tokens_to_generate = 4 + + # Deterministic prompt generation. + torch.manual_seed(seed) + shared_ab = torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + suffix_b = torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + + prompt_a = shared_ab.clone() # 64 tokens (2 blocks) + prompt_b = torch.cat([shared_ab, suffix_b]) # 128 tokens (4 blocks) + # C shares first 96 tokens with B (blocks 0, 1, 2) + prompt_c = torch.cat([shared_ab, suffix_b[:32]]) # 96 tokens (3 blocks) + + def make_req(req_id, prompt, enable_pc): + return DynamicInferenceRequest( + request_id=req_id, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=enable_pc, + ) + + configs = [ + { + "name": "chunked+prefix", + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "prefix_caching_mamba_gb": 0.01, + "max_tokens": max_tokens, + "eviction_policy": PrefixCachingEvictionPolicy.LRU, + }, + { + "name": "chunked_only", + "enable_chunked_prefill": True, + "enable_prefix_caching": False, + "prefix_caching_mamba_gb": None, + "max_tokens": max_tokens, + }, + ] + + all_results = {} + for config in configs: + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + + engine, _ = _build_engine( + 
enable_chunked_prefill=config["enable_chunked_prefill"], + enable_prefix_caching=config["enable_prefix_caching"], + prefix_caching_mamba_gb=config["prefix_caching_mamba_gb"], + block_size_tokens=block_size, + max_tokens=config["max_tokens"], + max_requests=8, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=config.get("eviction_policy", PrefixCachingEvictionPolicy.REF_ZERO), + ) + _install_deterministic_mock_forward(engine, vocab_size) + + # --- Request A: 64 tokens (2 blocks) → run to completion --- + # Engine computes: last_aligned = 64, no divergence (first request). + # After A's prefill (total_prefilled=64 == last_aligned=64): + # stores Mamba state for block 1 (last complete block). + # After A completes: KV blocks 0,1 cached (LRU), Mamba at block 1. + req_a = make_req(0, prompt_a, config["enable_prefix_caching"]) + _run_to_completion(engine, [req_a]) + + # --- Request B: 128 tokens (first 64 shared with A) → run to completion --- + # Engine finds: KV match = 2 blocks (from A), Mamba match = 2 + # (_find_mamba_divergence_block scans backward: block 1 has state → return 2). + # num_mamba_matched == num_kv_matched → no divergence, _kv_divergence_token = 0. + # last_aligned = 128. prefix_skip = 64, effective = 64 tokens. + # Mamba restored from block 1 (initial states passed to varlen kernel). + # After B's prefill (total_prefilled=128 == last_aligned=128): + # stores Mamba state for block 3 (last complete block). + # After B: Mamba at {1, 3}. KV blocks {0,1,2,3} all cached. + req_b = make_req(1, prompt_b, config["enable_prefix_caching"]) + _run_to_completion(engine, [req_b]) + + # --- Mamba state assertions (prefix caching config only) --- + if config["name"] == "chunked+prefix": + ctx = engine.context + alloc = ctx.block_allocator + + # Look up block IDs from B's hashes (B covers blocks 0-3). + # Blocks remain in KV hash map after release (LRU eviction policy). 
+ b_hashes = req_b.precomputed_block_hashes + b_block_ids = [alloc.kv_hash_to_block_id.get(h) for h in b_hashes] + assert all(bid is not None for bid in b_block_ids), \ + "All of B's blocks should be registered in KV hash map" + + # After B completes: Mamba state at blocks 1 and 3 only. + # Block 0: never stored (only the last complete block at a boundary is stored). + # Block 1: stored after A's prefill. + # Block 2: not stored (B's prefill boundary was at block 3, not block 2). + # Block 3: stored after B's prefill. + assert ctx.has_mamba_state_for_block(b_block_ids[1]), \ + "Block 1 should have Mamba state (stored after A's prefill)" + assert ctx.has_mamba_state_for_block(b_block_ids[3]), \ + "Block 3 should have Mamba state (stored after B's prefill)" + assert not ctx.has_mamba_state_for_block(b_block_ids[0]), \ + "Block 0 should NOT have Mamba state" + assert not ctx.has_mamba_state_for_block(b_block_ids[2]), \ + "Block 2 should NOT have Mamba state" + + # --- Request C: 96 tokens (first 96 shared with B) → run to completion --- + # Engine finds: KV match = 3 blocks (0,1,2 from A/B). + # _find_mamba_divergence_block scans backward: + # block 2 (no Mamba) → block 1 (yes) → return 2. + # num_mamba_matched=2 < num_kv_matched=3 → divergence! + # _kv_divergence_token = 3*32 = 96, _mamba_last_aligned_token = 96. + # In add_request: KV match truncated to 2 blocks. + # prefix_skip = min(2*32, 96-1) = 64, effective = 32 tokens. + # Mamba restored from block 1 (initial states passed to varlen kernel). + # After C's prefill (total_prefilled=96 == _kv_divergence_token=96): + # stores Mamba state for block 2 (divergence boundary). 
+ req_c = make_req(2, prompt_c, config["enable_prefix_caching"]) + results_c = _run_to_completion(engine, [req_c]) + + # --- Mamba state assertion: block 2 now has Mamba state --- + if config["name"] == "chunked+prefix": + ctx = engine.context + alloc = ctx.block_allocator + + c_hashes = req_c.precomputed_block_hashes + block_2_id = alloc.kv_hash_to_block_id.get(c_hashes[2]) + assert block_2_id is not None, "Block 2 should be in KV hash map" + assert ctx.has_mamba_state_for_block(block_2_id), \ + "Block 2 should now have Mamba state (stored at divergence boundary)" + + all_results[config["name"]] = results_c[2] + + # --- Cross-config comparison: C's output must match --- + assert all_results["chunked+prefix"] == all_results["chunked_only"], ( + f"Request C output mismatch:\n" + f" chunked+prefix: {all_results['chunked+prefix']}\n" + f" chunked_only: {all_results['chunked_only']}" + ) + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_last_aligned_boundary_forces_chunk_and_stores_mamba_state(self): + """Request X (80 tokens, no shared prefix) gets forced-chunked at 64 tokens + (last block-aligned boundary). Mamba state is stored after the first chunk, + and the second chunk (16 tokens) continues with per-request Mamba state. + + Sequence: + X (80 tokens, no prefix) → + Chunk 1: 64 tokens → stores Mamba for block 1 + Chunk 2: 16 tokens (initial Mamba state restored, unified varlen path) + + Verified by comparing X's output between chunked+prefix and chunked_only configs. 
+ """ + _skip_if_mamba_sequence_packing_not_available() + + # --- Parameters --- + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 256 + max_sequence_length = 256 + buffer_size_gb = 0.1 + num_tokens_to_generate = 4 + + torch.manual_seed(seed) + prompt_x = torch.randint(0, vocab_size - 1, (80,), device="cuda", dtype=torch.int64) + + def make_req(req_id, prompt, enable_pc): + return DynamicInferenceRequest( + request_id=req_id, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=enable_pc, + ) + + configs = [ + { + "name": "chunked+prefix", + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "prefix_caching_mamba_gb": 0.01, + "max_tokens": max_tokens, + "eviction_policy": PrefixCachingEvictionPolicy.LRU, + }, + { + "name": "chunked_only", + "enable_chunked_prefill": True, + "enable_prefix_caching": False, + "prefix_caching_mamba_gb": None, + "max_tokens": max_tokens, + }, + ] + + all_results = {} + for config in configs: + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + + engine, _ = _build_engine( + enable_chunked_prefill=config["enable_chunked_prefill"], + enable_prefix_caching=config["enable_prefix_caching"], + prefix_caching_mamba_gb=config["prefix_caching_mamba_gb"], + block_size_tokens=block_size, + max_tokens=config["max_tokens"], + max_requests=8, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=config.get("eviction_policy", PrefixCachingEvictionPolicy.REF_ZERO), + ) + _install_deterministic_mock_forward(engine, vocab_size) + + # --- Request X: 80 tokens (no prefix) --- + # In chunked+prefix config: + # No prefix match. _mamba_num_matched_blocks = 0. + # _mamba_last_aligned_token = (80//32)*32 = 64. 
+ # _get_mamba_chunk_limit returns 64 (last_aligned - finished = 64 - 0). + # mamba_forces_chunk = (80 > 64) → True. + # chunk_length = min(80, min(256, 64)) = 64. + # Chunk 1: 64 tokens → Mamba zero-init, processed from scratch. + # _store_mamba_states: total_prefilled=64 == last_aligned=64 → stores block 1. + # Chunk 2: 16 tokens, initial Mamba state restored (unified varlen path). + # + # In chunked_only config: + # No Mamba budget → max_mamba_cache_slots=0 → no mamba_limit → no forced chunk. + # X processes all 80 tokens in one shot. + req_x = make_req(0, prompt_x, config["enable_prefix_caching"]) + results_x = _run_to_completion(engine, [req_x]) + + all_results[config["name"]] = results_x[0] + + # --- Cross-config comparison: X's output must match --- + assert all_results["chunked+prefix"] == all_results["chunked_only"], ( + f"Request X output mismatch:\n" + f" chunked+prefix: {all_results['chunked+prefix']}\n" + f" chunked_only: {all_results['chunked_only']}" + ) + + +class TestMixedKernelRouting: + """Engine-level test for mixed restored-state and fresh prefill routing. + + In a single forward step, a continuing chunked prefill request with + restored Mamba state runs alongside two fresh prefill requests with + zero-initialized Mamba state, all through the unified varlen path. + + Scenario: + block_size=32, max_tokens=112, num_tokens_to_generate=4 + + 1. Request A (64 tokens) -> run to completion, Mamba state at block 1 + 2. Request B (160 tokens, first 64 shared with A) -> chunk 1 = 112 tokens + (64 prefix-skipped + 48 effective). B has 48 remaining, is continuing. + 3. Add C (32 tokens, fresh) and D (32 tokens, fresh) + 4. Schedule -> B continues (restored state, 48 tokens) + C (fresh, 32) + D (fresh, 32) + 5. Step -> forward pass with unified varlen kernel + 6. Continue until all complete + + Verified by comparing B, C, D outputs against individual baseline runs. 
+ + NOTE: this test covers a strictly more complex scenario than + TestMultiplePrefillWithInitialStates (mixed restored + fresh initial + states vs all-restored). + """ + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_restored_state_continuation_with_fresh_prefills(self): + _skip_if_mamba_sequence_packing_not_available() + + # --- Parameters --- + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 112 + max_requests = 8 + max_sequence_length = 512 + buffer_size_gb = 0.1 + num_tokens_to_generate = 4 + + # Deterministic prompt generation. + torch.manual_seed(seed) + shared_prefix = torch.randint( + 0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64 + ) + suffix_b = torch.randint( + 0, vocab_size - 1, (96,), device="cuda", dtype=torch.int64 + ) + prompt_a = shared_prefix.clone() # 64 tokens (2 blocks) + prompt_b = torch.cat([shared_prefix, suffix_b]) # 160 tokens (5 blocks) + prompt_c = torch.randint( + 0, vocab_size - 1, (32,), device="cuda", dtype=torch.int64 + ) # 32 tokens, no shared prefix + prompt_d = torch.randint( + 0, vocab_size - 1, (32,), device="cuda", dtype=torch.int64 + ) # 32 tokens, no shared prefix + + engine_kwargs = dict( + enable_chunked_prefill=True, + enable_prefix_caching=True, + prefix_caching_mamba_gb=0.01, + block_size_tokens=block_size, + max_tokens=max_tokens, + max_requests=max_requests, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + + def make_req(req_id, prompt, enable_pc=True): + return DynamicInferenceRequest( + request_id=req_id, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + 
enable_prefix_caching=enable_pc, + ) + + # ================================================================= + # Simultaneous run: A first, then B (chunked), then B+C+D together + # ================================================================= + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + engine_sim, _ = _build_engine(**engine_kwargs) + _install_deterministic_mock_forward(engine_sim, vocab_size) + + # Step 1: Run A (64 tokens) to completion. + # After A: Mamba state stored for block 1 (last_aligned=64), + # KV blocks 0,1 cached (ref_count=0, LRU). + results_a = _run_to_completion(engine_sim, [make_req(0, prompt_a)]) + assert 0 in results_a + + # Step 2: Add B (160 tokens, first 64 shared with A) and step once. + # Engine scheduling: + # KV match = 2 blocks (from A), Mamba match = 2 (block 1 has state). + # _mamba_last_aligned_token = 160, _kv_divergence_token = 0 (mamba==kv). + # mamba_limit = 160 (last_aligned - finished). remaining = 160. + # token_fully_can_be_added = (160 <= 112) → False. + # chunk_length = min(160, 112) = 112 (capped by max_tokens). + # min(112, 160) = 112 → no additional mamba cap. + # In add_request: + # finished_chunk_token_count=0 → is_chunked_prefill=False → allocates Mamba. + # prefix_skip = min(64, 111) = 64, effective = 48 tokens. + # Mamba restored from block 1 (initial states passed to varlen kernel). + # After step: B has 48 remaining tokens, is continuing chunked prefill. + engine_sim._add_request(make_req(1, prompt_b)) + step_result = engine_sim.step_modern() + assert len(step_result["finished_request_records"]) == 0, \ + "B should not be finished after first chunk" + + # Step 3: Add C and D (fresh requests, no shared prefix with anything). + # C: 32 tokens. D: 32 tokens. Both go to waiting queue behind B. 
+ engine_sim._add_request(make_req(2, prompt_c)) + engine_sim._add_request(make_req(3, prompt_d)) + + # Steps 4-6: Schedule and step until all complete. + # Next schedule_chunked_prefill processes: + # B continues (restored state): 48 tokens. + # C fresh: 32 tokens (zero-initialized Mamba state). + # D fresh: 32 tokens (zero-initialized Mamba state). + # Total tokens: 48 + 32 + 32 = 112 = max_tokens. + # All three go through unified varlen path. + # Subsequent steps: decode until all requests generate num_tokens_to_generate. + results_sim = {} + while engine_sim.has_unfinished_requests(): + result = engine_sim.step_modern() + for record in result["finished_request_records"]: + finished = record.merge() + results_sim[finished.request_id] = list(finished.generated_tokens) + + assert 1 in results_sim, "B did not complete" + assert 2 in results_sim, "C did not complete" + assert 3 in results_sim, "D did not complete" + + # ================================================================= + # Individual baseline runs (chunked, no prefix caching) + # Each request runs alone in a fresh engine, producing baseline output. 
+ # ================================================================= + baseline_kwargs = dict( + enable_chunked_prefill=True, + enable_prefix_caching=False, + prefix_caching_mamba_gb=None, + block_size_tokens=block_size, + max_tokens=max_tokens, + max_requests=max_requests, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + ) + + # B alone (160 tokens, chunked at 112 + 48, no prefix skip) + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + engine_b, _ = _build_engine(**baseline_kwargs) + _install_deterministic_mock_forward(engine_b, vocab_size) + results_b = _run_to_completion( + engine_b, [make_req(1, prompt_b, enable_pc=False)] + ) + + # C alone (32 tokens, single prefill) + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + engine_c, _ = _build_engine(**baseline_kwargs) + _install_deterministic_mock_forward(engine_c, vocab_size) + results_c = _run_to_completion( + engine_c, [make_req(2, prompt_c, enable_pc=False)] + ) + + # D alone (32 tokens, single prefill) + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + engine_d, _ = _build_engine(**baseline_kwargs) + _install_deterministic_mock_forward(engine_d, vocab_size) + results_d = _run_to_completion( + engine_d, [make_req(3, prompt_d, enable_pc=False)] + ) + + # ================================================================= + # Assertions: simultaneous outputs must match individual baselines + # ================================================================= + assert results_sim[1] == results_b[1], ( + f"Request B mismatch:\n" + f" simultaneous: {results_sim[1]}\n" + f" individual: {results_b[1]}" + ) + assert results_sim[2] == results_c[2], ( + f"Request C mismatch:\n" + f" simultaneous: 
{results_sim[2]}\n" + f" individual: {results_c[2]}" + ) + assert results_sim[3] == results_d[3], ( + f"Request D mismatch:\n" + f" simultaneous: {results_sim[3]}\n" + f" individual: {results_d[3]}" + ) + + # Verify non-trivial output + assert len(results_sim[1]) == num_tokens_to_generate + assert len(results_sim[2]) == num_tokens_to_generate + assert len(results_sim[3]) == num_tokens_to_generate + + +class TestMambaEvictionEdgeCases: + """Tests for Mamba eviction edge cases: all-active raises, mixed ref counts, restore-after-evict.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_all_active_slots_raises_error(self): + """When all Mamba slots hold blocks with ref_count > 0, eviction raises RuntimeError.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.001, eviction_policy=PrefixCachingEvictionPolicy.LRU) + max_slots = ctx.max_mamba_cache_slots + assert max_slots > 0 + + # Fill all slots with blocks whose ref_count > 0 (active) + for i in range(max_slots): + ctx.block_to_mamba_slot[i] = i + ctx.mamba_slot_to_block[i] = i + ctx.block_allocator.block_ref_counts[i] = 1 # active + ctx.block_allocator.block_timestamps[i] = i * 100 + ctx.mamba_cache_free_count = 0 + + # Allocating a new slot should raise because no evictable candidates + new_block_id = max_slots + 1 + with pytest.raises(RuntimeError, match="all slots in active use"): + ctx._allocate_mamba_cache_slot(new_block_id) + + @pytest.mark.internal + def test_mixed_ref_counts_evicts_only_inactive(self): + """With mixed ref_count=0 and ref_count=1, only the oldest ref_count=0 block is evicted.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.001, eviction_policy=PrefixCachingEvictionPolicy.LRU) + max_slots = ctx.max_mamba_cache_slots + assert max_slots >= 3, 
f"Need at least 3 Mamba slots, got {max_slots}" + + # Fill all slots: some active (ref=1), some evictable (ref=0) + # slot 0 → block 0: ref=1 (active), timestamp=0 (oldest) + # slot 1 → block 1: ref=0 (evictable), timestamp=100 + # slot 2 → block 2: ref=0 (evictable), timestamp=50 (oldest evictable) + for i in range(max_slots): + ctx.block_to_mamba_slot[i] = i + ctx.mamba_slot_to_block[i] = i + if i == 0: + ctx.block_allocator.block_ref_counts[i] = 1 + ctx.block_allocator.block_timestamps[i] = 0 + elif i == 1: + ctx.block_allocator.block_ref_counts[i] = 0 + ctx.block_allocator.block_timestamps[i] = 100 + elif i == 2: + ctx.block_allocator.block_ref_counts[i] = 0 + ctx.block_allocator.block_timestamps[i] = 50 + else: + # Extra slots beyond 3: make them active so they don't interfere + ctx.block_allocator.block_ref_counts[i] = 1 + ctx.block_allocator.block_timestamps[i] = 200 + i + ctx.mamba_cache_free_count = 0 + + # Allocate for a new block → should evict block 2 (oldest evictable, timestamp=50) + new_block_id = max_slots + 1 + slot = ctx._allocate_mamba_cache_slot(new_block_id) + assert slot >= 0 + + # Block 0 (active) should be untouched + assert ctx.block_to_mamba_slot[0].item() == 0, "Active block 0 should be untouched" + assert ctx.mamba_slot_to_block[0].item() == 0 + + # Block 2 (oldest evictable) should be evicted + assert ctx.block_to_mamba_slot[2].item() == -1, "Block 2 should be evicted" + + # Block 1 (evictable but newer) should be untouched + assert ctx.block_to_mamba_slot[1].item() == 1, "Block 1 should be untouched" + + # New block should have the evicted slot + assert ctx.block_to_mamba_slot[new_block_id].item() == slot + + @pytest.mark.internal + def test_mamba_restore_after_eviction_cycle(self): + """After evicting and re-storing Mamba for a block, the new values are restored (not stale).""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.001) + block_size = ctx.block_size_tokens + max_slots = ctx.max_mamba_cache_slots + assert max_slots >= 
1 + + # Add request A (2 blocks) + a_idx = ctx.total_request_count + prompt_a = torch.arange(block_size * 2, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt_a, block_size) + ctx.add_request(req_a) + + # Write known values (conv=1.0, ssm=2.0) and store for block 0 + mamba_idx_a = ctx.mamba_metadata.request_to_mamba_state_idx[a_idx].item() + ctx.mamba_conv_states[:, mamba_idx_a] = 1.0 + ctx.mamba_ssm_states[:, mamba_idx_a] = 2.0 + block_0_id = ctx.request_to_kv_block_ids[a_idx][0].item() + ctx.store_mamba_state_for_block(block_0_id, a_idx) + assert ctx.has_mamba_state_for_block(block_0_id) + + # Evict block 0's Mamba state by invalidating + ctx.invalidate_mamba_state_for_block(block_0_id) + assert not ctx.has_mamba_state_for_block(block_0_id) + + # Re-store with new values (conv=3.0, ssm=4.0) + ctx.mamba_conv_states[:, mamba_idx_a] = 3.0 + ctx.mamba_ssm_states[:, mamba_idx_a] = 4.0 + ctx.store_mamba_state_for_block(block_0_id, a_idx) + assert ctx.has_mamba_state_for_block(block_0_id) + + # Clear request state, then restore from cache + ctx.mamba_conv_states[:, mamba_idx_a] = 0.0 + ctx.mamba_ssm_states[:, mamba_idx_a] = 0.0 + restored = ctx.restore_mamba_state_from_block(a_idx, block_0_id) + assert restored + + # Verify new values (3.0, 4.0) not stale (1.0, 2.0) + assert torch.allclose( + ctx.mamba_conv_states[:, mamba_idx_a], + torch.full_like(ctx.mamba_conv_states[:, mamba_idx_a], 3.0), + ), "Should restore new conv values (3.0), not stale (1.0)" + assert torch.allclose( + ctx.mamba_ssm_states[:, mamba_idx_a], + torch.full_like(ctx.mamba_ssm_states[:, mamba_idx_a], 4.0), + ), "Should restore new ssm values (4.0), not stale (2.0)" + + +class TestKvMambaEvictionInteraction: + """Tests for the interaction between KV eviction and Mamba invalidation.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def 
teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_kv_eviction_cascades_to_mamba_invalidation(self): + """evict_lru_blocks removes KV blocks AND invalidates their Mamba state.""" + ctx = _build_hybrid_context( + prefix_caching_mamba_gb=0.01, + buffer_size_gb=0.01, + rounder=1, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + block_size = ctx.block_size_tokens + alloc = ctx.block_allocator + + # Add request A (2 blocks) + a_idx = ctx.total_request_count + prompt_a = torch.arange(block_size * 2, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt_a, block_size) + ctx.add_request(req_a) + + # Store Mamba state for block 0 and register mamba hash + block_0_id = ctx.request_to_kv_block_ids[a_idx][0].item() + block_1_id = ctx.request_to_kv_block_ids[a_idx][1].item() + block_0_hash = alloc.block_hashes[block_0_id].item() + ctx.store_mamba_state_for_block(block_0_id, a_idx) + alloc.register_mamba_block_hash(block_0_id, block_0_hash) + + assert ctx.has_mamba_state_for_block(block_0_id) + assert block_0_hash in alloc.kv_hash_to_block_id + assert block_0_hash in alloc.mamba_hash_to_block_id + free_before = ctx.mamba_cache_free_count + + # Release A (ref_count → 0) + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + ctx.total_request_count = 0 + + # Make blocks old so they get evicted first + alloc.block_timestamps[block_0_id] = 0 + alloc.block_timestamps[block_1_id] = 1 + + # Evict A's KV blocks + result = alloc.evict_lru_blocks(2) + assert result, "Should evict 2 blocks" + + # Mamba state should be gone + assert not ctx.has_mamba_state_for_block(block_0_id), \ + "Mamba state should be invalidated after KV eviction" + # Mamba free count should have increased + assert ctx.mamba_cache_free_count == free_before + 1 + # Both hash maps should be clear + assert block_0_hash not in alloc.kv_hash_to_block_id + assert block_0_hash not in alloc.mamba_hash_to_block_id + + 
@pytest.mark.internal + def test_mamba_eviction_preserves_kv_cache(self): + """Mamba eviction removes Mamba state but leaves KV blocks intact.""" + # Use extremely small Mamba budget to get exactly 1 slot. + # Each slot needs ~41 KB (conv_states + ssm_states), so 0.00005 GB ≈ 53 KB → 1 slot. + ctx = _build_hybrid_context( + prefix_caching_mamba_gb=0.00005, + buffer_size_gb=0.01, + rounder=1, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + block_size = ctx.block_size_tokens + alloc = ctx.block_allocator + max_slots = ctx.max_mamba_cache_slots + assert max_slots >= 1, f"Need at least 1 Mamba slot, got {max_slots}" + + # Add request A + a_idx = ctx.total_request_count + prompt_a = torch.arange(block_size * 2, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt_a, block_size) + ctx.add_request(req_a) + + block_a0_id = ctx.request_to_kv_block_ids[a_idx][0].item() + block_a0_hash = alloc.block_hashes[block_a0_id].item() + ctx.store_mamba_state_for_block(block_a0_id, a_idx) + alloc.register_mamba_block_hash(block_a0_id, block_a0_hash) + + assert ctx.has_mamba_state_for_block(block_a0_id) + + # Release A (ref_count → 0, blocks become evictable) + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + + # Make A's block have old timestamp so it gets evicted first + alloc.block_timestamps[block_a0_id] = 0 + + # Fill any remaining Mamba slots with filler blocks to ensure cache is full + for i in range(alloc.total_count): + if ctx.mamba_cache_free_count == 0: + break + if ctx.block_to_mamba_slot[i].item() < 0 and i != block_a0_id: + alloc.block_ref_counts[i] = 0 + alloc.block_timestamps[i] = 1000 # newer than A + ctx._allocate_mamba_cache_slot(i) + + assert ctx.mamba_cache_free_count == 0, "Mamba cache should be full" + + # Add request B (different prompt) + b_idx = ctx.total_request_count + prompt_b = torch.arange(1000, 1000 + block_size * 2, device=torch.cuda.current_device()) + req_b = _make_request(2, prompt_b, block_size) + 
ctx.add_request(req_b) + + # Store Mamba for B's block → forces eviction of oldest (A's block) + block_b0_id = ctx.request_to_kv_block_ids[b_idx][0].item() + ctx.store_mamba_state_for_block(block_b0_id, b_idx) + + # After Mamba eviction: KV blocks should still be in kv_hash_to_block_id + assert block_a0_hash in alloc.kv_hash_to_block_id, \ + "A's KV block should still be cached after Mamba eviction" + + # A's Mamba state should be gone + assert not ctx.has_mamba_state_for_block(block_a0_id), \ + "A's Mamba state should be evicted" + assert block_a0_hash not in alloc.mamba_hash_to_block_id, \ + "A's hash should be removed from mamba map" + + @pytest.mark.internal + def test_combined_kv_and_mamba_pressure(self): + """After both Mamba eviction and KV deregistration, all cached state is gone.""" + ctx = _build_hybrid_context( + prefix_caching_mamba_gb=0.00005, # ~1 Mamba slot + buffer_size_gb=0.01, + rounder=1, + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + block_size = ctx.block_size_tokens + alloc = ctx.block_allocator + + # Add request A (2 blocks) + a_idx = ctx.total_request_count + prompt_a = torch.arange(block_size * 2, device=torch.cuda.current_device()) + req_a = _make_request(1, prompt_a, block_size) + ctx.add_request(req_a) + + # Save block IDs before releasing (release clears request_to_kv_block_ids) + block_a0_id = ctx.request_to_kv_block_ids[a_idx][0].item() + block_a1_id = ctx.request_to_kv_block_ids[a_idx][1].item() + block_a0_hash = alloc.block_hashes[block_a0_id].item() + ctx.store_mamba_state_for_block(block_a0_id, a_idx) + alloc.register_mamba_block_hash(block_a0_id, block_a0_hash) + + assert ctx.has_mamba_state_for_block(block_a0_id) + assert block_a0_hash in alloc.kv_hash_to_block_id + assert block_a0_hash in alloc.mamba_hash_to_block_id + + # Release A (ref_count → 0) + ctx.release_memory_blocks_from_request_indexes(torch.tensor([a_idx])) + ctx.total_request_count = 0 + alloc.block_timestamps[block_a0_id] = 0 # oldest + 
alloc.block_timestamps[block_a1_id] = 1 + + # Now directly deregister A's KV blocks (simulates LRU eviction) + # This should cascade to Mamba invalidation via the callback + blocks_tensor = torch.tensor( + [block_a0_id, block_a1_id], device=torch.cuda.current_device(), dtype=torch.int32 + ) + alloc._deregister_blocks(blocks_tensor) + + # After KV eviction: both KV and Mamba state should be gone + assert block_a0_hash not in alloc.kv_hash_to_block_id, \ + "A's KV hash should be gone after eviction" + assert block_a0_hash not in alloc.mamba_hash_to_block_id, \ + "A's Mamba hash should be gone after eviction" + assert not ctx.has_mamba_state_for_block(block_a0_id), \ + "A's Mamba state should be gone after KV eviction" + + # Add request D with same prefix as A → should get no matches (all state evicted) + d_idx = ctx.total_request_count + req_d = _make_request(4, prompt_a.clone(), block_size) + req_d._mamba_num_matched_blocks = 0 + ctx.add_request(req_d) + + # D should have zero-init Mamba state (fresh allocation) + mamba_idx_d = ctx.mamba_metadata.request_to_mamba_state_idx[d_idx].item() + assert torch.all(ctx.mamba_conv_states[:, mamba_idx_d] == 0.0), \ + "D should get zero-init Mamba state (no cache hit)" + + +class TestEvictionEndToEnd: + """Engine-level tests verifying correct output after KV and/or Mamba eviction.""" + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_mamba_eviction_forces_full_recompute(self): + """After Mamba eviction, a repeated prefix triggers full recompute and still produces correct output.""" + _skip_if_mamba_sequence_packing_not_available() + + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 256 + max_sequence_length = 256 + buffer_size_gb = 0.1 + num_tokens_to_generate = 4 + + torch.manual_seed(seed) + prompt_a = torch.randint(0, 
vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + prompt_b = torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + + def make_req(req_id, prompt, enable_pc): + return DynamicInferenceRequest( + request_id=req_id, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=enable_pc, + ) + + # Config with very small Mamba cache (1-2 slots) to force Mamba eviction + configs = [ + { + "name": "prefix+small_mamba", + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "prefix_caching_mamba_gb": 0.0001, # Very small: 1-2 Mamba slots + "max_tokens": max_tokens, + "eviction_policy": PrefixCachingEvictionPolicy.LRU, + }, + { + "name": "baseline", + "enable_chunked_prefill": True, + "enable_prefix_caching": False, + "prefix_caching_mamba_gb": None, + "max_tokens": max_tokens, + }, + ] + + all_results = {} + for config in configs: + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + + engine, _ = _build_engine( + enable_chunked_prefill=config["enable_chunked_prefill"], + enable_prefix_caching=config["enable_prefix_caching"], + prefix_caching_mamba_gb=config["prefix_caching_mamba_gb"], + block_size_tokens=block_size, + max_tokens=config["max_tokens"], + max_requests=8, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=config.get("eviction_policy", PrefixCachingEvictionPolicy.REF_ZERO), + ) + _install_deterministic_mock_forward(engine, vocab_size) + + # A: 64 tokens → completes → Mamba stored + _run_to_completion(engine, [make_req(0, prompt_a, config["enable_prefix_caching"])]) + + # B: different 64 tokens → completes → may evict A's Mamba + _run_to_completion(engine, [make_req(1, prompt_b, config["enable_prefix_caching"])]) + + # C: 
same as A → KV match possible, but Mamba may be evicted → full recompute + results_c = _run_to_completion( + engine, [make_req(2, prompt_a, config["enable_prefix_caching"])] + ) + all_results[config["name"]] = results_c[2] + + assert all_results["prefix+small_mamba"] == all_results["baseline"], ( + f"Output mismatch after Mamba eviction:\n" + f" prefix+small_mamba: {all_results['prefix+small_mamba']}\n" + f" baseline: {all_results['baseline']}" + ) + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_kv_mamba_eviction_correct_output(self): + """With a small buffer, many intervening requests force KV LRU eviction. + A repeated prefix gets full recompute and still produces correct output. + + Uses 65-token prompts (3 blocks each, non-block-aligned to avoid decode + pausing). With buffer_size_gb=0.01 (~33 blocks), 11 requests fill all + blocks, and the 12th request (C, same as A) triggers eviction of A's + blocks during prefill allocation. 
+ """ + _skip_if_mamba_sequence_packing_not_available() + + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 256 + max_sequence_length = 256 + buffer_size_gb = 0.01 # ~33 blocks + num_tokens_to_generate = 4 + prompt_len = 65 # Non-block-aligned to avoid decode pause + + torch.manual_seed(seed) + prompt_a = torch.randint(0, vocab_size - 1, (prompt_len,), device="cuda", dtype=torch.int64) + # Generate 10 different filler prompts (A + 10 fillers = 11 requests × 3 blocks = 33) + filler_prompts = [ + torch.randint(0, vocab_size - 1, (prompt_len,), device="cuda", dtype=torch.int64) + for _ in range(10) + ] + + def make_req(req_id, prompt, enable_pc): + return DynamicInferenceRequest( + request_id=req_id, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=enable_pc, + ) + + configs = [ + { + "name": "prefix+small_buffer", + "enable_chunked_prefill": True, + "enable_prefix_caching": True, + "prefix_caching_mamba_gb": 0.0005, # 1 Mamba slot + "max_tokens": max_tokens, + "eviction_policy": PrefixCachingEvictionPolicy.LRU, + }, + { + "name": "baseline", + "enable_chunked_prefill": True, + "enable_prefix_caching": False, + "prefix_caching_mamba_gb": None, + "max_tokens": max_tokens, + }, + ] + + all_results = {} + for config in configs: + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + + engine, _ = _build_engine( + enable_chunked_prefill=config["enable_chunked_prefill"], + enable_prefix_caching=config["enable_prefix_caching"], + prefix_caching_mamba_gb=config["prefix_caching_mamba_gb"], + block_size_tokens=block_size, + max_tokens=config["max_tokens"], + max_requests=8, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=config.get("eviction_policy", 
PrefixCachingEvictionPolicy.REF_ZERO), + ) + _install_deterministic_mock_forward(engine, vocab_size) + + # A → completes → KV blocks cached, Mamba stored + _run_to_completion(engine, [make_req(0, prompt_a, config["enable_prefix_caching"])]) + + # Run 10 filler requests to fill all blocks in the cache + for i, filler in enumerate(filler_prompts): + _run_to_completion( + engine, [make_req(i + 1, filler, config["enable_prefix_caching"])] + ) + + # C (same as A) → A's blocks evicted → full recompute + results_c = _run_to_completion( + engine, [make_req(12, prompt_a, config["enable_prefix_caching"])] + ) + all_results[config["name"]] = results_c[12] + + assert all_results["prefix+small_buffer"] == all_results["baseline"], ( + f"Output mismatch after KV+Mamba eviction:\n" + f" prefix+small_buffer: {all_results['prefix+small_buffer']}\n" + f" baseline: {all_results['baseline']}" + ) + + @pytest.mark.internal + @pytest.mark.skipif( + not is_fa_min_version("2.7.3"), reason="need latest flash attn for dynamic batching" + ) + @torch.no_grad() + def test_eviction_with_interleaved_shared_prefixes(self): + """Multiple requests sharing a prefix produce correct output even when Mamba is evicted between them.""" + _skip_if_mamba_sequence_packing_not_available() + + seed = 42 + vocab_size = 100 + block_size = 32 + max_tokens = 256 + max_sequence_length = 256 + buffer_size_gb = 0.1 + num_tokens_to_generate = 4 + + torch.manual_seed(seed) + shared_prefix = torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + prompt_b = torch.randint(0, vocab_size - 1, (64,), device="cuda", dtype=torch.int64) + + def make_req(req_id, prompt, enable_pc): + return DynamicInferenceRequest( + request_id=req_id, + prompt_tokens=prompt.clone(), + sampling_params=SamplingParams( + num_tokens_to_generate=num_tokens_to_generate, + termination_id=-1, + ), + block_size_tokens=block_size, + enable_prefix_caching=enable_pc, + ) + + configs = [ + { + "name": "prefix+small_mamba", + 
"enable_chunked_prefill": True, + "enable_prefix_caching": True, + "prefix_caching_mamba_gb": 0.0001, # Very small: 1-2 Mamba slots + "max_tokens": max_tokens, + "eviction_policy": PrefixCachingEvictionPolicy.LRU, + }, + { + "name": "baseline", + "enable_chunked_prefill": True, + "enable_prefix_caching": False, + "prefix_caching_mamba_gb": None, + "max_tokens": max_tokens, + }, + ] + + all_results = {} + for config in configs: + Utils.destroy_model_parallel() + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + + engine, _ = _build_engine( + enable_chunked_prefill=config["enable_chunked_prefill"], + enable_prefix_caching=config["enable_prefix_caching"], + prefix_caching_mamba_gb=config["prefix_caching_mamba_gb"], + block_size_tokens=block_size, + max_tokens=config["max_tokens"], + max_requests=8, + vocab_size=vocab_size, + max_sequence_length=max_sequence_length, + buffer_size_gb=buffer_size_gb, + seed=seed, + eviction_policy=config.get("eviction_policy", PrefixCachingEvictionPolicy.REF_ZERO), + ) + _install_deterministic_mock_forward(engine, vocab_size) + + # A1: shared prefix → completes → Mamba stored + results_a1 = _run_to_completion( + engine, [make_req(0, shared_prefix, config["enable_prefix_caching"])] + ) + + # A2: same prefix → matches A1's cache → completes + results_a2 = _run_to_completion( + engine, [make_req(1, shared_prefix, config["enable_prefix_caching"])] + ) + + # B: different prefix → completes → may evict Mamba for shared prefix + _run_to_completion(engine, [make_req(2, prompt_b, config["enable_prefix_caching"])]) + + # A3: same prefix → KV may still be cached (LRU), but Mamba evicted → full recompute + results_a3 = _run_to_completion( + engine, [make_req(3, shared_prefix, config["enable_prefix_caching"])] + ) + + all_results[config["name"]] = { + "a1": results_a1[0], + "a2": results_a2[1], + "a3": results_a3[3], + } + + # All three should match within each config + for name in all_results: + 
assert all_results[name]["a1"] == all_results[name]["a2"], \ + f"[{name}] A1 and A2 outputs should match" + assert all_results[name]["a1"] == all_results[name]["a3"], \ + f"[{name}] A1 and A3 outputs should match" + + # Cross-config: all should match baseline + for req_label in ("a1", "a2", "a3"): + assert all_results["prefix+small_mamba"][req_label] == all_results["baseline"][req_label], ( + f"Request {req_label} mismatch between configs:\n" + f" prefix+small_mamba: {all_results['prefix+small_mamba'][req_label]}\n" + f" baseline: {all_results['baseline'][req_label]}" + ) + + +class TestMambaStressAndBudget: + """Stress tests for Mamba cache slot management.""" + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, pipeline_model_parallel_size=1 + ) + model_parallel_cuda_manual_seed(123) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.internal + def test_rapid_allocation_eviction_cycle(self): + """100 rapid allocate/invalidate cycles leave the slot pool in a clean state.""" + ctx = _build_hybrid_context(prefix_caching_mamba_gb=0.01) + max_slots = ctx.max_mamba_cache_slots + assert max_slots > 0 + + initial_free = ctx.mamba_cache_free_count + assert initial_free == max_slots + + for i in range(100): + block_id = i % max_slots # Reuse block IDs + slot = ctx._allocate_mamba_cache_slot(block_id) + assert slot >= 0 + ctx.invalidate_mamba_state_for_block(block_id) + + # After all cycles: all slots should be free + assert ctx.mamba_cache_free_count == max_slots, \ + f"Expected {max_slots} free slots, got {ctx.mamba_cache_free_count}" + + # No dangling references in mamba_slot_to_block + for s in range(max_slots): + assert ctx.mamba_slot_to_block[s].item() == -1, \ + f"Slot {s} has dangling block reference" + + @pytest.mark.internal + def test_many_blocks_few_mamba_slots(self): + """Large KV buffer with small Mamba cache: only first N blocks get Mamba state.""" + ctx = 
_build_hybrid_context( + prefix_caching_mamba_gb=0.001, # Few Mamba slots + buffer_size_gb=0.01, # Many KV blocks + max_tokens=None, # No token limit + eviction_policy=PrefixCachingEvictionPolicy.LRU, + ) + max_slots = ctx.max_mamba_cache_slots + assert max_slots >= 1 + + block_size = ctx.block_size_tokens + + # Add a request with many blocks + num_blocks = max_slots + 3 + prompt = torch.arange( + block_size * num_blocks, device=torch.cuda.current_device() + ) + req = _make_request(1, prompt, block_size) + ctx.add_request(req) + + # Store Mamba for first max_slots blocks + for i in range(max_slots): + block_id = ctx.request_to_kv_block_ids[0][i].item() + ctx.store_mamba_state_for_block(block_id, 0) + assert ctx.has_mamba_state_for_block(block_id), \ + f"Block {i} should have Mamba state" + + # Cache should be full + assert ctx.mamba_cache_free_count == 0, "Mamba cache should be full" + + # Blocks beyond max_slots don't have state yet + for i in range(max_slots, min(num_blocks, ctx.request_to_kv_block_ids.shape[1])): + block_id = ctx.request_to_kv_block_ids[0][i].item() + if block_id >= 0: + assert not ctx.has_mamba_state_for_block(block_id), \ + f"Block {i} should NOT have Mamba state" + + # Release request so blocks become evictable + ctx.release_memory_blocks_from_request_indexes(torch.tensor([0])) + + # Set timestamps: block 0 is oldest + for i in range(max_slots): + block_id = ctx.request_to_kv_block_ids[0][i].item() + if block_id >= 0: + ctx.block_allocator.block_timestamps[block_id] = i * 100 + + # Store for one more block beyond max_slots → triggers LRU eviction + extra_block_id = ctx.request_to_kv_block_ids[0][max_slots].item() + if extra_block_id >= 0: + # Need to allocate a new request slot to store from + ctx.total_request_count = 0 + req2 = _make_request(2, prompt.clone(), block_size) + ctx.add_request(req2) + ctx.store_mamba_state_for_block(extra_block_id, ctx.total_request_count - 1) + + # The LRU victim (block 0, oldest timestamp) should have lost 
state + first_block_id = ctx.request_to_kv_block_ids[0][0].item() + assert not ctx.has_mamba_state_for_block(first_block_id), \ + "Block 0 (LRU victim) should have lost Mamba state" + + # New block should have state + assert ctx.has_mamba_state_for_block(extra_block_id), \ + "Newly stored block should have Mamba state" diff --git a/tests/unit_tests/ssm/ops/test_causal_conv1d_varlen.py b/tests/unit_tests/ssm/ops/test_causal_conv1d_varlen.py new file mode 100644 index 00000000000..6be012a3c29 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_causal_conv1d_varlen.py @@ -0,0 +1,183 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Unit tests for the Triton varlen causal conv1d kernel. + +Tests correctness of `causal_conv1d_varlen_fn` against a reference implementation +that loops over requests calling `causal_conv1d_fn` with `initial_states`. +""" + +import pytest +import torch + +from megatron.core.ssm.ops.causal_conv1d_varlen import causal_conv1d_varlen_fn + +try: + from causal_conv1d import causal_conv1d_fn + + HAS_CAUSAL_CONV1D = True +except ImportError: + HAS_CAUSAL_CONV1D = False + + +def _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states, activation="silu"): + """Reference: per-request loop calling causal_conv1d_fn with initial_states.""" + num_requests = cu_seqlens.shape[0] - 1 + conv_dim = x.shape[1] + d_conv = weight.shape[1] + parts = [] + for r in range(num_requests): + start = cu_seqlens[r].item() + end = cu_seqlens[r + 1].item() + if end <= start: + continue + seq_len_r = end - start + if initial_states is not None: + init_r = initial_states[r : r + 1] # (1, conv_dim, d_conv-1) + # causal_conv1d_fn with initial_states requires channels-last layout + # for both x and initial_states: create as (1, L, C) then transpose + x_r = x[start:end].unsqueeze(0).transpose(1, 2) # channels-last (1, C, L) + init_r = init_r.permute(0, 2, 1).contiguous().transpose(1, 2) # channels-last + else: + init_r = None + x_r = 
x[start:end].T.unsqueeze(0).contiguous() # (1, conv_dim, seq_len) + out_r = causal_conv1d_fn( + x=x_r, + weight=weight, + bias=bias, + activation=activation, + initial_states=init_r, + ) + parts.append(out_r.squeeze(0).T.contiguous()) # (seq_len, conv_dim) + return torch.cat(parts, dim=0) if parts else torch.empty(0, conv_dim, device=x.device) + + +@pytest.mark.skipif(not HAS_CAUSAL_CONV1D, reason="causal_conv1d not installed") +class TestCausalConv1dVarlen: + """Test causal_conv1d_varlen_fn against per-request causal_conv1d_fn reference.""" + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) + def test_single_request(self, dtype): + """Single request should match causal_conv1d_fn exactly.""" + torch.manual_seed(42) + conv_dim, d_conv, seq_len = 64, 4, 32 + device = "cuda" + + x = torch.randn(seq_len, conv_dim, dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device) + initial_states = torch.randn(1, conv_dim, d_conv - 1, dtype=dtype, device=device) + + out = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, initial_states) + ref = _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states) + + atol = 1e-2 if dtype == torch.bfloat16 else 1e-5 + torch.testing.assert_close(out, ref, atol=atol, rtol=1e-2) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) + def test_multiple_requests_varying_lengths(self, dtype): + """Multiple requests with different sequence lengths.""" + torch.manual_seed(123) + conv_dim, d_conv = 128, 4 + seq_lens = [10, 25, 3, 50, 8] + device = "cuda" + + total_tokens = sum(seq_lens) + x = torch.randn(total_tokens, conv_dim, dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + + cu_seqlens_list = [0] + for sl in seq_lens: 
+ cu_seqlens_list.append(cu_seqlens_list[-1] + sl) + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + num_requests = len(seq_lens) + initial_states = torch.randn( + num_requests, conv_dim, d_conv - 1, dtype=dtype, device=device + ) + + out = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, initial_states) + ref = _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states) + + atol = 1e-2 if dtype == torch.bfloat16 else 1e-5 + torch.testing.assert_close(out, ref, atol=atol, rtol=1e-2) + + def test_seqlen_shorter_than_d_conv(self): + """Sequence shorter than d_conv should use initial_states for all taps.""" + torch.manual_seed(7) + conv_dim, d_conv = 32, 4 + seq_lens = [2, 1, 3] # All shorter than d_conv + device = "cuda" + dtype = torch.float32 + + total_tokens = sum(seq_lens) + x = torch.randn(total_tokens, conv_dim, dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + + cu_seqlens_list = [0] + for sl in seq_lens: + cu_seqlens_list.append(cu_seqlens_list[-1] + sl) + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + num_requests = len(seq_lens) + initial_states = torch.randn( + num_requests, conv_dim, d_conv - 1, dtype=dtype, device=device + ) + + out = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, initial_states) + ref = _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states) + + torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5) + + def test_zero_initial_states(self): + """Zero initial_states should produce same result as None initial_states.""" + torch.manual_seed(99) + conv_dim, d_conv = 64, 4 + seq_lens = [16, 24] + device = "cuda" + dtype = torch.float32 + + total_tokens = sum(seq_lens) + x = torch.randn(total_tokens, conv_dim, dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, 
dtype=dtype, device=device) + + cu_seqlens_list = [0] + for sl in seq_lens: + cu_seqlens_list.append(cu_seqlens_list[-1] + sl) + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + num_requests = len(seq_lens) + zero_states = torch.zeros( + num_requests, conv_dim, d_conv - 1, dtype=dtype, device=device + ) + + out_zero = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, zero_states) + out_none = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, None) + + torch.testing.assert_close(out_zero, out_none, atol=1e-5, rtol=1e-5) + + def test_nonzero_vs_zero_initial_states_differ(self): + """Non-zero initial_states should produce different results from zero.""" + torch.manual_seed(55) + conv_dim, d_conv = 64, 4 + seq_len = 16 + device = "cuda" + dtype = torch.float32 + + x = torch.randn(seq_len, conv_dim, dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device) + + nonzero_states = torch.randn(1, conv_dim, d_conv - 1, dtype=dtype, device=device) + + out_nonzero = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, nonzero_states) + out_none = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, None) + + # First few tokens should differ (those that depend on initial state) + assert not torch.allclose(out_nonzero[:d_conv - 1], out_none[:d_conv - 1], atol=1e-5), ( + "Non-zero initial states should produce different outputs for early tokens" + ) diff --git a/tests/unit_tests/ssm/ops/test_ops_init.py b/tests/unit_tests/ssm/ops/test_ops_init.py new file mode 100644 index 00000000000..2a7c8b42a4a --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ops_init.py @@ -0,0 +1,27 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +"""Test that the megatron.core.ssm.ops package exports the public API.""" + +import unittest + +try: + from megatron.core.ssm import ops as ssm_ops + HAVE_SSD_OPS = True +except Exception:  # ImportError, or Triton/CUDA failure during import + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +class TestOpsPackagePublicAPI(unittest.TestCase): + """Ensure the ops package exposes the documented public API.""" + + def test_all_exported(self): + self.assertIn("mamba_chunk_scan_combined_varlen", ssm_ops.__all__) + + def test_mamba_chunk_scan_combined_varlen_importable(self): + self.assertTrue(hasattr(ssm_ops, "mamba_chunk_scan_combined_varlen")) + self.assertTrue(callable(ssm_ops.mamba_chunk_scan_combined_varlen)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssd_bmm.py b/tests/unit_tests/ssm/ops/test_ssd_bmm.py new file mode 100644 index 00000000000..cc15c758291 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssd_bmm.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ +import unittest +import torch + +try: + from megatron.core.ssm.ops.ssd_bmm import _bmm_chunk_fwd + HAVE_SSD_OPS = True +except Exception:  # ImportError, or Triton/CUDA failure during import + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestBmmChunkFwd(unittest.TestCase): + """Tests for _bmm_chunk_fwd (C^T @ B per chunk).""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.chunk_size = 16 + self.seqlen = 32 + self.ngroups = 2 + self.dstate = 8 # K dimension + self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + + def test_bmm_chunk_fwd_shape(self): + # a: (seqlen, ngroups, k), b: (seqlen, ngroups, k) -> out: (nchunks, ngroups, chunk_size, chunk_size) + a = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + b = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + + out = _bmm_chunk_fwd( + a, b, self.chunk_size, self.cu_chunk_seqlens, causal=True, output_dtype=torch.float32 + ) + + nchunks = self.cu_chunk_seqlens.shape[0] - 1 + self.assertEqual(out.shape, (nchunks, self.ngroups, self.chunk_size, self.chunk_size)) + self.assertFalse(torch.isnan(out).any()) + + def test_bmm_chunk_fwd_vs_torch_per_chunk(self): + """Compare first chunk with explicit C^T @ B for that chunk.""" + a = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + b = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + + out = _bmm_chunk_fwd( + a, b, self.chunk_size, self.cu_chunk_seqlens, causal=False, output_dtype=torch.float32 + ) + + # Chunk 0: rows 0:16 of a and b. out[0, g] = a[0:16, g] @ b[0:16, g].T + # Relaxed tolerances: Triton block-wise reduction order can differ from torch; + # atol is the main check (max abs diff was ~0.008 in practice).
+ for g in range(self.ngroups): + a_chunk = a[0:16, g, :].contiguous() # (16, dstate) + b_chunk = b[0:16, g, :].contiguous() # (16, dstate) + expected = torch.mm(a_chunk, b_chunk.T) # (16, 16) + torch.testing.assert_close(out[0, g], expected, rtol=1.0, atol=0.02) + + def test_bmm_chunk_fwd_causal_vs_non_causal_shape(self): + a = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + b = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + + out_causal = _bmm_chunk_fwd(a, b, self.chunk_size, self.cu_chunk_seqlens, causal=True) + out_noncausal = _bmm_chunk_fwd(a, b, self.chunk_size, self.cu_chunk_seqlens, causal=False) + + self.assertEqual(out_causal.shape, out_noncausal.shape) + # Causal: lower triangle is correct; upper can differ + for c in range(out_causal.shape[0]): + for g in range(self.ngroups): + for i in range(self.chunk_size): + for j in range(i + 1): + self.assertTrue( + torch.allclose(out_causal[c, g, i, j], out_noncausal[c, g, i, j]), + f"c={c} g={g} i={i} j={j}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py b/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py new file mode 100644 index 00000000000..1c6d4ecfbc3 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py @@ -0,0 +1,98 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +import unittest +import torch + +try: + from megatron.core.ssm.ops.ssd_chunk_scan import _chunk_scan_fwd + HAVE_SSD_OPS = True +except Exception:  # ImportError, or Triton/CUDA failure during import + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestChunkScanFwd(unittest.TestCase): + """Tests for _chunk_scan_fwd.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.seqlen = 32 + self.nheads = 4 + self.headdim = 16 + self.ngroups = 2 + self.dstate = 8 + self.chunk_size = 16 + self.nchunks = 2 + self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + self.seq_idx = torch.tensor([0, 1], dtype=torch.int32, device=self.device) + + def test_chunk_scan_fwd_shape_and_inplace_out(self): + cb = torch.randn( + self.nchunks, self.ngroups, self.chunk_size, self.chunk_size, + device=self.device, dtype=torch.float32, + ) + x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + C = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + states = torch.randn(self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32) + out = torch.zeros(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + _chunk_scan_fwd( + cb, x, dt, dA_cumsum, C, states, + self.cu_chunk_seqlens, out, self.seq_idx, + D=None, z=None, initial_states=None, + ) + + self.assertEqual(out.shape, (self.seqlen, self.nheads, self.headdim)) + self.assertFalse(torch.isnan(out).any()) + # Output should be non-zero (scan writes to out) + self.assertGreater(out.abs().max().item(), 0.0) + + def
test_chunk_scan_fwd_with_D(self): + cb = torch.randn( + self.nchunks, self.ngroups, self.chunk_size, self.chunk_size, + device=self.device, dtype=torch.float32, + ) + x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + C = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + states = torch.randn(self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32) + out = torch.zeros(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + D = torch.ones(self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + _chunk_scan_fwd( + cb, x, dt, dA_cumsum, C, states, + self.cu_chunk_seqlens, out, self.seq_idx, + D=D, z=None, initial_states=None, + ) + + self.assertFalse(torch.isnan(out).any()) + + def test_chunk_scan_fwd_with_z(self): + cb = torch.randn( + self.nchunks, self.ngroups, self.chunk_size, self.chunk_size, + device=self.device, dtype=torch.float32, + ) + x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + z = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + C = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + states = torch.randn(self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32) + out = torch.zeros(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + _chunk_scan_fwd( + cb, 
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Unit tests for varlen chunk-state kernels:
``_chunk_cumsum_fwd``, ``_chunk_state_fwd`` and ``chunk_state_varlen``."""

import unittest

import torch

try:
    from megatron.core.ssm.ops.ssd_chunk_state import (
        _chunk_cumsum_fwd,
        _chunk_state_fwd,
        chunk_state_varlen,
    )

    HAVE_SSD_OPS = True
except Exception:  # ImportError, or Triton-version issues raised at import time
    HAVE_SSD_OPS = False


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestChunkCumsumFwd(unittest.TestCase):
    """Tests for _chunk_cumsum_fwd."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.seqlen = 32
        self.nheads = 4
        self.chunk_size = 16
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # Number of chunks is implied by the chunk-boundary tensor.
        self.nchunks = self.cu_chunk_seqlens.shape[0] - 1

    def test_chunk_cumsum_fwd_shape(self):
        dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32)
        A = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        dA_cumsum, dt_out = _chunk_cumsum_fwd(dt, A, self.chunk_size, self.cu_chunk_seqlens)

        self.assertEqual(dA_cumsum.shape, (self.nheads, self.nchunks, self.chunk_size))
        self.assertEqual(dt_out.shape, (self.nheads, self.nchunks, self.chunk_size))
        self.assertFalse(torch.isnan(dA_cumsum).any())
        self.assertFalse(torch.isnan(dt_out).any())

    def test_chunk_cumsum_fwd_cumsum_per_chunk(self):
        """dA_cumsum should be the cumsum of dt * A along the chunk dimension."""
        dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32)
        A = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        dA_cumsum, dt_out = _chunk_cumsum_fwd(
            dt, A, self.chunk_size, self.cu_chunk_seqlens,
            dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
        )

        for c in range(self.nchunks):
            start = self.cu_chunk_seqlens[c].item()
            end = self.cu_chunk_seqlens[c + 1].item()
            chunk_len = end - start
            for h in range(self.nheads):
                # Reference: per-chunk cumulative sum of dt * A, computed on CPU.
                dA_chunk = (dt_out[h, c, :chunk_len] * A[h]).cpu()
                expected_cumsum = torch.cumsum(dA_chunk, dim=0)
                torch.testing.assert_close(
                    dA_cumsum[h, c, :chunk_len].cpu(), expected_cumsum, rtol=1e-4, atol=1e-4
                )

    def test_chunk_cumsum_fwd_with_dt_bias(self):
        dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32)
        A = torch.randn(self.nheads, device=self.device, dtype=torch.float32)
        dt_bias = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        dA_cumsum, _ = _chunk_cumsum_fwd(
            dt, A, self.chunk_size, self.cu_chunk_seqlens, dt_bias=dt_bias
        )

        self.assertEqual(dA_cumsum.shape, (self.nheads, self.nchunks, self.chunk_size))
        self.assertFalse(torch.isnan(dA_cumsum).any())


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestChunkStateFwd(unittest.TestCase):
    """Tests for _chunk_state_fwd."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.seqlen = 32
        self.nheads = 4
        self.headdim = 16
        self.ngroups = 2
        self.dstate = 8
        self.chunk_size = 16
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # Derive the chunk count from the boundaries instead of hard-coding it.
        self.nchunks = self.cu_chunk_seqlens.shape[0] - 1

    def test_chunk_state_fwd_shape(self):
        x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32)
        B = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32)
        dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)
        dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)

        states = _chunk_state_fwd(B, x, dt, dA_cumsum, self.cu_chunk_seqlens)

        self.assertEqual(states.shape, (self.nchunks, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(states).any())


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestChunkStateVarlen(unittest.TestCase):
    """Tests for chunk_state_varlen."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.seqlen = 32
        self.nheads = 4
        self.headdim = 16
        self.ngroups = 2
        self.dstate = 8
        self.chunk_size = 16
        self.batch = 2
        # Two sequences of 16 tokens; each sequence is exactly one chunk.
        self.cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        self.last_chunk_indices = torch.tensor([0, 1], dtype=torch.int64, device=self.device)
        self.nchunks = self.cu_chunk_seqlens.shape[0] - 1

    def _make_inputs(self):
        """Build the common (x, B, dt, dA_cumsum, chunk_states) argument set."""
        kw = dict(device=self.device, dtype=torch.float32)
        x = torch.randn(self.seqlen, self.nheads, self.headdim, **kw)
        B = torch.randn(self.seqlen, self.ngroups, self.dstate, **kw)
        dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, **kw)
        dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, **kw)
        chunk_states = torch.randn(self.nchunks, self.nheads, self.headdim, self.dstate, **kw)
        return x, B, dt, dA_cumsum, chunk_states

    def test_chunk_state_varlen_shape(self):
        x, B, dt, dA_cumsum, chunk_states = self._make_inputs()

        states = chunk_state_varlen(
            B, x, dt, dA_cumsum, self.cu_seqlens, chunk_states,
            last_chunk_indices=self.last_chunk_indices,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
        )

        self.assertEqual(states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(states).any())

    def test_chunk_state_varlen_with_initial_states(self):
        x, B, dt, dA_cumsum, chunk_states = self._make_inputs()
        initial_states = torch.randn(
            self.batch, self.nheads, self.headdim, self.dstate,
            device=self.device, dtype=torch.float32,
        )

        states = chunk_state_varlen(
            B, x, dt, dA_cumsum, self.cu_seqlens, chunk_states,
            initial_states=initial_states,
            last_chunk_indices=self.last_chunk_indices,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
        )

        self.assertEqual(states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(states).any())


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Unit tests for ``is_int_pow_2`` and ``mamba_chunk_scan_combined_varlen``."""

import unittest

import torch

try:
    from megatron.core.ssm.ops.ssd_combined import (
        is_int_pow_2,
        mamba_chunk_scan_combined_varlen,
    )

    HAVE_SSD_OPS = True
except Exception:  # ImportError, or Triton-version issues raised at import time
    HAVE_SSD_OPS = False


# Pure-Python utility: needs the import to succeed, but not a GPU.
@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
class TestIsIntPow2(unittest.TestCase):
    """Tests for is_int_pow_2 utility."""

    def test_powers_of_two(self):
        for exp in range(12):
            n = 2 ** exp
            self.assertTrue(is_int_pow_2(n), f"2^{exp}={n} should be power of 2")

    def test_non_powers_of_two(self):
        for n in [0, 3, 5, 6, 7, 9, 10, 12, 15, 18]:
            self.assertFalse(is_int_pow_2(n), f"{n} should not be power of 2")

    def test_negative_and_float(self):
        self.assertFalse(is_int_pow_2(-1))
        self.assertFalse(is_int_pow_2(-4))
        # Floats are rejected even when their value is a power of two.
        self.assertFalse(is_int_pow_2(2.0))
        self.assertFalse(is_int_pow_2(0))


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestMambaChunkScanCombinedVarlen(unittest.TestCase):
    """Tests for mamba_chunk_scan_combined_varlen."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.chunk_size = 16
        self.seqlen = 32
        self.nheads = 4
        self.headdim = 16
        self.ngroups = 2
        self.dstate = 8
        self.batch = 2
        # cu_seqlens: [0, 16, 32] -> two sequences of length 16 each
        self.cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # 2 chunks of 16 each
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # last chunk index per sequence: seq0 ends in chunk 0, seq1 ends in chunk 1
        self.last_chunk_indices = torch.tensor([0, 1], dtype=torch.int64, device=self.device)
        # seq_idx: which sequence each chunk belongs to (nchunks,)
        self.seq_idx = torch.tensor([0, 1], dtype=torch.int32, device=self.device)

    def _rand_inputs(self, seqlen):
        """Build random (x, dt, A, B, C, out) tensors for a packed batch of `seqlen` tokens."""
        kw = dict(device=self.device, dtype=torch.float32)
        x = torch.randn(seqlen, self.nheads, self.headdim, **kw)
        dt = torch.randn(seqlen, self.nheads, **kw)
        A = torch.randn(self.nheads, **kw)
        B = torch.randn(seqlen, self.ngroups, self.dstate, **kw)
        C = torch.randn(seqlen, self.ngroups, self.dstate, **kw)
        out = torch.empty(seqlen, self.nheads, self.headdim, **kw)
        return x, dt, A, B, C, out

    def test_mamba_chunk_scan_combined_varlen_shape_and_no_nan(self):
        x, dt, A, B, C, out = self._rand_inputs(self.seqlen)

        varlen_states = mamba_chunk_scan_combined_varlen(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            chunk_size=self.chunk_size,
            cu_seqlens=self.cu_seqlens,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
            last_chunk_indices=self.last_chunk_indices,
            seq_idx=self.seq_idx,
            out=out,
        )

        self.assertEqual(varlen_states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertEqual(out.shape, (self.seqlen, self.nheads, self.headdim))
        self.assertFalse(torch.isnan(out).any(), "output should have no NaN")
        self.assertFalse(torch.isnan(varlen_states).any(), "varlen_states should have no NaN")

    def test_mamba_chunk_scan_combined_varlen_with_D_and_dt_bias(self):
        x, dt, A, B, C, out = self._rand_inputs(self.seqlen)
        D = torch.ones(self.nheads, self.headdim, device=self.device, dtype=torch.float32)
        dt_bias = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        varlen_states = mamba_chunk_scan_combined_varlen(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            chunk_size=self.chunk_size,
            cu_seqlens=self.cu_seqlens,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
            last_chunk_indices=self.last_chunk_indices,
            seq_idx=self.seq_idx,
            out=out,
            D=D,
            dt_bias=dt_bias,
        )

        self.assertEqual(varlen_states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(out).any())

    def test_mamba_chunk_scan_combined_varlen_single_sequence(self):
        """Single sequence: cu_seqlens [0, 32], one sequence of 32 spanning two chunks."""
        cu_seqlens = torch.tensor([0, 32], dtype=torch.int32, device=self.device)
        cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        last_chunk_indices = torch.tensor([1], dtype=torch.int64, device=self.device)
        seq_idx = torch.tensor([0, 0], dtype=torch.int32, device=self.device)

        x, dt, A, B, C, out = self._rand_inputs(self.seqlen)

        varlen_states = mamba_chunk_scan_combined_varlen(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            chunk_size=self.chunk_size,
            cu_seqlens=cu_seqlens,
            cu_chunk_seqlens=cu_chunk_seqlens,
            last_chunk_indices=last_chunk_indices,
            seq_idx=seq_idx,
            out=out,
        )

        self.assertEqual(varlen_states.shape, (1, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(out).any())


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Unit tests for the varlen state-passing forward kernel (``_state_passing_fwd``)."""

import unittest

import torch

try:
    from megatron.core.ssm.ops.ssd_state_passing import _state_passing_fwd

    HAVE_SSD_OPS = True
except Exception:  # ImportError, or Triton-version issues raised at import time
    HAVE_SSD_OPS = False


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestStatePassingFwd(unittest.TestCase):
    """Tests for _state_passing_fwd: recurrence out = exp(dA_cs_last) * prev + new_states."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.nchunks = 4
        self.nheads = 2
        self.chunk_size = 16
        self.dim = self.chunk_size * 8  # headdim * dstate flattened
        self.cu_chunk_seqlens = torch.tensor(
            [0, 16, 32, 48, 64], dtype=torch.int32, device=self.device
        )

    def test_state_passing_fwd_shape(self):
        states = torch.randn(
            self.nchunks, self.nheads, self.dim, device=self.device, dtype=torch.float32
        )
        dA_cumsum = torch.randn(
            self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32
        )
        # All chunks belong to a single sequence.
        seq_idx = torch.zeros(self.nchunks, dtype=torch.int32, device=self.device)

        out = _state_passing_fwd(
            states, dA_cumsum, self.cu_chunk_seqlens, seq_idx, initial_states=None
        )

        self.assertEqual(out.shape, (self.nchunks, self.nheads, self.dim))
        self.assertFalse(torch.isnan(out).any())

    def test_state_passing_fwd_with_initial_states(self):
        states = torch.randn(
            self.nchunks, self.nheads, self.dim, device=self.device, dtype=torch.float32
        )
        dA_cumsum = torch.randn(
            self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32
        )
        # Two sequences of two chunks each; one initial state per sequence.
        seq_idx = torch.tensor([0, 0, 1, 1], dtype=torch.int32, device=self.device)
        initial_states = torch.randn(
            2, self.nheads, self.dim, device=self.device, dtype=torch.float32
        )

        out = _state_passing_fwd(
            states,
            dA_cumsum,
            self.cu_chunk_seqlens,
            seq_idx,
            initial_states=initial_states,
        )

        self.assertEqual(out.shape, (self.nchunks, self.nheads, self.dim))
        self.assertFalse(torch.isnan(out).any())

    def test_state_passing_fwd_recurrence_single_head_single_dim(self):
        """Sanity: single head, small dim, check the recurrence manually for the first chunk."""
        dim = 4
        nchunks = 2
        nheads = 1
        chunk_size = 2
        cu_chunk_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32, device=self.device)
        seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device)

        states = torch.randn(nchunks, nheads, dim, device=self.device, dtype=torch.float32)
        dA_cumsum = torch.randn(nheads, nchunks, chunk_size, device=self.device, dtype=torch.float32)

        out = _state_passing_fwd(states, dA_cumsum, cu_chunk_seqlens, seq_idx)

        # Chunk 0: out[0] = exp(dA_cumsum[0, -1]) * 0 + states[0] = states[0]
        # (no initial state), so out[0] should equal states[0].
        torch.testing.assert_close(out[0], states[0], rtol=1e-4, atol=1e-4)
        self.assertEqual(out.shape, (nchunks, nheads, dim))


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Unit tests for MambaMixer's dynamic-inference SSM paths (_ssm_prefill)."""

import unittest
from unittest.mock import MagicMock

import torch
import torch.nn as nn

try:
    # MambaMixer lives in mamba_mixer.py; guard the import so this module can
    # still be collected (and skipped) when Megatron / Triton is unavailable,
    # consistent with the other SSD-op test modules.
    from megatron.core.ssm.mamba_mixer import MambaMixer

    HAVE_MAMBA = True
except Exception:
    MambaMixer = None
    HAVE_MAMBA = False


class MockContextParallel:
    """Mocks the MambaContextParallel helper used by MambaMixer._ssm_prefill."""

    def __init__(self, d_inner, ngroups, nheads, d_state, device):
        # Local (TP x CP) sizes; cp_size = 1 means no context parallelism.
        self.d_inner_local_tpcp = d_inner
        self.ngroups_local_tpcp = ngroups
        self.nheads_local_tpcp = nheads
        self.cp_size = 1

        # Random weights for the mock.
        conv_channels = d_inner + 2 * ngroups * d_state
        self.conv1d_weight = torch.randn(conv_channels, 1, 4, device=device)
        self.conv1d_bias = torch.randn(conv_channels, device=device)
        self.A_log = torch.randn(nheads, device=device)
        self.D = torch.ones(nheads, device=device)
        self.dt_bias = torch.randn(nheads, device=device)

        # Simple depthwise conv1d layer for the fallback path if needed.
        self.conv1d_layer = nn.Conv1d(
            in_channels=conv_channels,
            out_channels=conv_channels,
            kernel_size=4,
            groups=conv_channels,
            padding=3,
        ).to(device)

    def get_A_log(self):
        return self.A_log

    def get_D(self):
        return self.D

    def get_dt_bias(self):
        return self.dt_bias

    def get_conv1d_weight(self):
        return self.conv1d_weight

    def get_conv1d_bias(self):
        return self.conv1d_bias

    def conv1d(self, x):
        return self.conv1d_layer(x)

    def pre_conv_ssm(self, x):
        # Identity: no context-parallel redistribution in the mock.
        return x

    def post_conv_ssm(self, x):
        return x


@unittest.skipIf(not HAVE_MAMBA, "MambaMixer not available")
class TestMambaDynamicInference(unittest.TestCase):
    """Exercises the prefill SSM path on a mocked MambaMixer instance."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if self.device.type == 'cpu':
            self.skipTest("Mamba Triton kernels require CUDA")

        # --- Configuration ---
        self.d_model = 256
        self.d_state = 16
        self.headdim = 64
        self.d_conv = 4
        self.ngroups = 1
        self.d_inner = self.d_model * 2  # expand=2
        self.nheads = self.d_inner // self.headdim

        # Create the Mixer instance directly
        self.mixer = MagicMock(spec=MambaMixer)
        self.mixer.d_state = self.d_state
        self.mixer.d_conv = self.d_conv
        self.mixer.headdim = self.headdim
        self.mixer.chunk_size = 256
        self.mixer.activation = "silu"
        self.mixer.act = nn.SiLU()
        self.mixer.D_has_hdim = False
        self.mixer.rmsnorm = True

        # Mock the Context Parallel wrapper (used by ssm_prefill)
        self.mixer.cp = MockContextParallel(
            d_inner=self.d_inner,
            ngroups=self.ngroups,
            nheads=self.nheads,
            d_state=self.d_state,
            device=self.device
        )

        # --- Setup for ssm_decode ---
        # ssm_decode accesses attributes directly from self, not self.cp
        self.mixer.d_inner_local_tp = self.d_inner
        self.mixer.ngroups_local_tp = self.ngroups
        self.mixer.nheads_local_tp = self.nheads

        # Create real parameters for ssm_decode to access
        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        self.mixer.conv1d = nn.Conv1d(
            in_channels=conv_dim, out_channels=conv_dim,
            kernel_size=self.d_conv, groups=conv_dim, padding=self.d_conv - 1,
            bias=True, device=self.device
        )
        self.mixer.dt_bias = nn.Parameter(torch.randn(self.nheads, device=self.device))
        self.mixer.A_log = nn.Parameter(torch.randn(self.nheads, device=self.device))
        self.mixer.D = nn.Parameter(torch.ones(self.nheads, device=self.device))

        # Bind the real (unbound) methods onto the mock instance.
        self.mixer._ssm_prefill = MambaMixer._ssm_prefill.__get__(self.mixer, MambaMixer)
        self.mixer._ssm_decode = MambaMixer._ssm_decode.__get__(self.mixer, MambaMixer)

    def test_ssm_prefill_padding_isolation(self):
        """
        Tests that ssm_prefill only updates states for the real request
        and outputs zeros for padding tokens.
        """
        num_requests = 48
        real_seq_len = 6
        total_tokens = 63

        # Inputs
        dim_inputs = self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads
        zxBCdt = torch.randn(total_tokens, 1, dim_inputs, device=self.device, dtype=torch.float32)

        # Metadata: only request 0 is real; everything past real_seq_len is padding.
        seq_idx = torch.full((total_tokens,), -1, dtype=torch.int32, device=self.device)
        seq_idx[:real_seq_len] = 0
        seq_idx = seq_idx.unsqueeze(0)

        cu_seqlens = torch.full((num_requests + 1,), real_seq_len, dtype=torch.int32, device=self.device)
        cu_seqlens[0] = 0

        batch_indices = torch.full((num_requests,), -1, dtype=torch.long, device=self.device)
        batch_indices[0] = 0

        # States
        conv_dim = self.d_inner + 2 * self.ngroups * self.d_state
        conv_state = torch.zeros(num_requests, conv_dim, self.d_conv, device=self.device)
        ssm_state = torch.zeros(num_requests, self.nheads, self.headdim, self.d_state, device=self.device)

        # Run
        self.mixer.norm = MagicMock(side_effect=lambda x, z: x * z)
        output = self.mixer._ssm_prefill(
            zxBCdt=zxBCdt, conv_state=conv_state, ssm_state=ssm_state,
            seq_idx=seq_idx, cu_seqlens=cu_seqlens,
            batch_indices=batch_indices, return_varlen_states=True
        )

        # Assertions
        padding_output = output[real_seq_len:]

        self.assertTrue(torch.allclose(padding_output, torch.zeros_like(padding_output)),
                        "Output for padding tokens should be 0")
        self.assertTrue(conv_state[0].abs().max() > 0, "Real request conv_state should be modified")

        # Verify isolation of padding states
        remaining_conv_states = conv_state[1:num_requests]
        remaining_ssm_states = ssm_state[1:num_requests]

        self.assertTrue(torch.allclose(remaining_conv_states, torch.zeros_like(remaining_conv_states)),
                        "Conv states for padding requests (indices 1 to N-1) should remain 0")
        self.assertTrue(torch.allclose(remaining_ssm_states, torch.zeros_like(remaining_ssm_states)),
                        "SSM states for padding requests (indices 1 to N-1) should remain 0")


if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)