82 changes: 44 additions & 38 deletions docker/deepseekv32/megatron.patch
@@ -1,8 +1,8 @@
diff --git a/megatron/core/transformer/dot_product_attention_context_parallel.py b/megatron/core/transformer/dot_product_attention_context_parallel.py
index 89659a1d7..1def27c69 100644
index 89659a1d7..c69859a04 100644
--- a/megatron/core/transformer/dot_product_attention_context_parallel.py
+++ b/megatron/core/transformer/dot_product_attention_context_parallel.py
@@ -3,9 +3,12 @@
@@ -3,107 +3,12 @@
# Some of this code was adopted from https://github.com/zhuzilin/ring-flash-attention/
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@@ -11,14 +11,15 @@ index 89659a1d7..1def27c69 100644
import torch
+import torch.distributed as dist
from torch.nn import functional as F
+from .tilelang_kernel import sparse_mla_bwd, sparse_mla_fwd_interface

try:
import einops
@@ -15,96 +18,6 @@ except ImportError:
HAVE_EINOPS = False


-
-try:
- import einops
-
- HAVE_EINOPS = True
-except ImportError:
- HAVE_EINOPS = False
-
-
-@torch.no_grad
-def eager_attn_fwd(q, k, v, attn_bias, sinks, scale, dropout):
- """Forward pass for eager attention"""
@@ -108,11 +109,11 @@ index 89659a1d7..1def27c69 100644
- grad_q = einops.rearrange(grad__q, 'b h s d -> b s h d')
- return grad_q, grad_k, grad_v, grad_sinks
-
-
+from .tilelang_kernel import sparse_mla_bwd, sparse_mla_fwd_interface

class AllGatherComm:
"""All gather communication with async operations"""

@@ -131,212 +44,146 @@ class AllGatherComm:
@@ -131,212 +36,145 @@ class AllGatherComm:
handle.wait()
self.handles = []

@@ -146,9 +147,9 @@ index 89659a1d7..1def27c69 100644
'''Forward pass for the native attention function with context parallelism'''

- # Assert einops exists
if not HAVE_EINOPS:
raise ImportError("einops is required by the attention CP but cannot be imported.")
- if not HAVE_EINOPS:
- raise ImportError("einops is required by the attention CP but cannot be imported.")
-
- # Initialize communication group and constants
cp_size = 1
if pg is not None:
@@ -164,8 +165,6 @@ index 89659a1d7..1def27c69 100644
- # Initialize KV buffers
- kv_buffer = torch.empty(
- (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
+ s, b, heads, dim = q.shape
+ skv, _, kv_groups, _ = k.shape
+
+ k_buffer = torch.empty(
+ (k.shape[0] * cp_size, k.shape[1], 1, k.shape[3]),
@@ -232,9 +231,9 @@ index 89659a1d7..1def27c69 100644
+ k_i = k_buffer
+
+ s_, b_, h_, d_ = q_i.shape
+ q_i = einops.rearrange(q_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ q_i = q_i.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = k_i.shape
+ k_i = einops.rearrange(k_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_)
Comment on lines 233 to +236
medium

The .flatten().view(...) pattern is used here to make the tensor contiguous after a transpose operation. While functionally correct, using .contiguous() is more explicit and readable. This pattern appears multiple times in this file (e.g., for q, k_i, dout, out in the backward pass). Consider replacing it for better code clarity.

        s_, b_, h_, d_ = q_i.shape
        q_i = q_i.transpose(0, 1).contiguous()
        s_, b_, h_, d_ = k_i.shape
        k_i = k_i.transpose(0, 1).contiguous()
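
For reference, a minimal standalone sketch (not part of the patch; shapes are hypothetical) confirming the two spellings are interchangeable: flatten() copies the transposed data into contiguous memory, so view() on the result matches what contiguous() produces.

    import torch

    s, b, h, d = 4, 2, 3, 8                     # made-up seq, batch, heads, dim
    q_i = torch.randn(s, b, h, d)

    # Pattern currently in the patch: flatten() materializes a contiguous copy
    # of the transposed tensor, then view() reshapes that copy.
    via_view = q_i.transpose(0, 1).flatten().view(b, s, h, d)

    # Suggested spelling: contiguous() makes the same copy explicitly.
    via_contiguous = q_i.transpose(0, 1).contiguous()

    assert torch.equal(via_view, via_contiguous)
    assert via_view.is_contiguous() and via_contiguous.is_contiguous()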

+ zz_indices_i = zz_indices
+ b_, s_, g_, topk_ = zz_indices_i.shape
+ zz_indices_i = zz_indices_i.flatten().view(b_, s_, g_, topk_)
@@ -245,7 +244,8 @@ index 89659a1d7..1def27c69 100644
+ out_i, lse_i = sparse_mla_fwd_interface(q_i.contiguous(), k_i, zz_indices_i, zz_masks_i, dim_v, sm_scale = softmax_scale)
+
+ # out: [B, seq_len_shard, h, dim] -> [seq_len, B, h, dim]
+ out_i = einops.rearrange(out_i, 'b s h d -> s b h d')
+ b_, s_, h_, d_ = out_i.shape
+ out_i = out_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous()
Comment on lines +247 to +248
medium

The .flatten().view(...) pattern already returns a contiguous tensor, so the final .contiguous() call is redundant. For improved readability and to avoid the unnecessary operation, you can replace the entire chain with .transpose(0, 1).contiguous().

        b_, s_, h_, d_ = out_i.shape
        out_i = out_i.transpose(0, 1).contiguous()
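
As a quick check of the redundancy claim (again a standalone sketch with hypothetical shapes, not part of the patch): contiguous() on the already-contiguous result of flatten().view(...) returns the same tensor storage, so the extra call does no work.

    import torch

    b, s, h, d = 2, 4, 3, 8                     # made-up shapes
    out_i = torch.randn(b, s, h, d)

    y = out_i.transpose(0, 1).flatten().view(s, b, h, d)
    assert y.is_contiguous()                    # already contiguous here

    z = y.contiguous()                          # no-op: no copy is made
    assert z.data_ptr() == y.data_ptr()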

+
+ # outs: [[B, seq_len_shard, nheads // kv_group, dim], ...., [B, seq_len_shard, nheads // kv_group, dim]], repeat kv_group // heads_kv_stride times
+ # lses: [[B, seq_len_shard, heads_kv_stride], ...., [B, seq_len_shard, heads_kv_stride]], repeat kv_group // heads_kv_stride times
@@ -353,13 +353,13 @@ index 89659a1d7..1def27c69 100644
+ dk_list = []
+
+ s_, b_, h_, d_ = q.shape
+ q = einops.rearrange(q, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ q = q.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = k_i.shape
+ k_i = einops.rearrange(k_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = dout.shape
+ dout = einops.rearrange(dout, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ dout = dout.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = out.shape
+ out = einops.rearrange(out, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ out = out.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ b_, s_, h_ = lse.shape
+ lse = lse.flatten().view(b_, s_, h_)
+ zz_indices_i = zz_indices
@@ -379,12 +379,16 @@ index 89659a1d7..1def27c69 100644
+
+ # TODO: needs casual = True, may not be compatible with zz
+ dq_i, _dk_i = sparse_mla_bwd(q_i, k_i, out_i, dout_i, zz_indices_i, zz_masks_i, lse_i, dim_v, sm_scale = softmax_scale)
+
+ b_, s_, h_, d_ = dq_i.shape
+ dq_i = dq_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous()
+ b_, s_, h_, d_ = _dk_i.shape
+ _dk_i = _dk_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous()
Comment on lines +383 to +386
medium

Similar to the forward pass, the .flatten().view(...) pattern makes the tensor contiguous, so the trailing .contiguous() call is redundant for both dq_i and _dk_i. Consider simplifying this to .transpose(0, 1).contiguous() for better clarity and consistency.

            b_, s_, h_, d_ = dq_i.shape
            dq_i = dq_i.transpose(0, 1).contiguous()
            b_, s_, h_, d_ = _dk_i.shape
            _dk_i = _dk_i.transpose(0, 1).contiguous()


- # Rearrange gradients to (s, b, h, d)
dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d')
_dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d')
- dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d')
- _dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d')
- _dv_i = einops.rearrange(_dv_i, 'b s h d -> s b h d')
+
if pg is None:
dk_i = _dk_i
- dv_i = _dv_i
@@ -393,13 +397,14 @@ index 89659a1d7..1def27c69 100644
dk_i = torch.zeros(
(k_i.shape[1] // cp_size, k_i.shape[0], k_i.shape[2], k_i.shape[3]),
device=k_i.device,
dtype=k_i.dtype,
)
- dtype=k_i.dtype,
- )
- dv_i = torch.zeros(
- (v_i.shape[1] // cp_size, v_i.shape[0], v_i.shape[2], v_i.shape[3]),
- device=v_i.device,
- dtype=v_i.dtype,
- )
+ dtype=torch.float32,
)
torch.distributed.reduce_scatter_tensor(dk_i, _dk_i, group=pg)
- torch.distributed.reduce_scatter_tensor(dv_i, _dv_i, group=pg)

@@ -416,9 +421,11 @@ index 89659a1d7..1def27c69 100644
- dv = torch.cat(dv, dim=2)
- return dq, dk, dv, None, None, None, None
+ dq = torch.cat(dq_list, dim=2)
+ dk = sum(dk_list)
+ dk_ = torch.cat(dk_list, dim=2)
+ dk = torch.sum(dk_, dim=2, keepdim=True).to(torch.bfloat16)
+
+ return dq, dk, None, None, None, None, None, None
\ No newline at end of file
diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py
index fc994490b..b23d2e9a8 100644
--- a/megatron/core/transformer/experimental_attention_variant/dsa.py
@@ -998,10 +1005,10 @@ index 000000000..d8f2425f0
+]
diff --git a/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py
new file mode 100644
index 000000000..b8ea416dd
index 000000000..83a259efa
--- /dev/null
+++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py
@@ -0,0 +1,274 @@
@@ -0,0 +1,272 @@
+# ruff: noqa
+import tilelang
+from tilelang import language as T
@@ -1267,22 +1274,20 @@ index 000000000..b8ea416dd
+ # Get kernels
+ preprocess_kernel = preprocess(B, S, H, D)
+ bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual)
+ postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group)
+
+ if delta is None:
+ delta = preprocess_kernel(o, do)
+ dkv = torch.zeros_like(kv, dtype=torch.float32)
+ dq = bwd_kernel(q, kv, do, indices, masks, lse, delta, dkv)
+ dkv = postprocess_kernel(dkv)
+
+ return dq, dkv
\ No newline at end of file
diff --git a/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py
new file mode 100644
index 000000000..d338a2fa6
index 000000000..e247038de
--- /dev/null
+++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py
@@ -0,0 +1,190 @@
@@ -0,0 +1,191 @@
+# ruff: noqa
+import torch
+import tilelang
@@ -1403,6 +1408,7 @@ index 000000000..d338a2fa6
+ for i_i in T.Pipelined(NI, num_stages=num_stages):
+ for bi_i in T.Parallel(BI):
+ mask[bi_i] = Masks[b_i, s_i, g_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i]]
+
+ for bi_i, d_i in T.Parallel(BI, D):
+ KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
+ for bi_i, d_i in T.Parallel(BI, D_tail):