Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@
num_accepted_tokens = attn_metadata.num_accepted_tokens
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc

Check failure on line 598 in vllm/model_executor/layers/mamba/mamba_mixer2.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (F841)

vllm/model_executor/layers/mamba/mamba_mixer2.py:598:13: F841 Local variable `non_spec_query_start_loc` is assigned to but never used
num_spec_decodes = attn_metadata.num_spec_decodes
# token count (non-spec only!)
num_decodes = attn_metadata.num_decode_tokens
Expand Down Expand Up @@ -679,14 +679,18 @@
# for prefill and decode
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
torch.split(
attn_metadata.block_idx_last_computed_token,
attn_metadata.block_idx_last_computed_token[
: num_decodes + num_prefills
],
[num_decodes, num_prefills],
dim=0,
)
)
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
torch.split(
attn_metadata.block_idx_last_scheduled_token,
attn_metadata.block_idx_last_scheduled_token[
: num_decodes + num_prefills
],
[num_decodes, num_prefills],
dim=0,
)
Expand Down
5 changes: 0 additions & 5 deletions vllm/model_executor/models/nemotron_h_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ class NemotronHMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
self.vllm_config = vllm_config
self.config = config
self.quant_config = vllm_config.quant_config
Expand All @@ -305,10 +304,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config.parallel_config.eplb_config.num_redundant_experts
)

assert not cache_config.enable_prefix_caching, (
"NemotronHMTP currently does not support prefix caching"
)

# MTP predictor
self.model = NemotronHMultiTokenPredictor(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
Expand Down
41 changes: 31 additions & 10 deletions vllm/v1/attention/backends/mamba2_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,17 +375,27 @@ def build(
device=common_attn_metadata.query_start_loc.device,
)

spec_state_indices_tensor = common_attn_metadata.block_table_tensor[
:, : self.num_spec + 1
]
if self.vllm_config.cache_config.enable_prefix_caching:
num_cacheable_blocks = num_computed_tokens // mamba_block_size
block_indices = num_cacheable_blocks.unsqueeze(1) + torch.arange(
self.num_spec + 1, device=self.device
).unsqueeze(0)
batch_indices = torch.arange(
common_attn_metadata.block_table_tensor.size(0),
device=self.device,
).unsqueeze(1)
spec_state_indices_tensor = common_attn_metadata.block_table_tensor[
batch_indices, block_indices
]
else:
spec_state_indices_tensor = common_attn_metadata.block_table_tensor[
:, : self.num_spec + 1
]
non_spec_state_indices_tensor = None
spec_query_start_loc = common_attn_metadata.query_start_loc
non_spec_query_start_loc = None
else:
if self.vllm_config.cache_config.enable_prefix_caching:
block_idx_first_scheduled_token = block_idx_first_scheduled_token[
~spec_sequence_masks
]
block_idx_last_scheduled_token = block_idx_last_scheduled_token[
~spec_sequence_masks
]
Expand All @@ -401,14 +411,24 @@ def build(
non_spec_token_indx = index[:num_non_spec_tokens]
spec_token_indx = index[num_non_spec_tokens:]

spec_state_indices_tensor = common_attn_metadata.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
if self.vllm_config.cache_config.enable_prefix_caching:
num_cacheable_blocks = num_computed_tokens // mamba_block_size
block_indices = num_cacheable_blocks[spec_sequence_masks].unsqueeze(
1
) + torch.arange(self.num_spec + 1, device=self.device).unsqueeze(0)
batch_indices = torch.arange(
spec_sequence_masks.sum().item(), device=self.device
).unsqueeze(1)
spec_state_indices_tensor = common_attn_metadata.block_table_tensor[
spec_sequence_masks
][batch_indices, block_indices]
non_spec_state_indices_tensor = (
common_attn_metadata.block_table_tensor[~spec_sequence_masks]
)
else:
spec_state_indices_tensor = common_attn_metadata.block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = (
common_attn_metadata.block_table_tensor[~spec_sequence_masks, 0]
)
Expand Down Expand Up @@ -450,7 +470,8 @@ def build(
common_attn_metadata.query_start_loc.device
)

# Subtract ALL decode tokens (spec + non-spec) to get prefill-only coordinates
# Subtract ALL decode tokens (spec + non-spec)
# to get prefill-only coordinates
total_decode_tokens = num_decode_tokens + num_spec_decode_tokens
query_start_loc_p = (
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/core/single_type_kv_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,12 @@ def find_longest_cache_hit(
computed.append(cached)
break # we just need the last match - early stopping

# TODO - do we need to pop the last block if use_eagle is True?
# With hybrid models, it seems unnecessary,
# since it is done in the full attention manager.
# Popping again means the last 2 blocks are popped overall,
# and the performance is hit hard.

return computed_blocks

def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
Expand Down
Loading