
GDN (GatedDeltaNet) Computation Flow in vLLM

1. Architecture Overview

GDN (Gated Delta Net) is the linear attention mechanism used in the Qwen3-Next model, interleaved with conventional full attention. For each layer, config.layer_types determines whether it is "linear_attention" (GDN) or "full_attention" (standard attention).

The core code lives in:

  • GDN model layer: vllm/model_executor/models/qwen3_next.py → Qwen3NextGatedDeltaNet
  • GDN attention metadata: vllm/v1/attention/backends/gdn_attn.py → GDNAttentionMetadataBuilder
  • Recurrent core kernel: vllm/model_executor/layers/fla/ops/fused_recurrent.py
  • Chunked core kernel: vllm/model_executor/layers/fla/ops/chunk.py

Key point: QSA (Query-Side Aggregation) is used only in full_attention layers and does not affect the computation of GDN layers.

# vllm/model_executor/models/qwen3_next.py : Qwen3NextDecoderLayer.__init__
if self.layer_type == "linear_attention":
    topk_indices_buffer = None  # only use qsa in full-attn
    self.linear_attn = Qwen3NextGatedDeltaNet(...)
elif self.layer_type == "full_attention":
    self.self_attn = Qwen3NextAttention(..., topk_indices_buffer=topk_indices_buffer)
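
For orientation, layer_types is a per-layer list in the HF config. A hypothetical excerpt (the actual pattern is fixed by the Qwen3-Next checkpoint's config.json, not by this document):

# Hypothetical illustration -- the real list comes from the checkpoint's config.json
layer_types = [
    "linear_attention", "linear_attention", "linear_attention",
    "full_attention",
    # ... the block repeats across all decoder layers
]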

2. The GDN Layer's Three-Stage Computation Flow

The GDN forward pass consists of three stages.

Stage 1: Input Projection

Two projection matrices map hidden_states to six tensors:

# vllm/model_executor/models/qwen3_next.py : Qwen3NextGatedDeltaNet.forward
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)  # → q, k, v, z
projected_states_ba, _ = self.in_proj_ba(hidden_states)      # → b, a

There are two split paths:

  • Fused Triton kernel (fused_qkvzba_split_reshape_cat): used when its preconditions are met
  • Native PyTorch (fix_query_key_value_ordering): the fallback path

Roles of the six output tensors:

| Tensor | Shape | Purpose |
| --- | --- | --- |
| q | [tokens, num_k_heads/tp, head_k_dim] | query vectors |
| k | [tokens, num_k_heads/tp, head_k_dim] | key vectors |
| v | [tokens, num_v_heads/tp, head_v_dim] | value vectors |
| z | [tokens, num_v_heads/tp, head_v_dim] | output gate |
| b | [tokens, num_v_heads/tp] | produces beta (the update gate) |
| a | [tokens, num_v_heads/tp] | produces g (the decay factor) |

Finally, q, k, v are concatenated into mixed_qkv = cat(q, k, v).
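
For intuition, here is a simplified sketch of the fallback split (a flat [q | k | v | z] layout is assumed; the real fix_query_key_value_ordering handles a per-head-group interleaved layout, so this is illustrative only):

import torch

def split_qkvzba_reference(qkvz, ba, num_k_heads, num_v_heads,
                           head_k_dim, head_v_dim):
    # qkvz: [tokens, 2*num_k_heads*head_k_dim + 2*num_v_heads*head_v_dim]
    # ba:   [tokens, 2*num_v_heads]
    qk = num_k_heads * head_k_dim
    vz = num_v_heads * head_v_dim
    q, k, v, z = torch.split(qkvz, [qk, qk, vz, vz], dim=-1)
    b, a = torch.split(ba, [num_v_heads, num_v_heads], dim=-1)
    q = q.view(-1, num_k_heads, head_k_dim)
    k = k.view(-1, num_k_heads, head_k_dim)
    v = v.view(-1, num_v_heads, head_v_dim)
    z = z.view(-1, num_v_heads, head_v_dim)
    # q, k, v flattened back to 2D and concatenated for the shared conv below
    mixed_qkv = torch.cat([q.flatten(1), k.flatten(1), v.flatten(1)], dim=-1)
    return q, k, v, z, b, a, mixed_qkv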

Stage 2: Core Attention

The core computation happens in _forward_core and breaks into three sub-steps.

2.1 Causal Conv1d

A 1D causal convolution is applied to mixed_qkv, updating the conv_state cache. Depending on whether speculative decoding is active, the spec and non-spec parts of the batch are handled separately:

# vllm/model_executor/models/qwen3_next.py : _forward_core

# Spec-decode part: single-step update
if spec_sequence_masks is not None:
    mixed_qkv_spec = causal_conv1d_update(
        mixed_qkv_spec, conv_state, conv_weights, ...,
        conv_state_indices=spec_state_indices_tensor[:, 0],
        num_accepted_tokens=num_accepted_tokens,
        query_start_loc=spec_query_start_loc,
    )

# Non-spec prefill part: whole-sequence convolution
if attn_metadata.num_prefills > 0:
    mixed_qkv_non_spec = causal_conv1d_fn(
        mixed_qkv_non_spec_T, conv_weights, ...,
        conv_states=conv_state,
        has_initial_state=has_initial_state,
        cache_indices=non_spec_state_indices_tensor,
        query_start_loc=non_spec_query_start_loc,
    ).transpose(0, 1)

# Non-spec decode part: single-step update
elif attn_metadata.num_decodes > 0:
    mixed_qkv_non_spec = causal_conv1d_update(
        mixed_qkv_non_spec, conv_state, conv_weights, ...,
        conv_state_indices=non_spec_state_indices_tensor,
    )
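
Conceptually, the single-step update keeps a per-channel sliding window of the last kernel_size inputs. A minimal reference sketch (single sequence, no cache indexing; the SiLU activation is an assumption based on how vLLM's Mamba-style conv kernels are typically configured):

import torch

def causal_conv1d_update_reference(x, conv_state, weight, bias=None):
    # x:          [dim]               current-step input
    # conv_state: [dim, kernel_size]  sliding window of recent inputs
    # weight:     [dim, kernel_size]  per-channel (depthwise) filter
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # drop oldest
    conv_state[:, -1] = x                                         # append newest
    out = (conv_state * weight).sum(dim=-1)                       # depthwise conv
    if bias is not None:
        out = out + bias
    return torch.nn.functional.silu(out)  # SiLU assumed, per common configuration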

2.2 Gating

A fused Triton kernel computes the decay factor g and the update gate beta:

# vllm/model_executor/models/qwen3_next.py
g, beta = fused_gdn_gating(self.A_log, a.contiguous(), b, self.dt_bias)

The math:

g = -exp(A_log) * softplus(a + dt_bias)
beta = sigmoid(b)

where softplus(x) = (1/β) * log(1 + exp(β*x)), which reduces to x when β*x > threshold (for numerical stability).

The Triton kernel implementation (pseudocode):

# fused_gdn_gating_kernel
x = a.float() + dt_bias.float()
softplus_x = where(beta * x <= threshold, (1/beta) * log(1 + exp(beta * x)), x)
g = -exp(A_log.float()) * softplus_x
beta_output = sigmoid(b.float())
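
An unfused PyTorch reference of the same gating math (a validation sketch; F.softplus's default beta and threshold are assumptions, not necessarily the fused kernel's exact constants):

import torch
import torch.nn.functional as F

def gdn_gating_reference(A_log, a, b, dt_bias):
    # g = -exp(A_log) * softplus(a + dt_bias);  beta = sigmoid(b)
    x = a.float() + dt_bias.float()
    # F.softplus switches to the linear branch above its threshold,
    # matching the where(...) in the Triton kernel above
    g = -A_log.float().exp() * F.softplus(x)
    return g, torch.sigmoid(b.float())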

2.3 Recurrent Attention

This is the core math of GDN; prefill and decode use different algorithms.

The decode path uses fused_recurrent_gated_delta_rule (step-by-step recurrence, O(1) per step):

# Triton kernel: fused_recurrent_gated_delta_rule_fwd_kernel
# File: vllm/model_executor/layers/fla/ops/fused_recurrent.py

for i_t in range(0, T):
    b_q = load(p_q)    # query
    b_k = load(p_k)    # key
    b_v = load(p_v)    # value

    # L2 normalization
    b_q = b_q / sqrt(sum(b_q * b_q) + 1e-6)
    b_k = b_k / sqrt(sum(b_k * b_k) + 1e-6)
    b_q = b_q * scale

    # State decay
    b_h *= exp(b_g)                       # h = h * exp(g)

    # Delta rule
    b_v -= sum(b_h * b_k[:, None], 0)    # v' = v - h^T @ k
    b_v *= b_beta                         # v' = beta * v'

    # State update
    b_h += b_k[:, None] * b_v[None, :]   # h = h + k ⊗ v'

    # Output
    b_o = sum(b_h * b_q[:, None], 0)     # o = h^T @ q
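
The same recurrence as a single-head PyTorch sketch (state h is [d_k, d_v]; useful for checking the math against the kernel above, not for performance):

import torch

def gated_delta_rule_step(h, q, k, v, g, beta, scale, eps=1e-6):
    # h: [d_k, d_v] recurrent state; q, k: [d_k]; v: [d_v]; g, beta: scalars
    q = q / torch.sqrt((q * q).sum() + eps) * scale   # L2 norm + scale
    k = k / torch.sqrt((k * k).sum() + eps)           # L2 norm
    h = h * torch.exp(g)                              # state decay
    v_delta = beta * (v - h.T @ k)                    # delta rule: write only the residual
    h = h + torch.outer(k, v_delta)                   # rank-1 state update
    o = h.T @ q                                       # readout
    return h, o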

The prefill path uses chunk_gated_delta_rule (chunked parallelism, which handles long sequences more efficiently):

# vllm/model_executor/models/qwen3_next.py : _forward_core
if attn_metadata.num_prefills > 0:
    initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
    initial_state[~has_initial_state, ...] = 0
    core_attn_out_non_spec, last_recurrent_state = chunk_gated_delta_rule(
        q=query_non_spec, k=key_non_spec, v=value_non_spec,
        g=g_non_spec, beta=beta_non_spec,
        initial_state=initial_state,
        output_final_state=True,
        cu_seqlens=non_spec_query_start_loc,
        head_first=False,
        use_qk_l2norm_in_kernel=True,
    )
    # write the final state back to the cache
    ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)

ๆ•ฐๅญฆๅ…ฌๅผๆ€ป็ป“๏ผˆๆฏไธช time step t๏ผ‰๏ผš

$$\hat{q}_t = \frac{q_t}{\|q_t\|_2} \cdot \text{scale}, \quad \hat{k}_t = \frac{k_t}{\|k_t\|_2}$$

$$h_t = e^{g_t} \cdot h_{t-1} + \hat{k}_t \otimes \left[\sigma(b_t) \cdot (v_t - h_{t-1}^T \hat{k}_t)\right]$$

$$o_t = h_t^T \hat{q}_t$$

where:

  • $h_t \in \mathbb{R}^{d_k \times d_v}$ is the recurrent state matrix
  • $e^{g_t}$ is the exponential decay factor, controlling how fast history is forgotten
  • $\sigma(b_t)$ is the sigmoid write gate, controlling how strongly new information is written
  • $v_t - h_{t-1}^T \hat{k}_t$ is the "delta rule": only the part of the value not already predicted by the state is written (checked numerically below)
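
A quick numeric check of that last point, reusing the gated_delta_rule_step sketch above: when the state already predicts the value exactly, the delta is zero and (absent decay) the state is unchanged:

h = torch.randn(4, 8)                   # d_k=4, d_v=8
k = torch.randn(4); k = k / k.norm()    # unit-norm key
v = h.T @ k                             # value exactly predicted by the current state
g, beta = torch.tensor(0.0), torch.tensor(1.0)   # no decay, full write strength
h_new, _ = gated_delta_rule_step(h, torch.randn(4), k, v, g, beta, scale=1.0, eps=0.0)
assert torch.allclose(h_new, h, atol=1e-5)       # zero delta => no state change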

Stage 3: Output Projection

# vllm/model_executor/models/qwen3_next.py : Qwen3NextGatedDeltaNet.forward
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])  # flatten to 2D for the norm
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z)     # RMSNormGated: silu(z) * norm(out)
core_attn_out = core_attn_out.reshape(z_shape_og)  # restore [..., h, d]; z_shape_og saved earlier
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")  # merge heads
output[:num_tokens], _ = self.out_proj(core_attn_out)  # linear projection back to hidden_size

่พ“ๅ‡บ็ป่ฟ‡ RMSNormGated๏ผˆไฝฟ็”จ z ไฝœไธบ gate๏ผŒsilu ๆฟ€ๆดป๏ผ‰ๅŽ้€š่ฟ‡ out_proj ๆ˜ ๅฐ„ๅ›ž hidden_sizeใ€‚


3. Prefill vs. Decode Path Comparison

| Stage | Prefill | Decode |
| --- | --- | --- |
| Conv1d | causal_conv1d_fn (whole-sequence convolution) | causal_conv1d_update (single-step update) |
| Recurrent attention | chunk_gated_delta_rule (chunked parallel) | fused_recurrent_gated_delta_rule (step-by-step recurrence) |
| State initialization | loaded from ssm_state; zeroed when there is no history | existing state in ssm_state used directly |
| State write-back | final state written back to ssm_state after the pass | ssm_state updated in place |
| Complexity | O(n) total, accelerated by chunked parallelism | O(d_k × d_v) per head per step |

4. GDN Attention Metadata Construction

GDNAttentionMetadataBuilder (in vllm/v1/attention/backends/gdn_attn.py) builds the metadata for every forward pass.

Core fields

@dataclass
class GDNAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int

    has_initial_state: torch.Tensor | None          # for prefills: whether a previous state exists
    spec_query_start_loc: torch.Tensor | None       # query start offsets for spec-decode tokens
    non_spec_query_start_loc: torch.Tensor | None   # query start offsets for non-spec tokens
    spec_state_indices_tensor: torch.Tensor | None  # state indices for spec decode
    non_spec_state_indices_tensor: torch.Tensor | None  # state indices for non-spec requests
    spec_sequence_masks: torch.Tensor | None        # marks which sequences are spec decode
    spec_token_indx: torch.Tensor | None            # global indices of spec tokens
    non_spec_token_indx: torch.Tensor | None        # global indices of non-spec tokens
    num_accepted_tokens: torch.Tensor | None        # number of accepted tokens per sequence
    retrieve_parent_token: torch.Tensor | None      # parent-token indices for tree attention

Build logic (the build method)

  1. Determine whether spec decode is active: via num_decode_draft_tokens_cpu
  2. Without spec decode: split the batch directly into prefill/decode; state indices come from block_table[:, 0] (a minimal sketch follows this list)
  3. With spec decode:
    • split the batch into a spec part (the first num_spec_decodes requests) and a non-spec part
    • the spec part's state indices come from block_table[:, :num_spec+1] (multi-step states)
    • the spec_token_indx / non_spec_token_indx index tensors separate and re-merge the tokens
  4. CUDAGraph support: decode-only batches are padded to match CUDAGraph capture sizes
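
A minimal sketch of the no-spec-decode split (step 2 above). The argument names are illustrative stand-ins for the builder's inputs, and the batch is assumed to be reordered with decode requests first, as vLLM's Mamba-style backends arrange it:

import torch

def split_batch_no_spec(num_scheduled_tokens, num_computed_tokens, block_table):
    # Assumes decode requests (exactly 1 scheduled token) are ordered before prefills.
    num_reqs = num_scheduled_tokens.numel()
    num_decodes = int((num_scheduled_tokens == 1).sum())
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = num_decodes
    num_prefill_tokens = int(num_scheduled_tokens[num_decodes:].sum())

    # each request's recurrent state lives in one slot: the first block-table entry
    non_spec_state_indices_tensor = block_table[:num_reqs, 0]

    # prefills that continue a cached prefix must reload their previous state
    has_initial_state = num_computed_tokens[num_decodes:num_reqs] > 0
    return (num_prefills, num_prefill_tokens, num_decodes, num_decode_tokens,
            non_spec_state_indices_tensor, has_initial_state)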

5. QSA (Query-Side Aggregation) and Full Attention Layers

QSA ๅชๅฝฑๅ“ full_attention ๅฑ‚๏ผŒไธๅฝฑๅ“ GDN ๅฑ‚ใ€‚ๅฝ“ config ไธญ index_topk > 0 ๆ—ถๅฏ็”จใ€‚

5.1 Full Attention without QSA

hidden_states → QKV Proj → Q/K Norm → RoPE → standard Flash Attention (full KV) → [Gate ×] O Proj

5.2 Full Attention with QSA

With QSA, an Indexer module is inserted before the standard attention to sparsify it:

hidden_states → QKV Proj → Q/K Norm → RoPE
                                         ↓
                              ┌── Indexer (sparse index computation) ──┐
                              │                                        │
                              │  hidden → QKW Proj                     │
                              │  Q/K LayerNorm                         │
                              │  RoPE                                  │
                              │  Write K to indexer cache              │
                              │  Score = Q @ K_cache * W               │
                              │  TopK indices                          │
                              └────────────────────────────────────────┘
                                         ↓
                              Sparse Flash Attention (KV of only the TopK tokens)
                                         ↓
                              [Gate ×] O Proj → output

5.3 The Indexer in Detail

Core components of the Indexer class (vllm/model_executor/models/qwen3_next.py):

class Indexer(nn.Module):
    def __init__(self, ...):
        self.index_topk = config.index_topk           # number of TopK tokens to keep
        self.index_n_heads = config.index_n_heads     # number of indexer query heads
        self.index_kv_heads = config.index_kv_heads   # number of indexer kv heads
        self.index_head_dim = config.index_head_dim   # indexer head dimension

        # fused QKW projection (replicated, no TP sharding)
        self.index_qkw_proj = ReplicatedLinear(hidden_size, q_dim + k_dim + w_dim)

        # Q/K LayerNorm (stabilizes score scales across layers)
        self.index_q_layernorm = RMSNorm(index_head_dim)
        self.index_k_layernorm = RMSNorm(index_head_dim)

        # dedicated K cache
        self.k_cache = Qwen3NextIndexerCache(...)

        self.score_scale = index_head_dim ** 0.5

The Indexer forward pass:

  1. QKW projection: qkw = index_qkw_proj(hidden_states) → split into q, k, w
  2. LayerNorm: RMSNorm applied to q and k separately
  3. RoPE: rotary position embedding applied (with partial_rotary_factor=0.5)
  4. Write to the K cache: the current k is written into the Indexer's dedicated block KV cache
  5. Relevance scores (steps 5-6 are sketched after this list):
    • bf16 mode: score = (Q @ K_cache) * W / score_scale
    • fp8 mode: Q is first quantized to per-token-group FP8, then scored
  6. TopK selection: for each token, pick the indices of the TopK most relevant history tokens
  7. Output: topk_indices_buffer, consumed by the downstream sparse attention
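
Steps 5 and 6 as a dense-tensor sketch (bf16 path, single sequence; the per-head weighting and head reduction are assumptions about the layout, and causal masking of future positions is omitted for brevity):

import torch

def indexer_topk_reference(q, k_cache, w, score_scale, index_topk):
    # q:       [t, h, d]  indexer queries for the current tokens
    # k_cache: [s, d]     indexer keys for all s cached positions (shared across heads)
    # w:       [t, h]     per-head score weights from the QKW projection
    scores = torch.einsum("thd,sd->ths", q, k_cache)              # Q @ K_cache
    scores = (scores * w.unsqueeze(-1)).sum(dim=1) / score_scale  # weight, reduce heads
    topk = min(index_topk, k_cache.shape[0])
    return scores.topk(topk, dim=-1).indices                      # [t, topk] sparse indices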

5.4 Sparse Attention: Prefill and Decode

During prefill (flash_attn_sparse_prefill):

  • gather the full K from the K cache, block by block
  • compute logits = Q @ K_cache * W
  • take the TopK to get sparse indices
  • use the sparse indices to extract the corresponding K, V from the KV cache and run sparse attention
  • Context Parallel (CP) is supported to accelerate long sequences

During decode:

  • scores are computed with paged MQA logits
  • after TopK selection, the sparse KV is gathered
  • two sparse attention backends are supported: pai-fa3 or triton (a conceptual sketch follows)

6. End-to-End Flow Diagram

                        Qwen3-Next Decoder Layer

  input: hidden_states
         ↓
  ┌────────────────┐
  │ InputLayerNorm │
  └───────┬────────┘
          │
    ┌─────┴──────────────────────────────────┐
    │          layer_type?                   │
    ▼                                        ▼
  "linear_attention" (GDN)           "full_attention"
    │                                        │
    │ ┌───────────────────────┐      ┌───────┴─────────────────────────┐
    │ │ 1. in_proj_qkvz(h)    │      │ qkv_proj(h) → q, k, v [, gate]  │
    │ │    in_proj_ba(h)      │      │ q_norm, k_norm                  │
    │ │    → q,k,v,z,b,a      │      │ rotary_emb(pos, q, k)           │
    │ └───────┬───────────────┘      └───────┬─────────────────────────┘
    │         │                              │
    │ ┌───────┴───────────────┐      ┌───────┴───────────┐
    │ │ 2. causal_conv1d      │      │  QSA enabled?     │
    │ │   (prefill: conv1d_fn │      │                   │
    │ │    decode: conv1d_upd)│      │   YES      NO     │
    │ └───────┬───────────────┘      │    │        │     │
    │         │                      │    ▼        │     │
    │ ┌───────┴───────────────┐      │ Indexer:    │     │
    │ │ 3. fused_gdn_gating   │      │ QKW proj    │     │
    │ │  g = -exp(A) *        │      │ LayerNorm   │     │
    │ │      softplus(a+bias) │      │ RoPE        │     │
    │ │  β = sigmoid(b)       │      │ Write K$    │     │
    │ └───────┬───────────────┘      │ Score calc  │     │
    │         │                      │ TopK select │     │
    │ ┌───────┴───────────────┐      │    │        │     │
    │ │ 4. Recurrent Attn     │      └────┤────────┤─────┘
    │ │ (prefill: chunk_gdr)  │           ▼        ▼
    │ │ (decode: fused_rec)   │      ┌─────────────────────┐
    │ │                       │      │  Flash Attention    │
    │ │ For each step t:      │      │  (sparse if QSA)    │
    │ │  q̂ = q/‖q‖₂ * scale   │      │  (dense if no QSA)  │
    │ │  k̂ = k/‖k‖₂           │      └────────┬────────────┘
    │ │  h *= exp(g)          │               │
    │ │  v' = β(v - hᵀk̂)      │      ┌────────┴────────────┐
    │ │  h += k̂ ⊗ v'          │      │ [gate *] o_proj     │
    │ │  o = hᵀq̂              │      └────────┬────────────┘
    │ └───────┬───────────────┘               │
    │ ┌───────┴───────────────┐               │
    │ │ 5. norm(out, z) +     │               │
    │ │    out_proj           │               │
    │ └───────┬───────────────┘               │
    │         │                               │
    └─────────┴───────────────────────────────┘
              │
          attention_output
              │
     + residual + [layer_scale]
              │
     PostAttentionLayerNorm
              │
     MLP (Dense or MoE)
              │
     + residual + [layer_scale]
              │
          output

7. KV Cache Structure Comparison

GDN layer cache

GDN uses a Mamba-style state cache; each sequence maintains two states:

็Šถๆ€ ๅฝข็Šถ ่ฏดๆ˜Ž
conv_state [num_slots, conv_dim, conv_kernel_size] ๅ› ๆžœๅท็งฏ็š„ๆป‘ๅŠจ็ช—ๅฃ็Šถๆ€
ssm_state [num_slots, num_v_heads, head_k_dim, head_v_dim] ้€’ๅฝ’ๆณจๆ„ๅŠ›็š„็Šถๆ€็Ÿฉ้˜ต h

where conv_dim = key_dim * 2 + value_dim.

With spec decode, each sequence needs num_spec + 1 slots to store the multi-step states.
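
A quick footprint check with hypothetical dimensions (the numbers below are illustrative, not Qwen3-Next's real config):

# Hypothetical per-TP-rank dimensions, for illustration only
num_k_heads, num_v_heads = 16, 32
head_k_dim, head_v_dim = 128, 128
conv_kernel_size, num_slots = 4, 1024

key_dim = num_k_heads * head_k_dim           # 2048
value_dim = num_v_heads * head_v_dim         # 4096
conv_dim = key_dim * 2 + value_dim           # q and k each use key_dim; v adds value_dim

conv_state_shape = (num_slots, conv_dim, conv_kernel_size)
ssm_state_shape = (num_slots, num_v_heads, head_k_dim, head_v_dim)
print(conv_state_shape)  # (1024, 8192, 4)
print(ssm_state_shape)   # (1024, 32, 128, 128)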

Full Attention layer cache

The standard block KV cache: [2, num_blocks, block_size, num_kv_heads, head_dim]

With QSA there is an additional Indexer K cache: [num_blocks, block_size, index_kv_heads, index_head_dim]


8. Key Comparison Summary

| Dimension | GDN (linear_attention) | Full Attention (no QSA) | Full Attention (with QSA) |
| --- | --- | --- | --- |
| KV cache | conv_state + ssm_state | block KV cache | block KV cache + indexer K cache |
| Prefill algorithm | chunk_gated_delta_rule | Flash Attention | Indexer + Sparse Flash Attention |
| Decode algorithm | fused_recurrent | Flash Attention | Indexer + Sparse Flash Attention |
| Decode complexity | O(d_k × d_v) per head | O(seq_len × d) | O(TopK × d) |
| Positional encoding | no RoPE | RoPE | RoPE + Indexer RoPE |
| QSA effect | none | - | Indexer picks TopK; attention is sparse |
| State math | h = exp(g)·h + k⊗[β(v − hᵀk)] | softmax(QKᵀ/√d)V | softmax(QK_topkᵀ/√d)V_topk |

9. Related Environment Variables

็Žฏๅขƒๅ˜้‡ ไฝœ็”จ
VLLM_GDN_USE_BLADNN ๆ˜ฏๅฆไฝฟ็”จ bladnn ๅŠ ้€Ÿ GDN ้€’ๅฝ’่ฎก็ฎ—
VLLM_QSA_USE_FP8_INDEXER Indexer ๆ˜ฏๅฆไฝฟ็”จ FP8 ้‡ๅŒ–
VLLM_QSA_PREFILL_USE_TL_INDEXER Prefill ๆ—ถ Indexer ๆ˜ฏๅฆไฝฟ็”จ tilelang kernel
VLLM_QSA_DECODE_USE_TL_INDEXER Decode ๆ—ถ Indexer ๆ˜ฏๅฆไฝฟ็”จ tilelang kernel
VLLM_QSA_PREFILL_ATTN_BACKEND QSA prefill ็š„ attention backend๏ผˆtriton / pai-fa3๏ผ‰
VLLM_QSA_DECODE_ATTN_BACKEND QSA decode ็š„ attention backend๏ผˆpai-fa3 ็ญ‰๏ผ‰
VLLM_QSA_PREFILL_USE_CP QSA prefill ๆ˜ฏๅฆไฝฟ็”จ Context Parallel
VLLM_DSA_USE_CONTEXT_PARALLEL_THRESHOLD ๅฏ็”จ CP ็š„ token ๆ•ฐ้˜ˆๅ€ผ
VLLM_DSA_USE_DENSE_PREFILL_THRESHOLD ไฝฟ็”จ dense prefill ็š„้˜ˆๅ€ผ

10. ๆบ็ ๆ–‡ไปถ็ดขๅผ•

| File | Main content |
| --- | --- |
| vllm/model_executor/models/qwen3_next.py | GDN layer (Qwen3NextGatedDeltaNet), attention layer (Qwen3NextAttention), Indexer, fused gating kernel |
| vllm/v1/attention/backends/gdn_attn.py | GDN attention metadata builder |
| vllm/model_executor/layers/fla/ops/fused_recurrent.py | recurrent GDN Triton kernel (decode path) |
| vllm/model_executor/layers/fla/ops/chunk.py | chunked GDN kernel (prefill path) |
| vllm/v1/attention/backends/flash_attn.py | Flash Attention + QSA sparse attention implementation |
| vllm/v1/attention/backends/flash_attn_qsautils.py | QSA helper utilities |
| vllm/model_executor/models/qsa_indexer_utils.py | MQA logits utilities for the QSA Indexer |
| vllm/model_executor/layers/mamba/ops/causal_conv1d.py | causal convolution implementation |
| vllm/v1/worker/gpu_model_runner.py | integration of GDN metadata with the model runner |