Skip to content

Conversation

@xiuhu17
Copy link

@xiuhu17 xiuhu17 commented Jan 7, 2026

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @xiuhu17, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request fundamentally refactors the attention mechanism within the Megatron-LM framework by integrating highly optimized sparse Multi-Head Attention (MLA) kernels developed using Tilelang. The change aims to improve performance and efficiency, particularly for the DeepSeekV32 model, by leveraging custom kernel implementations for critical computational steps. The accompanying Docker environment updates ensure that users can easily set up and utilize this enhanced version.

Highlights

  • Tilelang Kernel Integration: New Tilelang-optimized kernels for sparse Multi-Head Attention (MLA) forward and backward passes have been introduced to enhance performance.
  • Attention Mechanism Refactor: The existing eager attention implementation within AttentionFuncionWithContextParallel has been replaced with the new Tilelang kernels, involving significant changes to function signatures and internal logic.
  • Docker Environment Update: The Docker setup instructions have been updated to reflect the new tilelang branch and image, streamlining the use of the Tilelang-enabled Megatron-LM.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for TileLang kernels for sparse attention, specifically for DeepSeek-V3.2 style Multi-Latent Attention. The changes are extensive, modifying the context-parallel attention implementation and the dynamic sparse attention (DSA) variant. New TileLang kernel files for forward and backward passes are added. While this is a significant performance-oriented change, I've found a critical issue in the integration of the new kernel. The value tensor is ignored in the attention computation, leading to incorrect results. There is also a potential issue in the backward pass regarding causality assumptions with context parallelism. The README is also updated to reflect new docker images and repository setup instructions.

Comment on lines 145 to 259
+ def forward(ctx, q, kv, v, indices, dim_v, K, attention_dropout, softmax_scale, pg):
'''Forward pass for the native attention function with context parallelism'''

# Assert einops exists
@@ -171,12 +171,17 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
probs = []
@@ -164,72 +66,71 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
cp_size = torch.distributed.get_world_size(pg)
comm = AllGatherComm(group=pg)
nheads = q.shape[2]
- nheads_k = k.shape[2]
- heads_k_stride = 1
- assert nheads % nheads_k == 0 and nheads_k % heads_k_stride == 0
+ kv_group = kv.shape[2]
+ heads_kv_stride = 1
+ assert nheads % kv_group == 0 and kv_group % heads_kv_stride == 0
outs = []
- probs = []
+ lses = []

