-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Enable DSA CP/absorbed/THD paths with TileLang fused ops #3674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
HollowMan6
wants to merge
6
commits into
NVIDIA:main
Choose a base branch
from
HollowMan6:dsa_cp_thd
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+3,950
−170
Draft
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
7a16b4d
Enable DSA CP/absorbed/THD paths with TileLang fused ops
HollowMan6 1223eb5
Add cache for compiled kernel
HollowMan6 bf05dad
Increase chunk size for better performance
HollowMan6 30203d0
fix tilelang recompile under pp>1
HollowMan6 f133b53
threading lock for compiling
HollowMan6 9ad3f65
No recompile
HollowMan6 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
2,035 changes: 1,872 additions & 163 deletions
2,035
megatron/core/transformer/experimental_attention_variant/dsa.py
Large diffs are not rendered by default.
Oops, something went wrong.
80 changes: 80 additions & 0 deletions
80
megatron/core/transformer/experimental_attention_variant/ops/indexer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| import torch | ||
|
|
||
| from .tilelang_indexer_bwd import indexer_bwd_interface | ||
| from .tilelang_indexer_fwd import indexer_fwd_interface | ||
|
|
||
|
|
||
def pytorch_extract_topk_scores(logits, topk_indices, dim=-1):
    """Pick the logits at ``topk_indices`` along ``dim``.

    Entries whose index is -1 are treated as padding and replaced with
    ``-inf`` so downstream consumers ignore them.
    """
    is_valid = topk_indices.ne(-1)
    gather_idx = topk_indices.clamp(min=0).long()
    picked = torch.gather(logits, dim=dim, index=gather_idx)
    return picked.masked_fill(~is_valid, float("-inf"))
|
|
||
|
|
||
class IndexerFunction(torch.autograd.Function):
    """Differentiable entry point for the fused TileLang indexer kernels."""

    @staticmethod
    def forward(
        ctx,
        index_q: torch.Tensor,
        index_k: torch.Tensor,
        weights: torch.Tensor,
        cu_seqlen_ks: torch.Tensor,
        cu_seqlen_ke: torch.Tensor,
        topk: int,
        topk_indices: torch.Tensor | None = None,
    ):
        """Compute indexer logits, select top-k scores/indices, and stash state for backward."""
        head_num = index_q.shape[1]
        logits = indexer_fwd_interface(
            index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True
        )
        if topk_indices is None:
            # Derive the selection from the logits; -inf scores mark invalid
            # positions, which get sentinel index -1.
            raw_scores, topk_indices = torch.topk(logits, topk, dim=-1)
            topk_indices = topk_indices.to(torch.int32).masked_fill(
                raw_scores == -torch.inf, -1
            )

        index_score = pytorch_extract_topk_scores(logits, topk_indices)

        ctx.save_for_backward(index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk_indices)
        ctx.topk = topk
        ctx.head_num = head_num
        return index_score, topk_indices

    @staticmethod
    def backward(ctx, grad_scores, grad_indices):
        """Route score gradients through the fused backward kernel.

        ``grad_indices`` is ignored — the index output is non-differentiable.
        """
        index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk_indices = ctx.saved_tensors
        grad_q, grad_w, grad_k = indexer_bwd_interface(
            index_q, weights, index_k, topk_indices, grad_scores
        )
        # One slot per forward argument; only the three tensor inputs get grads.
        return grad_q, grad_k, grad_w, None, None, None, None
|
|
||
|
|
||
def lighting_indexer(
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    topk: int,
    topk_indices: torch.Tensor | None = None,
):
    """Public wrapper: run the fused indexer and return (top-k scores, top-k indices).

    ``weights`` arrives with a trailing singleton dimension; it is squeezed
    away before entering the autograd function.
    """
    squeezed_weights = weights.squeeze(-1)
    return IndexerFunction.apply(
        index_q, index_k, squeezed_weights, cu_seqlen_ks, cu_seqlen_ke, topk, topk_indices
    )
|
|
||
|
|
||
def generate_varlen_mask_params(cu_seqlens):
    """Per-token causal window bounds for packed (THD) sequences.

    For each flattened query position t, returns ``starts[t]`` (offset of t's
    sequence inside the packed buffer) and ``ends[t]`` (t + 1, the causal
    exclusive bound), i.e. token t may attend to keys in [starts[t], ends[t]).
    """
    total_tokens = cu_seqlens[-1].item()
    token_pos = torch.arange(0, total_tokens, device=cu_seqlens.device)
    # Map each flattened token to the sequence that owns it.
    owner_seq = torch.searchsorted(cu_seqlens, token_pos, right=True) - 1
    starts = cu_seqlens[owner_seq]
    ends = token_pos + 1
    assert torch.all((ends - starts) > 0)
    return starts, ends
