feat: Tilelang Kernel support for CP #410
Conversation
Summary of Changes
Hello @xiuhu17, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request refactors the attention mechanism in the Megatron-LM framework by integrating optimized sparse Multi-head Latent Attention (MLA) kernels written in TileLang. The change aims to improve performance and efficiency, particularly for the DeepSeek-V3.2 model, by using custom kernel implementations for the critical computational steps. The accompanying Docker environment updates ensure that users can easily set up and use this enhanced version.
Code Review
This pull request introduces support for TileLang kernels for sparse attention, specifically for DeepSeek-V3.2-style Multi-head Latent Attention (MLA). The changes are extensive: they modify the context-parallel attention implementation and the dynamic sparse attention (DSA) variant, and add new TileLang kernel files for the forward and backward passes. While this is a significant performance-oriented change, I've found a critical issue in the integration of the new kernel: the value tensor is ignored in the attention computation, leading to incorrect results. There is also a potential issue in the backward pass regarding causality assumptions under context parallelism. The README is also updated with new Docker images and repository setup instructions.
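For orientation before the patch excerpt below, here is a toy sketch of what DSA-style top-k sparse attention consumes; the names and shapes are illustrative assumptions, not the actual Megatron-LM or TileLang API.

```python
import torch

# Toy shapes; names are illustrative, not the real Megatron-LM / TileLang API.
b, s, h_kv, topk = 2, 8, 1, 4
indexer_scores = torch.randn(b, s, h_kv, s)           # one score per (query, key) pair
indices = indexer_scores.topk(topk, dim=-1).indices   # [b, s, h_kv, topk]
# `indices` is what a sparse MLA kernel consumes: for each query position, attention
# is computed only over these top-k key/value rows. (A real indexer is learned and
# causal; random scores here only demonstrate the shapes.)
```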
docker/deepseekv32/megatron.patch
Outdated
+ def forward(ctx, q, kv, v, indices, dim_v, K, attention_dropout, softmax_scale, pg):
      '''Forward pass for the native attention function with context parallelism'''

      # Assert einops exists
@@ -171,12 +171,17 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
      probs = []
@@ -164,72 +66,71 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
      cp_size = torch.distributed.get_world_size(pg)
      comm = AllGatherComm(group=pg)
      nheads = q.shape[2]
-     nheads_k = k.shape[2]
-     heads_k_stride = 1
-     assert nheads % nheads_k == 0 and nheads_k % heads_k_stride == 0
+     kv_group = kv.shape[2]
+     heads_kv_stride = 1
+     assert nheads % kv_group == 0 and kv_group % heads_kv_stride == 0
      outs = []
-     probs = []
+     lses = []

      # Initialize KV buffers
-     kv_buffer = torch.empty(
      kv_buffer = torch.empty(
-         (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
+     # seperate KV buffer for MLA
+     kv_buffer = [torch.empty(
+         (k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
          dtype=k.dtype,
          device=k.device,
-     )
-     kv_buffer_copy = torch.empty_like(kv_buffer)
+     ), torch.empty(
+         (v.shape[0] * cp_size, v.shape[1], heads_k_stride, v.shape[3]),
+         dtype=v.dtype,
+         device=v.device,
+     )]
+     kv_buffer_copy = [torch.empty_like(kv_buffer[0]), torch.empty_like(kv_buffer[1])]
-         dtype=k.dtype,
-         device=k.device,
+         (kv.shape[0] * cp_size, kv.shape[1], heads_kv_stride, kv.shape[3]),
+         dtype=kv.dtype,
+         device=kv.device,
      )
      kv_buffer_copy = torch.empty_like(kv_buffer)

      # All-gather first chunk of KV buffers
      k_0 = k[:, :, :heads_k_stride].contiguous()
@@ -186,7 +191,7 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):

      # Prepare attention bias
      attn_bias = to_zz_mask_attn_bias(
-     k_0 = k[:, :, :heads_k_stride].contiguous()
-     v_0 = v[:, :, :heads_k_stride].contiguous()
-     comm.all_gather(kv_buffer_copy[0], k_0)
-     comm.all_gather(kv_buffer_copy[1], v_0)
-
-     # Prepare attention bias
-     attn_bias = to_zz_mask_attn_bias(
-         attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype
+         attention_mask, cp_size, nheads, nheads_k, heads_k_stride, q.device, q.dtype, if_zz_mask
      )
-     )
-
-     # Iterate over heads
-     for i in range(0, nheads_k, heads_k_stride):
+     kv_0 = kv[:, :, :heads_kv_stride].contiguous()
+     comm.all_gather(kv_buffer_copy, kv_0)
+
+     # Prepare topk
+     zz_indices = indices.transpose(1, 2)
+
+     # Iterate over heads, sequential, i
+     for i in range(0, kv_group, heads_kv_stride):
          # Wait for previous all-gather to complete
          comm.wait()
          kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer
          # All-gather the next portion of KV buffers if not the last iteration
-         if i < nheads_k - heads_k_stride:
-             kvsl = i + heads_k_stride
-             kvsr = kvsl + heads_k_stride
-             send_k = k[:, :, kvsl:kvsr].contiguous()
-             send_v = v[:, :, kvsl:kvsr].contiguous()
-             comm.all_gather(kv_buffer_copy[0], send_k)
-             comm.all_gather(kv_buffer_copy[1], send_v)
+         if i < kv_group - heads_kv_stride:
+             kvsl = i + heads_kv_stride
+             kvsr = kvsl + heads_kv_stride
+             send_kv = kv[:, :, kvsl:kvsr].contiguous()
+             comm.all_gather(kv_buffer_copy, send_kv)

          # Iterate over heads
@@ -226,6 +231,7 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
          # Prepare query, key, value for attention
-         q_i = q[:, :, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k]
-         k_i = kv_buffer[0]
-         v_i = kv_buffer[1]
+         q_i = q[:, :, i * nheads // kv_group : (i + heads_kv_stride) * nheads // kv_group]
+         kv_i = kv_buffer

          # Rearrange query, key, value to (b, s, h, d)
          q_i = einops.rearrange(q_i, 's b h d -> b s h d')
-         k_i = einops.rearrange(k_i, 's b h d -> b s h d')
-         v_i = einops.rearrange(v_i, 's b h d -> b s h d')
+         s_, b_, h_, d_ = kv_i.shape
+         kv_i = einops.rearrange(kv_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+         zz_indices_i = zz_indices[:, :, i:(i+heads_kv_stride)]
+         b_, s_, g_, topk_ = zz_indices_i.shape
+         zz_indices_i = zz_indices_i.flatten().view(b_, s_, g_, topk_)

          # Forward pass
-         out_i, probs_i = eager_attn_fwd(
-             q_i, k_i, v_i, attn_bias, None, softmax_scale, attention_dropout
-         )
-         outs.append(out_i)
-         probs.append(probs_i)
+         out_i, lse_i = sparse_mla_fwd_interface(q_i.contiguous(), kv_i, zz_indices_i, dim_v, sm_scale = softmax_scale)
+
+         outs.append(out_i.contiguous())
+         lses.append(lse_i.contiguous())

-     # Concatenate outputs and rearrange to (s, b, h, d)
+     # out: [B, seq_len_shard, h, dim] -> [seq_len, B, h, dim]
      out = torch.cat(outs, dim=2)
      out = einops.rearrange(out, 'b s h d -> s b h d')

      # Save contexts for backward pass
      ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs)
+     ctx.if_zz_mask = if_zz_mask
-     ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs)
+     # outs: [[B, seq_len_shard, nheads // kv_group, dim], ...., [B, seq_len_shard, nheads // kv_group, dim]], repeat kv_group // heads_kv_stride times
+     # lses: [[B, seq_len_shard, heads_kv_stride], ...., [B, seq_len_shard, heads_kv_stride]], repeat kv_group // heads_kv_stride times
+     ctx.save_for_backward(q, kv, indices, *outs, *lses)
+     ctx.K = K
      ctx.dropout = attention_dropout
      ctx.scale = softmax_scale
      ctx.heads_k_stride = heads_k_stride # TODO make it configurable
@@ -252,12 +258,16 @@ class AttentionFuncionWithContextParallel(torch.autograd.Function):
      comm = AllGatherComm(group=pg)
-     ctx.scale = softmax_scale
-     ctx.heads_k_stride = heads_k_stride # TODO make it configurable
+     ctx.softmax_scale = softmax_scale
+     ctx.heads_kv_stride = heads_kv_stride # TODO make it configurable
      ctx.pg = pg
+     ctx.dim_v = dim_v

      return out
The v parameter, which corresponds to the value tensor, is completely ignored within the forward pass of AttentionFuncionWithContextParallel. The sparse_mla_fwd_interface is called with kv_i, which is derived from the kv input tensor (the key tensor from dsa.py). The TileLang kernel then incorrectly uses a slice of the key tensor as the value tensor during attention computation.
This is a critical correctness bug that will lead to wrong attention outputs.
Furthermore, the backward pass computes the gradient for v (dv) as a slice of the gradient for kv (dkv). Since key and value are separate tensors, this will result in incorrect gradients for the value tensor, and the value tensor's weights will not be updated correctly.
The v tensor needs to be correctly passed through the all-gather communication and used in the attention computation. This likely requires changes to how the kv_buffer is handled and potentially how the sparse_mla_fwd_interface is called, if it can't be modified to accept k and v separately.
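To make the requirement concrete, here is a toy, loop-based reference of the semantics a sparse MLA forward has to reproduce; the helper name and shapes are assumptions for illustration, with no causal masking or dropout. The point it demonstrates: the output is a mixture of value rows, so the tensor occupying the V slot must actually carry the values. Either v is gathered and passed explicitly, or kv must genuinely contain the value channels (as in some compressed-MLA layouts where v = kv[..., :dim_v]); which of the two holds here is worth confirming.

```python
import torch

def sparse_attn_reference(q, k, v, indices, softmax_scale):
    """Toy reference for top-k sparse attention with separate K and V tensors.

    Assumed shapes (illustration only, not the patch's exact layout):
      q:       [b, s_q, h, d_qk]
      k:       [b, s_kv, 1, d_qk]   # one shared KV head, MLA-style
      v:       [b, s_kv, 1, d_v]
      indices: [b, s_q, 1, topk]    # selected key/value positions per query
    """
    b, s_q, h, _ = q.shape
    out = torch.empty(b, s_q, h, v.shape[-1], dtype=q.dtype, device=q.device)
    for bi in range(b):
        for si in range(s_q):
            idx = indices[bi, si, 0]                        # [topk]
            k_sel = k[bi, idx, 0]                           # [topk, d_qk]
            v_sel = v[bi, idx, 0]                           # [topk, d_v]  <- the real values
            scores = (q[bi, si] @ k_sel.T) * softmax_scale  # [h, topk]
            probs = scores.float().softmax(dim=-1).to(q.dtype)
            out[bi, si] = probs @ v_sel                     # output mixes V rows, never K rows
    return out
```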
docker/deepseekv32/megatron.patch
Outdated
+ # TODO: needs casual = True, may not be compatible with zz
+ dq_i, _dkv_i = sparse_mla_bwd(q_i.contiguous(), kv_i, outs[i], dout_i.contiguous(), zz_indices_i, dim_v, lses[i], softmax_scale, True)
The causality flag (is_causal) passed to sparse_mla_bwd is hardcoded to True. A TODO comment notes that this might be incompatible with the zz (zigzag) masking used for context parallelism. The zz_indices_i are derived from indices computed with a zz-aware causal mask, but the sparse_mla_bwd kernel asserts is_causal == True and its implementation may not account for the zz key/value reordering, potentially leading to incorrect gradient calculations under context parallelism. This should be investigated to ensure correctness.
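For context, a sketch under the assumption that zz is the usual zigzag sharding (rank r holds chunk r followed by chunk 2 * cp_size - 1 - r); the patch's actual layout may differ. The point: each CP rank owns two non-contiguous chunks, so a key's position in the gathered, zz-ordered KV tensor is not its global position, and a causality check has to be expressed in global coordinates, for example via the hypothetical mapping below.

```python
import torch

def zz_local_to_global(local_pos: torch.Tensor, rank: int, cp_size: int, chunk_len: int) -> torch.Tensor:
    """Map zigzag-local positions back to global sequence positions.

    Assumes the common zigzag CP layout: the sequence is cut into 2 * cp_size chunks
    of length chunk_len, and rank r holds chunk r followed by chunk (2 * cp_size - 1 - r).
    Hypothetical helper; this patch's actual layout may differ.
    """
    in_first_chunk = local_pos < chunk_len
    global_first = rank * chunk_len + local_pos
    global_second = (2 * cp_size - 1 - rank) * chunk_len + (local_pos - chunk_len)
    return torch.where(in_first_chunk, global_first, global_second)

# Causality is defined on global positions: a query at global position g_q may only attend
# to keys with global position <= g_q. If sparse_mla_bwd applies is_causal over the gathered,
# zz-ordered KV axis instead, the implied mask differs from the one used to build the top-k
# indices, and gradients can be silently wrong.
```

A cheap sanity check is to compare dq/dkv from sparse_mla_bwd against an eager autograd reference with the zz-aware mask on a tiny CP configuration before relying on the hardcoded True.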