# Initialize KV buffers
- kv_buffer = torch.empty(
kv_buffer = torch.empty(
- (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
+ # seperate KV buffer for MLA
+ kv_buffer = [torch.empty(
+ (k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
dtype=k.dtype,
device=k.device,
- )
- kv_buffer_copy = torch.empty_like(kv_buffer)
+ ), torch.empty(
+ (v.shape[0] * cp_size, v.shape[1], heads_k_stride, v.shape[3]),
+ dtype=v.dtype,
+ device=v.device,
+ )]
+ kv_buffer_copy = [torch.empty_like(kv_buffer[0]), torch.empty_like(kv_buffer[1])]
- dtype=k.dtype,
- device=k.device,
+ (kv.shape[0] * cp_size, kv.shape[1], heads_kv_stride, kv.shape[3]),
+ dtype=kv.dtype,
+ device=kv.device,
)
kv_buffer_copy = torch.empty_like(kv_buffer)

# All-gather first chunk of KV buffers
k_0 = k[:, :, :heads_k_stride].contiguous()
@@ -186,7 +191,7 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):

# Prepare attention bias
attn_bias = to_zz_mask_attn_bias(
- k_0 = k[:, :, :heads_k_stride].contiguous()
- v_0 = v[:, :, :heads_k_stride].contiguous()
- comm.all_gather(kv_buffer_copy[0], k_0)
- comm.all_gather(kv_buffer_copy[1], v_0)
-
- # Prepare attention bias
- attn_bias = to_zz_mask_attn_bias(
- attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype
+ attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype, if_zz_mask
)
- )
-
- # Iterate over heads
- for i in range(0, nheads_k, heads_k_stride):
+ kv_0 = kv[:, :, :heads_kv_stride].contiguous()
+ comm.all_gather(kv_buffer_copy, kv_0)
+
+ # Prepare topk
+ zz_indices = indices.transpose(1, 2)
+
+ # Iterate over heads, sequential, i
+ for i in range(0, kv_group, heads_kv_stride):
# Wait for previous all-gather to complete
comm.wait()
kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer
# All-gather the next portion of KV buffers if not the last iteration
- if i < nheads_k - heads_k_stride:
- kvsl = i + heads_k_stride
- kvsr = kvsl + heads_k_stride
- send_k = k[:, :, kvsl:kvsr].contiguous()
- send_v = v[:, :, kvsl:kvsr].contiguous()
- comm.all_gather(kv_buffer_copy[0], send_k)
- comm.all_gather(kv_buffer_copy[1], send_v)
+ if i < kv_group - heads_kv_stride:
+ kvsl = i + heads_kv_stride
+ kvsr = kvsl + heads_kv_stride
+ send_kv = kv[:, :, kvsl:kvsr].contiguous()
+ comm.all_gather(kv_buffer_copy, send_kv)

# Iterate over heads
@@ -226,6 +231,7 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
# Prepare query, key, value for attention
- q_i = q[:, :, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k]
- k_i = kv_buffer[0]
- v_i = kv_buffer[1]
+ q_i = q[:, :, i * nheads // kv_group : (i + heads_kv_stride) * nheads // kv_group]
+ kv_i = kv_buffer

# Rearrange query, key, value to (b, s, h, d)
q_i = einops.rearrange(q_i, 's b h d -> b s h d')
- k_i = einops.rearrange(k_i, 's b h d -> b s h d')
- v_i = einops.rearrange(v_i, 's b h d -> b s h d')
+ s_, b_, h_, d_ = kv_i.shape
+ kv_i = einops.rearrange(kv_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ zz_indices_i = zz_indices[:, :, i:(i+heads_kv_stride)]
+ b_, s_, g_, topk_ = zz_indices_i.shape
+ zz_indices_i = zz_indices_i.flatten().view(b_, s_, g_, topk_)

# Forward pass
- out_i, probs_i = eager_attn_fwd(
- q_i, k_i, v_i, attn_bias, None, softmax_scale, attention_dropout
- )
- outs.append(out_i)
- probs.append(probs_i)
+ out_i, lse_i = sparse_mla_fwd_interface(q_i.contiguous(), kv_i, zz_indices_i, dim_v, sm_scale = softmax_scale)
+
+ outs.append(out_i.contiguous())
+ lses.append(lse_i.contiguous())

- # Concatenate outputs and rearrange to (s, b, h, d)
+ # out: [B, seq_len_shard, h, dim] -> [seq_len, B, h, dim]
out = torch.cat(outs, dim=2)
out = einops.rearrange(out, 'b s h d -> s b h d')

# Save contexts for backward pass
ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs)
+ ctx.if_zz_mask = if_zz_mask
- ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs)
+ # outs: [[B, seq_len_shard, nheads // kv_group, dim], ...., [B, seq_len_shard, nheads // kv_group, dim]], repeat kv_group // heads_kv_stride times
+ # lses: [[B, seq_len_shard, heads_kv_stride], ...., [B, seq_len_shard, heads_kv_stride]], repeat kv_group // heads_kv_stride times
+ ctx.save_for_backward(q, kv, indices, *outs, *lses)
+ ctx.K = K
ctx.dropout = attention_dropout
ctx.scale = softmax_scale
ctx.heads_k_stride = heads_k_stride # TODO make it configurable
@@ -252,12 +258,16 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
comm = AllGatherComm(group=pg)
- ctx.scale = softmax_scale
- ctx.heads_k_stride = heads_k_stride # TODO make it configurable
+ ctx.softmax_scale = softmax_scale
+ ctx.heads_kv_stride = heads_kv_stride # TODO make it configurable
ctx.pg = pg
+ ctx.dim_v = dim_v

return out
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The v parameter, which corresponds to the value tensor, is completely ignored within the forward pass of AttentionFuncionWithContextParallel. The sparse_mla_fwd_interface is called with kv_i, which is derived from the kv input tensor (the key tensor from dsa.py). The TileLang kernel then incorrectly uses a slice of the key tensor as the value tensor during attention computation.

This is a critical correctness bug that will lead to wrong attention outputs.

Furthermore, the backward pass computes the gradient for v (dv) as a slice of the gradient for kv (dkv). Since key and value are separate tensors, this will result in incorrect gradients for the value tensor, and the value tensor's weights will not be updated correctly.

The v tensor needs to be correctly passed through the all-gather communication and used in the attention computation. This likely requires changes to how the kv_buffer is handled and potentially how the sparse_mla_fwd_interface is called, if it can't be modified to accept k and v separately.

Comment on lines 366 to 367
+ # TODO: needs casual = True, may not be compatible with zz
+ dq_i, _dkv_i = sparse_mla_bwd(q_i.contiguous(), kv_i, outs[i], dout_i.contiguous(), zz_indices_i, dim_v, lses[i], softmax_scale, True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The is_casual parameter to sparse_mla_bwd is hardcoded to True. A TODO comment notes that this might be incompatible with the zz (zigzag) masking used for context parallelism. The zz_indices_i are derived from indices calculated using a zz-aware causal mask. However, the sparse_mla_bwd kernel asserts is_causal == True and its implementation might not be aware of the zz key/value reordering, potentially leading to incorrect gradient calculations under context parallelism. This should be investigated to ensure correctness.

@xiuhu17 xiuhu17 changed the title Tilelang Kernel support feat: Tilelang Kernel support for CP Jan 7, 2026
@xiuhu17 xiuhu17 closed this Jan 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant