fix bugs with non contiguous tensors while doing communication #497
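Background for the fix (an illustrative sketch, not code from this diff): a transposed PyTorch tensor is a non-contiguous view over the same storage, and torch.distributed collectives generally expect contiguous buffers, so tensors have to be materialized into contiguous memory before they are communicated.

```python
import torch

# Attention activations in this file use the [s, b, h, d] layout.
x = torch.randn(8, 2, 4, 16)            # contiguous [s, b, h, d]
y = x.transpose(0, 1)                   # [b, s, h, d] view over the same storage
assert not y.is_contiguous()

# torch.distributed collectives (all_gather_into_tensor, reduce_scatter_tensor, ...)
# work on flat buffers, so handing them a non-contiguous view like `y` can fail or
# transfer the wrong elements depending on the backend.  Materializing a contiguous
# copy first is what this PR adds at the communication call sites:
y = y.contiguous()
assert y.is_contiguous()
```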
@@ -1,8 +1,8 @@
diff --git a/megatron/core/transformer/dot_product_attention_context_parallel.py b/megatron/core/transformer/dot_product_attention_context_parallel.py
index 89659a1d7..1def27c69 100644
index 89659a1d7..c69859a04 100644
--- a/megatron/core/transformer/dot_product_attention_context_parallel.py
+++ b/megatron/core/transformer/dot_product_attention_context_parallel.py
@@ -3,9 +3,12 @@
@@ -3,107 +3,12 @@
# Some of this code was adopted from https://github.com/zhuzilin/ring-flash-attention/
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

@@ -11,14 +11,15 @@ index 89659a1d7..1def27c69 100644
import torch
+import torch.distributed as dist
from torch.nn import functional as F
+from .tilelang_kernel import sparse_mla_bwd, sparse_mla_fwd_interface

try:
import einops
@@ -15,96 +18,6 @@ except ImportError:
HAVE_EINOPS = False

-
-try:
- import einops
-
- HAVE_EINOPS = True
-except ImportError:
- HAVE_EINOPS = False
-
-
-@torch.no_grad
-def eager_attn_fwd(q, k, v, attn_bias, sinks, scale, dropout):
- """Forward pass for eager attention"""

@@ -108,11 +109,11 @@ index 89659a1d7..1def27c69 100644
- grad_q = einops.rearrange(grad__q, 'b h s d -> b s h d')
- return grad_q, grad_k, grad_v, grad_sinks
-
-
+from .tilelang_kernel import sparse_mla_bwd, sparse_mla_fwd_interface

class AllGatherComm:
"""All gather communication with async operations"""

@@ -131,212 +44,146 @@ class AllGatherComm:
@@ -131,212 +36,145 @@ class AllGatherComm:
handle.wait()
self.handles = []
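The hunks above only show the AllGatherComm docstring and its handle cleanup. As a rough sketch of how such a helper is typically shaped (the method names and signature here are assumptions, not the repository's exact code), it launches asynchronous all-gathers on contiguous buffers and waits on the stored handles later:

```python
import torch.distributed as dist

class AllGatherComm:
    """All gather communication with async operations (illustrative sketch)."""

    def __init__(self, group=None):
        self.group = group
        self.handles = []

    def all_gather(self, output_tensor, input_tensor):
        # Launch the collective asynchronously; the input must be contiguous,
        # which is exactly what this PR fixes at the call sites.
        handle = dist.all_gather_into_tensor(
            output_tensor, input_tensor.contiguous(), group=self.group, async_op=True
        )
        self.handles.append(handle)

    def wait(self):
        # Block until every outstanding all-gather has completed.
        for handle in self.handles:
            handle.wait()
        self.handles = []
```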
|
|
@@ -146,9 +147,9 @@ index 89659a1d7..1def27c69 100644
'''Forward pass for the native attention function with context parallelism'''

- # Assert einops exists
if not HAVE_EINOPS:
raise ImportError("einops is required by the attention CP but cannot be imported.")
- if not HAVE_EINOPS:
- raise ImportError("einops is required by the attention CP but cannot be imported.")
-
- # Initialize communication group and constants
cp_size = 1
if pg is not None:

@@ -164,8 +165,6 @@ index 89659a1d7..1def27c69 100644
- # Initialize KV buffers
- kv_buffer = torch.empty(
- (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
+ s, b, heads, dim = q.shape
+ skv, _, kv_groups, _ = k.shape
+
+ k_buffer = torch.empty(
+ (k.shape[0] * cp_size, k.shape[1], 1, k.shape[3]),

@@ -232,9 +231,9 @@ index 89659a1d7..1def27c69 100644
+ k_i = k_buffer
+
+ s_, b_, h_, d_ = q_i.shape
+ q_i = einops.rearrange(q_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ q_i = q_i.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = k_i.shape
+ k_i = einops.rearrange(k_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ zz_indices_i = zz_indices
+ b_, s_, g_, topk_ = zz_indices_i.shape
+ zz_indices_i = zz_indices_i.flatten().view(b_, s_, g_, topk_)

@@ -245,7 +244,8 @@ index 89659a1d7..1def27c69 100644
+ out_i, lse_i = sparse_mla_fwd_interface(q_i.contiguous(), k_i, zz_indices_i, zz_masks_i, dim_v, sm_scale = softmax_scale)
+
+ # out: [B, seq_len_shard, h, dim] -> [seq_len, B, h, dim]
+ out_i = einops.rearrange(out_i, 'b s h d -> s b h d')
+ b_, s_, h_, d_ = out_i.shape
+ out_i = out_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous()

Comment on lines +247 to +248 (Contributor): The […]
+
+ # outs: [[B, seq_len_shard, nheads // kv_group, dim], ...., [B, seq_len_shard, nheads // kv_group, dim]], repeat kv_group // heads_kv_stride times
+ # lses: [[B, seq_len_shard, heads_kv_stride], ...., [B, seq_len_shard, heads_kv_stride]], repeat kv_group // heads_kv_stride times

@@ -353,13 +353,13 @@ index 89659a1d7..1def27c69 100644
+ dk_list = []
+
+ s_, b_, h_, d_ = q.shape
+ q = einops.rearrange(q, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ q = q.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = k_i.shape
+ k_i = einops.rearrange(k_i, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ k_i = k_i.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = dout.shape
+ dout = einops.rearrange(dout, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ dout = dout.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ s_, b_, h_, d_ = out.shape
+ out = einops.rearrange(out, 's b h d -> b s h d').flatten().view(b_, s_, h_, d_)
+ out = out.transpose(0, 1).flatten().view(b_, s_, h_, d_)
+ b_, s_, h_ = lse.shape
+ lse = lse.flatten().view(b_, s_, h_)
+ zz_indices_i = zz_indices

@@ -379,12 +379,16 @@ index 89659a1d7..1def27c69 100644
+
+ # TODO: needs casual = True, may not be compatible with zz
+ dq_i, _dk_i = sparse_mla_bwd(q_i, k_i, out_i, dout_i, zz_indices_i, zz_masks_i, lse_i, dim_v, sm_scale = softmax_scale)
+
+ b_, s_, h_, d_ = dq_i.shape
+ dq_i = dq_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous()
+ b_, s_, h_, d_ = _dk_i.shape
+ _dk_i = _dk_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous()

Comment on lines +383 to +386 (Contributor): Similar to the forward pass, the […]
|
|
- # Rearrange gradients to (s, b, h, d)
dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d')
_dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d')
- dq_i = einops.rearrange(dq_i, 'b s h d -> s b h d')
- _dk_i = einops.rearrange(_dk_i, 'b s h d -> s b h d')
- _dv_i = einops.rearrange(_dv_i, 'b s h d -> s b h d')
+
if pg is None:
dk_i = _dk_i
- dv_i = _dv_i
|
|
@@ -393,13 +397,14 @@ index 89659a1d7..1def27c69 100644
dk_i = torch.zeros(
(k_i.shape[1] // cp_size, k_i.shape[0], k_i.shape[2], k_i.shape[3]),
device=k_i.device,
dtype=k_i.dtype,
)
- dtype=k_i.dtype,
- )
- dv_i = torch.zeros(
- (v_i.shape[1] // cp_size, v_i.shape[0], v_i.shape[2], v_i.shape[3]),
- device=v_i.device,
- dtype=v_i.dtype,
- )
+ dtype=torch.float32,
)
torch.distributed.reduce_scatter_tensor(dk_i, _dk_i, group=pg)
- torch.distributed.reduce_scatter_tensor(dv_i, _dv_i, group=pg)

@@ -416,9 +421,11 @@ index 89659a1d7..1def27c69 100644
- dv = torch.cat(dv, dim=2)
- return dq, dk, dv, None, None, None, None
+ dq = torch.cat(dq_list, dim=2)
+ dk = sum(dk_list)
+ dk_ = torch.cat(dk_list, dim=2)
+ dk = torch.sum(dk_, dim=2, keepdim=True).to(torch.bfloat16)
+
+ return dq, dk, None, None, None, None, None, None
\ No newline at end of file
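A hedged, non-distributed sketch of the dtype flow in the two hunks above (shapes invented for illustration): the sparse-MLA backward produces dk shards in float32, the reduce-scatter output buffer is presumably allocated as float32 so its dtype matches those shards, and the summed gradient is cast back to bfloat16 only at the very end.

```python
import torch

# Hypothetical per-head-group dk shards, in float32 as produced by the backward
# kernel (dkv = torch.zeros_like(kv, dtype=torch.float32) in sparse_mla_bwd.py).
dk_list = [torch.randn(2, 128, 1, 64, dtype=torch.float32) for _ in range(4)]

# Mirror of the new epilogue: concatenate the shards along dim=2 (the head/group
# axis in this layout), reduce that axis, and cast to bfloat16 only once at the
# end so intermediate sums do not round at every step.
dk_ = torch.cat(dk_list, dim=2)                          # float32, grouped
dk = torch.sum(dk_, dim=2, keepdim=True).to(torch.bfloat16)
assert dk.shape == (2, 128, 1, 64) and dk.dtype == torch.bfloat16
```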
diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py
index fc994490b..b23d2e9a8 100644
--- a/megatron/core/transformer/experimental_attention_variant/dsa.py

@@ -998,10 +1005,10 @@ index 000000000..d8f2425f0
+]
diff --git a/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py
new file mode 100644
index 000000000..b8ea416dd
index 000000000..83a259efa
--- /dev/null
+++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_bwd.py
@@ -0,0 +1,274 @@
@@ -0,0 +1,272 @@
+# ruff: noqa
+import tilelang
+from tilelang import language as T

@@ -1267,22 +1274,20 @@ index 000000000..b8ea416dd
+ # Get kernels
+ preprocess_kernel = preprocess(B, S, H, D)
+ bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual)
+ postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group)
+
+ if delta is None:
+ delta = preprocess_kernel(o, do)
+ dkv = torch.zeros_like(kv, dtype=torch.float32)
+ dq = bwd_kernel(q, kv, do, indices, masks, lse, delta, dkv)
+ dkv = postprocess_kernel(dkv)
+
+ return dq, dkv
\ No newline at end of file
diff --git a/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py
new file mode 100644
index 000000000..d338a2fa6
index 000000000..e247038de
--- /dev/null
+++ b/megatron/core/transformer/tilelang_kernel/sparse_mla_fwd.py
@@ -0,0 +1,190 @@
@@ -0,0 +1,191 @@
+# ruff: noqa
+import torch
+import tilelang

@@ -1403,6 +1408,7 @@ index 000000000..d338a2fa6
+ for i_i in T.Pipelined(NI, num_stages=num_stages):
+ for bi_i in T.Parallel(BI):
+ mask[bi_i] = Masks[b_i, s_i, g_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i]]
+
+ for bi_i, d_i in T.Parallel(BI, D):
+ KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
+ for bi_i, d_i in T.Parallel(BI, D_tail):

Review comment:
The .flatten().view(...) pattern is used here to make the tensor contiguous after a transpose operation. While functionally correct, using .contiguous() is more explicit and readable. This pattern appears multiple times in this file (e.g., for q, k_i, dout, out in the backward pass). Consider replacing it for better code clarity.
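To make the suggestion concrete, a small sketch (shapes invented for illustration) showing that the pattern used in the diff and the suggested .contiguous() form produce identical, contiguous tensors:

```python
import torch

s_, b_, h_, d_ = 128, 2, 16, 64
out_i = torch.randn(b_, s_, h_, d_)      # [b, s, h, d], as returned by the kernel

# Pattern currently used in the diff: flatten() copies the transposed view into
# contiguous memory, then view() restores the [s, b, h, d] shape.
via_view = out_i.transpose(0, 1).flatten().view(s_, b_, h_, d_).contiguous()

# Reviewer's suggestion: state the intent directly.
via_contiguous = out_i.transpose(0, 1).contiguous()

assert via_view.is_contiguous() and via_contiguous.is_contiguous()
assert torch.equal(via_view, via_contiguous)
```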