48 changes: 48 additions & 0 deletions
48
megatron/core/transformer/experimental_attention_variant/ops/sparse_mla.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import torch | ||
|
|
||
| from .tilelang_sparse_mla_bwd import sparse_mla_bwd | ||
| from .tilelang_sparse_mla_fwd import sparse_mla_fwd_interface | ||
|
|
||
|
|
||
class SparseMLA(torch.autograd.Function):
    """Differentiable wrapper over the TileLang sparse-MLA kernels."""

    @staticmethod
    def forward(ctx, q, kv, indices, scaling):
        """Run the fused sparse attention forward pass.

        Args:
            q: Query tensor (seq_len, heads, dim_plus_tail_dim)
            kv: Key-Value tensor (seq_len_kv, kv_group, dim_plus_tail_dim)
            indices: Sparse indices tensor (seq_len, kv_group, topk)
            scaling: softmax scale forwarded to the kernel

        Returns:
            out: Output tensor (seq_len, heads, dim)
            lse: log-sum-exp tensor consumed by the backward kernel
        """
        q = q.contiguous()
        kv = kv.contiguous()
        indices = indices.contiguous()
        ctx.scaling = scaling
        out, lse = sparse_mla_fwd_interface(q, kv, indices, sm_scale=scaling)

        # Backward needs the inputs plus both forward outputs.
        ctx.save_for_backward(q, kv, indices, out, lse)

        return out, lse

    @staticmethod
    def backward(ctx, grad_output, grad_lse):
        """Backpropagate through the sparse attention.

        Args:
            grad_output: gradient of the loss w.r.t. the attention output
            grad_lse: gradient w.r.t. the lse output (not used by the kernel)

        Returns:
            Gradients for (q, kv, indices, scaling); the last two are None
            since neither is differentiable.
        """
        q, kv, indices, out, lse = ctx.saved_tensors

        dq, dkv = sparse_mla_bwd(
            q, kv, out, grad_output.contiguous(), indices, lse, sm_scale=ctx.scaling
        )

        return dq, dkv, None, None
238 changes: 238 additions & 0 deletions
238
megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,238 @@ | ||
| # ruff: noqa | ||
| # Adapted from: | ||
| # https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/ | ||
| # examples/dsa_sparse_finetune/indexer_bwd.py | ||
| import os | ||
| import threading | ||
| from collections import OrderedDict | ||
|
|
||
| import tilelang as tl | ||
| import tilelang.language as T | ||
| import torch | ||
|
|
||
# Short aliases for the TileLang scalar dtypes used throughout this file.
BF16 = T.bfloat16
FP32 = T.float32
INT32 = T.int32
# LRU cache of compiled backward kernels keyed by (heads, dim, topk),
# guarded by a lock so at most one thread compiles at a time.
_tilelang_indexer_bwd_kernel_cache = OrderedDict()
_tilelang_indexer_bwd_cache_lock = threading.Lock()

