Skip to content
70 changes: 55 additions & 15 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as _aiter_triton_fp8_bmm,
)


from atom.plugin import is_plugin_mode

from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode

# torch.set_printoptions(threshold=10_000)

logger = logging.getLogger("atom")
Expand Down Expand Up @@ -92,11 +97,12 @@ def dynamic_per_batched_tensor_quant(
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


@MLAAttentionImplDecoratorForPluginMode
class MLAAttention(nn.Module):
def __init__(
self,
num_heads: int,
head_dim: int,
head_size: int,
scale: float,
num_kv_heads: int,
kv_cache_dtype: str,
Expand All @@ -107,7 +113,7 @@ def __init__(
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.head_dim = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype if kv_cache_dtype == "fp8" else "auto"
Expand All @@ -134,7 +140,7 @@ def __init__(
)
self.layer_num = layer_num

def process_weights_after_loading(self):
def process_weights_after_loading(self, act_dtype: Optional[torch.dtype] = None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zejunchen-zejun do we need to add this arg?

Copy link
Author

@XiaobingSuper XiaobingSuper Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if is_rocm_aiter_fp4bmm_enabled():
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
self.W_K, self.W_K_scale, W_V, self.W_V_scale = quark_post_load_weights(
Expand All @@ -146,7 +152,7 @@ def process_weights_after_loading(self):
self.W_K_scale = self.W_K_scale.transpose(-2, -1).contiguous()
self.W_V = self.W_V.transpose(-2, -1).contiguous()
self.W_V_scale = self.W_V_scale.transpose(-2, -1).contiguous()
else: # is_rocm_aiter_fp8bmm_enabled():
else: # is_rocm_aiter_fp8bmm_enabled()
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
Expand Down Expand Up @@ -175,7 +181,7 @@ def process_weights_after_loading(self):
W_V, dtype=dtypes.fp8
)

def _v_up_proj_and_o_proj(self, x):
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V), Convert from (N, B, V) to (B, N, V)
Expand Down Expand Up @@ -207,7 +213,7 @@ def _v_up_proj_and_o_proj(self, x):
)
# Convert from (B, N, V) to (B, N * V)
x = x.reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)
return x

def _q_proj_and_k_up_proj(self, x, x_scale=None):
q_nope, q_pe = (
Expand Down Expand Up @@ -413,7 +419,7 @@ def _forward_prefill_mha(
causal=True,
)

return self.o_proj(output.flatten(start_dim=-2))
return output.flatten(start_dim=-2)

def _forward_prefill_mla(
self,
Expand Down Expand Up @@ -480,7 +486,7 @@ def _forward_prefill_mla(
None,
)

return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

def _forward_decode(
self,
Expand Down Expand Up @@ -555,16 +561,15 @@ def _forward_decode(
kv_scale=self._k_scale,
)

return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

def forward(
def forward_impl_server_mode(
self,
q: torch.Tensor, # query in unified attn
q: torch.Tensor,
k_nope: torch.Tensor,
k_rope: torch.Tensor,
positions: torch.Tensor,
q_scale: Optional[torch.Tensor],
qkv: Optional[torch.Tensor],
positions: torch.Tensor = None,
q_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# kv_cache = self.kv_cache
forward_context: ForwardContext = get_forward_context()
Expand All @@ -577,7 +582,7 @@ def forward(
if forward_context.context.is_dummy_run:
# dummy run: skip real attention and return
output_shape = list(q.shape)
output_shape[-1] = 7168
output_shape[-1] = self.num_heads * self.v_head_dim
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.num_heads * self.v_head_dim does not look equal to 7168 for DeepSeek

Copy link
Author

@XiaobingSuper XiaobingSuper Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is o_proj's input, see atom path:
image and plugin path:
image

The reason is that vllm does o_proj outside of the attention backend.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The plugin path is also in our repo... so why do we have to move o_proj out of attn?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is for the fallback path, i.e., plugin mode that uses the vllm attn backend. Because we use vllm's MLAAttention class (here it is self.attn), the forward path doesn't have o_proj — see https://github.com/vllm-project/vllm/blob/v0.15.1/vllm/attention/layer.py#L640; it is only for attention computation.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

atom_config = get_current_atom_config()
output_dtype = atom_config.torch_dtype
output = torch.empty(output_shape, dtype=output_dtype, device=q.device)
Expand Down Expand Up @@ -647,6 +652,41 @@ def forward(

return output

def forward(
    self,
    layer: torch.nn.Module,
    query: torch.Tensor,  # query in unified attn
    k_nope: torch.Tensor,
    k_rope: torch.Tensor,
    kv_cache: Optional[torch.Tensor] = None,
    attn_metadata=None,
    positions: Optional[torch.Tensor] = None,
    q_scale: Optional[torch.Tensor] = None,
    output: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Dispatch MLA attention to the plugin-mode or server-mode implementation.

    Returns the attention output BEFORE o_proj; the caller (see
    unified_attention_with_output_base) applies o_proj itself.

    NOTE(review): kv_cache/attn_metadata/output are only forwarded on the
    plugin path; positions/q_scale only on the server path.
    """
    if is_plugin_mode():
        # forward impl method are added by the decorator
        # MLAAttentionImplDecoratorForPluginMode
        return self.forward_impl_plugin_mode(
            layer=layer,
            q=query,
            k_c_normed=k_nope,
            k_pe=k_rope,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            output=output,
        )
    else:
        # only for server mode, keep the original method
        return self.forward_impl_server_mode(
            q=query,
            k_nope=k_nope,
            k_rope=k_rope,
            positions=positions,
            q_scale=q_scale,
        )


@triton.jit
def _convert_req_index_to_global_index_kernel(
Expand Down
14 changes: 11 additions & 3 deletions atom/model_ops/attentions/aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,22 @@

from .backends import AttentionBackend, CommonAttentionBuilder

from atom.plugin.prepare import is_plugin_mode
from atom.plugin.attention import AiterMLAAttentionMetadataBuilderDecoratorForPluginMode
from atom.plugin.attention import AiterBackendDecoratorForPluginMode

logger = logging.getLogger("atom")


def cdiv(a, b):
    """Ceiling division: smallest integer >= a / b (assumes positive b)."""
    shifted = a + b - 1
    return shifted // b


@AiterBackendDecoratorForPluginMode
class AiterMLABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "ROCM_AITER_MLA"
return "ROCM_AITER_MLA" if not is_plugin_mode() else "CUSTOM"

@staticmethod
def get_builder_cls() -> Type["AiterMLAMetadataBuilder"]:
Expand All @@ -40,11 +45,14 @@ def get_impl_cls() -> Type["MLAAttention"]:
return MLAAttention


@AiterMLAAttentionMetadataBuilderDecoratorForPluginMode(
default_base_class=CommonAttentionBuilder
)
class AiterMLAMetadataBuilder(CommonAttentionBuilder):

def __init__(self, model_runner):
self.block_size = 1
super().__init__(model_runner)
CommonAttentionBuilder.__init__(self, model_runner)
config = model_runner.config
hf_config = config.hf_config
self.num_attention_heads = (
Expand Down Expand Up @@ -190,7 +198,7 @@ def prepare_mtp_decode(self, bs: int, max_seqlen_q: int, max_seqlen_k: int):
return self.set_mla_persistent_worker_buffers(bs, max_seqlen_q)

def prepare_prefill(self, batch: ScheduledBatch):
attn_metadata, positions = super().prepare_prefill(batch)
attn_metadata, positions = CommonAttentionBuilder.prepare_prefill(self, batch)
bs = batch.total_seqs_num_prefill
sum_scheduled_tokens = batch.total_tokens_num_prefill
var = self.model_runner.forward_vars
Expand Down
14 changes: 11 additions & 3 deletions atom/model_ops/base_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ def fake_(
qkv: torch.Tensor,
) -> torch.Tensor:
output_shape = list(q.shape)
if use_mla:
output_shape[-1] = 7168
# If we fusion rmsnorm and quant, the input dtype is fp8, but actually we use bf16 for output.
atom_config = get_current_atom_config()
if use_mla:
output_shape[-1] = atom_config.hf_config.hidden_size
output_dtype = atom_config.torch_dtype
output = torch.zeros(output_shape, dtype=output_dtype, device=q.device)

Expand All @@ -218,7 +218,15 @@ def unified_attention_with_output_base(
atom_config = get_current_atom_config()
self = atom_config.compilation_config.static_forward_context[layer_name]
if use_mla:
return self.impl.forward(q, k, v, positions, q_scale, qkv)
output = self.impl.forward(
layer=self,
query=q,
k_nope=k,
k_rope=v,
positions=positions,
q_scale=q_scale,
)
return self.impl.o_proj(output)
else:
return self.impl.forward(
layer=self,
Expand Down
87 changes: 63 additions & 24 deletions atom/model_ops/paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

# from flash_attn import flash_attn_with_kvcache
from typing import Optional

import torch
from torch import nn

Expand Down Expand Up @@ -60,15 +59,15 @@ def __init__(
**kwargs,
)

self.use_mla = use_mla
# for plugin mode
if is_vllm():
self.use_mla = use_mla
self.rotary_emb = rotary_emb
self.rotary_emb = mla_modules.rotary_emb if use_mla else rotary_emb

try:
from vllm.attention.layer import Attention, AttentionType
from vllm.attention.layer import Attention, MLAAttention, AttentionType
except ImportError:
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.v1.attention.backend import AttentionType

atom_config = get_current_atom_config()
Expand All @@ -88,28 +87,68 @@ def __init__(
extra_impl_args["rotary_emb"] = rotary_emb
extra_impl_args["q_norm"] = q_norm
extra_impl_args["k_norm"] = k_norm

self.attn = Attention(
num_heads=num_heads,
head_size=head_dim,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=None,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{prefix}",
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
**extra_impl_args,
)
if use_mla:
extra_impl_args["layer_num"] = layer_num
extra_impl_args["mla_modules"] = mla_modules

if use_mla:
assert (
mla_modules.indexer is None
), "MLAAttention is not supported for sparse mode"
self.num_heads = num_heads
self.v_head_dim = mla_modules.v_head_dim
self.qk_head_dim = mla_modules.qk_head_dim
self.qk_nope_head_dim = mla_modules.qk_nope_head_dim
self.q_proj = mla_modules.q_proj
self.o_proj = mla_modules.o_proj

self.attn = MLAAttention(
num_heads=num_heads,
scale=scale,
qk_nope_head_dim=mla_modules.qk_nope_head_dim,
qk_rope_head_dim=mla_modules.qk_rope_head_dim,
v_head_dim=mla_modules.v_head_dim,
q_lora_rank=mla_modules.q_lora_rank,
kv_lora_rank=mla_modules.kv_lora_rank,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
kv_b_proj=mla_modules.kv_b_proj,
use_sparse=False,
indexer=mla_modules.indexer,
**extra_impl_args,
)
else:
self.attn = Attention(
num_heads=num_heads,
head_size=head_dim,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=None,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{prefix}",
attn_type=AttentionType.DECODER,
kv_sharing_target_layer_name=None,
**extra_impl_args,
)

compilation_config = atom_config.compilation_config
self.layer_name = prefix
if self.layer_name in compilation_config.static_forward_context:
raise ValueError("Duplicate layer: {}".format(self.layer_name))
compilation_config.static_forward_context[self.layer_name] = self

if self.use_mla:
if "positions" not in compilation_config.static_forward_context:
max_num_tokens = (
atom_config.plugin_config.vllm_scheduler_config.max_num_batched_tokens
)
compilation_config.static_forward_context["positions"] = (
torch.zeros(max_num_tokens, dtype=torch.int64, device="cuda")
)
return

self.num_heads = num_heads
Expand All @@ -122,7 +161,6 @@ def __init__(
self.k_scale = self.v_scale = None
self.layer_num = layer_num
self.mla_modules = mla_modules
self.use_mla = use_mla
self.base_attention = None
self.kv_cache = torch.tensor([])
self.indexer = mla_modules.indexer if mla_modules is not None else None
Expand All @@ -136,7 +174,7 @@ def __init__(
use_mla=self.use_mla,
)
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
impl_args = dict(
num_heads=num_heads,
head_dim=head_dim,
scale=scale,
Expand All @@ -153,7 +191,8 @@ def __init__(
k_norm=k_norm,
**kwargs,
)

impl_args["head_size" if self.use_mla else "head_dim"] = head_dim
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this also come from vllm? I would prefer that we always use head_dim.

Copy link
Author

@XiaobingSuper XiaobingSuper Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, vllm uses head_size, see https://github.com/vllm-project/vllm/blob/v0.15.1/vllm/attention/layer.py#L579. Before this PR, commit 6c40248 also used head_size.

self.impl = impl_cls(**impl_args)
compilation_config = atom_config.compilation_config
default_name = f"MLA_{layer_num}" if self.use_mla else f"MHA_{layer_num}"
self.layer_name = prefix if prefix is not None else default_name
Expand Down
Loading
Loading