diff --git a/docker/deepseekv32/megatron.patch b/docker/deepseekv32/megatron.patch index ac7a1be3c..886b73e90 100644 --- a/docker/deepseekv32/megatron.patch +++ b/docker/deepseekv32/megatron.patch @@ -1,8 +1,8 @@ diff --git a/megatron/core/transformer/dot_product_attention_context_parallel.py b/megatron/core/transformer/dot_product_attention_context_parallel.py -index 89659a1d7..1def27c69 100644 +index 89659a1d7..c69859a04 100644 --- a/megatron/core/transformer/dot_product_attention_context_parallel.py +++ b/megatron/core/transformer/dot_product_attention_context_parallel.py -@@ -3,9 +3,12 @@ +@@ -3,107 +3,12 @@ # Some of this code was adopted from https://github.com/zhuzilin/ring-flash-attention/ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. @@ -11,14 +11,15 @@ index 89659a1d7..1def27c69 100644 import torch +import torch.distributed as dist from torch.nn import functional as F -+from .tilelang_kernel import sparse_mla_bwd, sparse_mla_fwd_interface - - try: - import einops -@@ -15,96 +18,6 @@ except ImportError: - HAVE_EINOPS = False - - +- +-try: +- import einops +- +- HAVE_EINOPS = True +-except ImportError: +- HAVE_EINOPS = False +- +- -@torch.no_grad -def eager_attn_fwd(q, k, v, attn_bias, sinks, scale, dropout): - """Forward pass for eager attention""" @@ -108,11 +109,11 @@ index 89659a1d7..1def27c69 100644 - grad_q = einops.rearrange(grad__q, 'b h s d -> b s h d') - return grad_q, grad_k, grad_v, grad_sinks - -- ++from .tilelang_kernel import sparse_mla_bwd, sparse_mla_fwd_interface + class AllGatherComm: """All gather communication with async operations""" - -@@ -131,212 +44,146 @@ class AllGatherComm: +@@ -131,212 +36,145 @@ class AllGatherComm: handle.wait() self.handles = [] @@ -146,9 +147,9 @@ index 89659a1d7..1def27c69 100644 '''Forward pass for the native attention function with context parallelism''' - # Assert einops exists - if not HAVE_EINOPS: - raise ImportError("einops is required by the attention CP but cannot be imported.") - +- if not HAVE_EINOPS: +- raise ImportError("einops is required by the attention CP but cannot be imported.") +- - # Initialize communication group and constants cp_size = 1 if pg is not None: @@ -164,8 +165,6 @@ index 89659a1d7..1def27c69 100644 - # Initialize KV buffers - kv_buffer = torch.empty( - (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]), -+ s, b, heads, dim = q.shape -+ skv, _, kv_groups, _ = k.shape + + k_buffer = torch.empty( + (k.shape[0] * cp_size, k.shape[1], 1, k.shape[3]), @@ -232,9 +231,9 @@ index 89659a1d7..1def27c69 100644 + k_i = k_buffer + + s_, b_, h_, d_ = q_i.shape -+ q_i = einops.rearrange(q_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_) ++ q_i = q_i.transpose(0, 1).flatten().view(b_, s_, h_, d_) + s_, b_, h_, d_ = k_i.shape -+ k_i = einops.rearrange(k_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_) ++ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_) + zz_indices_i = zz_indices + b_, s_, g_, topk_ = zz_indices_i.shape + zz_indices_i = zz_indices_i.flatten().view(b_, s_, g_, topk_) @@ -245,7 +244,8 @@ index 89659a1d7..1def27c69 100644 + out_i, lse_i = sparse_mla_fwd_interface(q_i.contiguous(), k_i, zz_indices_i, zz_masks_i, dim_v, sm_scale = softmax_scale) + + # out: [B, seq_len_shard, h, dim] -> [seq_len, B, h, dim] -+ out_i = einops.rearrange(out_i, 'b s h d -> s b h d') ++ b_, s_, h_, d_ = out_i.shape ++ out_i = out_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous() + + # outs: [[B, seq_len_shard, nheads // kv_group, dim], ...., [B, seq_len_shard, nheads // kv_group, dim]], repeat kv_group // heads_kv_stride times + # lses: [[B, seq_len_shard, heads_kv_stride], ...., [B, seq_len_shard, heads_kv_stride]], repeat kv_group // heads_kv_stride times @@ -353,13 +353,13 @@ index 89659a1d7..1def27c69 100644 + dk_list = [] + + s_, b_, h_, d_ = q.shape -+ q = einops.rearrange(q, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_) ++ q = q.transpose(0, 1).flatten().view(b_, s_, h_, d_) + s_, b_, h_, d_ = k_i.shape -+ k_i = einops.rearrange(k_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_) ++ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_) + s_, b_, h_, d_ = dout.shape -+ dout = einops.rearrange(dout, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_) ++ dout = dout.transpose(0, 1).flatten().view(b_, s_, h_, d_) + s_, b_, h_, d_ = out.shape -+ out = einops.rearrange(out, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_) ++ out = out.transpose(0, 1).flatten().view(b_, s_, h_, d_) + b_, s_, h_ = lse.shape + lse = lse.flatten().view(b_, s_, h_) + zz_indices_i = zz_indices @@ -379,12 +379,16 @@ index 89659a1d7..1def27c69 100644 + + # TODO: needs casual = True, may not be compatible with zz + dq_i, _dk_i = sparse_mla_bwd(q_i, k_i, out_i, dout_i, zz_indices_i, zz_masks_i, lse_i, dim_v, sm_scale = softmax_scale) ++ ++ b_, s_, h_, d_ = dq_i.shape ++ dq_i = dq_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous() ++ b_, s_, h_, d_ = _dk_i.shape ++ _dk_i = _dk_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous() - # Rearrange gradients to (s, b, h, d) - dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d') - _dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d') +- dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d') +- _dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d') - _dv_i = einops.rearrange(_dv_i, 'b s h d -> s b h d') -+ if pg is None: dk_i = _dk_i - dv_i = _dv_i @@ -393,13 +397,14 @@ index 89659a1d7..1def27c69 100644 dk_i = torch.zeros( (k_i.shape[1] // cp_size, k_i.shape[0], k_i.shape[2], k_i.shape[3]), device=k_i.device, - dtype=k_i.dtype, - ) +- dtype=k_i.dtype, +- ) - dv_i = torch.zeros( - (v_i.shape[1] // cp_size, v_i.shape[0], v_i.shape[2], v_i.shape[3]), - device=v_i.device, - dtype=v_i.dtype, -- ) ++ dtype=torch.float32, + ) torch.distributed.reduce_scatter_tensor(dk_i, _dk_i, group=pg) - torch.distributed.reduce_scatter_tensor(dv_i, _dv_i, group=pg) @@ -416,9 +421,11 @@ index 89659a1d7..1def27c69 100644 - dv = torch.cat(dv, dim=2) - return dq, dk, dv, None, None, None, None + dq = torch.cat(dq_list, dim=2) -+ dk = sum(dk_list) ++ dk_ = torch.cat(dk_list, dim=2) ++ dk = torch.sum(dk_, dim=2, keepdim=True).to(torch.bfloat16) + + return dq, dk, None, None, None, None, None, None +\ No newline at end of file diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py index fc994490b..b23d2e9a8 100644 --- a/megatron/core/transformer/experimental_attention_variant/dsa.py @@ -998,10 +1005,10 @@ index 000000000..d8f2425f0 +] diff --git a/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py new file mode 100644 -index 000000000..b8ea416dd +index 000000000..83a259efa --- /dev/null +++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py -@@ -0,0 +1,274 @@ +@@ -0,0 +1,272 @@ +# ruff: noqa +import tilelang +from tilelang import language as T @@ -1267,22 +1274,20 @@ index 000000000..b8ea416dd + # Get kernels + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual) -+ postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, masks, lse, delta, dkv) -+ dkv = postprocess_kernel(dkv) + + return dq, dkv \ No newline at end of file diff --git a/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py new file mode 100644 -index 000000000..d338a2fa6 +index 000000000..e247038de --- /dev/null +++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py -@@ -0,0 +1,190 @@ +@@ -0,0 +1,191 @@ +# ruff: noqa +import torch +import tilelang @@ -1403,6 +1408,7 @@ index 000000000..d338a2fa6 + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + mask[bi_i] = Masks[b_i, s_i, g_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i]] ++ + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i] + for bi_i, d_i in T.Parallel(BI, D_tail):