# NOTE(review): TMA lowering and warp specialization are disabled for this
# kernel — presumably for compatibility/correctness; confirm before re-enabling.
pass_configs = {
    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
|
|
||
|
|
||
| def _env_int(name: str, default: int) -> int: | ||
| value = os.getenv(name) | ||
| if value is None: | ||
| return default | ||
| try: | ||
| parsed = int(value) | ||
| except ValueError: | ||
| return default | ||
| return parsed if parsed > 0 else default | ||
|
|
||
|
|
||
# Upper bound on cached compiled kernels; tunable via environment variable.
_TILELANG_KERNEL_CACHE_MAX = _env_int("MCORE_DSA_TILELANG_KERNEL_CACHE_MAX", 512)
|
|
||
|
|
||
def _cache_put_lru(cache: OrderedDict, key, value):
    """Store *value* under *key* as most-recently-used, evicting LRU overflow.

    Eviction keeps the cache at `_TILELANG_KERNEL_CACHE_MAX` entries,
    removing the least-recently-used (front) entries first.
    """
    cache[key] = value
    cache.move_to_end(key)
    overflow = len(cache) - _TILELANG_KERNEL_CACHE_MAX
    for _ in range(overflow):
        cache.popitem(last=False)
|
|
||
|
|
||
| def _next_power_of_two(x: int) -> int: | ||
| if x <= 1: | ||
| return 1 | ||
| return 1 << (x - 1).bit_length() | ||
|
|
||
|
|
||
| def _round_up(x: int, multiple: int) -> int: | ||
| return ((x + multiple - 1) // multiple) * multiple | ||
|
|
||
|
|
||
def _canonical_topk(topk: int, block_i: int = 32) -> int:
    """Canonical kernel topk: padded to a power of two, then to a multiple of block_i."""
    padded = _next_power_of_two(topk)
    return _round_up(padded, block_i)
|
|
||
|
|
||
def _get_indexer_bwd_kernel(heads: int, dim: int, topk: int):
    """Return the compiled backward kernel for this shape, compiling on a miss.

    The whole lookup-compile-insert sequence runs under the cache lock so
    concurrent callers never trigger duplicate TileLang compilations.
    """
    cache_key = (heads, dim, topk)
    with _tilelang_indexer_bwd_cache_lock:
        # pop + reinsert refreshes the entry's LRU position.
        compiled = _tilelang_indexer_bwd_kernel_cache.pop(cache_key, None)
        if compiled is None:
            compiled = tl_indexer_bwd_impl(heads, dim, topk)
        _cache_put_lru(_tilelang_indexer_bwd_kernel_cache, cache_key, compiled)
        return compiled
|
|
||
|
|
||
@tl.jit(pass_configs=pass_configs)
def tl_indexer_bwd_impl(
    heads: int, dim: int, topk: int, block_I: int = 32, num_stages: int = 0, num_threads: int = 128
):
    """Build tilelang backward kernel for sparse indexer.

    One thread block per query row. For each block of ``block_I`` selected
    keys, the kernel recomputes the ReLU'd logits and accumulates gradients
    for the query, per-head weights, and (via atomics) the gathered keys.

    Args:
        heads: number of indexer heads (must be a multiple of 8, <= 64).
        dim: indexer head dimension.
        topk: padded top-k (power of two, multiple of ``block_I``).
        block_I: how many selected keys are processed per inner iteration.
        num_stages: pipelining stages; only 0 (serial) is supported here.
        num_threads: threads per block.
    """
    assert num_stages == 0
    assert topk == tl.math.next_power_of_2(topk)
    assert topk % block_I == 0
    assert heads <= 64 and heads % 8 == 0
    # Sequence lengths are symbolic so one compiled kernel serves all lengths
    # (avoids recompilation as shapes vary across microbatches).
    seq_len = T.symbolic("seq_len")
    q_seq_len = T.symbolic("q_seq_len")

    dtype: str = BF16
    accum_dtype: str = FP32
    index_q_shape = [q_seq_len, heads, dim]
    weights_shape = [q_seq_len, heads]
    index_k_shape = [seq_len, dim]
    shape_p = [q_seq_len, topk]
    topk_indices_shape = [q_seq_len, topk]

    # Pad the head axis to at least 16 — presumably a GEMM tile-size
    # requirement; padded lanes are zero-filled and sliced off on output.
    pad_heads = heads
    if heads < 16:
        pad_heads = 16

    @T.prim_func
    def tl_indexer_bwd_kernel(
        IndexQ: T.Tensor(index_q_shape, dtype),
        IndexK: T.Tensor(index_k_shape, dtype),
        Weights: T.Tensor(weights_shape, FP32),
        TopkIndices: T.Tensor(topk_indices_shape, INT32),
        OGrad: T.Tensor(shape_p, FP32),
        dIndexQ: T.Tensor(index_q_shape, dtype),
        dWeights: T.Tensor(weights_shape, FP32),
        dIndexK: T.Tensor(index_k_shape, FP32),
    ):

        # bx = query row handled by this thread block.
        with T.Kernel(q_seq_len, threads=num_threads) as (bx):
            index_q_shared = T.alloc_shared([pad_heads, dim], dtype=FP32)
            weights_shared = T.alloc_shared([pad_heads], dtype=FP32)
            index_k_shared = T.alloc_shared([block_I, dim], dtype=FP32)
            indices_shared = T.alloc_shared([block_I], dtype=INT32)
            d_index_q_frag = T.alloc_fragment([pad_heads, dim], dtype=accum_dtype)
            d_weights_frag = T.alloc_fragment([pad_heads], dtype=accum_dtype)
            d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype)
            logits = T.alloc_fragment((block_I, pad_heads), dtype=accum_dtype)
            _logits = T.alloc_shared((block_I, pad_heads), dtype=accum_dtype)
            grad = T.alloc_shared([block_I], dtype=FP32)

            num_blocks = T.ceildiv(topk, block_I)
            # Stage this row's query (zero-padded heads) and weights once.
            for i, j in T.Parallel(pad_heads, dim):
                index_q_shared[i, j] = T.if_then_else(i < heads, IndexQ[bx, i, j], 0)
            for i in T.Parallel(heads):
                weights_shared[i] = Weights[bx, i]

            T.fill(d_index_q_frag, 0)
            T.fill(d_weights_frag, 0)

            # for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):
            for bi_i in T.serial(num_blocks):
                # Load this block of selected key indices and their upstream grads.
                for i in T.Parallel(block_I):
                    if bi_i * block_I + i < topk:
                        indices_shared[i] = TopkIndices[bx, bi_i * block_I + i]
                        grad[i] = OGrad[bx, bi_i * block_I + i]

                T.sync_threads()
                # Gather the selected keys; sentinel (-1) / out-of-range rows read as 0.
                for i, j in T.Parallel(block_I, dim):
                    index_k_shared[i, j] = T.if_then_else(
                        indices_shared[i] > -1 and indices_shared[i] < seq_len,
                        IndexK[indices_shared[i], j],
                        0,
                    )

                T.sync_threads()
                # Recompute forward logits: K_block @ Q^T, then ReLU.
                T.gemm(
                    index_k_shared,
                    index_q_shared,
                    logits,
                    transpose_A=False,
                    transpose_B=True,
                    clear_accum=True,
                )
                for i, j in T.Parallel(block_I, heads):
                    logits[i, j] = T.max(logits[i, j], 0)

                # dW += sum_i grad_i * relu(logit)_{i,h}
                d_weights_i = T.alloc_fragment((block_I, pad_heads), accum_dtype)
                for i, j in T.Parallel(block_I, heads):
                    d_weights_i[i, j] = grad[i] * logits[i, j]
                T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)

                # Upstream grad through the ReLU, scaled by the head weight;
                # zero where the ReLU was inactive or in padded head lanes.
                for i, j in T.Parallel(block_I, pad_heads):
                    _logits[i, j] = T.if_then_else(
                        logits[i, j] > 0 and j < heads, grad[i] * weights_shared[j], 0
                    )
                T.sync_threads()
                # dQ += dLogits^T @ K_block (accumulated across key blocks).
                T.gemm(
                    _logits,
                    index_k_shared,
                    d_index_q_frag,
                    transpose_A=True,
                    transpose_B=False,
                    clear_accum=False,
                )

                # dK_block = dLogits @ Q (fresh per block).
                T.gemm(
                    _logits,
                    index_q_shared,
                    d_index_k_frag,
                    transpose_A=False,
                    transpose_B=False,
                    clear_accum=True,
                )

                # Scatter key grads back; atomics because multiple query rows
                # may select the same key.
                for i, j in T.Parallel(block_I, dim):
                    if indices_shared[i] > -1 and indices_shared[i] < seq_len:
                        T.atomic_add(dIndexK[indices_shared[i], j], d_index_k_frag[i, j])

            # Write back the real (unpadded) head lanes only.
            T.copy(d_index_q_frag[:heads, :], dIndexQ[bx, :, :])
            T.copy(d_weights_frag[:heads], dWeights[bx, :])

    return tl_indexer_bwd_kernel
