diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index a4efec7948e5..e39ceabd3988 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -679,14 +679,18 @@ def conv_ssm_forward( # for prefill and decode block_idx_last_computed_token_d, block_idx_last_computed_token_p = ( torch.split( - attn_metadata.block_idx_last_computed_token, + attn_metadata.block_idx_last_computed_token[ + : num_decodes + num_prefills + ], [num_decodes, num_prefills], dim=0, ) ) block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = ( torch.split( - attn_metadata.block_idx_last_scheduled_token, + attn_metadata.block_idx_last_scheduled_token[ + : num_decodes + num_prefills + ], [num_decodes, num_prefills], dim=0, ) diff --git a/vllm/model_executor/models/nemotron_h_mtp.py b/vllm/model_executor/models/nemotron_h_mtp.py index ce927f8e2c7f..f454ba433442 100644 --- a/vllm/model_executor/models/nemotron_h_mtp.py +++ b/vllm/model_executor/models/nemotron_h_mtp.py @@ -290,7 +290,6 @@ class NemotronHMTP(nn.Module, SupportsPP): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config self.vllm_config = vllm_config self.config = config self.quant_config = vllm_config.quant_config @@ -305,10 +304,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.parallel_config.eplb_config.num_redundant_experts ) - assert not cache_config.enable_prefix_caching, ( - "NemotronHMTP currently does not support prefix caching" - ) - # MTP predictor self.model = NemotronHMultiTokenPredictor( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 4957df7ad66a..f451b8395744 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -375,17 +375,27 @@ def build( device=common_attn_metadata.query_start_loc.device, ) - spec_state_indices_tensor = common_attn_metadata.block_table_tensor[ - :, : self.num_spec + 1 - ] + if self.vllm_config.cache_config.enable_prefix_caching: + num_cacheable_blocks = num_computed_tokens // mamba_block_size + block_indices = num_cacheable_blocks.unsqueeze(1) + torch.arange( + self.num_spec + 1, device=self.device + ).unsqueeze(0) + batch_indices = torch.arange( + common_attn_metadata.block_table_tensor.size(0), + device=self.device, + ).unsqueeze(1) + spec_state_indices_tensor = common_attn_metadata.block_table_tensor[ + batch_indices, block_indices + ] + else: + spec_state_indices_tensor = common_attn_metadata.block_table_tensor[ + :, : self.num_spec + 1 + ] non_spec_state_indices_tensor = None spec_query_start_loc = common_attn_metadata.query_start_loc non_spec_query_start_loc = None else: if self.vllm_config.cache_config.enable_prefix_caching: - block_idx_first_scheduled_token = block_idx_first_scheduled_token[ - ~spec_sequence_masks - ] block_idx_last_scheduled_token = block_idx_last_scheduled_token[ ~spec_sequence_masks ] @@ -401,14 +411,24 @@ def build( non_spec_token_indx = index[:num_non_spec_tokens] spec_token_indx = index[num_non_spec_tokens:] - spec_state_indices_tensor = common_attn_metadata.block_table_tensor[ - spec_sequence_masks, : self.num_spec + 1 - ] if self.vllm_config.cache_config.enable_prefix_caching: + num_cacheable_blocks = num_computed_tokens // mamba_block_size + block_indices = num_cacheable_blocks[spec_sequence_masks].unsqueeze( + 1 + ) + torch.arange(self.num_spec + 1, device=self.device).unsqueeze(0) + batch_indices = torch.arange( + spec_sequence_masks.sum().item(), device=self.device + ).unsqueeze(1) + spec_state_indices_tensor = common_attn_metadata.block_table_tensor[ + spec_sequence_masks + ][batch_indices, block_indices] non_spec_state_indices_tensor = ( common_attn_metadata.block_table_tensor[~spec_sequence_masks] ) else: + spec_state_indices_tensor = common_attn_metadata.block_table_tensor[ + spec_sequence_masks, : self.num_spec + 1 + ] non_spec_state_indices_tensor = ( common_attn_metadata.block_table_tensor[~spec_sequence_masks, 0] ) @@ -450,7 +470,8 @@ def build( common_attn_metadata.query_start_loc.device ) - # Subtract ALL decode tokens (spec + non-spec) to get prefill-only coordinates + # Subtract ALL decode tokens (spec + non-spec) + # to get prefill-only coordinates total_decode_tokens = num_decode_tokens + num_spec_decode_tokens query_start_loc_p = ( common_attn_metadata.query_start_loc[-num_prefills - 1 :] diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d90ec550f766..c60f377bd52f 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -645,6 +645,12 @@ def find_longest_cache_hit( computed.append(cached) break # we just need the last match - early stopping + # TODO - do we need to pop the last block if use_eagle is True? + # With hybrid models, it seems unnecessary, + # since it is done in the full attention manager. + # Popping again means the last 2 blocks are popped overall, + # and the performance is hit hard. + return computed_blocks def get_num_common_prefix_blocks(self, running_request_id: str) -> int: