-
Notifications
You must be signed in to change notification settings - Fork 19
[feat][plugin] make ATOM mla attention works for vllm #265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
fef62be
b9a0acf
07fd497
d45cd8a
5e633ed
007c829
4d9971f
45f7c4e
5c460b9
f7341d2
5d696d7
9e8486a
f57e9ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,6 +35,11 @@ | |
| batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm, | ||
| ) | ||
|
|
||
|
|
||
| from atom.plugin import is_plugin_mode | ||
|
|
||
| from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode | ||
|
|
||
| # torch.set_printoptions(threshold=10_000) | ||
|
|
||
| logger = logging.getLogger("atom") | ||
|
|
@@ -92,11 +97,12 @@ def dynamic_per_batched_tensor_quant( | |
| return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() | ||
|
|
||
|
|
||
| @MLAAttentionImplDecoratorForPluginMode | ||
| class MLAAttention(nn.Module): | ||
| def __init__( | ||
| self, | ||
| num_heads: int, | ||
| head_dim: int, | ||
| head_size: int, | ||
| scale: float, | ||
| num_kv_heads: int, | ||
| kv_cache_dtype: str, | ||
|
|
@@ -107,7 +113,7 @@ def __init__( | |
| ) -> None: | ||
| super().__init__() | ||
| self.num_heads = num_heads | ||
| self.head_dim = head_dim | ||
| self.head_dim = head_size | ||
| self.scale = float(scale) | ||
| self.num_kv_heads = num_kv_heads | ||
| self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto" | ||
|
|
@@ -134,7 +140,7 @@ def __init__( | |
| ) | ||
| self.layer_num = layer_num | ||
|
|
||
| def process_weights_after_loading(self): | ||
| def process_weights_after_loading(self, act_dtype: Optional[torch.dtype] = None): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @zejunchen-zejun do we need to add this arg?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For the vLLM side, the calling path is like https://github.com/vllm-project/vllm/blob/1892993bc18e243e2c05841314c5e9c06a80c70d/vllm/attention/layer.py#L675, so it needs such an arg. |
||
| if is_rocm_aiter_fp4bmm_enabled(): | ||
| kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj) | ||
| self.W_K, self.W_K_scale, W_V, self.W_V_scale = quark_post_load_weights( | ||
|
|
@@ -146,7 +152,7 @@ def process_weights_after_loading(self): | |
| self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous() | ||
| self.W_V = self.W_V.transpose(-2, -1).contiguous() | ||
| self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous() | ||
| else: # is_rocm_aiter_fp8bmm_enabled(): | ||
| else: # is_rocm_aiter_fp8bmm_enabled() | ||
| kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T | ||
| assert kv_b_proj_weight.shape == ( | ||
| self.kv_lora_rank, | ||
|
|
@@ -175,7 +181,7 @@ def process_weights_after_loading(self): | |
| W_V, dtype=dtypes.fp8 | ||
| ) | ||
|
|
||
| def _v_up_proj_and_o_proj(self, x): | ||
| def _v_up_proj(self, x): | ||
| # Convert from (B, N, L) to (N, B, L) | ||
| x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) | ||
| # Multiply (N, B, L) x (N, L, V) -> (N, B, V), Convert from (N, B, V) to (B, N, V) | ||
|
|
@@ -207,7 +213,7 @@ def _v_up_proj_and_o_proj(self, x): | |
| ) | ||
| # Convert from (B, N, V) to (B, N * V) | ||
| x = x.reshape(-1, self.num_heads * self.v_head_dim) | ||
| return self.o_proj(x) | ||
| return x | ||
|
|
||
| def _q_proj_and_k_up_proj(self, x, x_scale=None): | ||
| q_nope, q_pe = ( | ||
|
|
@@ -413,7 +419,7 @@ def _forward_prefill_mha( | |
| causal=True, | ||
| ) | ||
|
|
||
| return self.o_proj(output.flatten(start_dim=-2)) | ||
| return output.flatten(start_dim=-2) | ||
|
|
||
| def _forward_prefill_mla( | ||
| self, | ||
|
|
@@ -480,7 +486,7 @@ def _forward_prefill_mla( | |
| None, | ||
| ) | ||
|
|
||
| return self._v_up_proj_and_o_proj(o) | ||
| return self._v_up_proj(o) | ||
|
|
||
| def _forward_decode( | ||
| self, | ||
|
|
@@ -555,16 +561,15 @@ def _forward_decode( | |
| kv_scale=self._k_scale, | ||
| ) | ||
|
|
||
| return self._v_up_proj_and_o_proj(o) | ||
| return self._v_up_proj(o) | ||
|
|
||
| def forward( | ||
| def forward_impl_server_mode( | ||
| self, | ||
| q: torch.Tensor, # query in unified attn | ||
| q: torch.Tensor, | ||
| k_nope: torch.Tensor, | ||
| k_rope: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| q_scale: Optional[torch.Tensor], | ||
| qkv: Optional[torch.Tensor], | ||
| positions: torch.Tensor = None, | ||
| q_scale: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| # kv_cache = self.kv_cache | ||
| forward_context: ForwardContext = get_forward_context() | ||
|
|
@@ -577,7 +582,7 @@ def forward( | |
| if forward_context.context.is_dummy_run: | ||
| # dummy run: skip real attention and return | ||
| output_shape = list(q.shape) | ||
| output_shape[-1] = 7168 | ||
| output_shape[-1] = self.num_heads * self.v_head_dim | ||
XiaobingSuper marked this conversation as resolved.
Show resolved
Hide resolved
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. self.num_heads * self.v_head_dim looks like it does not equal 7168 for DeepSeek
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The plugin path is also in our repo... then why do we have to move o_proj out of attn?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change is for the fallback path, i.e., for plugin mode but using the vllm attn backend. Because we use vllm's MLAAttention class (here, self.attn), its forward path doesn't have o_proj — see https://github.com/vllm-project/vllm/blob/v0.15.1/vllm/attention/layer.py#L640; it is only for the attention compute.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @valarLip The plugin mode using vllm atten backend will be like(setting ATOM_DISABLE_VLLM_PLUGIN_ATTENTION=1): https://github.com/zejunchen-zejun/ATOM/blob/zejun/plugin_for_atom_1223/recipes/vLLM-ATOM-OOT-Plugin-Backend.md#launching-server-of-vllm-with-atom-oot-plugin-platform |
||
| atom_config = get_current_atom_config() | ||
| output_dtype = atom_config.torch_dtype | ||
| output = torch.empty(output_shape, dtype=output_dtype, device=q.device) | ||
|
|
@@ -647,6 +652,41 @@ def forward( | |
|
|
||
| return output | ||
|
|
||
| def forward( | ||
| self, | ||
| layer: torch.nn.Module, | ||
| query: torch.Tensor, # query in unified attn | ||
| k_nope: torch.Tensor, | ||
| k_rope: torch.Tensor, | ||
| kv_cache: torch.Tensor = None, | ||
| attn_metadata=None, | ||
| positions: torch.Tensor = None, | ||
| q_scale: Optional[torch.Tensor] = None, | ||
| output: torch.Tensor = None, | ||
| **kwargs, | ||
| ) -> torch.Tensor: | ||
| if is_plugin_mode(): | ||
| # forward impl method are added by the decorator | ||
| # MLAAttentionImplDecoratorForPluginMode | ||
| return self.forward_impl_plugin_mode( | ||
| layer=layer, | ||
| q=query, | ||
| k_c_normed=k_nope, | ||
| k_pe=k_rope, | ||
| kv_cache=kv_cache, | ||
| attn_metadata=attn_metadata, | ||
| output=output, | ||
| ) | ||
| else: | ||
| # only for server mode, keep the original method | ||
| return self.forward_impl_server_mode( | ||
| q=query, | ||
| k_nope=k_nope, | ||
| k_rope=k_rope, | ||
| positions=positions, | ||
| q_scale=q_scale, | ||
| ) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def _convert_req_index_to_global_index_kernel( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,6 @@ | |
|
|
||
| # from flash_attn import flash_attn_with_kvcache | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
|
|
@@ -60,15 +59,15 @@ def __init__( | |
| **kwargs, | ||
| ) | ||
|
|
||
| self.use_mla = use_mla | ||
| # for plugin mode | ||
| if is_vllm(): | ||
| self.use_mla = use_mla | ||
| self.rotary_emb = rotary_emb | ||
| self.rotary_emb = mla_modules.rotary_emb if use_mla else rotary_emb | ||
|
|
||
| try: | ||
| from vllm.attention.layer import Attention, AttentionType | ||
| from vllm.attention.layer import Attention, MLAAttention, AttentionType | ||
| except ImportError: | ||
| from vllm.model_executor.layers.attention import Attention | ||
| from vllm.model_executor.layers.attention import Attention, MLAAttention | ||
| from vllm.v1.attention.backend import AttentionType | ||
XiaobingSuper marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| atom_config = get_current_atom_config() | ||
|
|
@@ -88,28 +87,68 @@ def __init__( | |
| extra_impl_args["rotary_emb"] = rotary_emb | ||
| extra_impl_args["q_norm"] = q_norm | ||
| extra_impl_args["k_norm"] = k_norm | ||
|
|
||
| self.attn = Attention( | ||
| num_heads=num_heads, | ||
| head_size=head_dim, | ||
| scale=scale, | ||
| num_kv_heads=num_kv_heads, | ||
| alibi_slopes=alibi_slopes, | ||
| cache_config=cache_config, | ||
| quant_config=quant_config, | ||
| logits_soft_cap=None, | ||
| per_layer_sliding_window=per_layer_sliding_window, | ||
| prefix=f"{prefix}", | ||
| attn_type=AttentionType.DECODER, | ||
| kv_sharing_target_layer_name=None, | ||
| **extra_impl_args, | ||
| ) | ||
| if use_mla: | ||
| extra_impl_args["layer_num"] = layer_num | ||
| extra_impl_args["mla_modules"] = mla_modules | ||
|
|
||
| if use_mla: | ||
| assert ( | ||
| mla_modules.indexer is None | ||
| ), "MLAAttention is not supported for sparse mode" | ||
| self.num_heads = num_heads | ||
| self.v_head_dim = mla_modules.v_head_dim | ||
| self.qk_head_dim = mla_modules.qk_head_dim | ||
| self.qk_nope_head_dim = mla_modules.qk_nope_head_dim | ||
| self.q_proj = mla_modules.q_proj | ||
| self.o_proj = mla_modules.o_proj | ||
|
|
||
| self.attn = MLAAttention( | ||
| num_heads=num_heads, | ||
| scale=scale, | ||
| qk_nope_head_dim=mla_modules.qk_nope_head_dim, | ||
| qk_rope_head_dim=mla_modules.qk_rope_head_dim, | ||
| v_head_dim=mla_modules.v_head_dim, | ||
| q_lora_rank=mla_modules.q_lora_rank, | ||
| kv_lora_rank=mla_modules.kv_lora_rank, | ||
| cache_config=cache_config, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.attn", | ||
| kv_b_proj=mla_modules.kv_b_proj, | ||
| use_sparse=False, | ||
| indexer=mla_modules.indexer, | ||
| **extra_impl_args, | ||
| ) | ||
| else: | ||
| self.attn = Attention( | ||
| num_heads=num_heads, | ||
| head_size=head_dim, | ||
| scale=scale, | ||
| num_kv_heads=num_kv_heads, | ||
| alibi_slopes=alibi_slopes, | ||
| cache_config=cache_config, | ||
| quant_config=quant_config, | ||
| logits_soft_cap=None, | ||
| per_layer_sliding_window=per_layer_sliding_window, | ||
| prefix=f"{prefix}", | ||
| attn_type=AttentionType.DECODER, | ||
| kv_sharing_target_layer_name=None, | ||
| **extra_impl_args, | ||
| ) | ||
|
|
||
| compilation_config = atom_config.compilation_config | ||
| self.layer_name = prefix | ||
| if self.layer_name in compilation_config.static_forward_context: | ||
| raise ValueError("Duplicate layer: {}".format(self.layer_name)) | ||
| compilation_config.static_forward_context[self.layer_name] = self | ||
|
|
||
| if self.use_mla: | ||
| if "positions" not in compilation_config.static_forward_context: | ||
| max_num_tokens = ( | ||
| atom_config.plugin_config.vllm_scheduler_config.max_num_batched_tokens | ||
| ) | ||
| compilation_config.static_forward_context["positions"] = ( | ||
| torch.zeros(max_num_tokens, dtype=torch.int64, device="cuda") | ||
| ) | ||
| return | ||
|
|
||
| self.num_heads = num_heads | ||
|
|
@@ -122,7 +161,6 @@ def __init__( | |
| self.k_scale = self.v_scale = None | ||
| self.layer_num = layer_num | ||
| self.mla_modules = mla_modules | ||
| self.use_mla = use_mla | ||
| self.base_attention = None | ||
| self.kv_cache = torch.tensor([]) | ||
| self.indexer = mla_modules.indexer if mla_modules is not None else None | ||
|
|
@@ -136,7 +174,7 @@ def __init__( | |
| use_mla=self.use_mla, | ||
| ) | ||
| impl_cls = self.attn_backend.get_impl_cls() | ||
| self.impl = impl_cls( | ||
| impl_args = dict( | ||
| num_heads=num_heads, | ||
| head_dim=head_dim, | ||
XiaobingSuper marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| scale=scale, | ||
|
|
@@ -153,7 +191,8 @@ def __init__( | |
| k_norm=k_norm, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| impl_args["head_size" if self.use_mla else "head_dim"] = head_dim | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is also comes from vllm? i would like we always use head_dim
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, vllm use head_size, see https://github.com/vllm-project/vllm/blob/v0.15.1/vllm/attention/layer.py#L579. Before this PR 6c40248, it also use head_size. |
||
| self.impl = impl_cls(**impl_args) | ||
| compilation_config = atom_config.compilation_config | ||
| default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}" | ||
| self.layer_name = prefix if prefix is not None else default_name | ||
|
|
||


Uh oh!
There was an error while loading. Please reload this page.