|
|
||
|
|
||
def indexer_bwd_interface(
    index_q: torch.Tensor,
    weights: torch.Tensor,
    index_k: torch.Tensor,
    topk_indices: torch.Tensor,
    grad_scores: torch.Tensor,
):
    """Launch the TileLang indexer backward kernel.

    Pads the top-k axis up to the kernel's canonical size (power of two and
    a multiple of the kernel index block) using sentinel -1 indices and zero
    gradients, then returns ``(grad_q, grad_w, grad_k)``; the weight and key
    gradients come back in fp32 regardless of input dtype.
    """
    _, head_num, head_dim = index_q.shape
    k_top = int(topk_indices.shape[1])
    assert k_top > 0, "topk must be positive"
    padded_topk = _canonical_topk(k_top)

    if padded_topk != k_top:
        # Extra columns get index -1 (skipped by the kernel) and zero grad.
        idx_padded = torch.full(
            (topk_indices.size(0), padded_topk),
            -1,
            dtype=topk_indices.dtype,
            device=topk_indices.device,
        )
        idx_padded[:, :k_top].copy_(topk_indices)
        topk_indices = idx_padded

        grad_padded = torch.zeros(
            (grad_scores.size(0), padded_topk), dtype=grad_scores.dtype, device=grad_scores.device
        )
        grad_padded[:, :k_top].copy_(grad_scores)
        grad_scores = grad_padded

    grad_scores = grad_scores.contiguous()
    dq = torch.empty_like(index_q)
    dw = torch.empty_like(weights, dtype=torch.float32)
    dk = torch.zeros_like(index_k, dtype=torch.float32)

    bwd_kernel = _get_indexer_bwd_kernel(head_num, head_dim, padded_topk)
    # NOTE(review): dw.squeeze(-1) is a view, so in-kernel writes land in dw —
    # assumes weights carries at most a trailing singleton dim; confirm callers.
    bwd_kernel(
        index_q.contiguous(),
        index_k.contiguous(),
        weights.squeeze(-1).contiguous(),
        topk_indices.contiguous(),
        grad_scores,
        dq,
        dw.squeeze(-1),
        dk,
    )

    return dq, dw, dk
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.