GDN๏ผGated Delta Net๏ผๆฏ Qwen3-Next ๆจกๅไธญ็ไธ็ง็บฟๆงๆณจๆๅๆบๅถ๏ผไธไผ ็ป็ full attention ไบคๆฟไฝฟ็จใๆจกๅ็ๆฏไธๅฑ็ฑ config.layer_types ๅณๅฎๆฏ "linear_attention"๏ผGDN๏ผ่ฟๆฏ "full_attention"๏ผๆ ๅ Attention๏ผใ
ๆ ธๅฟไปฃ็ ไฝไบ๏ผ
- GDN ๆจกๅๅฑ๏ผ
vllm/model_executor/models/qwen3_next.pyโQwen3NextGatedDeltaNet - GDN Attention Metadata๏ผ
vllm/v1/attention/backends/gdn_attn.pyโGDNAttentionMetadataBuilder - ้ๅฝๆ ธๅฟ Kernel๏ผ
vllm/model_executor/layers/fla/ops/fused_recurrent.py - ๅๅๆ ธๅฟ Kernel๏ผ
vllm/model_executor/layers/fla/ops/chunk.py
ๅ ณ้ฎ็น๏ผQSA๏ผQuery-Side Aggregation๏ผๅชๅจ full_attention ๅฑไธญไฝฟ็จ๏ผไธๅฝฑๅ GDN ๅฑ็่ฎก็ฎใ
# vllm/model_executor/models/qwen3_next.py : Qwen3NextDecoderLayer.__init__
if self.layer_type == "linear_attention":
topk_indices_buffer = None # only use qsa in full-attn
self.linear_attn = Qwen3NextGatedDeltaNet(...)
elif self.layer_type == "full_attention":
self.self_attn = Qwen3NextAttention(..., topk_indices_buffer=topk_indices_buffer)GDN ็ forward ๅไธบไธไธช้ถๆฎตใ
ไป hidden_states ้่ฟไธคไธชๆๅฝฑ็ฉ้ตๅพๅฐ 6 ไธชๅผ ้๏ผ
# vllm/model_executor/models/qwen3_next.py : Qwen3NextGatedDeltaNet.forward
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) # โ q, k, v, z
projected_states_ba, _ = self.in_proj_ba(hidden_states) # โ b, aๆไธค็ง split ๆนๅผ๏ผ
- Fused Triton kernel๏ผ
fused_qkvzba_split_reshape_cat๏ผ๏ผๅจๆปก่ถณๆกไปถๆถไฝฟ็จ่ๅ kernel - PyTorch ๅ็๏ผ
fix_query_key_value_ordering๏ผ๏ผfallback ่ทฏๅพ
่พๅบ็ 6 ไธชๅผ ้็ไฝ็จ๏ผ
| ๅผ ้ | ๅฝข็ถ | ็จ้ |
|---|---|---|
| q | [tokens, num_k_heads/tp, head_k_dim] |
ๆฅ่ฏขๅ้ |
| k | [tokens, num_k_heads/tp, head_k_dim] |
้ฎๅ้ |
| v | [tokens, num_v_heads/tp, head_v_dim] |
ๅผๅ้ |
| z | [tokens, num_v_heads/tp, head_v_dim] |
่พๅบ้จๆง |
| b | [tokens, num_v_heads/tp] |
่ฎก็ฎ beta๏ผๆดๆฐ้จๆง๏ผ |
| a | [tokens, num_v_heads/tp] |
่ฎก็ฎ g๏ผ่กฐๅๅ ๅญ๏ผ |
ๆ็ป q, k, v ่ขซๆผๆฅไธบ mixed_qkv = cat(q, k, v)ใ
ๆ ธๅฟ่ฎก็ฎ้่ฟ _forward_core ๅฎๆ๏ผๅไธบไธไธชๅญๆญฅ้ชคใ
ๅฏน mixed_qkv ่ฟ่ก 1D ๅ ๆๅท็งฏ๏ผๆดๆฐ conv_state ็ผๅญใๆ นๆฎๆฏๅฆๆ spec decode๏ผๅไธบ spec ้จๅๅ non-spec ้จๅๅๅซๅค็๏ผ
# vllm/model_executor/models/qwen3_next.py : _forward_core
# Spec decode ้จๅ๏ผๅๆญฅๆดๆฐ
if spec_sequence_masks is not None:
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec, conv_state, conv_weights, ...,
conv_state_indices=spec_state_indices_tensor[:, 0],
num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc,
)
# Non-spec Prefill ้จๅ๏ผๆดๅบๅๅท็งฏ
if attn_metadata.num_prefills > 0:
mixed_qkv_non_spec = causal_conv1d_fn(
mixed_qkv_non_spec_T, conv_weights, ...,
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
).transpose(0, 1)
# Non-spec Decode ้จๅ๏ผๅๆญฅๆดๆฐ
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec, conv_state, conv_weights, ...,
conv_state_indices=non_spec_state_indices_tensor,
)้่ฟ fused triton kernel ่ฎก็ฎ่กฐๅๅ ๅญ g ๅๆดๆฐ้จๆง beta๏ผ
# vllm/model_executor/models/qwen3_next.py
g, beta = fused_gdn_gating(self.A_log, a.contiguous(), b, self.dt_bias)ๆฐๅญฆๅ ฌๅผ๏ผ
g = -exp(A_log) * softplus(a + dt_bias)
beta = sigmoid(b)
ๅ
ถไธญ softplus(x) = (1/ฮฒ) * log(1 + exp(ฮฒ*x))๏ผๅฝ ฮฒ*x > threshold ๆถ้ๅไธบ xใ
Triton kernel ๅฎ็ฐ๏ผ
# fused_gdn_gating_kernel
x = a.float() + dt_bias.float()
softplus_x = where(beta * x <= threshold, (1/beta) * log(1 + exp(beta * x)), x)
g = -exp(A_log.float()) * softplus_x
beta_output = sigmoid(b.float())่ฟๆฏ GDN ็ๆ ธๅฟๆฐๅญฆ่ฎก็ฎ๏ผๆ นๆฎ prefill/decode ไฝฟ็จไธๅ็็ฎๆณใ
Decode ่ทฏๅพไฝฟ็จ fused_recurrent_gated_delta_rule๏ผ้ๆญฅ้ๆจ๏ผๅคๆๅบฆ O(1)/step๏ผ๏ผ
# Triton kernel: fused_recurrent_gated_delta_rule_fwd_kernel
# ๆไปถ: vllm/model_executor/layers/fla/ops/fused_recurrent.py
for i_t in range(0, T):
b_q = load(p_q) # query
b_k = load(p_k) # key
b_v = load(p_v) # value
# L2 normalization
b_q = b_q / sqrt(sum(b_q * b_q) + 1e-6)
b_k = b_k / sqrt(sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# State decay
b_h *= exp(b_g) # h = h * exp(g)
# Delta rule
b_v -= sum(b_h * b_k[:, None], 0) # v' = v - h^T @ k
b_v *= b_beta # v' = beta * v'
# State update
b_h += b_k[:, None] * b_v[None, :] # h = h + k โ v'
# Output
b_o = sum(b_h * b_q[:, None], 0) # o = h^T @ qPrefill ่ทฏๅพไฝฟ็จ chunk_gated_delta_rule๏ผๅๅๅนถ่ก๏ผๆด้ซๆๅฐๅค็้ฟๅบๅ๏ผ๏ผ
# vllm/model_executor/models/qwen3_next.py : _forward_core
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
core_attn_out_non_spec, last_recurrent_state = chunk_gated_delta_rule(
q=query_non_spec, k=key_non_spec, v=value_non_spec,
g=g_non_spec, beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# ๅฐๆ็ป็ถๆๅๅ cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)ๆฐๅญฆๅ ฌๅผๆป็ป๏ผๆฏไธช time step t๏ผ๏ผ
ๅ ถไธญ๏ผ
-
$h_t \in \mathbb{R}^{d_k \times d_v}$ ไธบ้ๅฝ็ถๆ็ฉ้ต -
$e^{g_t}$ ไธบๆๆฐ่กฐๅๅ ๅญ๏ผๆงๅถๅๅฒไฟกๆฏ็้ๅฟ -
$\sigma(b_t)$ ไธบ sigmoid ้จๆง๏ผๆงๅถๆฐไฟกๆฏ็ๅๅ ฅๅผบๅบฆ -
$v_t - h_{t-1}^T \hat{k}_t$ ไธบ"delta rule"๏ผๅชๅๅ ฅไธๅฝๅ็ถๆๅทฎๅผ็้จๅ
# vllm/model_executor/models/qwen3_next.py : Qwen3NextGatedDeltaNet.forward
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
z = z.reshape(-1, z.shape[-1])
core_attn_out = self.norm(core_attn_out, z) # RMSNormGated: silu(z) * norm(out)
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out) # ็บฟๆงๆๅฝฑๅ hidden_size่พๅบ็ป่ฟ RMSNormGated๏ผไฝฟ็จ z ไฝไธบ gate๏ผsilu ๆฟๆดป๏ผๅ้่ฟ out_proj ๆ ๅฐๅ hidden_sizeใ
| ้ถๆฎต | Prefill | Decode |
|---|---|---|
| Conv1d | causal_conv1d_fn๏ผๆดๅบๅๅท็งฏ๏ผ |
causal_conv1d_update๏ผๅๆญฅๆดๆฐ๏ผ |
| ้ๅฝๆณจๆๅ | chunk_gated_delta_rule๏ผๅๅๅนถ่ก๏ผ |
fused_recurrent_gated_delta_rule๏ผ้ๆญฅ้ๆจ๏ผ |
| State ๅๅงๅ | ไป ssm_state ๅ ่ฝฝ๏ผ่ฅๆ ๅๅฒๅ็ฝฎ้ถ | ็ดๆฅไฝฟ็จ ssm_state ไธญ็ๅทฒๆ็ถๆ |
| State ๅๅ | ่ฎก็ฎๅฎๆๅๆดไฝๅๅ ssm_state | ๅๅฐๆดๆฐ ssm_state๏ผinplace๏ผ |
| ๅคๆๅบฆ | O(n) ๆป่ฎก๏ผๅๅๅนถ่กๅ ้ | O(d_k ร d_v) per head per step |
GDNAttentionMetadataBuilder๏ผไฝไบ vllm/v1/attention/backends/gdn_attn.py๏ผ่ด่ดฃไธบๆฏๆฌก forward ๆๅปบ metadataใ
@dataclass
class GDNAttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
num_spec_decodes: int
num_spec_decode_tokens: int
num_actual_tokens: int
has_initial_state: torch.Tensor | None # prefill ๆถๆ ่ฎฐๆฏๅฆๆๅๅฒ็ถๆ
spec_query_start_loc: torch.Tensor | None # spec decode ็ query ไฝ็ฝฎ
non_spec_query_start_loc: torch.Tensor | None # non-spec ็ query ไฝ็ฝฎ
spec_state_indices_tensor: torch.Tensor | None # spec decode ็ state ็ดขๅผ
non_spec_state_indices_tensor: torch.Tensor | None # non-spec ็ state ็ดขๅผ
spec_sequence_masks: torch.Tensor | None # ๆ ่ฎฐๅชไบๅบๅๆฏ spec decode
spec_token_indx: torch.Tensor | None # spec token ็ๅ
จๅฑ็ดขๅผ
non_spec_token_indx: torch.Tensor | None # non-spec token ็ๅ
จๅฑ็ดขๅผ
num_accepted_tokens: torch.Tensor | None # ๆฏไธชๅบๅๆฅๅ็ token ๆฐ
retrieve_parent_token: torch.Tensor | None # tree attention ็็ถ่็น็ดขๅผ- ๅคๆญๆฏๅฆๆ spec decode๏ผ้่ฟ
num_decode_draft_tokens_cpuๅคๆญ - ๆ spec decode ๆถ๏ผ็ดๆฅๆ prefill/decode ๅๅ๏ผstate ็ดขๅผๅ
block_table[:, 0] - ๆ spec decode ๆถ๏ผ
- ๅฐ batch ๅไธบ spec ้จๅ๏ผๅ
num_spec_decodesไธช๏ผๅ non-spec ้จๅ - spec ้จๅ็ state ็ดขๅผๅ
block_table[:, :num_spec+1]๏ผๅคๆญฅ state๏ผ - ้่ฟ
spec_token_indx/non_spec_token_indx็ดขๅผๆฅๅ็ฆปๅๅๅนถ token
- ๅฐ batch ๅไธบ spec ้จๅ๏ผๅ
- CUDAGraph ๆฏๆ๏ผๅฏน decode-only ๅบๆฏ่ฟ่ก padding ไปฅ้้ CUDAGraph ๆ่ท
QSA ๅชๅฝฑๅ full_attention ๅฑ๏ผไธๅฝฑๅ GDN ๅฑใๅฝ config ไธญ index_topk > 0 ๆถๅฏ็จใ
hidden_states โ QKV Proj โ Q/K Norm โ RoPE โ Standard Flash Attention๏ผๅ
จ KV๏ผโ [Gate ร] O Proj
ๆ QSA ๆถ๏ผๅจๆ ๅ attention ไนๅๅขๅ ไบไธไธช Indexer ๆจกๅ็จไบ็จ็ๅ๏ผ
hidden_states โ QKV Proj โ Q/K Norm โ RoPE
โ
โโโ Indexer๏ผ็จ็็ดขๅผ่ฎก็ฎ๏ผโโโ
โ โ
โ hidden โ QKW Proj โ
โ Q/K LayerNorm โ
โ RoPE โ
โ Write K to indexer cache โ
โ Score = Q @ K_cache * W โ
โ TopK indices โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ
Sparse Flash Attention๏ผไป
TopK ไธช token ็ KV๏ผ
โ
[Gate ร] O Proj โ output
Indexer ็ฑป๏ผvllm/model_executor/models/qwen3_next.py๏ผๆ ธๅฟ็ปไปถ๏ผ
class Indexer(nn.Module):
def __init__(self, ...):
self.index_topk = config.index_topk # TopK ๆฐ้
self.index_n_heads = config.index_n_heads # indexer query heads ๆฐ้
self.index_kv_heads = config.index_kv_heads # indexer kv heads ๆฐ้
self.index_head_dim = config.index_head_dim # indexer head ็ปดๅบฆ
# QKW ่ๅๆๅฝฑ๏ผreplicated๏ผไธๅ TP๏ผ
self.index_qkw_proj = ReplicatedLinear(hidden_size, q_dim + k_dim + w_dim)
# Q/K LayerNorm๏ผ็จณๅฎไธๅๅฑ็ score scale๏ผ
self.index_q_layernorm = RMSNorm(index_head_dim)
self.index_k_layernorm = RMSNorm(index_head_dim)
# ็ฌ็ซ็ K cache
self.k_cache = Qwen3NextIndexerCache(...)
self.score_scale = index_head_dim ** 0.5Indexer ็ forward ่ฎก็ฎ่ฟ็จ๏ผ
- QKW ๆๅฝฑ๏ผ
qkw = index_qkw_proj(hidden_states)โ ๆๅไธบ q, k, w - LayerNorm๏ผๅฏน q, k ๅๅซๅ RMSNorm
- RoPE๏ผๆฝๅ ๆ่ฝฌไฝ็ฝฎ็ผ็ ๏ผไฝฟ็จ
partial_rotary_factor=0.5๏ผ - ๅๅ ฅ K Cache๏ผๅฐๅฝๅ k ๅๅ ฅ indexer ็ฌ็ซ็ block KV cache
- ่ฎก็ฎ็ธๅ
ณๆงๅๆฐ๏ผ
- bf16 ๆจกๅผ๏ผ
score = (Q @ K_cache) * W / score_scale - fp8 ๆจกๅผ๏ผๅ ๅฏน Q ๅ per-token-group FP8 ้ๅ๏ผๅ่ฎก็ฎ
- bf16 ๆจกๅผ๏ผ
- TopK ้ๆฉ๏ผๅฏนๆฏไธช token ้ๅบ TopK ไธชๆ็ธๅ ณ็ๅๅฒ token ็ดขๅผ
- ่พๅบ๏ผ
topk_indices_bufferไพๅ็ปญ sparse attention ไฝฟ็จ
Prefill ๆถ๏ผflash_attn_sparse_prefill๏ผ๏ผ
- ไป K cache ไธญๆ block ๆถ้ๅฎๆด็ K
- ่ฎก็ฎ
logits = Q @ K_cache * W - ้ TopK ๅพๅฐ็จ็็ดขๅผ
- ็จ็จ็็ดขๅผไป KV cache ไธญๆๅๅฏนๅบ็ KใV๏ผๅ sparse attention
- ๆฏๆ Context Parallel (CP) ๅ ้้ฟๅบๅ
Decode ๆถ๏ผ
- ไฝฟ็จ paged MQA logits ่ฎก็ฎๅๆฐ
- TopK ้ๆฉๅ๏ผๆๅ็จ็ KV
- ๆฏๆ pai-fa3 ๆ triton ไธค็ง sparse attention backend
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Qwen3-Next Decoder Layer โ
โ โ
โ input: hidden_states โ
โ โ โ
โ โโโโโโโโโโโโโโโ โ
โ โ InputLayerNormโ โ
โ โโโโโโโโฌโโโโโโโ โ
โ โ โ
โ โโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ layer_type? โ โ
โ โ โ โ
โ โผ โผ โ
โ "linear_attention" (GDN) "full_attention" โ
โ โ โ โ
โ โ โโโโโโโโโโโโโโโโโโโโโโโโโ โโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ 1. in_proj_qkvz(h) โ โ qkv_proj(h) โ q, k, v [, gate] โ โ
โ โ โ in_proj_ba(h) โ โ q_norm, k_norm โ โ
โ โ โ โ q,k,v,z,b,a โ โ rotary_emb(pos, q, k) โ โ
โ โ โโโโโโโโโฌโโโโโโโโโโโโโโโโ โโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ โ โ
โ โ โโโโโโโโโดโโโโโโโโโโโโโโโโ โโโโโโโโโโดโโโโโโโโโโ โ
โ โ โ 2. causal_conv1d โ โ QSA enabled? โ โ
โ โ โ (prefill: conv1d_fnโ โ โ โ
โ โ โ decode: conv1d_updโ โ YES NO โ โ
โ โ โโโโโโโโโฌโโโโโโโโโโโโโโโโ โ โ โ โ โ
โ โ โ โ โผ โ โ โ
โ โ โโโโโโโโโดโโโโโโโโโโโโโโโโ โ Indexer: โ โ โ
โ โ โ 3. fused_gdn_gating โ โ QKW proj โ โ โ
โ โ โ g = -exp(A) * โ โ LayerNorm โ โ โ
โ โ โ softplus(a+bias) โ โ RoPE โ โ โ
โ โ โ ฮฒ = sigmoid(b) โ โ Write K$ โ โ โ
โ โ โโโโโโโโโฌโโโโโโโโโโโโโโโโ โ Score calc โ โ โ
โ โ โ โ TopK select โ โ โ
โ โ โโโโโโโโโดโโโโโโโโโโโโโโโโ โ โ โ โ โ
โ โ โ 4. Recurrent Attn โ โโโโโโคโโโโโโโโโคโโโโโโ โ
โ โ โ (prefill: chunk_gdr) โ โผ โผ โ
โ โ โ (decode: fused_rec) โ โโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ โ โ Flash Attention โ โ
โ โ โ For each step t: โ โ (sparse if QSA) โ โ
โ โ โ qฬ=q/โqโโ * scale โ โ (dense if no QSA)โ โ
โ โ โ kฬ=k/โkโโ โ โโโโโโโโโโฌโโโโโโโโโโโโ โ
โ โ โ h *= exp(g) โ โ โ
โ โ โ v' = ฮฒ(v - hแตkฬ) โ โโโโโโโโโโดโโโโโโโโโโโโ โ
โ โ โ h += kฬ โ v' โ โ [gate * ] o_proj โ โ
โ โ โ o = hแตqฬ โ โโโโโโโโโโฌโโโโโโโโโโโโ โ
โ โ โโโโโโโโโฌโโโโโโโโโโโโโโโโ โ โ
โ โ โโโโโโโโโดโโโโโโโโโโโโโโโโ โ โ
โ โ โ 5. norm(out, z) + โ โ โ
โ โ โ out_proj โ โ โ
โ โ โโโโโโโโโฌโโโโโโโโโโโโโโโโ โ โ
โ โ โ โ โ
โ โโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ โ
โ โ โ
โ attention_output โ
โ โ โ
โ + residual + [layer_scale] โ
โ โ โ
โ PostAttentionLayerNorm โ
โ โ โ
โ MLP (Dense or MoE) โ
โ โ โ
โ + residual + [layer_scale] โ
โ โ โ
โ output โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
GDN ไฝฟ็จ Mamba ้ฃๆ ผ็ state cache๏ผๆฏไธชๅบๅ็ปดๆคไธคไธช็ถๆ๏ผ
| ็ถๆ | ๅฝข็ถ | ่ฏดๆ |
|---|---|---|
| conv_state | [num_slots, conv_dim, conv_kernel_size] |
ๅ ๆๅท็งฏ็ๆปๅจ็ชๅฃ็ถๆ |
| ssm_state | [num_slots, num_v_heads, head_k_dim, head_v_dim] |
้ๅฝๆณจๆๅ็็ถๆ็ฉ้ต h |
ๅ
ถไธญ conv_dim = key_dim * 2 + value_dimใ
Spec decode ๆถ๏ผๆฏไธชๅบๅ้่ฆ num_spec + 1 ไธช slot ๆฅๅญๅจๅคๆญฅ็ถๆใ
ๆ ๅ็ block KV cache๏ผ[2, num_blocks, block_size, num_kv_heads, head_dim]
QSA ๆถ้ขๅคๆ Indexer K cache๏ผ[num_blocks, block_size, index_kv_heads, index_head_dim]
| ็ปดๅบฆ | GDN (linear_attention) | Full Attention (ๆ QSA) | Full Attention (ๆ QSA) |
|---|---|---|---|
| KV Cache | conv_state + ssm_state | block KV cache | block KV cache + indexer K cache |
| Prefill ็ฎๆณ | chunk_gated_delta_rule | Flash Attention | Indexer + Sparse Flash Attention |
| Decode ็ฎๆณ | fused_recurrent | Flash Attention | Indexer + Sparse Flash Attention |
| Decode ๅคๆๅบฆ | O(d_k ร d_v) per head | O(seq_len ร d) | O(TopK ร d) |
| ไฝ็ฝฎ็ผ็ | ๆ RoPE | RoPE | RoPE + Indexer RoPE |
| QSA ๅฝฑๅ | ๆ | - | Indexer ้ TopK๏ผ็จ็ attention |
| ็ถๆๆฐๅญฆ | h = exp(g)ยทh + kโ[ฮฒ(v-hแตk)] | softmax(QKแต/โd)V | softmax(QK_topk^T/โd)V_topk |
| ็ฏๅขๅ้ | ไฝ็จ |
|---|---|
VLLM_GDN_USE_BLADNN |
ๆฏๅฆไฝฟ็จ bladnn ๅ ้ GDN ้ๅฝ่ฎก็ฎ |
VLLM_QSA_USE_FP8_INDEXER |
Indexer ๆฏๅฆไฝฟ็จ FP8 ้ๅ |
VLLM_QSA_PREFILL_USE_TL_INDEXER |
Prefill ๆถ Indexer ๆฏๅฆไฝฟ็จ tilelang kernel |
VLLM_QSA_DECODE_USE_TL_INDEXER |
Decode ๆถ Indexer ๆฏๅฆไฝฟ็จ tilelang kernel |
VLLM_QSA_PREFILL_ATTN_BACKEND |
QSA prefill ็ attention backend๏ผtriton / pai-fa3๏ผ |
VLLM_QSA_DECODE_ATTN_BACKEND |
QSA decode ็ attention backend๏ผpai-fa3 ็ญ๏ผ |
VLLM_QSA_PREFILL_USE_CP |
QSA prefill ๆฏๅฆไฝฟ็จ Context Parallel |
VLLM_DSA_USE_CONTEXT_PARALLEL_THRESHOLD |
ๅฏ็จ CP ็ token ๆฐ้ๅผ |
VLLM_DSA_USE_DENSE_PREFILL_THRESHOLD |
ไฝฟ็จ dense prefill ็้ๅผ |
| ๆไปถ | ไธป่ฆๅ ๅฎน |
|---|---|
vllm/model_executor/models/qwen3_next.py |
GDN ๅฑ (Qwen3NextGatedDeltaNet)ใAttention ๅฑ (Qwen3NextAttention)ใIndexerใ่ๅ gating kernel |
vllm/v1/attention/backends/gdn_attn.py |
GDN Attention Metadata ๆๅปบๅจ |
vllm/model_executor/layers/fla/ops/fused_recurrent.py |
้ๅฝ GDN Triton kernel๏ผdecode ่ทฏๅพ๏ผ |
vllm/model_executor/layers/fla/ops/chunk.py |
ๅๅ GDN kernel๏ผprefill ่ทฏๅพ๏ผ |
vllm/v1/attention/backends/flash_attn.py |
Flash Attention + QSA ็จ็ attention ๅฎ็ฐ |
vllm/v1/attention/backends/flash_attn_qsautils.py |
QSA ่พ ๅฉๅทฅๅ ทๅฝๆฐ |
vllm/model_executor/models/qsa_indexer_utils.py |
QSA Indexer ็ MQA logits ่ฎก็ฎๅทฅๅ ท |
vllm/model_executor/layers/mamba/ops/causal_conv1d.py |
ๅ ๆๅท็งฏๅฎ็ฐ |
vllm/v1/worker/gpu_model_runner.py |
GDN metadata ไธ model runner ็้ๆ |