From 7a16b4d3ee2c059b3cd9cc3dbc70498faa44a7a5 Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Tue, 3 Mar 2026 12:12:20 +0200
Subject: [PATCH 1/6] Enable DSA CP/absorbed/THD paths with TileLang fused ops

This PR upgrades the DSA path end-to-end: context parallelism (allgather CP)
with THD packing, absorbed-MLA integration, and fused TileLang kernels with
safe fallbacks.

Signed-off-by: Hollow Man
---
 .../experimental_attention_variant/dsa.py     | 2006 +++++++++++++++--
 .../ops/indexer.py                            |   80 +
 .../ops/sparse_mla.py                         |   48 +
 .../ops/tilelang_indexer_bwd.py               |  168 ++
 .../ops/tilelang_indexer_fwd.py               |  132 ++
 .../ops/tilelang_sparse_mla_bwd.py            |  355 +++
 .../ops/tilelang_sparse_mla_fwd.py            |  222 ++
 .../transformer/multi_latent_attention.py     |   90 +
 .../core/transformer/transformer_config.py    |    3 -
 .../transformer/test_attention_variant_dsa.py |  554 ++++-
 10 files changed, 3498 insertions(+), 160 deletions(-)
 create mode 100644 megatron/core/transformer/experimental_attention_variant/ops/indexer.py
 create mode 100644 megatron/core/transformer/experimental_attention_variant/ops/sparse_mla.py
 create mode 100644 megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py
 create mode 100644 megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py
 create mode 100644 megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py
 create mode 100644 megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py

diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py
index 3734db7043f..9ce12433725 100644
--- a/megatron/core/transformer/experimental_attention_variant/dsa.py
+++ b/megatron/core/transformer/experimental_attention_variant/dsa.py
@@ -1,7 +1,9 @@
 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 
 import copy
+import logging
 import math
+from collections import OrderedDict
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -21,11 +23,1212 @@ from megatron.core.transformer.spec_utils import ModuleSpec, build_module
 from megatron.core.transformer.transformer_config import TransformerConfig
 
+logger = logging.getLogger(__name__)
+
 try:
     from fast_hadamard_transform import hadamard_transform
 except ImportError:
     hadamard_transform = None
 
+try:
+    from megatron.core.transformer.experimental_attention_variant.ops.indexer import (
+        lighting_indexer,
+    )
+except Exception:
+    lighting_indexer = None
+
+try:
+    from megatron.core.transformer.experimental_attention_variant.ops.sparse_mla import SparseMLA
+except Exception:
+    SparseMLA = None
+
+# Reusable no-grad scratch buffers keyed by (name, shape, dtype, device).
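+#
+# Illustrative calling pattern (a sketch, not part of the kernels): repeated
+# requests with an identical (name, shape, dtype, device) key return the same
+# cached tensor, so callers must reinitialize contents before each use:
+#
+#     buf1 = _get_scratch_buffer("tmp", (8, 8), torch.float32, device)
+#     buf1.zero_()
+#     buf2 = _get_scratch_buffer("tmp", (8, 8), torch.float32, device)
+#     assert buf1 is buf2  # LRU hit: same storage, stale contents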
+_DSA_SCRATCH_CACHE_MAX_ENTRIES = 128 +_DSA_SCRATCH_CACHE_MAX_BYTES = 512 * 1024 * 1024 +_DSA_SCRATCH_CACHE = OrderedDict() + + +def _scratch_cache_total_bytes() -> int: + """Return total bytes held by cached scratch tensors.""" + return sum(buf.numel() * buf.element_size() for buf in _DSA_SCRATCH_CACHE.values()) + + +def _evict_scratch_cache_if_needed() -> None: + """Bound scratch cache growth by LRU eviction.""" + while ( + len(_DSA_SCRATCH_CACHE) > _DSA_SCRATCH_CACHE_MAX_ENTRIES + or _scratch_cache_total_bytes() > _DSA_SCRATCH_CACHE_MAX_BYTES + ): + _DSA_SCRATCH_CACHE.popitem(last=False) + + +def _get_scratch_buffer( + name: str, shape: Tuple[int, ...], dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + """Get a reusable scratch tensor for temporary no-grad workspaces.""" + key = (name, shape, dtype, device) + buf = _DSA_SCRATCH_CACHE.pop(key, None) + if buf is None: + buf = torch.empty(shape, dtype=dtype, device=device) + _DSA_SCRATCH_CACHE[key] = buf + _evict_scratch_cache_if_needed() + return buf + + +def _normalize_cp_comm_type(cp_comm_type: Optional[str]) -> str: + """Normalize CP communication type to a canonical lowercase form.""" + if cp_comm_type is None: + return "p2p" + return cp_comm_type.replace("_", "").lower() + + +def _get_cp_positions_from_layout( + sq: int, + skv: int, + cp_size: int, + cp_rank: int, + cp_comm_type: Optional[str], + device: torch.device, + cp_group: Optional[torch.distributed.ProcessGroup] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Infer query/key global token positions under CP layout. + + This helper currently supports allgather CP layout, where each rank owns a + contiguous query chunk and sees gathered keys in global order. + """ + if cp_size <= 1: + query_pos = torch.arange(sq, device=device, dtype=torch.int64) + key_pos = torch.arange(skv, device=device, dtype=torch.int64) + return query_pos, key_pos + + if _normalize_cp_comm_type(cp_comm_type) != "allgather": + raise NotImplementedError( + "DSAttention context parallelism currently supports cp_comm_type=allgather only." + ) + + # Avoid assuming uniform per-rank query lengths (cp_rank * sq). When available, + # gather local lengths to build the true global offset for this CP rank. + query_offset = cp_rank * sq + if ( + cp_group is not None + and torch.distributed.is_available() + and torch.distributed.is_initialized() + and cp_group.size() == cp_size + ): + local_len = torch.tensor([sq], device=device, dtype=torch.int64) + all_lens = [torch.empty_like(local_len) for _ in range(cp_size)] + torch.distributed.all_gather(all_lens, local_len, group=cp_group) + if cp_rank > 0: + query_offset = int(torch.stack(all_lens[:cp_rank]).sum().item()) + else: + query_offset = 0 + + query_pos = torch.arange(sq, device=device, dtype=torch.int64) + query_offset + key_pos = torch.arange(skv, device=device, dtype=torch.int64) + return query_pos, key_pos + + +def _build_causal_mask_from_positions( + query_pos: torch.Tensor, key_pos: torch.Tensor +) -> torch.Tensor: + """Build a causal mask from explicit query/key global positions.""" + assert query_pos.dtype in (torch.int32, torch.int64), "query_pos must be integer tensor" + assert key_pos.dtype in (torch.int32, torch.int64), "key_pos must be integer tensor" + assert query_pos.device == key_pos.device, "query_pos and key_pos must be on the same device" + + # mask[q, k] = -inf if key_pos[k] > query_pos[q], else 0. 
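+    # Worked example (illustrative): query_pos=[2, 3] (a CP rank owning global
+    # rows 2..3) with key_pos=[0, 1, 2, 3] yields rows [0, 0, 0, -inf] and
+    # [0, 0, 0, 0], i.e. each local query attends only to keys at or before
+    # its global position.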
+ invalid = key_pos.unsqueeze(0) > query_pos.unsqueeze(-1) + mask = torch.zeros( + (query_pos.numel(), key_pos.numel()), dtype=torch.float32, device=query_pos.device + ) + mask.masked_fill_(invalid, float("-inf")) + return mask + + +def _generate_varlen_mask_params(cu_seqlens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Generate row-wise [start, end) key bounds for packed causal masking.""" + assert cu_seqlens.ndim == 1 and cu_seqlens.numel() >= 2, "invalid cu_seqlens" + cu_seqlens = cu_seqlens.to(dtype=torch.int64) + seq_len = int(cu_seqlens[-1].item()) + q_indices = torch.arange(seq_len, dtype=torch.int64, device=cu_seqlens.device) + seq_indices = torch.searchsorted(cu_seqlens, q_indices, right=True) - 1 + starts = cu_seqlens[seq_indices] + ends = q_indices + 1 + return starts, ends + + +def _build_valid_mask_from_starts_ends( + starts: torch.Tensor, ends: torch.Tensor, key_positions: torch.Tensor +) -> torch.Tensor: + """Build boolean validity mask [sq, sk] from row-wise [start, end) bounds.""" + assert starts.ndim == ends.ndim == 1, "starts/ends must be 1D" + assert starts.shape == ends.shape, "starts/ends shape mismatch" + assert key_positions.ndim == 1, "key_positions must be 1D" + assert starts.device == ends.device == key_positions.device, "device mismatch" + assert starts.dtype in (torch.int32, torch.int64), "starts must be int tensor" + assert ends.dtype in (torch.int32, torch.int64), "ends must be int tensor" + assert key_positions.dtype in (torch.int32, torch.int64), "key_positions must be int tensor" + key_positions = key_positions.to(dtype=torch.int64) + starts = starts.to(dtype=torch.int64) + ends = ends.to(dtype=torch.int64) + return (key_positions.unsqueeze(0) >= starts.unsqueeze(-1)) & ( + key_positions.unsqueeze(0) < ends.unsqueeze(-1) + ) + + +def _apply_starts_ends_mask_to_scores( + scores: torch.Tensor, starts: torch.Tensor, ends: torch.Tensor, key_positions: torch.Tensor +) -> torch.Tensor: + """Apply varlen starts/ends mask to score tensor. + + Supports scores with shape [b, sq, sk] or [b, np, sq, sk]. + """ + valid = _build_valid_mask_from_starts_ends(starts, ends, key_positions) + if scores.ndim == 3: + return scores.masked_fill(~valid.unsqueeze(0), float("-inf")) + if scores.ndim == 4: + return scores.masked_fill(~valid.unsqueeze(0).unsqueeze(0), float("-inf")) + raise ValueError(f"Unsupported scores ndim={scores.ndim}, expected 3 or 4.") + + +def _build_default_causal_mask(sq: int, sk: int, device: torch.device) -> torch.Tensor: + """Build standard upper-triangular additive causal mask.""" + return torch.triu( + torch.full((sq, sk), float("-inf"), dtype=torch.float32, device=device), diagonal=1 + ) + + +def _prepare_additive_mask( + mask: Optional[torch.Tensor], *, sq: int, sk: int, b: int, device: torch.device +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Validate/build additive mask and return useful broadcasted views. 
+ + Returns: + score_mask: [sq, sk] or [b, sq, sk] + attn_score_mask: [1, 1, sq, sk] or [b, 1, sq, sk] + index_score_mask: [1, sq, sk] or [b, sq, sk] + valid_mask: [b, sq, sk] bool, True means finite (not masked) + """ + if mask is None: + score_mask = _build_default_causal_mask(sq, sk, device=device) + else: + assert mask.dtype == torch.float32, "mask dtype must be float32" + assert mask.device == device, "mask device mismatch" + assert mask.ndim in (2, 3), "mask must be 2D or 3D" + if mask.ndim == 2: + assert mask.shape == (sq, sk), "mask shape mismatch" + else: + assert mask.shape == (b, sq, sk), "mask shape mismatch" + score_mask = mask + + if score_mask.ndim == 2: + attn_score_mask = score_mask.view(1, 1, sq, sk) + index_score_mask = score_mask.unsqueeze(0) + valid_mask = torch.isfinite(score_mask).unsqueeze(0).expand(b, sq, sk) + else: + attn_score_mask = score_mask.view(b, 1, sq, sk) + index_score_mask = score_mask + valid_mask = torch.isfinite(score_mask) + return score_mask, attn_score_mask, index_score_mask, valid_mask + + +def _prepare_sparse_mask_context( + *, + mask: Optional[torch.Tensor], + varlen_starts: Optional[torch.Tensor], + varlen_ends: Optional[torch.Tensor], + key_positions: Optional[torch.Tensor], + sq: int, + sk: int, + b: int, + device: torch.device, +) -> Tuple[ + Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor] +]: + """Prepare shared sparse-mask context for unfused attention paths.""" + if mask is not None and varlen_starts is not None: + raise ValueError("mask and varlen_starts are mutually exclusive") + + if varlen_starts is not None: + if varlen_ends is None: + raise ValueError("varlen_ends is required when varlen_starts is provided") + varlen_starts_i64 = varlen_starts.to(device=device, dtype=torch.int64) + varlen_ends_i64 = varlen_ends.to(device=device, dtype=torch.int64) + if key_positions is None: + key_positions_i64 = torch.arange(sk, dtype=torch.int64, device=device) + else: + key_positions_i64 = key_positions.to(device=device, dtype=torch.int64) + return None, varlen_starts_i64, varlen_ends_i64, key_positions_i64 + + _, _, index_score_mask, _ = _prepare_additive_mask(mask, sq=sq, sk=sk, b=b, device=device) + return index_score_mask, None, None, None + + +def _apply_sparse_validity_to_index_mask( + index_mask: torch.Tensor, + *, + row_mask: Optional[torch.Tensor], + varlen_starts: Optional[torch.Tensor], + varlen_ends: Optional[torch.Tensor], + key_positions: Optional[torch.Tensor], +) -> torch.Tensor: + """Apply either varlen or additive mask validity constraints to index_mask.""" + if varlen_starts is not None: + if varlen_ends is None: + raise ValueError("varlen_ends is required when varlen_starts is provided") + if key_positions is None: + raise ValueError("key_positions is required when varlen_starts is provided") + valid_mask = _build_valid_mask_from_starts_ends( + varlen_starts, varlen_ends, key_positions + ).unsqueeze(0) + return index_mask.masked_fill(~valid_mask, float("-inf")) + + if row_mask is None: + raise ValueError("row_mask is required when varlen_starts is None") + return index_mask + row_mask + + +def _gather_sparse_topk_validity_and_bias( + *, + idx_topk: torch.Tensor, + valid_t: torch.Tensor, + bi: int, + s0: int, + s1: int, + row_mask: Optional[torch.Tensor], + varlen_starts: Optional[torch.Tensor], + varlen_ends: Optional[torch.Tensor], + key_positions: Optional[torch.Tensor], + dtype: torch.dtype, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Gather top-k validity mask and 
optional additive bias for one [s_chunk, topk] block.""" + if varlen_starts is not None: + if varlen_ends is None: + raise ValueError("varlen_ends is required when varlen_starts is provided") + if key_positions is None: + raise ValueError("key_positions is required when varlen_starts is provided") + key_pos_sel = key_positions.index_select(0, idx_topk.reshape(-1)).view_as(idx_topk) + valid_varlen = (key_pos_sel >= varlen_starts[s0:s1].unsqueeze(-1)) & ( + key_pos_sel < varlen_ends[s0:s1].unsqueeze(-1) + ) + return valid_t & valid_varlen, None + + if row_mask is None: + raise ValueError("row_mask is required when varlen_starts is None") + mask_src = row_mask[0, s0:s1, :] if row_mask.size(0) == 1 else row_mask[bi, s0:s1, :] + mask_bias = mask_src.gather(-1, idx_topk).to(dtype=dtype) + return valid_t & torch.isfinite(mask_bias), mask_bias + + +def _scatter_topk_into_index_mask( + index_mask: torch.Tensor, topk_indices: torch.Tensor, *, seq_chunk_size: int = 256 +) -> None: + """Scatter top-k supports into index_mask using chunk-wise int64 casts.""" + b, sq, _ = index_mask.shape + assert topk_indices.ndim == 3, "topk_indices must be [b, sq, topk]" + assert topk_indices.shape[:2] == (b, sq), "topk_indices shape mismatch" + device = index_mask.device + seq_chunk_size = max(1, int(seq_chunk_size)) + + for s0 in range(0, sq, seq_chunk_size): + s1 = min(s0 + seq_chunk_size, sq) + idx_chunk = topk_indices[:, s0:s1] + if idx_chunk.dtype != torch.int64 or idx_chunk.device != device: + idx_chunk = idx_chunk.to(dtype=torch.int64, device=device) + if torch.any(idx_chunk < 0): + valid_topk = idx_chunk >= 0 + if valid_topk.any(): + b_idx, q_rel_idx, t_idx = torch.where(valid_topk) + q_idx = q_rel_idx + s0 + k_idx = idx_chunk[b_idx, q_rel_idx, t_idx] + index_mask[b_idx, q_idx, k_idx] = 0.0 + else: + index_mask[:, s0:s1].scatter_(-1, idx_chunk, 0.0) + + +def _extract_query_positions_from_position_ids( + position_ids: Optional[torch.Tensor], sq: int, device: torch.device +) -> Optional[torch.Tensor]: + """Extract per-rank query positions from position_ids if compatible.""" + if position_ids is None: + return None + if position_ids.ndim == 2: + if position_ids.size(0) > 1: + assert torch.equal( + position_ids[0], position_ids[-1] + ), "Allgather-CP DSA expects identical position_ids across batch" + query_pos = position_ids[0] + elif position_ids.ndim == 1: + query_pos = position_ids + else: + raise ValueError(f"position_ids should be 1D or 2D tensor, got {position_ids.ndim}D.") + + if query_pos.numel() != sq: + return None + return query_pos.to(device=device, dtype=torch.int64) + + +def _get_packed_qk_cu_seqlens( + packed_seq_params: PackedSeqParams, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Select packed cu_seqlens for query and key/value streams.""" + cu_seqlens_q = ( + packed_seq_params.cu_seqlens_q_padded + if packed_seq_params.cu_seqlens_q_padded is not None + else packed_seq_params.cu_seqlens_q + ) + cu_seqlens = ( + packed_seq_params.cu_seqlens_kv_padded + if packed_seq_params.cu_seqlens_kv_padded is not None + else packed_seq_params.cu_seqlens_kv + ) + cu_seqlens_kv = cu_seqlens + + if cu_seqlens_q is None and cu_seqlens_kv is None: + raise ValueError("Packed sequence parameters must provide cu_seqlens for DSA masking.") + if cu_seqlens_q is None: + cu_seqlens_q = cu_seqlens_kv + if cu_seqlens_kv is None: + cu_seqlens_kv = cu_seqlens_q + return cu_seqlens_q, cu_seqlens_kv + + +def _build_dsattention_forward_mask( + *, + sq: int, + skv: int, + b: int, + device: torch.device, + cp_size: int, + 
cp_rank: int, + cp_comm_type: str, + cp_group: Optional[torch.distributed.ProcessGroup], + attn_mask_type: Optional[AttnMaskType], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor], + packed_seq_params: Optional[PackedSeqParams], +) -> Tuple[Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: + """Build DSAttention mask. + + Returns: + float_mask: Optional additive mask [sq, skv] or [b, sq, skv]. + varlen_params: Optional (starts, ends, key_positions), each int64 tensor. + """ + packed_thd = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + if attn_mask_type is not None: + assert attn_mask_type == AttnMaskType.causal, "Only causal mask is supported for now" + if packed_thd: + cu_seqlens_q, _ = _get_packed_qk_cu_seqlens(packed_seq_params) + cu_seqlens_q = cu_seqlens_q.to(device=device, dtype=torch.int64) + starts, ends = _generate_varlen_mask_params(cu_seqlens_q) + if cp_size > 1: + query_idx, key_idx = _get_cp_positions_from_layout( + sq=sq, + skv=skv, + cp_size=cp_size, + cp_rank=cp_rank, + cp_comm_type=cp_comm_type, + device=device, + cp_group=cp_group, + ) + else: + query_idx = torch.arange(sq, dtype=torch.int64, device=device) + key_idx = torch.arange(skv, dtype=torch.int64, device=device) + varlen_starts = starts.index_select(0, query_idx) + varlen_ends = ends.index_select(0, query_idx) + return None, (varlen_starts, varlen_ends, key_idx) + + if cp_size > 1: + query_pos = _extract_query_positions_from_position_ids(position_ids, sq, device) + if query_pos is None: + query_pos, key_pos = _get_cp_positions_from_layout( + sq=sq, + skv=skv, + cp_size=cp_size, + cp_rank=cp_rank, + cp_comm_type=cp_comm_type, + device=device, + cp_group=cp_group, + ) + else: + key_pos = torch.arange(skv, dtype=torch.int64, device=device) + return _build_causal_mask_from_positions(query_pos, key_pos), None + + return _build_default_causal_mask(sq, skv, device=device), None + + assert attention_mask is not None, "attention_mask is required when attn_mask_type is None" + assert attention_mask.shape == (b, 1, sq, skv), "attention_mask shape mismatch" + mask = attention_mask[:, 0, :, :] + float_mask = torch.zeros_like(mask, dtype=torch.float32).masked_fill(mask, float("-inf")) + return float_mask, None + + +def _build_fused_indexer_varlen_bounds( + *, + sq: int, + skv: int, + device: torch.device, + mask: Optional[torch.Tensor], + varlen_starts: Optional[torch.Tensor], + varlen_ends: Optional[torch.Tensor], + key_positions: Optional[torch.Tensor], +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Build starts/ends bounds for tilelang fused indexer. + + Fused indexer expects row-wise contiguous valid key ranges [start, end). + """ + if varlen_starts is not None: + if varlen_ends is None: + raise ValueError("varlen_ends is required when varlen_starts is provided") + if key_positions is None: + key_positions = torch.arange(skv, dtype=torch.int64, device=device) + expected_key_pos = torch.arange(skv, dtype=torch.int64, device=device) + key_positions = key_positions.to(dtype=torch.int64, device=device) + if not torch.equal(key_positions, expected_key_pos): + return None + return ( + varlen_starts.to(dtype=torch.int32, device=device), + varlen_ends.to(dtype=torch.int32, device=device), + ) + + if mask is None: + # Standard local causal mask. 
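+        # e.g. sq=4, skv=4 gives starts=[0, 0, 0, 0], ends=[1, 2, 3, 4]: row i
+        # keeps keys in [0, i], the dense lower-triangular causal mask written
+        # as contiguous [start, end) bounds.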
+ ends = torch.arange(1, sq + 1, dtype=torch.int64, device=device).clamp_max(skv) + starts = torch.zeros_like(ends) + return starts.to(dtype=torch.int32), ends.to(dtype=torch.int32) + + if mask.ndim == 3: + # Fused indexer uses one shared starts/ends schedule. For batched masks, only + # enable fused path when all batch masks are identical. + if mask.size(0) > 1: + ref_mask = mask[0] + for bi in range(1, mask.size(0)): + if not torch.equal(mask[bi], ref_mask): + return None + row_mask = mask[0] + else: + row_mask = mask + if row_mask.ndim != 2 or row_mask.shape != (sq, skv): + return None + + finite = torch.isfinite(row_mask) + ends = finite.sum(dim=-1, dtype=torch.int64) + key_ids = torch.arange(skv, dtype=torch.int64, device=device).unsqueeze(0) + expected = key_ids < ends.unsqueeze(-1) + if not torch.equal(finite, expected): + return None + + starts = torch.zeros_like(ends) + return starts.to(dtype=torch.int32), ends.to(dtype=torch.int32) + + +def _fused_qk_topk_lighting( + q: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + index_topk: int, + starts: torch.Tensor, + ends: torch.Tensor, + block_size: int, +) -> Optional[torch.Tensor]: + """Run fused tilelang indexer and return top-k indices [b, sq, topk].""" + if lighting_indexer is None: + return None + if q.ndim != 4 or k.ndim != 3 or weights.ndim != 3: + return None + + sq, b = q.size(0), q.size(1) + if k.size(1) != b or weights.size(1) != b: + return None + starts = starts.contiguous() + ends = ends.contiguous() + + topk_out = None + for bi in range(b): + index_q = q[:, bi].contiguous() + index_k = k[:, bi].contiguous() + index_w = weights[:, bi].float().contiguous() + for start in range(0, sq, block_size): + end = min(start + block_size, sq) + _, topk_indices = lighting_indexer( + index_q[start:end], + index_k, + index_w[start:end], + starts[start:end], + ends[start:end], + index_topk, + topk_indices=None, + ) + if topk_out is None: + topk_out = torch.empty( + (b, sq, topk_indices.size(-1)), + dtype=topk_indices.dtype, + device=topk_indices.device, + ) + topk_out[bi, start:end].copy_(topk_indices) + + if topk_out is None: + return None + return topk_out + + +def _compute_topk_target_chunk_sum( + *, + query_h: torch.Tensor, + key_shared: Optional[torch.Tensor], + key_per_head: Optional[torch.Tensor], + s0: int, + s1: int, + idx_seq: torch.Tensor, + valid_seq: torch.Tensor, + softmax_scale: float, + head_chunk_size: int, + topk_chunk_size: int, + sk: int, + hn: int, +) -> torch.Tensor: + """Compute unnormalized target probability mass on top-k support for one sequence chunk.""" + s_len = s1 - s0 + topk = idx_seq.size(-1) + device = query_h.device + np = query_h.size(0) + + attn_chunk_sum = _get_scratch_buffer("kl_attn_chunk_sum", (s_len, topk), torch.float32, device) + attn_chunk_sum.zero_() + + for h0 in range(0, np, head_chunk_size): + h1 = min(h0 + head_chunk_size, np) + h_chunk = h1 - h0 + q_chunk = query_h[h0:h1, s0:s1, :] # [h_chunk, s_len, hn] + + if key_shared is None: + key_chunk = key_per_head[h0:h1] # [h_chunk, sk, hn] + flat_keys = key_chunk.reshape(h_chunk * sk, hn) + head_offsets = ( + torch.arange(h_chunk, device=device, dtype=torch.int64).view(-1, 1, 1) * sk + ) + else: + flat_keys = None + head_offsets = None + + # Two-pass online softmax over top-k chunks (per head): + # 1) pass computes row-wise max and denominator; + # 2) pass recomputes chunk logits and accumulates probabilities. 
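+        # Concretely, per logits chunk c the recurrence is:
+        #   m_new = max(m, max_j c[j])
+        #   l     = l * exp(m - m_new) + sum_j exp(c[j] - m_new)
+        # so after the last chunk softmax(x_j) = exp(x_j - m) / l, which pass 2
+        # evaluates using the final m (stable_m) and 1 / l (inv_l).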
+ m = _get_scratch_buffer("kl_m", (h_chunk, s_len), torch.float32, device) + l = _get_scratch_buffer("kl_l", (h_chunk, s_len), torch.float32, device) + m.fill_(float("-inf")) + l.zero_() + + # Pass 1: row max + denominator. + for t0 in range(0, topk, topk_chunk_size): + t1 = min(t0 + topk_chunk_size, topk) + tk = t1 - t0 + idx_topk = idx_seq[:, t0:t1] # [s_len, tk] + valid_topk_chunk = valid_seq[:, t0:t1] # [s_len, tk] + + if key_shared is not None: + key_sel = key_shared.index_select(0, idx_topk.reshape(-1)).view(s_len, tk, hn) + logits = ( + torch.einsum('hsd,skd->hsk', q_chunk.float(), key_sel.float()) * softmax_scale + ) + else: + flat_idx = idx_topk.unsqueeze(0) + head_offsets # [h_chunk, s_len, tk] + key_sel = flat_keys.index_select(0, flat_idx.reshape(-1)).view( + h_chunk, s_len, tk, hn + ) + logits = (q_chunk.float().unsqueeze(2) * key_sel.float()).sum( + dim=-1 + ) * softmax_scale + + logits = logits.masked_fill(~valid_topk_chunk.unsqueeze(0), float("-inf")) + + chunk_max = logits.max(dim=-1).values + m_new = torch.maximum(m, chunk_max) + alpha = torch.exp(m - m_new) + alpha = torch.nan_to_num(alpha, nan=0.0, posinf=0.0, neginf=0.0) + p_chunk = torch.exp(logits - m_new.unsqueeze(-1)) + p_chunk = torch.nan_to_num(p_chunk, nan=0.0, posinf=0.0, neginf=0.0) + l = l * alpha + p_chunk.sum(dim=-1) + m = m_new + + # Pass 2: probabilities accumulation per top-k chunk. + stable_m = torch.where(torch.isfinite(m), m, torch.zeros_like(m)) + inv_l = l.clamp_min(1e-10).reciprocal() + for t0 in range(0, topk, topk_chunk_size): + t1 = min(t0 + topk_chunk_size, topk) + tk = t1 - t0 + idx_topk = idx_seq[:, t0:t1] # [s_len, tk] + valid_topk_chunk = valid_seq[:, t0:t1] # [s_len, tk] + + if key_shared is not None: + key_sel = key_shared.index_select(0, idx_topk.reshape(-1)).view(s_len, tk, hn) + logits = ( + torch.einsum('hsd,skd->hsk', q_chunk.float(), key_sel.float()) * softmax_scale + ) + else: + flat_idx = idx_topk.unsqueeze(0) + head_offsets # [h_chunk, s_len, tk] + key_sel = flat_keys.index_select(0, flat_idx.reshape(-1)).view( + h_chunk, s_len, tk, hn + ) + logits = (q_chunk.float().unsqueeze(2) * key_sel.float()).sum( + dim=-1 + ) * softmax_scale + + logits = logits.masked_fill(~valid_topk_chunk.unsqueeze(0), float("-inf")) + probs = torch.exp(logits - stable_m.unsqueeze(-1)) * inv_l.unsqueeze(-1) + probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) + attn_chunk_sum[:, t0:t1] += probs.sum(dim=0) + + return attn_chunk_sum + + +def _compute_sparse_topk_kl_chunk( + target_chunk: torch.Tensor, index_logits_chunk: torch.Tensor, valid_seq: torch.Tensor +) -> torch.Tensor: + """Compute KL(target || index) sum for one [s_chunk, topk] chunk.""" + index_logits_chunk = index_logits_chunk.to(dtype=torch.float32, device=target_chunk.device) + index_logits_chunk = index_logits_chunk.masked_fill(~valid_seq, float("-inf")) + no_valid_rows = ~valid_seq.any(dim=-1, keepdim=True) + if no_valid_rows.any(): + index_logits_chunk = index_logits_chunk.masked_fill( + no_valid_rows.expand_as(index_logits_chunk), 0.0 + ) + index_scores_chunk = torch.nn.functional.softmax( + index_logits_chunk, dim=-1, dtype=torch.float32 + ) + kl_chunk = target_chunk * ( + torch.log(target_chunk + 1e-10) - torch.log(index_scores_chunk + 1e-10) + ) + return kl_chunk.sum() + + +def _normalize_topk_target_chunk(target_chunk: torch.Tensor) -> torch.Tensor: + """Normalize target probability mass over top-k support.""" + return target_chunk / target_chunk.sum(dim=-1, keepdim=True).clamp_min(1e-10) + + +def 
_stage_topk_target_chunk( + target_chunk: torch.Tensor, + *, + slot_prefix: str, + slot: int, + device: torch.device, + tp_group: torch.distributed.ProcessGroup, + tp_size: int, +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: + """Copy chunk into scratch slot and optionally launch async TP all-reduce.""" + target_chunk_work = _get_scratch_buffer( + f"{slot_prefix}_slot{slot}", tuple(target_chunk.shape), torch.float32, device + ) + target_chunk_work.copy_(target_chunk) + if tp_size > 1: + handle = torch.distributed.all_reduce(target_chunk_work, group=tp_group, async_op=True) + else: + handle = None + return target_chunk_work, handle + + +def _consume_pending_topk_kl_chunk( + *, + pending_handle: Optional[torch.distributed.Work], + pending_target_chunk: Optional[torch.Tensor], + pending_index_logits: Optional[torch.Tensor], + pending_valid_seq: Optional[torch.Tensor], + kl_sum: torch.Tensor, +) -> Tuple[ + torch.Tensor, + Optional[torch.distributed.Work], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + """Finalize one pending chunk and accumulate KL.""" + if pending_target_chunk is None: + return kl_sum, pending_handle, pending_target_chunk, pending_index_logits, pending_valid_seq + if pending_handle is not None: + pending_handle.wait() + normalized_target = _normalize_topk_target_chunk(pending_target_chunk) + kl_sum = kl_sum + _compute_sparse_topk_kl_chunk( + target_chunk=normalized_target, + index_logits_chunk=pending_index_logits, + valid_seq=pending_valid_seq, + ) + return kl_sum, None, None, None, None + + +def _enqueue_topk_kl_chunk( + *, + target_chunk: torch.Tensor, + index_logits_chunk: torch.Tensor, + valid_seq: torch.Tensor, + slot_prefix: str, + chunk_id: int, + device: torch.device, + tp_group: torch.distributed.ProcessGroup, + tp_size: int, + pending_handle: Optional[torch.distributed.Work], + pending_target_chunk: Optional[torch.Tensor], + pending_index_logits: Optional[torch.Tensor], + pending_valid_seq: Optional[torch.Tensor], + kl_sum: torch.Tensor, +) -> Tuple[ + torch.Tensor, + int, + Optional[torch.distributed.Work], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: + """Stage a new KL chunk, consume previous pending chunk, and update pending state.""" + slot = chunk_id & 1 + target_chunk_work, current_handle = _stage_topk_target_chunk( + target_chunk, + slot_prefix=slot_prefix, + slot=slot, + device=device, + tp_group=tp_group, + tp_size=tp_size, + ) + kl_sum, _, _, _, _ = _consume_pending_topk_kl_chunk( + pending_handle=pending_handle, + pending_target_chunk=pending_target_chunk, + pending_index_logits=pending_index_logits, + pending_valid_seq=pending_valid_seq, + kl_sum=kl_sum, + ) + return (kl_sum, chunk_id + 1, current_handle, target_chunk_work, index_logits_chunk, valid_seq) + + +def _flush_pending_topk_kl_chunk( + *, + pending_handle: Optional[torch.distributed.Work], + pending_target_chunk: Optional[torch.Tensor], + pending_index_logits: Optional[torch.Tensor], + pending_valid_seq: Optional[torch.Tensor], + kl_sum: torch.Tensor, +) -> torch.Tensor: + """Consume the final pending KL chunk and return updated kl_sum.""" + kl_sum, _, _, _, _ = _consume_pending_topk_kl_chunk( + pending_handle=pending_handle, + pending_target_chunk=pending_target_chunk, + pending_index_logits=pending_index_logits, + pending_valid_seq=pending_valid_seq, + kl_sum=kl_sum, + ) + return kl_sum + + +def _fused_qk_topk_lighting_with_streaming_sparse_kl( + q: torch.Tensor, + k: torch.Tensor, + weights: 
torch.Tensor, + index_topk: int, + starts: torch.Tensor, + ends: torch.Tensor, + block_size: int, + query: torch.Tensor, + key: torch.Tensor, + softmax_scale: float, + loss_coeff: float, + pg_collection: ProcessGroupCollection, + seq_chunk_size: int = 32, + head_chunk_size: int = 4, + topk_chunk_size: int = 64, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Run fused tilelang indexer and stream top-k logits directly into sparse KL accumulation.""" + if lighting_indexer is None: + return None + if q.ndim != 4 or k.ndim != 3 or weights.ndim != 3: + return None + + query, _ = _ensure_sbhd(query, "query") + key, _ = _ensure_sbhd(key, "key") + sq, b = q.size(0), q.size(1) + sq_q, b_q, np, hn = query.size() + sk, b_k, nk, hk = key.size() + if k.size(1) != b or weights.size(1) != b: + return None + if sq_q != sq or b_q != b or b_k != b or hk != hn: + return None + if nk != 1 and nk != np: + return None + + starts = starts.contiguous() + ends = ends.contiguous() + + topk_out = None + kl_sum = torch.zeros((), dtype=torch.float32, device=q.device) + tp_size = pg_collection.tp.size() + pending_handle = None + pending_target_chunk = None + pending_index_logits = None + pending_valid_seq = None + chunk_id = 0 + for bi in range(b): + query_h = query[:, bi].permute(1, 0, 2).contiguous() # [np, sq, hn] + if nk == 1: + key_shared = key[:, bi, 0].contiguous() # [sk, hn] + key_per_head = None + else: + key_shared = None + key_per_head = key[:, bi].permute(1, 0, 2).contiguous() # [np, sk, hn] + + index_q = q[:, bi].contiguous() + index_k = k[:, bi].contiguous() + index_w = weights[:, bi].float().contiguous() + + for start in range(0, sq, block_size): + end = min(start + block_size, sq) + topk_scores, topk_indices = lighting_indexer( + index_q[start:end], + index_k, + index_w[start:end], + starts[start:end], + ends[start:end], + index_topk, + topk_indices=None, + ) + + if topk_out is None: + topk_out = torch.empty( + (b, sq, topk_indices.size(-1)), + dtype=topk_indices.dtype, + device=topk_indices.device, + ) + topk_out[bi, start:end].copy_(topk_indices) + + s_len = end - start + for rel_start in range(0, s_len, seq_chunk_size): + rel_end = min(rel_start + seq_chunk_size, s_len) + abs_start = start + rel_start + abs_end = start + rel_end + + idx_seq_raw = topk_indices[rel_start:rel_end].to( + dtype=torch.int64, device=query.device + ) + valid_seq = idx_seq_raw >= 0 + idx_seq = idx_seq_raw.clamp(min=0) + target_chunk = _compute_topk_target_chunk_sum( + query_h=query_h, + key_shared=key_shared, + key_per_head=key_per_head, + s0=abs_start, + s1=abs_end, + idx_seq=idx_seq, + valid_seq=valid_seq, + softmax_scale=softmax_scale, + head_chunk_size=head_chunk_size, + topk_chunk_size=topk_chunk_size, + sk=sk, + hn=hn, + ) + ( + kl_sum, + chunk_id, + pending_handle, + pending_target_chunk, + pending_index_logits, + pending_valid_seq, + ) = _enqueue_topk_kl_chunk( + target_chunk=target_chunk, + index_logits_chunk=topk_scores[rel_start:rel_end], + valid_seq=valid_seq, + slot_prefix="stream_kl_target", + chunk_id=chunk_id, + device=query.device, + tp_group=pg_collection.tp, + tp_size=tp_size, + pending_handle=pending_handle, + pending_target_chunk=pending_target_chunk, + pending_index_logits=pending_index_logits, + pending_valid_seq=pending_valid_seq, + kl_sum=kl_sum, + ) + kl_sum = _flush_pending_topk_kl_chunk( + pending_handle=pending_handle, + pending_target_chunk=pending_target_chunk, + pending_index_logits=pending_index_logits, + pending_valid_seq=pending_valid_seq, + kl_sum=kl_sum, + ) + + if topk_out is 
None: + return None + kl_div = kl_sum / (b * sq) + return topk_out, kl_div * loss_coeff + + +def _fused_sparse_mla_absorbed( + query: torch.Tensor, + key: torch.Tensor, + topk_indices: torch.Tensor, + softmax_scale: float, + v_channels: int, +) -> Optional[torch.Tensor]: + """Run fused SparseMLA kernel for absorbed-MLA path. + + Inputs are expected in SBHD with MQA key heads (kv_group=1): + query: [sq, b, np, d_total] + key: [skv, b, 1, d_total] + topk: [b, sq, topk] + + Returns: + output: [sq, b, np, v_channels], or None if unsupported / unavailable. + """ + if SparseMLA is None: + return None + + if query.ndim != 4 or key.ndim != 4 or topk_indices.ndim != 3: + return None + if key.size(2) != 1: + return None + if query.size(1) != key.size(1) or topk_indices.size(0) != query.size(1): + return None + if topk_indices.size(1) != query.size(0): + return None + if query.size(-1) != key.size(-1): + return None + if query.size(-1) != 576 or v_channels != 512: + # Current copied tilelang kernels are specialized for GLM5/DeepSeek V3.2 absorbed dims. + return None + + # Kernel requires topk to be block-aligned. + if topk_indices.size(-1) % 64 != 0: + return None + + batch_outputs = None + for bi in range(query.size(1)): + q_t = query[:, bi].contiguous() # [sq, np, d_total] + kv_t = key[:, bi].contiguous() # [skv, 1, d_total] + idx_t = topk_indices[bi].unsqueeze(1).to(torch.int32).contiguous() # [sq, 1, topk] + out, _ = SparseMLA.apply(q_t, kv_t, idx_t, softmax_scale) + if out.ndim != 3 or out.size(-1) != v_channels: + return None + if batch_outputs is None: + batch_outputs = torch.empty( + (out.size(0), query.size(1), out.size(1), out.size(2)), + dtype=out.dtype, + device=out.device, + ) + batch_outputs[:, bi].copy_(out) + + if batch_outputs is None: + return None + return batch_outputs.contiguous() + + +def _explain_absorbed_fused_skip( + query: torch.Tensor, key: torch.Tensor, topk_indices: torch.Tensor, v_channels: int +) -> str: + """Return first failing condition for fused absorbed SparseMLA path.""" + if SparseMLA is None: + return "SparseMLA kernel unavailable (import failed)" + if query.ndim != 4 or key.ndim != 4 or topk_indices.ndim != 3: + return "invalid tensor rank (expected query/key 4D and topk 3D)" + if key.size(2) != 1: + return f"key head count {key.size(2)} != 1 (MQA required)" + if query.size(1) != key.size(1) or topk_indices.size(0) != query.size(1): + return "batch shape mismatch among query/key/topk" + if topk_indices.size(1) != query.size(0): + return "topk seqlen mismatch with query seqlen" + if query.size(-1) != key.size(-1): + return "query/key hidden dim mismatch" + if query.size(-1) != 576 or v_channels != 512: + return ( + f"kernel specialized for d_total=576,v_channels=512 but got " + f"d_total={query.size(-1)}, v_channels={v_channels}" + ) + if topk_indices.size(-1) % 64 != 0: + return f"topk ({topk_indices.size(-1)}) is not block-aligned (must be multiple of 64)" + return "unknown runtime fallback (SparseMLA.apply returned/raised failure)" + + +def _build_sparse_attn_reason( + *, + sparse_attn_path: str, + absorbed_mla: bool, + query: torch.Tensor, + key: torch.Tensor, + topk_indices: torch.Tensor, + config: TransformerConfig, +) -> str: + """Build a concise reason string for sparse attention path selection.""" + if sparse_attn_path == "fused_sparse_mla_absorbed": + return "fused absorbed SparseMLA is active" + if sparse_attn_path == "fused_sparse_mla_absorbed_upv": + return "fused absorbed SparseMLA + up_v projection active" + if sparse_attn_path == 
"unfused_absorbed_upv": + return "absorbed QK path with up_v projection active; fused SparseMLA unavailable" + if not absorbed_mla: + return "absorbed=False; non-absorbed DSAttention currently uses unfused_dsa only" + return _explain_absorbed_fused_skip( + query=query, + key=key, + topk_indices=topk_indices, + v_channels=int(getattr(config, "kv_lora_rank", 0) or 0), + ) + + +def _unfused_absorbed_dsa_fn( + query: torch.Tensor, + key: torch.Tensor, + topk_indices: torch.Tensor, + softmax_scale: float, + v_channels: int, + mask: Optional[torch.Tensor] = None, + varlen_starts: Optional[torch.Tensor] = None, + varlen_ends: Optional[torch.Tensor] = None, + key_positions: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Unfused absorbed-MLA attention: output stays [sq, b, np, v_channels].""" + sq, b, np, hn = query.size() + skv = key.size(0) + assert key.size(2) == 1, "Absorbed DSA expects MQA key head dimension = 1" + assert key.size(-1) >= v_channels, "key last dim must contain latent value channels" + row_mask, varlen_starts, varlen_ends, key_positions = _prepare_sparse_mask_context( + mask=mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + sq=sq, + sk=skv, + b=b, + device=query.device, + ) + + # [sq,b,np,hn] -> [b,np,sq,hn] + q = query.permute(1, 2, 0, 3) + # [skv,b,1,hn] -> [b,1,hn,skv] + k = key.permute(1, 2, 3, 0) + attention_scores = torch.matmul(q.float(), k.float()) * softmax_scale + + # Sparse + causal/varlen validity mask. + index_mask = torch.full((b, sq, skv), float("-inf"), device=attention_scores.device) + _scatter_topk_into_index_mask(index_mask, topk_indices, seq_chunk_size=256) + index_mask = _apply_sparse_validity_to_index_mask( + index_mask, + row_mask=row_mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + ) + + attention_scores += index_mask.unsqueeze(1) + attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32) + + # Latent value is the first v_channels slice of absorbed key cache. + value = key[..., :v_channels].permute(1, 2, 0, 3) # [b,1,skv,v] + output = torch.matmul(attention_scores.to(value.dtype), value) # [b,np,sq,v] + return output.permute(2, 0, 1, 3).contiguous() + + +def _run_sparse_attention( + *, + absorbed_mla: bool, + query: torch.Tensor, + key: torch.Tensor, + value: Optional[torch.Tensor], + up_v_weight: Optional[torch.Tensor], + topk_indices: torch.Tensor, + softmax_scale: float, + config: TransformerConfig, + mask: Optional[torch.Tensor], + varlen_starts: Optional[torch.Tensor], + varlen_ends: Optional[torch.Tensor], + key_positions: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, str]: + """Run sparse attention for absorbed and non-absorbed MLA paths.""" + if absorbed_mla: + latent_v_channels = int(getattr(config, "kv_lora_rank", 0) or 0) + if latent_v_channels <= 0: + raise RuntimeError( + "Invalid kv_lora_rank for absorbed-MLA DSAttention sparse attention." + ) + if value is not None: + raise RuntimeError( + "Absorbed DSAttention expects value=None (latent path). " + "Received absorbed layout with explicit value tensor." 
+ ) + output = _fused_sparse_mla_absorbed( + query, key, topk_indices, softmax_scale, latent_v_channels + ) + if output is None: + output = _unfused_absorbed_dsa_fn( + query, + key, + topk_indices, + softmax_scale, + latent_v_channels, + mask=mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + ) + if up_v_weight is None: + return output, "unfused_absorbed" + output = torch.einsum("sbhc,hdc->sbhd", output, up_v_weight).contiguous() + output = output.view(output.size(0), output.size(1), -1) + return output, "unfused_absorbed_upv" + if up_v_weight is None: + return output, "fused_sparse_mla_absorbed" + output = torch.einsum("sbhc,hdc->sbhd", output, up_v_weight).contiguous() + output = output.view(output.size(0), output.size(1), -1) + return output, "fused_sparse_mla_absorbed_upv" + + return ( + unfused_dsa_fn( + query, + key, + value, + topk_indices, + softmax_scale, + mask=mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + ), + "unfused_dsa", + ) + + +def _ensure_sbhd(tensor: torch.Tensor, name: str) -> Tuple[torch.Tensor, bool]: + """Ensure tensor is [s, b, h, d], allowing packed [t, h, d] input.""" + if tensor.ndim == 4: + return tensor, False + if tensor.ndim == 3: + return tensor.unsqueeze(1), True + raise ValueError(f"{name} must be 3D ([t,h,d]) or 4D ([s,b,h,d]), got {tensor.ndim}D") + + +def _normalize_dsattention_output_rank(output: torch.Tensor, target_ndim: int) -> torch.Tensor: + """Normalize DSAttention output rank to match caller hidden-state rank.""" + if target_ndim not in (2, 3): + raise RuntimeError(f"DSAttention expected x.ndim in (2, 3), got {target_ndim}") + + if output.ndim == 4: + output = output.reshape(output.size(0), output.size(1), -1) + elif output.ndim not in (2, 3): + raise RuntimeError( + f"DSAttention produced unexpected output rank {output.ndim}; expected 2D/3D/4D." + ) + + if target_ndim == 3 and output.ndim == 2: + output = output.unsqueeze(1) + elif target_ndim == 2 and output.ndim == 3: + if output.size(1) != 1: + raise RuntimeError( + "DSAttention cannot squeeze non-singleton batch dim for packed output: " + f"shape={tuple(output.shape)}" + ) + output = output.squeeze(1) + + if output.ndim != target_ndim: + raise RuntimeError( + "DSAttention output rank mismatch after normalization: " + f"target_ndim={target_ndim}, output_shape={tuple(output.shape)}" + ) + return output + def rotate_activation(x: torch.Tensor) -> torch.Tensor: """Apply Hadamard rotation activation. @@ -167,6 +1370,10 @@ def compute_dsa_indexer_loss( loss_coeff: float, sparse_loss: bool, pg_collection: ProcessGroupCollection, + mask: Optional[torch.Tensor] = None, + varlen_starts: Optional[torch.Tensor] = None, + varlen_ends: Optional[torch.Tensor] = None, + key_positions: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Compute KL divergence loss between index_scores and true attention_scores. @@ -187,10 +1394,20 @@ def compute_dsa_indexer_loss( sparse_loss: bool, whether to use sparse indexer loss. If True, only the topk indices will be used to compute the loss. pg_collection: Process group collection, must have TP process group. + mask: Optional additive attention mask. Supports shape [sq, sk] or [b, sq, sk]. + Invalid positions should be -inf. + varlen_starts: Optional row-wise key start bounds [sq] for packed THD. + varlen_ends: Optional row-wise key end bounds [sq] for packed THD. + key_positions: Optional global key positions [sk] for packed THD. 
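+            For example, packed cu_seqlens=[0, 3, 7] gives (via
+            _generate_varlen_mask_params) varlen_starts=[0, 0, 0, 3, 3, 3, 3]
+            and varlen_ends=[1, 2, 3, 4, 5, 6, 7].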
Returns: index_loss: KL divergence loss (scalar). """ + query, _ = _ensure_sbhd(query, "query") + key, _ = _ensure_sbhd(key, "key") + if mask is not None and varlen_starts is not None: + raise ValueError("mask and varlen_starts are mutually exclusive") + sq, b, np, hn = query.size() sk = key.size(0) @@ -203,18 +1420,29 @@ def compute_dsa_indexer_loss( # Reshape to [b, np, sq, sk] attention_scores = attention_scores.reshape(b, np, sq, sk) - # causal_mask [sq, sk] - causal_mask = torch.triu( - torch.full((sq, sk), float('-inf'), dtype=torch.float32, device=attention_scores.device), - diagonal=1, - ) + if varlen_starts is not None: + if varlen_ends is None: + raise ValueError("varlen_ends is required when varlen_starts is provided") + if key_positions is None: + key_positions = torch.arange(sk, dtype=torch.int64, device=attention_scores.device) + attention_scores = _apply_starts_ends_mask_to_scores( + attention_scores, varlen_starts, varlen_ends, key_positions + ) + index_scores = _apply_starts_ends_mask_to_scores( + index_scores, varlen_starts, varlen_ends, key_positions + ) + else: + _, attn_score_mask, _, _ = _prepare_additive_mask( + mask, sq=sq, sk=sk, b=b, device=attention_scores.device + ) + # [b, np, sq, sk] + [1/b, 1, sq, sk] -> [b, np, sq, sk] + attention_scores += attn_score_mask + # index_mask [b, sq, sk] index_mask = torch.full( - (b, sq, sk), float("-inf"), dtype=torch.float32, device=causal_mask.device + (b, sq, sk), float("-inf"), dtype=torch.float32, device=attention_scores.device ).scatter_(-1, topk_indices, 0) - # [b, np, sq, skv] + [1, 1, sq, skv] -> [b, np, sq, skv] - attention_scores += causal_mask.view(1, 1, sq, sk) if sparse_loss: # [b, np, sq, sk] + [b, 1, sq, sk] -> [b, np, sq, sk] attention_scores += index_mask.view(b, 1, sq, sk) @@ -252,6 +1480,111 @@ def compute_dsa_indexer_loss( return indexer_loss +def compute_dsa_indexer_loss_topk_sparse( + index_topk_scores: torch.Tensor, + topk_indices: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + softmax_scale: float, + loss_coeff: float, + pg_collection: ProcessGroupCollection, +) -> torch.Tensor: + """Compute sparse top-k KL loss using fused indexer top-k logits. + + This matches fused indexer semantics where indexer logits are + only materialized at selected top-k positions. + The implementation streams sequence chunks to reduce peak memory. + """ + query, _ = _ensure_sbhd(query, "query") + key, _ = _ensure_sbhd(key, "key") + + sq, b, np, hn = query.size() + sk, bk, nk, hk = key.size() + assert bk == b and hk == hn, "query/key shape mismatch" + assert index_topk_scores.shape[:2] == (b, sq), "index_topk_scores shape mismatch" + assert topk_indices.shape[:2] == (b, sq), "topk_indices shape mismatch" + + if nk != 1: + assert nk == np, "key head count must be 1 (MQA) or match query heads" + + # Compute KL in streaming chunks to avoid materializing full [b, sq, topk] + # and avoid full-size valid/safe top-k tensors. 
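+    # Double buffering: _enqueue_topk_kl_chunk stages chunk i into scratch slot
+    # (chunk_id & 1) and launches its async TP all-reduce before waiting on
+    # chunk i-1, so the reduction overlaps with the previous chunk's KL math
+    # and the next chunk's target computation.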
+ seq_chunk_size = 32 + head_chunk_size = 4 + topk_chunk_size = 64 + kl_sum = torch.zeros((), dtype=torch.float32, device=query.device) + tp_size = pg_collection.tp.size() + pending_handle = None + pending_target_chunk = None + pending_index_logits = None + pending_valid_seq = None + chunk_id = 0 + + for bi in range(b): + query_h = query[:, bi].permute(1, 0, 2).contiguous() # [np, sq, hn] + if nk == 1: + key_shared = key[:, bi, 0].contiguous() # [sk, hn] + key_per_head = None + else: + key_shared = None + key_per_head = key[:, bi].permute(1, 0, 2).contiguous() # [np, sk, hn] + + for s0 in range(0, sq, seq_chunk_size): + s1 = min(s0 + seq_chunk_size, sq) + idx_seq_raw = topk_indices[bi, s0:s1] # [s_len, topk] + if idx_seq_raw.dtype != torch.int64 or idx_seq_raw.device != query.device: + idx_seq_raw = idx_seq_raw.to(dtype=torch.int64, device=query.device) + valid_seq = idx_seq_raw >= 0 + idx_seq = idx_seq_raw.clamp(min=0) + + target_chunk = _compute_topk_target_chunk_sum( + query_h=query_h, + key_shared=key_shared, + key_per_head=key_per_head, + s0=s0, + s1=s1, + idx_seq=idx_seq, + valid_seq=valid_seq, + softmax_scale=softmax_scale, + head_chunk_size=head_chunk_size, + topk_chunk_size=topk_chunk_size, + sk=sk, + hn=hn, + ) + ( + kl_sum, + chunk_id, + pending_handle, + pending_target_chunk, + pending_index_logits, + pending_valid_seq, + ) = _enqueue_topk_kl_chunk( + target_chunk=target_chunk, + index_logits_chunk=index_topk_scores[bi, s0:s1], + valid_seq=valid_seq, + slot_prefix="topk_sparse_kl_target", + chunk_id=chunk_id, + device=query.device, + tp_group=pg_collection.tp, + tp_size=tp_size, + pending_handle=pending_handle, + pending_target_chunk=pending_target_chunk, + pending_index_logits=pending_index_logits, + pending_valid_seq=pending_valid_seq, + kl_sum=kl_sum, + ) + kl_sum = _flush_pending_topk_kl_chunk( + pending_handle=pending_handle, + pending_target_chunk=pending_target_chunk, + pending_index_logits=pending_index_logits, + pending_valid_seq=pending_valid_seq, + kl_sum=kl_sum, + ) + + kl_div = kl_sum / (b * sq) + return kl_div * loss_coeff + + def _compute_index_scores(q: torch.Tensor, weights: torch.Tensor, k: torch.Tensor) -> torch.Tensor: """ Perform index score using BF16 precision. 
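
For reference, the quantity the streaming loop above accumulates is, per query
row, the KL divergence between the head-summed attention mass renormalized over
the top-k support and the indexer softmax over the same top-k logits. A dense
single-row sketch of that math (toy helper with made-up names, assuming both
inputs are indexed by the same selected keys):

    import torch

    def sparse_topk_kl_row(attn_mass, index_logits, eps=1e-10):
        # attn_mass, index_logits: [topk] over identical key indices.
        target = attn_mass / attn_mass.sum().clamp_min(eps)
        index_probs = torch.softmax(index_logits.float(), dim=-1)
        return (target * (torch.log(target + eps) - torch.log(index_probs + eps))).sum()
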
@@ -301,22 +1634,35 @@ def fused_qk_topk_naive( weights: torch.Tensor, index_topk: int, mask: Optional[torch.Tensor] = None, + varlen_starts: Optional[torch.Tensor] = None, + varlen_ends: Optional[torch.Tensor] = None, + key_positions: Optional[torch.Tensor] = None, ): """Naive implementation of QK Topk.""" - seqlen = q.size(0) + sk = k.size(0) # ========================================= # Compute index scores # ========================================= # [batch, seqlen, seqlen] index_scores = _compute_index_scores(q, weights, k) - if mask is not None: + if mask is not None and varlen_starts is not None: + raise ValueError("mask and varlen_starts are mutually exclusive") + if varlen_starts is not None: + if varlen_ends is None: + raise ValueError("varlen_ends is required when varlen_starts is provided") + if key_positions is None: + key_positions = torch.arange(sk, dtype=torch.int64, device=index_scores.device) + index_scores = _apply_starts_ends_mask_to_scores( + index_scores, varlen_starts, varlen_ends, key_positions + ) + elif mask is not None: assert mask.dtype == index_scores.dtype, "Mask dtype must match index scores dtype" index_scores = index_scores + mask # ========================================= # Select top-k indices # ========================================= - topk_k = min(index_topk, seqlen) + topk_k = min(index_topk, sk) # [batch, seqlen, index_topk] topk_indices = index_scores.topk(topk_k, dim=-1)[1] @@ -324,10 +1670,32 @@ def fused_qk_topk_naive( def fwd_fused_indexer_loss_naive( - q, weights, k, query, key, topk, softmax_scale, loss_coeff, mask, sparse_loss, pg_collection + q, + weights, + k, + query, + key, + topk, + softmax_scale, + loss_coeff, + mask, + sparse_loss, + pg_collection, + varlen_starts=None, + varlen_ends=None, + key_positions=None, ): """Naive implementation of forward pass for indexer loss.""" - index_scores, topk_indices = fused_qk_topk_naive(q, k, weights, topk, mask) + index_scores, topk_indices = fused_qk_topk_naive( + q, + k, + weights, + topk, + mask=mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + ) indexer_loss = compute_dsa_indexer_loss( index_scores, @@ -338,6 +1706,10 @@ def fwd_fused_indexer_loss_naive( loss_coeff, sparse_loss, pg_collection, + mask=mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, ) return topk_indices, indexer_loss @@ -353,10 +1725,19 @@ def bwd_fused_indexer_loss_naive( softmax_scale, loss_coeff, sparse_loss, + mask, grad_loss, pg_collection, + varlen_starts=None, + varlen_ends=None, + key_positions=None, ): """Naive implementation of backward pass for indexer loss.""" + query, _ = _ensure_sbhd(query, "query") + key, _ = _ensure_sbhd(key, "key") + if mask is not None and varlen_starts is not None: + raise ValueError("mask and varlen_starts are mutually exclusive") + index_scores = _compute_index_scores(q, weights, k) # [B, Sq, Sk] sq, b, np, hn = query.size() @@ -374,24 +1755,36 @@ def bwd_fused_indexer_loss_naive( # Reshape to [b, np, sq, sk] attention_scores = attention_scores.reshape(b, np, sq, sk) - # causal_mask [sq, sk] - causal_mask = torch.triu( - torch.full((sq, sk), float('-inf'), dtype=torch.float32, device=attention_scores.device), - diagonal=1, - ) + if varlen_starts is not None: + if varlen_ends is None: + raise ValueError("varlen_ends is required when varlen_starts is provided") + if key_positions is None: + key_positions = torch.arange(sk, dtype=torch.int64, device=attention_scores.device) + attention_scores = 
_apply_starts_ends_mask_to_scores( + attention_scores, varlen_starts, varlen_ends, key_positions + ) + index_scores = _apply_starts_ends_mask_to_scores( + index_scores, varlen_starts, varlen_ends, key_positions + ) + base_valid_mask = ( + _build_valid_mask_from_starts_ends(varlen_starts, varlen_ends, key_positions) + .unsqueeze(0) + .expand(b, sq, sk) + ) + else: + _, attn_score_mask, index_score_mask, base_valid_mask = _prepare_additive_mask( + mask, sq=sq, sk=sk, b=b, device=attention_scores.device + ) + # [b, np, sq, sk] + [1/b, 1, sq, sk] -> [b, np, sq, sk] + attention_scores = attention_scores + attn_score_mask + # [b, sq, sk] + [1/b, sq, sk] -> [b, sq, sk] + index_scores = index_scores + index_score_mask + # index_mask [b, sq, sk] index_mask = torch.full( - (b, sq, sk), float("-inf"), dtype=torch.float32, device=causal_mask.device + (b, sq, sk), float("-inf"), dtype=torch.float32, device=attention_scores.device ).scatter_(-1, topk_indices, 0) - # Apply causal mask to both attention and index scores - # [b, np, sq, skv] + [1, 1, sq, skv] -> [b, np, sq, skv] - attention_scores = attention_scores + causal_mask.view(1, 1, sq, sk) - # [b, sq, sk] + [1, sq, sk] -> [b, sq, sk] - index_scores = index_scores + causal_mask.unsqueeze(0) - # Free causal_mask - no longer needed - del causal_mask - if sparse_loss: # [b, np, sq, sk] + [b, 1, sq, sk] -> [b, np, sq, sk] attention_scores = attention_scores + index_mask.view(b, 1, sq, sk) @@ -450,22 +1843,17 @@ def bwd_fused_indexer_loss_naive( # Free intermediate tensors del index_scores_softmax, grad_index_scores_softmax, sum_grad - # Zero out gradients for masked positions - # Create a mask for valid (non-masked) positions - # Causal mask: position (i, j) is valid if j <= i - causal_valid_mask = torch.tril( - torch.ones((sq, sk), device=q.device, dtype=torch.bool) - ) # [sq, sk] + # Zero out gradients for masked positions. if sparse_loss: - # Also apply index mask - only topk positions are valid + # Also apply index mask - only topk positions are valid. index_valid_mask = index_mask == 0 # [b, sq, sk] - del index_mask # Free index_mask immediately after use - valid_mask = causal_valid_mask.unsqueeze(0) & index_valid_mask # [b, sq, sk] + del index_mask + valid_mask = base_valid_mask & index_valid_mask # [b, sq, sk] del index_valid_mask else: - del index_mask # Free index_mask even if not used for sparse_loss - valid_mask = causal_valid_mask.unsqueeze(0).expand(b, sq, sk) # [b, sq, sk] - del causal_valid_mask + del index_mask + valid_mask = base_valid_mask # [b, sq, sk] + del base_valid_mask grad_index_scores_logits = grad_index_scores_logits * valid_mask.float() del valid_mask @@ -524,6 +1912,9 @@ def forward( mask, sparse_loss, pg_collection, + varlen_starts=None, + varlen_ends=None, + key_positions=None, ): """ Fused forward: index_scores never materialized in full. 
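
For context, the autograd wrapper above uses a recompute-in-backward strategy:
forward saves only the small indexer inputs and rebuilds the large score
tensors during backward. A minimal self-contained sketch of that pattern
(hypothetical names and toy shapes, not this patch's kernels):

    import torch

    class RecomputeScores(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q, k):
            # The [sq, sk] score matrix is consumed immediately, never saved.
            out = (q @ k.t()).softmax(dim=-1).sum()
            ctx.save_for_backward(q, k)  # keep only the small inputs
            return out

        @staticmethod
        def backward(ctx, grad_out):
            q, k = ctx.saved_tensors
            with torch.enable_grad():
                q_r = q.detach().requires_grad_(True)
                k_r = k.detach().requires_grad_(True)
                # Rebuild the graph, then differentiate the recomputed loss.
                out = (q_r @ k_r.t()).softmax(dim=-1).sum()
                grad_q, grad_k = torch.autograd.grad(out, (q_r, k_r), grad_out)
            return grad_q, grad_k
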
@@ -540,6 +1931,9 @@ def forward( mask, sparse_loss, pg_collection, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, ) # Save for backward (recomputation strategy) @@ -547,7 +1941,11 @@ def forward( ctx.softmax_scale = softmax_scale ctx.loss_coeff = loss_coeff ctx.sparse_loss = sparse_loss + ctx.mask = mask ctx.pg_collection = pg_collection + ctx.varlen_starts = varlen_starts + ctx.varlen_ends = varlen_ends + ctx.key_positions = key_positions return topk_indices, loss @@ -568,12 +1966,32 @@ def backward(ctx, grad_topk_indices, grad_loss): ctx.softmax_scale, ctx.loss_coeff, ctx.sparse_loss, + ctx.mask, grad_loss, ctx.pg_collection, + varlen_starts=ctx.varlen_starts, + varlen_ends=ctx.varlen_ends, + key_positions=ctx.key_positions, ) # query and key are detached in forward, so return None for their gradients - return grad_q, grad_weights, grad_k, None, None, None, None, None, None, None, None + grads = [ + grad_q, + grad_weights, + grad_k, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ] + return tuple(grads[: len(ctx.needs_input_grad)]) class DSAIndexerLossAutoScaler(torch.autograd.Function): @@ -776,21 +2194,36 @@ def __init__( parallel_mode="duplicated", ) - def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor, mscale: float): + def _apply_rope( + self, + x: torch.Tensor, + rotary_pos_emb: torch.Tensor, + mscale: float, + cu_seqlens: Optional[torch.Tensor] = None, + ): """Apply RoPE to the input tensor.""" # x_nope [seqlen, batch, *, index_head_dim - qk_pos_emb_head_dim] # x_pe [seqlen, batch, *, qk_pos_emb_head_dim] x_nope, x_pe = torch.split( x, [self.index_head_dim - self.qk_pos_emb_head_dim, self.qk_pos_emb_head_dim], dim=-1 ) + squeezed_batch_dim = False + if cu_seqlens is not None and cu_seqlens.device != x_pe.device: + cu_seqlens = cu_seqlens.to(device=x_pe.device) + # THD RoPE path expects [t, h, d], while indexer tensors are [t, 1, h, d]. 
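+        # e.g. (illustrative) cu_seqlens=[0, 3, 7] packs two sequences of
+        # lengths 3 and 4 into t=7; squeezing to [7, h, d] lets the varlen RoPE
+        # restart positions at each sequence boundary, and the singleton batch
+        # dim is restored right after.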
+ if cu_seqlens is not None and x_pe.ndim == 4 and x_pe.size(1) == 1: + x_pe = x_pe.squeeze(1) + squeezed_batch_dim = True x_pe = apply_rotary_pos_emb( x_pe, rotary_pos_emb, config=self.config, - cu_seqlens=None, + cu_seqlens=cu_seqlens, mscale=mscale, cp_group=self.pg_collection.cp, ) + if squeezed_batch_dim: + x_pe = x_pe.unsqueeze(1) # [seqlen, batch, *, index_head_dim] x = torch.cat([x_nope, x_pe], dim=-1) return x @@ -799,6 +2232,8 @@ def forward_before_topk( self, x: torch.Tensor, qr: torch.Tensor, packed_seq_params: Optional[PackedSeqParams] = None ) -> Tuple[torch.Tensor, torch.Tensor]: """All computations before topk.""" + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd" + # ========================================= # Prepare RoPE params # ========================================= @@ -806,10 +2241,14 @@ def forward_before_topk( None, None, x, self.config, packed_seq_params ) if self.config.rope_type == "rope": - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=False) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) mscale = 1.0 else: - rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=False) + rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + if packed_seq: + cu_seqlens_q, cu_seqlens_kv = _get_packed_qk_cu_seqlens(packed_seq_params) + else: + cu_seqlens_q = cu_seqlens_kv = None # ========================================= # Gather inputs if sp is enabled @@ -831,17 +2270,17 @@ def forward_before_topk( # [seqlen, batch, index_n_heads * index_head_dim] # -> [seqlen, batch, index_n_heads, index_head_dim] q = q.reshape(seqlen, bsz, self.index_n_heads, self.index_head_dim) - q = self._apply_rope(q, rotary_pos_emb, mscale) + q = self._apply_rope(q, rotary_pos_emb, mscale, cu_seqlens=cu_seqlens_q) # ========================================= # k linear and apply rope to k # ========================================= # [seqlen, batch, hidden_size] -> [seqlen, batch, index_head_dim] k, _ = self.linear_wk(x) - k = self.k_norm(k) + k = self.k_norm(k.float()).to(dtype=k.dtype) # [seqlen, batch, index_head_dim] -> [seqlen, batch, 1, index_head_dim] k = k.reshape(seqlen, bsz, 1, self.index_head_dim) - k = self._apply_rope(k, rotary_pos_emb, mscale) + k = self._apply_rope(k, rotary_pos_emb, mscale, cu_seqlens=cu_seqlens_kv) # [seqlen, batch, 1, index_head_dim] -> [seqlen, batch, index_head_dim] k = k.reshape(seqlen, bsz, self.index_head_dim) @@ -875,15 +2314,14 @@ def forward_with_scores( Args: x: hidden states [seqlen, batch, hidden_size]. qr: Low-rank query tensor [seqlen, batch, q_lora_rank]. - mask: Attention mask [batch, seqlen, seqlen]. + mask: Optional additive attention mask [seqlen, seqlen] or + [batch, seqlen, seqlen]. packed_seq_params: Packed sequence parameters for variable length sequences. Returns: index_scores: Index scores [batch, seqlen, seqlen]. topk_indices: Top-k indices [batch, seqlen, index_topk]. 
""" - assert packed_seq_params is None, "Packed sequence is not supported for DSAttention" - # [seqlen, batch, index_n_heads * index_head_dim] # [seqlen, batch, index_head_dim] # [seqlen, batch, index_n_heads] @@ -917,56 +2355,155 @@ def forward( return topk_indices -def unfused_dsa_fn(query, key, value, topk_indices, softmax_scale): +def unfused_dsa_fn( + query, + key, + value, + topk_indices, + softmax_scale, + mask: Optional[torch.Tensor] = None, + varlen_starts: Optional[torch.Tensor] = None, + varlen_ends: Optional[torch.Tensor] = None, + key_positions: Optional[torch.Tensor] = None, +): """ Unfused sparse attention implementation. + + This path uses chunked sparse softmax accumulation over top-k selected keys + to avoid materializing full [b, np, sq, skv] attention score tensors. """ + if value is None: + raise NotImplementedError("DSAttention unfused path requires value tensor.") + + query, query_was_thd = _ensure_sbhd(query, "query") + key, _ = _ensure_sbhd(key, "key") + value, _ = _ensure_sbhd(value, "value") + sq, b, np, hn = query.size() skv = key.size(0) + nk = key.size(2) hnv = value.size(3) + nv = value.size(2) + + # [sq, b, np, hn] -> [b, np, sq, hn] + query_b = query.permute(1, 2, 0, 3).contiguous() + # [skv, b, nk, hn] -> [b, nk, skv, hn] + key_b = key.permute(1, 2, 0, 3).contiguous() + # [skv, b, nv, hnv] -> [b, nv, skv, hnv] + value_b = value.permute(1, 2, 0, 3).contiguous() + if nk == 1 and np > 1: + key_b = key_b.expand(b, np, skv, hn) + else: + assert nk == np, "key head count must be 1 (MQA) or match query heads" + if nv == 1 and np > 1: + value_b = value_b.expand(b, np, skv, hnv) + else: + assert nv == np, "value head count must be 1 (MQA) or match query heads" + + row_mask, varlen_starts, varlen_ends, key_positions = _prepare_sparse_mask_context( + mask=mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + sq=sq, + sk=skv, + b=b, + device=query.device, + ) - # =================================== - # Raw attention scores [b, np, sq, skv] - # =================================== - # [sq, b, np, hn] -> [b, np, sq, hn] -> [b * np, sq, hn] - query = query.permute(1, 2, 0, 3).reshape(b * np, sq, hn) - # [skv, b, np, hn] -> [b, np, hn, skv] -> [b * np, hn, skv] - key = key.permute(1, 2, 3, 0).reshape(b * np, hn, skv) - # Compute attention scores [b * np, sq, skv] - attention_scores = torch.bmm(query.float(), key.float()) * softmax_scale - # Reshape to [b, np, sq, skv] - attention_scores = attention_scores.reshape(b, np, sq, skv) + seq_chunk_size = 64 + head_chunk_size = 4 + topk_chunk_size = 256 + safe_k_max = max(0, skv - 1) + output = torch.empty((sq, b, np * hnv), dtype=value.dtype, device=query.device) + + for bi in range(b): + for h0 in range(0, np, head_chunk_size): + h1 = min(h0 + head_chunk_size, np) + h_chunk = h1 - h0 + out_h0 = h0 * hnv + out_h1 = h1 * hnv + k_chunk = key_b[bi, h0:h1, :, :].contiguous() # [h_chunk, skv, hn] + v_chunk = value_b[bi, h0:h1, :, :].contiguous() # [h_chunk, skv, hnv] + flat_k = k_chunk.reshape(h_chunk * skv, hn) + flat_v = v_chunk.reshape(h_chunk * skv, hnv) + head_offsets = ( + torch.arange(h_chunk, device=query.device, dtype=torch.int64).view(-1, 1, 1) * skv + ) - # =================================== - # Apply sparse mask from indexer - # =================================== - # index_mask [b, sq, skv] - index_mask = torch.full((b, sq, skv), float("-inf"), device=attention_scores.device) - index_mask.scatter_(-1, topk_indices, 0) - # causal_mask [sq, skv] - causal_mask = torch.triu( - 
torch.full((sq, skv), float('-inf'), dtype=torch.float32, device=index_mask.device), - diagonal=1, - ) - # [b, sq, skv] + [1, sq, skv] -> [b, sq, skv] - index_mask += causal_mask.view(1, sq, skv) - # [b, np, sq, skv] + [b, 1, sq, skv] -> [b, np, sq, skv] - attention_scores += index_mask.unsqueeze(1) - attention_scores = torch.nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32) + for s0 in range(0, sq, seq_chunk_size): + s1 = min(s0 + seq_chunk_size, sq) + s_len = s1 - s0 + idx_seq_raw = topk_indices[bi, s0:s1] # [s_len, topk] + if idx_seq_raw.dtype != torch.int64 or idx_seq_raw.device != query.device: + idx_seq_raw = idx_seq_raw.to(dtype=torch.int64, device=query.device) + valid_seq = idx_seq_raw >= 0 + idx_seq = idx_seq_raw.clamp(min=0, max=safe_k_max) + q_chunk = query_b[bi, h0:h1, s0:s1, :] # [h_chunk, s_len, hn] + + m = _get_scratch_buffer( + "unfused_dsa_m", (h_chunk, s_len), torch.float32, query.device + ) + l = _get_scratch_buffer( + "unfused_dsa_l", (h_chunk, s_len), torch.float32, query.device + ) + acc = _get_scratch_buffer( + "unfused_dsa_acc", (h_chunk, s_len, hnv), torch.float32, query.device + ) + m.fill_(float("-inf")) + l.zero_() + acc.zero_() + + for t0 in range(0, idx_seq.size(-1), topk_chunk_size): + t1 = min(t0 + topk_chunk_size, idx_seq.size(-1)) + idx_topk = idx_seq[:, t0:t1] # [s_len, tk] + valid_t = valid_seq[:, t0:t1] # [s_len, tk] + flat_idx = idx_topk.unsqueeze(0) + head_offsets # [h_chunk, s_len, tk] + k_sel = flat_k.index_select(0, flat_idx.reshape(-1)).view( + h_chunk, s_len, -1, hn + ) + v_sel = flat_v.index_select(0, flat_idx.reshape(-1)).view( + h_chunk, s_len, -1, hnv + ) + logits = (q_chunk.float().unsqueeze(2) * k_sel.float()).sum( + dim=-1 + ) * softmax_scale + + valid_2d, mask_bias = _gather_sparse_topk_validity_and_bias( + idx_topk=idx_topk, + valid_t=valid_t, + bi=bi, + s0=s0, + s1=s1, + row_mask=row_mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + dtype=torch.float32, + ) + if mask_bias is not None: + logits = logits + mask_bias.unsqueeze(0) + logits = logits.masked_fill( + ~valid_2d.unsqueeze(0).expand(h_chunk, -1, -1), float("-inf") + ) + m_new = torch.maximum(m, logits.max(dim=-1).values) + alpha = torch.exp(m - m_new) + alpha = torch.nan_to_num(alpha, nan=0.0, posinf=0.0, neginf=0.0) + p = torch.exp(logits - m_new.unsqueeze(-1)) + p = torch.nan_to_num(p, nan=0.0, posinf=0.0, neginf=0.0) + acc = acc * alpha.unsqueeze(-1) + torch.einsum( + "hst,hstd->hsd", p, v_sel.float() + ) + l = l * alpha + p.sum(dim=-1) + m = m_new + + out_chunk = (acc / l.clamp_min(1e-10).unsqueeze(-1)).to(dtype=value.dtype) + output[s0:s1, bi, out_h0:out_h1] = out_chunk.permute(1, 0, 2).reshape( + s_len, h_chunk * hnv + ) - # =================================== - # Output - # =================================== - # [skv, b, np, hnv] -> [b, np, skv, hnv] -> [b * np, skv, hnv] - value = value.permute(1, 2, 0, 3).reshape(b * np, skv, hnv) - # Reshape attention_scores: [b, np, sq, skv] -> [b * np, sq, skv] - attention_scores = attention_scores.reshape(b * np, sq, skv) - # Compute output: [b * np, sq, hnv] - output = torch.bmm(attention_scores.to(value.dtype), value) - # Reshape output: [b * np, sq, hnv] -> [b, np, sq, hnv] -> [sq, b, np, hnv] - output = output.reshape(b, np, sq, hnv).permute(2, 0, 1, 3).contiguous() - # Flatten: [sq, b, np, hnv] -> [sq, b, np * hnv] - output = output.reshape(sq, b, np * hnv) + if query_was_thd: + output = output.squeeze(1) return output @@ -1006,28 +2543,44 @@ def __init__( 
k_channels if k_channels is not None else config.kv_channels ) self.softmax_scale = softmax_scale + self.cp_comm_type = _normalize_cp_comm_type(cp_comm_type) + self._last_debug_path_msg = None + + def _debug_print_path(self, msg: str) -> None: + """Print DSAttention path transitions for debugging.""" + if msg == self._last_debug_path_msg: + return + self._last_debug_path_msg = msg + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + logger.info("[DSAttention][L%s] %s", self.layer_number, msg) + return + if torch.distributed.get_rank() == 0: + logger.info("[DSAttention][L%s] %s", self.layer_number, msg) def forward( self, query: torch.Tensor, key: torch.Tensor, - value: torch.Tensor, + value: Optional[torch.Tensor], attention_mask: torch.Tensor, x: torch.Tensor, qr: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, attn_mask_type: AttnMaskType = None, attention_bias: torch.Tensor = None, packed_seq_params: PackedSeqParams = None, + up_v_weight: Optional[torch.Tensor] = None, ): """ Forward pass for Sparse Attention. Args: - query: Query tensor [sq, b, np, hn]. - key: Key tensor [skv, b, np, hn]. - value: Value tensor [skv, b, np, hnv]. + query: Query tensor [sq, b, np, hn] or packed [t, np, hn]. + key: Key tensor [skv, b, np, hn] or packed [t, np, hn]. + value: Value tensor [skv, b, np, hnv] or packed [t, np, hnv]. x: Original hidden states [sq, b, hidden_size]. qr: Low-rank query representation [sq, b, q_lora_rank]. + position_ids: Optional position ids [b, sq], used by allgather CP causal masking. attention_mask: Attention mask tensor [b, 1, sq, sk]. attn_mask_type: Type of attention mask. attention_bias: Optional attention bias. @@ -1036,84 +2589,231 @@ def forward( Returns: output: Output tensor [sq, b, hidden_size] """ - sq, b, np, hn = query.size() + query, _ = _ensure_sbhd(query, "query") + key, _ = _ensure_sbhd(key, "key") + if value is not None: + value, _ = _ensure_sbhd(value, "value") + if up_v_weight is not None: + assert up_v_weight.ndim == 3, "up_v_weight must be [heads, v_head_dim, kv_lora_rank]" + up_v_weight = up_v_weight.to(device=query.device, dtype=query.dtype).contiguous() + if value is not None: + raise RuntimeError( + "DSAttention received up_v_weight with explicit value tensor. " + "For absorbed DSA path, value must be None." + ) + + latent_v_channels = int(getattr(self.config, "kv_lora_rank", 0) or 0) + qk_pos_dim = int(getattr(self.config, "qk_pos_emb_head_dim", 0) or 0) + expected_absorbed_dim = latent_v_channels + qk_pos_dim + absorbed_layout = ( + latent_v_channels > 0 + and expected_absorbed_dim > 0 + and key.size(2) == 1 + and query.size(-1) == key.size(-1) == expected_absorbed_dim + ) + absorbed_mla = absorbed_layout + if value is None and not absorbed_mla: + raise RuntimeError( + "DSAttention received value=None but query/key are not in absorbed layout. " + f"query_hdim={query.size(-1)}, key_hdim={key.size(-1)}, key_heads={key.size(2)}, " + f"expected_absorbed_dim={expected_absorbed_dim}" + ) + if up_v_weight is not None and not absorbed_mla: + raise RuntimeError( + "DSAttention received up_v_weight but absorbed layout was not detected. " + f"query_hdim={query.size(-1)}, key_hdim={key.size(-1)}, key_heads={key.size(2)}, " + f"expected_absorbed_dim={expected_absorbed_dim}" + ) + if self.training and absorbed_mla and up_v_weight is None: + raise RuntimeError( + "Absorbed DSAttention training requires up_v_weight for latent-to-value projection." 
+ ) + + sq, b, _, _ = query.size() + + cp_group = getattr(self.indexer.pg_collection, "cp", None) + cp_size = cp_group.size() if cp_group is not None else 1 + cp_rank = cp_group.rank() if cp_group is not None else 0 + + if cp_size > 1: + assert ( + self.cp_comm_type == "allgather" + ), "DSAttention context parallelism currently supports cp_comm_type=allgather only." + # For allgather CP, keys/values are expected in full-sequence order. + # Gather only if inputs are local-sequence tensors. + if key.size(0) == sq: + key = gather_from_sequence_parallel_region(key, group=cp_group) + if value is not None and value.size(0) == sq: + value = gather_from_sequence_parallel_region(value, group=cp_group) + skv = key.size(0) - hnv = value.size(3) # Detach x and qr to prevent gradients of indexer from flowing back to the main model. x = x.detach() qr = qr.detach() - # Get a FP32 mask with -inf for masked positions. - if attn_mask_type is not None: - assert attn_mask_type == AttnMaskType.causal, 'Only causal mask is supported for now' - # Generate upper triangular mask with -inf above diagonal, 0 elsewhere - # torch.triu with diagonal=1 creates upper triangular matrix (excluding main diagonal) - # float_mask [sq, skv] - float_mask = torch.triu( - torch.full((sq, skv), float('-inf'), dtype=torch.float32, device=x.device), - diagonal=1, - ) + indexer_loss_coeff = self.config.dsa_indexer_loss_coeff + use_indexer_loss = self.training and torch.is_grad_enabled() and indexer_loss_coeff > 0 + float_mask, varlen_params = _build_dsattention_forward_mask( + sq=sq, + skv=skv, + b=b, + device=x.device, + cp_size=cp_size, + cp_rank=cp_rank, + cp_comm_type=self.cp_comm_type, + cp_group=cp_group, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + position_ids=position_ids, + packed_seq_params=packed_seq_params, + ) + if varlen_params is not None: + varlen_starts, varlen_ends, key_positions = varlen_params else: - assert attention_mask.shape == (b, 1, sq, skv), 'attention_mask shape mismatch' - # [b, 1, sq, skv] -> [b, sq, skv] - mask = attention_mask.squeeze() - # float_mask [b, sq, skv] - float_mask = torch.zeros_like(mask, dtype=torch.float32).masked_fill( - mask, float('-inf') - ) + varlen_starts = varlen_ends = key_positions = None + + # =================================== + # Prepare indexer inputs / top-k + # =================================== + q, k, weights = self.indexer.forward_before_topk(x, qr, packed_seq_params) + if cp_size > 1 and k.size(0) == sq: + k = gather_from_sequence_parallel_region(k, group=cp_group) + fused_bounds = _build_fused_indexer_varlen_bounds( + sq=sq, + skv=skv, + device=q.device, + mask=float_mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + ) - if self.training and torch.is_grad_enabled(): - # =================================== - # Prepare inputs for indexer loss - # =================================== - q, k, weights = self.indexer.forward_before_topk(x, qr, packed_seq_params) - indexer_loss_coeff = getattr(self.config, 'dsa_indexer_loss_coeff', 0.0) + topk_indices = None + indexer_path = "unknown" if use_indexer_loss else "naive_topk" + indexer_loss = None + if use_indexer_loss: # =================================== # Attach indexer topk and loss # =================================== - # Compute KL divergence loss between indexer scores and true attention scores - topk_indices, indexer_loss = FusedDSAIndexerLoss.apply( - q, - weights, - k, - query.detach(), - key.detach(), - self.softmax_scale, - 
self.indexer.index_topk, - indexer_loss_coeff, - float_mask, - getattr(self.config, "dsa_indexer_use_sparse_loss", False), - self.indexer.pg_collection, - ) - # Save indexer loss for logging + sparse_indexer_loss = self.config.dsa_indexer_use_sparse_loss + if sparse_indexer_loss and fused_bounds is not None: + starts_i32, ends_i32 = fused_bounds + block_size = int(getattr(self, "fused_indexer_block_size", 8192)) + fused_topk_with_loss = _fused_qk_topk_lighting_with_streaming_sparse_kl( + q, + k, + weights, + self.indexer.index_topk, + starts_i32, + ends_i32, + block_size=max(1, block_size), + query=query.detach(), + key=key.detach(), + softmax_scale=self.softmax_scale, + loss_coeff=indexer_loss_coeff, + pg_collection=self.indexer.pg_collection, + ) + if fused_topk_with_loss is not None: + topk_indices, indexer_loss = fused_topk_with_loss + indexer_path = "fused_topk_sparse_kl" + + if topk_indices is None or indexer_loss is None: + # Legacy dense path fallback. + key_for_loss = key.detach() + if absorbed_mla and key_for_loss.size(2) == 1 and query.size(2) > 1: + key_for_loss = key_for_loss.expand(-1, -1, query.size(2), -1).contiguous() + topk_indices, indexer_loss = FusedDSAIndexerLoss.apply( + q, + weights, + k, + query.detach(), + key_for_loss, + self.softmax_scale, + self.indexer.index_topk, + indexer_loss_coeff, + float_mask, + sparse_indexer_loss, + self.indexer.pg_collection, + varlen_starts, + varlen_ends, + key_positions, + ) + indexer_path = "dense_indexer_loss_fallback" + + # Save indexer loss for logging. if indexer_loss_coeff > 0: DSAIndexerLossLoggingHelper.save_loss_to_tracker( loss=indexer_loss, layer_number=self.layer_number, num_layers=self.config.num_layers, ) - - # =================================== - # Run sparse attention kernel - # =================================== - output = unfused_dsa_fn(query, key, value, topk_indices, self.softmax_scale) - - # Attach loss to output - output = DSAIndexerLossAutoScaler.apply(output, indexer_loss) - else: # =================================== - # Get index scores and top-k indices + # Get top-k indices # =================================== - _, topk_indices = self.indexer.forward_with_scores( - x, qr, mask=float_mask, packed_seq_params=packed_seq_params - ) + if fused_bounds is not None: + starts_i32, ends_i32 = fused_bounds + block_size = int(getattr(self, "fused_indexer_block_size", 8192)) + topk_indices = _fused_qk_topk_lighting( + q, + k, + weights, + self.indexer.index_topk, + starts_i32, + ends_i32, + block_size=max(1, block_size), + ) + if topk_indices is not None: + indexer_path = "fused_topk" + + if topk_indices is None: + _, topk_indices = fused_qk_topk_naive( + q, + k, + weights, + self.indexer.index_topk, + mask=float_mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + ) - # =================================== - # Run sparse attention kernel - # =================================== - output = unfused_dsa_fn(query, key, value, topk_indices, self.softmax_scale) + # =================================== + # Run sparse attention kernel + # =================================== + output, sparse_attn_path = _run_sparse_attention( + absorbed_mla=absorbed_mla, + query=query, + key=key, + value=value, + up_v_weight=up_v_weight, + topk_indices=topk_indices, + softmax_scale=self.softmax_scale, + config=self.config, + mask=float_mask, + varlen_starts=varlen_starts, + varlen_ends=varlen_ends, + key_positions=key_positions, + ) + sparse_attn_reason = _build_sparse_attn_reason( + 
sparse_attn_path=sparse_attn_path, + absorbed_mla=absorbed_mla, + query=query, + key=key, + topk_indices=topk_indices, + config=self.config, + ) + self._debug_print_path( + f"use_indexer_loss={use_indexer_loss}, indexer={indexer_path}, " + f"sparse_attn={sparse_attn_path}, cp_size={cp_size}, absorbed={absorbed_mla}, " + f"sparse_attn_reason={sparse_attn_reason}" + ) - return output + if use_indexer_loss: + if indexer_loss is None: + raise RuntimeError("Indexer loss path did not produce a valid loss tensor.") + output = DSAIndexerLossAutoScaler.apply(output, indexer_loss) + + return _normalize_dsattention_output_rank(output, x.ndim) diff --git a/megatron/core/transformer/experimental_attention_variant/ops/indexer.py b/megatron/core/transformer/experimental_attention_variant/ops/indexer.py new file mode 100644 index 00000000000..86504585132 --- /dev/null +++ b/megatron/core/transformer/experimental_attention_variant/ops/indexer.py @@ -0,0 +1,80 @@ +import torch + +from .tilelang_indexer_bwd import indexer_bwd_interface +from .tilelang_indexer_fwd import indexer_fwd_interface + + +def pytorch_extract_topk_scores(logits, topk_indices, dim=-1): + """Gather top-k logits and mask invalid (-1) entries with -inf.""" + valid_mask = topk_indices != -1 + safe_indices = topk_indices.clamp(min=0).to(torch.int64) + scores = torch.gather(logits, dim=dim, index=safe_indices) + scores = torch.where(valid_mask, scores, float("-inf")) + return scores + + +class IndexerFunction(torch.autograd.Function): + """Autograd wrapper for fused tilelang indexer forward/backward.""" + + @staticmethod + def forward( + ctx, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk: int, + topk_indices: torch.Tensor | None = None, + ): + """Run fused indexer forward and optionally select top-k indices.""" + _, head_num, _ = index_q.shape + logits = indexer_fwd_interface( + index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True + ) + if topk_indices is None: + index_score, topk_indices = torch.topk(logits, topk, dim=-1) + topk_indices = topk_indices.to(torch.int32) + topk_indices = topk_indices.masked_fill(index_score == -torch.inf, -1) + + index_score = pytorch_extract_topk_scores(logits, topk_indices) + + ctx.save_for_backward(index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk_indices) + ctx.topk = topk + ctx.head_num = head_num + return index_score, topk_indices + + @staticmethod + def backward(ctx, grad_scores, grad_indices): + """Propagate gradients through fused indexer outputs.""" + index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk_indices = ctx.saved_tensors + grad_q, grad_w, grad_k = indexer_bwd_interface( + index_q, weights, index_k, topk_indices, grad_scores + ) + return grad_q, grad_k, grad_w, None, None, None, None, None, None, None + + +def lighting_indexer( + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk: int, + topk_indices: torch.Tensor | None = None, +): + """Compute indexer top-k scores/indices via the custom autograd function.""" + return IndexerFunction.apply( + index_q, index_k, weights.squeeze(-1), cu_seqlen_ks, cu_seqlen_ke, topk, topk_indices + ) + + +def generate_varlen_mask_params(cu_seqlens): + """Generate inclusive-exclusive [start, end) bounds for each query row.""" + seq_len = cu_seqlens[-1].item() + q_indices = torch.arange(0, seq_len, device=cu_seqlens.device) + seq_indices = 
torch.searchsorted(cu_seqlens, q_indices, right=True) - 1 + starts = cu_seqlens[seq_indices] + ends = q_indices + 1 + assert torch.all((ends - starts) > 0) + return starts, ends diff --git a/megatron/core/transformer/experimental_attention_variant/ops/sparse_mla.py b/megatron/core/transformer/experimental_attention_variant/ops/sparse_mla.py new file mode 100644 index 00000000000..c2916ce031a --- /dev/null +++ b/megatron/core/transformer/experimental_attention_variant/ops/sparse_mla.py @@ -0,0 +1,48 @@ +import torch + +from .tilelang_sparse_mla_bwd import sparse_mla_bwd +from .tilelang_sparse_mla_fwd import sparse_mla_fwd_interface + + +class SparseMLA(torch.autograd.Function): + """Autograd wrapper around tilelang sparse-MLA forward/backward kernels.""" + + @staticmethod + def forward(ctx, q, kv, indices, scaling): + """ + Args: + q: Query tensor (seq_len, heads, dim_plus_tail_dim) + kv: Key-Value tensor (seq_len_kv, kv_group, dim_plus_tail_dim) + indices: Sparse indices tensor (seq_len, kv_group, topk) + + Returns: + out: Output tensor (seq_len, heads, dim) + """ + indices = indices.contiguous() + q, kv = q.contiguous(), kv.contiguous() + ctx.scaling = scaling + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, sm_scale=scaling) + + # Save tensors for backward pass + ctx.save_for_backward(q, kv, indices, tl_out, tl_lse) + + return tl_out, tl_lse + + @staticmethod + def backward(ctx, grad_output, grad_lse): + """ + Args: + grad_output: Gradient of the loss with respect to output + + Returns: + Gradients for q, kv, and indices (None for indices) + """ + q, kv, indices, tl_out, tl_lse = ctx.saved_tensors + scaling = ctx.scaling + + tl_dq, tl_dkv = sparse_mla_bwd( + q, kv, tl_out, grad_output.contiguous(), indices, tl_lse, sm_scale=scaling + ) + + # Return gradients for each input (None for indices as it's not differentiable) + return tl_dq, tl_dkv, None, None diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py new file mode 100644 index 00000000000..3189b9afe0a --- /dev/null +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py @@ -0,0 +1,168 @@ +# ruff: noqa +# Adapted from: +# https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/ +# examples/dsa_sparse_finetune/indexer_bwd.py +import tilelang as tl +import tilelang.language as T +import torch + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_bwd_impl( + heads: int, dim: int, topk: int, block_I: int = 32, num_stages: int = 0, num_threads: int = 128 +): + """Build tilelang backward kernel for sparse indexer.""" + assert num_stages == 0 + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_I == 0 + assert heads <= 64 and heads % 8 == 0 + seq_len = T.symbolic("seq_len") + q_seq_len = T.symbolic("q_seq_len") + + dtype: str = BF16 + accum_dtype: str = FP32 + index_q_shape = [q_seq_len, heads, dim] + weights_shape = [q_seq_len, heads] + index_k_shape = [seq_len, dim] + shape_p = [q_seq_len, topk] + topk_indices_shape = [q_seq_len, topk] + + pad_heads = heads + if heads < 16: + pad_heads = 16 + + @T.prim_func + def tl_indexer_bwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + Weights: 
T.Tensor(weights_shape, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + OGrad: T.Tensor(shape_p, FP32), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, FP32), + dIndexK: T.Tensor(index_k_shape, FP32), + ): + + with T.Kernel(q_seq_len, threads=num_threads) as (bx): + index_q_shared = T.alloc_shared([pad_heads, dim], dtype=FP32) + weights_shared = T.alloc_shared([pad_heads], dtype=FP32) + index_k_shared = T.alloc_shared([block_I, dim], dtype=FP32) + indices_shared = T.alloc_shared([block_I], dtype=INT32) + d_index_q_frag = T.alloc_fragment([pad_heads, dim], dtype=accum_dtype) + d_weights_frag = T.alloc_fragment([pad_heads], dtype=accum_dtype) + d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype) + logits = T.alloc_fragment((block_I, pad_heads), dtype=accum_dtype) + _logits = T.alloc_shared((block_I, pad_heads), dtype=accum_dtype) + grad = T.alloc_shared([block_I], dtype=FP32) + + num_blocks = T.ceildiv(topk, block_I) + for i, j in T.Parallel(pad_heads, dim): + index_q_shared[i, j] = T.if_then_else(i < heads, IndexQ[bx, i, j], 0) + for i in T.Parallel(heads): + weights_shared[i] = Weights[bx, i] + + T.fill(d_index_q_frag, 0) + T.fill(d_weights_frag, 0) + + # for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): + for bi_i in T.serial(num_blocks): + for i in T.Parallel(block_I): + if bi_i * block_I + i < topk: + indices_shared[i] = TopkIndices[bx, bi_i * block_I + i] + grad[i] = OGrad[bx, bi_i * block_I + i] + + T.sync_threads() + for i, j in T.Parallel(block_I, dim): + index_k_shared[i, j] = T.if_then_else( + indices_shared[i] > -1 and indices_shared[i] < seq_len, + IndexK[indices_shared[i], j], + 0, + ) + + T.sync_threads() + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + for i, j in T.Parallel(block_I, heads): + logits[i, j] = T.max(logits[i, j], 0) + + d_weights_i = T.alloc_fragment((block_I, pad_heads), accum_dtype) + for i, j in T.Parallel(block_I, heads): + d_weights_i[i, j] = grad[i] * logits[i, j] + T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) + + for i, j in T.Parallel(block_I, pad_heads): + _logits[i, j] = T.if_then_else( + logits[i, j] > 0 and j < heads, grad[i] * weights_shared[j], 0 + ) + T.sync_threads() + T.gemm( + _logits, + index_k_shared, + d_index_q_frag, + transpose_A=True, + transpose_B=False, + clear_accum=False, + ) + + T.gemm( + _logits, + index_q_shared, + d_index_k_frag, + transpose_A=False, + transpose_B=False, + clear_accum=True, + ) + + for i, j in T.Parallel(block_I, dim): + if indices_shared[i] > -1 and indices_shared[i] < seq_len: + T.atomic_add(dIndexK[indices_shared[i], j], d_index_k_frag[i, j]) + + T.copy(d_index_q_frag[:heads, :], dIndexQ[bx, :, :]) + T.copy(d_weights_frag[:heads], dWeights[bx, :]) + + return tl_indexer_bwd_kernel + + +def indexer_bwd_interface( + index_q: torch.Tensor, + weights: torch.Tensor, + index_k: torch.Tensor, + topk_indices: torch.Tensor, + grad_scores: torch.Tensor, +): + """Run indexer backward kernel and return gradients for q/w/k.""" + _, head_num, head_dim = index_q.shape + k_top = topk_indices.shape[1] + + grad_scores = grad_scores.contiguous() + grad_q = torch.empty_like(index_q) + grad_w = torch.empty_like(weights, dtype=torch.float32) + grad_k = torch.zeros_like(index_k, dtype=torch.float32) + + tl_indexer_bwd_impl(head_num, head_dim, k_top)( + index_q.contiguous(), + index_k.contiguous(), + weights.squeeze(-1).contiguous(), + topk_indices.contiguous(), + 
grad_scores, + grad_q, + grad_w.squeeze(-1), + grad_k, + ) + + return grad_q, grad_w, grad_k diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py new file mode 100644 index 00000000000..0e1e4a59b48 --- /dev/null +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py @@ -0,0 +1,132 @@ +# ruff: noqa +# Adapted from: +# https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/ +# examples/deepseek_v32/fp8_lighting_indexer.py +import tilelang +import torch +from tilelang import language as T + + +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tl_indexer_fwd_impl(heads, index_dim, block_N=256, num_stages=3, threads=512, block_Q=None): + """Build tilelang forward kernel for sparse indexer logits.""" + if block_Q is None: + block_Q = 128 // heads + dtype = T.bfloat16 + accum_dtype = T.float32 + index_dtype = T.int32 + + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + index_q_shape = [seq_len * heads, index_dim] + index_k_shape = [seq_len_kv, index_dim] + logits_shape = [seq_len, seq_len_kv] + + @T.prim_func + def tl_indexer_fwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) + logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + weights = T.alloc_fragment([block_Q, heads], accum_dtype) + + seq_len_i = bx * block_Q + + cu_k_s_min = T.alloc_var(index_dtype) + cu_k_e_max = T.alloc_var(index_dtype) + + cu_k_s_min = 2147483647 + cu_k_e_max = -2147483648 + + for bq_i in T.serial(block_Q): + cu_k_s_min = T.min(cu_k_s_min, T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) + for bq_i in T.serial(block_Q): + cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) + + T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(Weights[seq_len_i, 0], weights) + + for nbn_i in T.Pipelined( + T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages + ): + T.copy(IndexK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = ( + T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i] + ) + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for bq_i, bn_i in T.Parallel(block_Q, block_N): + Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[ + bn_i, bq_i + ] + + return tl_indexer_fwd_kernel + + +@tilelang.jit +def clean_logits_(threads: int = 512, block_K: int = 4096): + """Build kernel that masks out invalid key ranges in logits.""" + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + dtype = T.float + 
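+    # Row-wise masking sketch (illustrative values): for query row bx with bounds
+    # ks=2, ke=5 over seq_len_kv=6, a logits row [a, b, c, d, e, f] becomes
+    # [-inf, -inf, c, d, e, -inf]; only keys in [ks, ke) remain selectable by top-k.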
indices_dtype = T.int32 + + @T.prim_func + def clean_logits_kernel( + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + ): + with T.Kernel(seq_len, threads=threads) as bx: + tx = T.thread_binding(0, threads, thread="threadIdx.x") + cu_k_s = CuSeqLenKS[bx] + cu_k_e = CuSeqLenKE[bx] + + for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): + for k_i in T.serial(block_K // threads): + idx = n_i * block_K + k_i * threads + tx + if idx < cu_k_s or idx >= cu_k_e: + Logits[bx, idx] = -T.infinity(dtype) + + return clean_logits_kernel + + +def indexer_fwd_interface(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True): + """Run indexer forward kernel and optionally clean logits by row bounds.""" + seq_len, heads, index_dim = q.shape + seq_len_kv = kv.shape[0] + + clean_logits_kernel = clean_logits_() + + tl_indexer_fwd_kernel = tl_indexer_fwd_impl(heads=heads, index_dim=index_dim) + + logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) + tl_indexer_fwd_kernel( + q.view(seq_len * heads, index_dim), kv, logits, weights, cu_seqlen_ks, cu_seqlen_ke + ) + if clean_logits: + clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke) + return logits diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py new file mode 100644 index 00000000000..01e13ac1efe --- /dev/null +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py @@ -0,0 +1,355 @@ +# ruff: noqa +# Adapted from: +# https://github.com/tile-ai/tilelang/blob/4ff81c7d40803d269569e157e847623e84553f78/ +# examples/deepseek_v32/sparse_mla_bwd.py +import tilelang +import torch +from tilelang import language as T + + +@tilelang.jit(out_idx=[-1]) +def preprocess(B, S, H, D, block_ND=32, num_stages=5, dtype=T.bfloat16, accum_dtype=T.float32): + """Build preprocessing kernel that computes Delta = sum(O * dO) per row/head.""" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + shape = [B, S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy( + O[ + bz, + by * block_ND : (by + 1) * block_ND, + bx, + k * block_ND : (k + 1) * block_ND, + ], + o, + ) + T.copy( + dO[ + bz, + by * block_ND : (by + 1) * block_ND, + bx, + k * block_ND : (k + 1) * block_ND, + ], + do, + ) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + B, S_kv, D, D_tail, kv_group=1, block_N=64, threads=128, dtype=T.bfloat16, accum_dtype=T.float32 +): + """Build postprocess kernel that casts/exports accumulated dKV.""" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + dkv_shape = [B, S_kv, kv_group, D + D_tail] + + @T.prim_func + def 
postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), dKV_out: T.Tensor(dkv_shape, dtype) + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): + T.copy( + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, + }, +) +def bwd( + B, + S, + S_kv, + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=128, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + """Build sparse-MLA backward kernel.""" + assert is_causal == True, "non-casual is not supported now" + assert ( + topk % block_size == 0 + ), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) + + H_kv = H // kv_group + q_shape = [B, S, H, D + D_tail] + k_shape = [B, S_kv, kv_group, D + D_tail] + o_shape = [B, S, H, D] + indices_shape = [B, S, kv_group, topk] + delta_shape = [B, S, H] + lse_shape = [B, S, H] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + block_H = min(64, padded_H) + assert padded_H % block_H == 0 + NH = padded_H // block_H + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, B, kv_group * NH, threads=threads) as (s_i, by, bz): + Q_shared = T.alloc_shared([block_H, D], dtype) + Q_tail_shared = T.alloc_shared([block_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([block_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([block_H, BS], dtype) + dP_shared_cast = T.alloc_shared([block_H, BS], dtype) + dQ_shared = T.alloc_shared([block_H, D], dtype) + dQ_tail_shared = T.alloc_shared([block_H, D_tail], dtype) + + acc_p = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([block_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([block_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) + acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype) + + # max_kv_i = s_i + + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * block_H : (bz + 1) * block_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + # Process 
each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + # Changed here for thd + mask[bi_i] = Indices[by, s_i, bz // NH, i_i * BS + bi_i] != -1 + + # Compute attention scores + for h_i, bi_i in T.Parallel(block_H, BS): + acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype)) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[ + by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, d_i + ] + + T.gemm( + Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol + ) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[ + by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, D + d_i + ] + T.gemm( + Q_tail_shared, + KV_tail_shared, + acc_p, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for h_i, bi_i in T.Parallel(block_H, BS): + acc_p[h_i, bi_i] = T.exp2( + acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 + - Lse[by, s_i, bz * block_H + h_i] + ) + + T.copy(acc_p, P_shared_cast) + + T.gemm( + dO_shared, + KV_shared, + acc_dp, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + clear_accum=True, + ) + + for h_i, bi_i in T.Parallel(block_H, BS): + acc_dp[h_i, bi_i] = ( + acc_p[h_i, bi_i] + * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * block_H + h_i]) + * sm_scale + ) + + T.copy(acc_dp, dP_shared_cast) + T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) + T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) + + T.gemm( + dP_shared_cast, + Q_shared, + acc_dkv, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol, + clear_accum=True, + ) + T.gemm( + P_shared_cast, + dO_shared, + acc_dkv, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + T.clear(acc_dkv_tail) + T.gemm( + dP_shared_cast, + Q_tail_shared, + acc_dkv_tail, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[ + bi_i + s * (BS // split_store), d_i + ] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[ + by, + Indices[ + by, s_i, bz // NH, i_i * BS + bi_i + s * (BS // split_store) + ], + bz // NH, + d_i * 4, + ], + acc_dkv_shared[bi_i, d_i * 4], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[ + by, + Indices[ + by, s_i, bz // NH, i_i * BS + bi_i + s * (BS // split_store) + ], + bz // NH, + D + d_i * 4, + ], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd( + q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None +): + """Run sparse-MLA backward kernels and return (dq, dkv).""" + q = q.unsqueeze(0) + kv = kv.unsqueeze(0) + o = o.unsqueeze(0) + do = do.unsqueeze(0) + indices = indices.unsqueeze(0) + lse = lse.unsqueeze(0) + + assert q.is_contiguous() + assert kv.is_contiguous() + assert 
indices.is_contiguous() + assert lse.is_contiguous() + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert kv.shape[0] == B + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (B, S, kv_group, topk) + assert lse.shape == (B, S, H) + + # Get kernels + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, dkv) + dkv = postprocess_kernel(dkv) + + dq = dq.squeeze(0) + dkv = dkv.squeeze(0) + + return dq, dkv diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py new file mode 100644 index 00000000000..19d1a80963f --- /dev/null +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py @@ -0,0 +1,222 @@ +# ruff: noqa +# Adapted from: +# https://github.com/tile-ai/tilelang/blob/e666d2d3cc483829c57618c9ebf2e4f4ada0819d/ +# examples/deepseek_v32/sparse_mla_fwd.py +import tilelang +from tilelang import language as T + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=2, + threads=256, +): + """Build sparse-MLA forward kernel.""" + assert dim == tilelang.math.next_power_of_2( + dim + ), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2( + tail_dim + ), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert ( + topk % block_I == 0 + ), "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, otherwise handle Q/Output copy with " + "your own mask (for kv_group==1, g_i*padded_H:(g_i+1)*padded_H is handled)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, 
indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + O_shared = T.alloc_shared([H_per_block, D], dtype) + Lse_shared = T.alloc_shared([H_per_block], accum_dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_i, g_i = by, bz + s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + # Changed here for thd + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] != -1 + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i + ] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i + ] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm( + Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? 
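+                # Note: judging by the explicit clear=False usages elsewhere in these
+                # kernels, reduce_sum overwrites sumexp_i with this block's partial row
+                # sums rather than accumulating; the running online-softmax denominator
+                # is carried explicitly below via sumexp = sumexp * alpha + sumexp_i.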
+ for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) + T.copy(sumexp, Lse[b_i, s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface( + q, + kv, + indices, + sm_scale=None, + return_p_sum: bool = False, + d_v=512, + block_I=64, + num_stages=2, + threads=256, +): + """Run sparse-MLA forward kernel and return (out, lse).""" + q = q.unsqueeze(0) + kv = kv.unsqueeze(0) + indices = indices.unsqueeze(0) + + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + kernel = sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group, + sm_scale, + is_casual, + block_I=block_I, + num_stages=num_stages, + threads=threads, + ) + out, lse = kernel(q, kv, indices) + out = out.squeeze(0) + lse = lse.squeeze(0) + return out, lse diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index a9cdc697cc8..9ccd41dec62 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -290,6 +290,11 @@ def forward( # query representation. extra_kwargs["x"] = hidden_states extra_kwargs["qr"] = q_compressed + extra_kwargs["position_ids"] = position_ids + query, key, value, up_v_weight = self.get_absorb_query_key_value_tensors( + query, key, kv_compressed + ) + extra_kwargs["up_v_weight"] = up_v_weight with off_interface( self.offload_core_attention and self.training, query, "core_attn" ) as query: @@ -360,6 +365,91 @@ def forward( return output, bias + def get_absorb_query_key_value_tensors( + self, query: torch.Tensor, key: torch.Tensor, kv_compressed: torch.Tensor + ): + """Build absorbed query/key/value tensors for DSA static path. 
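+
+        Per-head absorption identity: <q_nope @ W_uk[h], kv_latent> equals
+        <q_nope, W_uk[h] @ kv_latent>, so queries are projected into the latent
+        space once and keys keep the shared kv_latent (plus RoPE dims) instead
+        of materializing per-head up-projected keys.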
+ + Returns: + query_absorbed, key_absorbed, value_absorbed(None), up_v_weight + """ + if self.linear_kv_up_proj is None: + raise RuntimeError("DSA absorbed path requires linear_kv_up_proj, but it is missing.") + + linear_kv_up_proj = self.linear_kv_up_proj + if not hasattr(linear_kv_up_proj, "weight") and hasattr(linear_kv_up_proj, "to_wrap"): + linear_kv_up_proj = linear_kv_up_proj.to_wrap + + kv_up_weight = linear_kv_up_proj.weight.view( + self.num_attention_heads_per_partition, + self.config.qk_head_dim + self.config.v_head_dim, + self.config.kv_lora_rank, + ) + up_k_weight = kv_up_weight[:, : self.config.qk_head_dim, :].contiguous() + up_v_weight = kv_up_weight[:, self.config.qk_head_dim :, :].contiguous() + + def _align_kv_latent_seq_len(kv_latent: torch.Tensor, target_seqlen: int) -> torch.Tensor: + """Align kv_latent sequence length with absorbed key/query sequence length.""" + if kv_latent.size(0) == target_seqlen: + return kv_latent + if self.config.sequence_parallel and get_pg_size(self.tp_group) > 1: + kv_latent = gather_from_sequence_parallel_region(kv_latent, group=self.tp_group) + if kv_latent.size(0) != target_seqlen: + raise RuntimeError( + "DSA absorbed rewrite sequence mismatch after SP alignment: " + f"kv_latent_seqlen={kv_latent.size(0)}, target_seqlen={target_seqlen}. " + "Check sequence_parallel and q/kv gathering consistency." + ) + return kv_latent + + if query.ndim == 4 and key.ndim == 4: + # query: [s, b, h, qk+pos] -> [s, b, h, kv_lora+pos] + q_no_pe = query[..., : self.config.qk_head_dim] + q_pos = query[..., self.config.qk_head_dim :] + q_content = torch.einsum("sbhd,hdk->sbhk", q_no_pe, up_k_weight) + query = torch.cat([q_content, q_pos], dim=-1).contiguous() + + # key: [s, b, h, qk+pos] -> [s, b, 1, kv_lora+pos] + if kv_compressed.ndim == 2: + kv_latent = kv_compressed.unsqueeze(1) + elif kv_compressed.ndim == 3: + kv_latent = kv_compressed + else: + raise RuntimeError( + f"Unsupported kv_compressed ndim={kv_compressed.ndim} for DSA absorbed path." + ) + kv_latent = _align_kv_latent_seq_len(kv_latent, target_seqlen=key.size(0)) + k_pos = key[:, :, 0, self.config.qk_head_dim :].contiguous() + key = torch.cat([kv_latent, k_pos], dim=-1).unsqueeze(2).contiguous() + value = None + elif query.ndim == 3 and key.ndim == 3: + # Packed THD path: query [t, h, qk+pos], key [t, h, qk+pos]. + q_no_pe = query[..., : self.config.qk_head_dim] + q_pos = query[..., self.config.qk_head_dim :] + q_content = torch.einsum("thd,hdk->thk", q_no_pe, up_k_weight) + query = torch.cat([q_content, q_pos], dim=-1).contiguous() + + if kv_compressed.ndim == 3: + kv_latent = kv_compressed.squeeze(1) + elif kv_compressed.ndim == 2: + kv_latent = kv_compressed + else: + raise RuntimeError( + "Unsupported kv_compressed ndim=" + f"{kv_compressed.ndim} for packed DSA absorbed path." 
+ ) + kv_latent = _align_kv_latent_seq_len(kv_latent, target_seqlen=key.size(0)) + k_pos = key[:, 0, self.config.qk_head_dim :].contiguous() + key = torch.cat([kv_latent, k_pos], dim=-1).unsqueeze(1).contiguous() + value = None + else: + raise RuntimeError( + f"Unsupported query/key ndim for DSA absorbed rewrite: " + f"query.ndim={query.ndim}, key.ndim={key.ndim}" + ) + + return query, key, value, up_v_weight + class MLASelfAttention(MultiLatentAttention): """MLA Self-attention layer class diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 642af8415d3..af17d099a47 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2158,9 +2158,6 @@ def __post_init__(self): assert not self.use_kitchen if self.experimental_attention_variant == "dsa": - assert ( - self.context_parallel_size == 1 - ), "Currently context parallelism is not supported by DSAttention!" assert not self.apply_rope_fusion, "RoPE fusion is not supported for DSAttention" if self.inference_fuse_tp_communication: diff --git a/tests/unit_tests/transformer/test_attention_variant_dsa.py b/tests/unit_tests/transformer/test_attention_variant_dsa.py index 96253a4ca10..a4fd5ae59a7 100644 --- a/tests/unit_tests/transformer/test_attention_variant_dsa.py +++ b/tests/unit_tests/transformer/test_attention_variant_dsa.py @@ -6,10 +6,9 @@ import torch import megatron.core.parallel_state as parallel_state -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed -from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.experimental_attention_variant.dsa import ( DSAIndexer, @@ -18,21 +17,30 @@ DSAttention, DSAttentionSubmodules, FusedDSAIndexerLoss, + _build_fused_indexer_varlen_bounds, + _unfused_absorbed_dsa_fn, + _fused_qk_topk_lighting, + _fused_qk_topk_lighting_with_streaming_sparse_kl, + _scatter_topk_into_index_mask, + _build_causal_mask_from_positions, + _generate_varlen_mask_params, _compute_index_scores, + _get_cp_positions_from_layout, compute_dsa_indexer_loss, + compute_dsa_indexer_loss_topk_sparse, fused_qk_topk_naive, rotate_activation, + unfused_dsa_fn, ) from megatron.core.transformer.transformer_config import MLATransformerConfig from tests.unit_tests.test_utilities import Utils try: - from fast_hadamard_transform import hadamard_transform as _hadamard_transform + from fast_hadamard_transform import hadamard_transform HAVE_HADAMARD = True except ImportError: HAVE_HADAMARD = False - _hadamard_transform = None def mock_hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: @@ -43,6 +51,114 @@ def mock_hadamard_transform(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor return x * scale +def _build_packed_causal_mask_for_test( + query_idx: torch.Tensor, key_idx: torch.Tensor, cu_seqlens: torch.Tensor +) -> torch.Tensor: + """Build packed-sequence causal mask for tests.""" + query_idx = query_idx.to(dtype=torch.int64) + key_idx = key_idx.to(dtype=torch.int64) + cu_seqlens = cu_seqlens.to(device=query_idx.device, dtype=torch.int64) + + boundaries = cu_seqlens[1:] + query_seq_id = torch.searchsorted(boundaries, query_idx, right=True) + key_seq_id = 
torch.searchsorted(boundaries, key_idx, right=True) + valid = (query_seq_id.unsqueeze(-1) == key_seq_id.unsqueeze(0)) & ( + key_idx.unsqueeze(0) <= query_idx.unsqueeze(-1) + ) + mask = torch.zeros( + (query_idx.numel(), key_idx.numel()), dtype=torch.float32, device=query_idx.device + ) + mask.masked_fill_(~valid, float("-inf")) + return mask + + +def _fake_lighting_indexer_for_test( + index_q: torch.Tensor, + index_k: torch.Tensor, + index_w: torch.Tensor, + starts: torch.Tensor, + ends: torch.Tensor, + index_topk: int, + topk_indices: torch.Tensor | None = None, +): + """Reference fake indexer for testing fused batched loop plumbing.""" + del topk_indices + + # [sq, h, d] @ [sk, d]^T -> [sq, h, sk] + logits = torch.einsum("qhd,kd->qhk", index_q.float(), index_k.float()) + logits = torch.relu(logits) * index_w.float().unsqueeze(-1) + logits = logits.sum(dim=1) # [sq, sk] + + key_pos = torch.arange(index_k.size(0), dtype=torch.int64, device=logits.device) + valid = (key_pos.unsqueeze(0) >= starts.to(torch.int64).unsqueeze(-1)) & ( + key_pos.unsqueeze(0) < ends.to(torch.int64).unsqueeze(-1) + ) + logits = logits.masked_fill(~valid, float("-inf")) + + topk_k = min(index_topk, logits.size(-1)) + topk_scores, topk_idx = torch.topk(logits, topk_k, dim=-1) + topk_idx = topk_idx.to(torch.int32) + topk_idx = topk_idx.masked_fill(topk_scores == float("-inf"), -1) + return topk_scores, topk_idx + + +def _fake_fused_scores_indices_for_test( + q: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + starts: torch.Tensor, + ends: torch.Tensor, + index_topk: int, + block_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build fused-style batched/chunked top-k scores+indices with fake indexer.""" + sq, b = q.size(0), q.size(1) + scores_out, idx_out = None, None + for bi in range(b): + for s0 in range(0, sq, block_size): + s1 = min(s0 + block_size, sq) + scores_chunk, idx_chunk = _fake_lighting_indexer_for_test( + q[:, bi][s0:s1], + k[:, bi], + weights[:, bi][s0:s1], + starts[s0:s1], + ends[s0:s1], + index_topk, + ) + if scores_out is None: + scores_out = torch.empty( + (b, sq, scores_chunk.size(-1)), + dtype=scores_chunk.dtype, + device=scores_chunk.device, + ) + if idx_out is None: + idx_out = torch.empty( + (b, sq, idx_chunk.size(-1)), dtype=idx_chunk.dtype, device=idx_chunk.device + ) + scores_out[bi, s0:s1].copy_(scores_chunk) + idx_out[bi, s0:s1].copy_(idx_chunk) + assert scores_out is not None and idx_out is not None + return scores_out, idx_out + + +class _FakeTPGroup: + def size(self) -> int: + return 1 + + +class _FakeCPGroup: + def __init__(self, size: int): + self._size = size + + def size(self) -> int: + return self._size + + +class _FakePGCollection: + def __init__(self): + self.tp = _FakeTPGroup() + + @pytest.fixture(autouse=True) def patch_hadamard_if_needed(): """Automatically patch hadamard_transform in dsa module if not installed.""" @@ -56,6 +172,397 @@ def patch_hadamard_if_needed(): yield +class TestDSACPPositionHelpers: + """Test helper utilities used for DSAttention context-parallel masking.""" + + def test_allgather_layout_positions(self): + """Allgather CP layout should map to contiguous query and global key positions.""" + query_pos, key_pos = _get_cp_positions_from_layout( + sq=4, skv=8, cp_size=2, cp_rank=1, cp_comm_type="allgather", device=torch.device("cpu") + ) + assert query_pos.tolist() == [4, 5, 6, 7] + assert key_pos.tolist() == list(range(8)) + + def test_position_based_causal_mask(self): + """Position-based causal mask should mask keys with strictly larger 
positions.""" + query_pos = torch.tensor([0, 2], dtype=torch.int64) + key_pos = torch.tensor([0, 1, 2, 3], dtype=torch.int64) + mask = _build_causal_mask_from_positions(query_pos, key_pos) + expected = torch.tensor( + [[0.0, float("-inf"), float("-inf"), float("-inf")], [0.0, 0.0, 0.0, float("-inf")]], + dtype=torch.float32, + ) + torch.testing.assert_close(mask, expected, rtol=0, atol=0) + + def test_packed_position_based_causal_mask(self): + """Packed causal mask should block cross-sequence attention using cu_seqlens boundaries.""" + # Two packed sequences: [0,1,2] and [3,4] + cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32) + query_idx = torch.tensor([1, 3, 4], dtype=torch.int64) + key_idx = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64) + + mask = _build_packed_causal_mask_for_test(query_idx, key_idx, cu_seqlens) + expected = torch.tensor( + [ + [0.0, 0.0, float("-inf"), float("-inf"), float("-inf")], + [float("-inf"), float("-inf"), float("-inf"), 0.0, float("-inf")], + [float("-inf"), float("-inf"), float("-inf"), 0.0, 0.0], + ], + dtype=torch.float32, + ) + torch.testing.assert_close(mask, expected, rtol=0, atol=0) + + def test_topk_uses_key_length(self): + """Top-k selection should be bounded by key length, not query length.""" + sq, skv, bsz, nheads, dim = 4, 7, 1, 2, 8 + topk = 6 + q = torch.randn(sq, bsz, nheads, dim, dtype=torch.float32) + k = torch.randn(skv, bsz, dim, dtype=torch.float32) + weights = torch.randn(sq, bsz, nheads, dtype=torch.float32) + + _, topk_indices = fused_qk_topk_naive(q, k, weights, topk, mask=None) + assert topk_indices.shape == (bsz, sq, topk) + + def test_cp_packed_varlen_end_to_end_matches_dense_mask(self): + """CP+THD multi-sequence varlen path should match dense packed mask end-to-end.""" + # Simulate cp_size=2 allgather layout with local query chunk and global keys. + cp_size, cp_rank = 2, 1 + sq, skv = 4, 8 + bsz, nheads, dim, vdim = 1, 2, 8, 6 + topk = 4 + softmax_scale = dim**-0.5 + + # Three packed sequences in global stream: [0,1,2], [3,4], [5,6,7] + cu_seqlens = torch.tensor([0, 3, 5, 8], dtype=torch.int32) + query_idx, key_idx = _get_cp_positions_from_layout( + sq=sq, + skv=skv, + cp_size=cp_size, + cp_rank=cp_rank, + cp_comm_type="allgather", + device=torch.device("cpu"), + ) + + # Build varlen starts/ends for local query rows. 
+ starts_all, ends_all = _generate_varlen_mask_params(cu_seqlens.to(torch.int64)) + starts = starts_all.index_select(0, query_idx) + ends = ends_all.index_select(0, query_idx) + + q = torch.randn(sq, bsz, nheads, dim, dtype=torch.float32) + k_for_index = torch.randn(skv, bsz, dim, dtype=torch.float32) + weights = torch.randn(sq, bsz, nheads, dtype=torch.float32) + query = torch.randn(sq, bsz, nheads, dim, dtype=torch.float32) + key = torch.randn(skv, bsz, nheads, dim, dtype=torch.float32) + value = torch.randn(skv, bsz, nheads, vdim, dtype=torch.float32) + + dense_mask = _build_packed_causal_mask_for_test(query_idx, key_idx, cu_seqlens) + _, dense_idx = fused_qk_topk_naive(q, k_for_index, weights, topk, mask=dense_mask) + out_dense = unfused_dsa_fn(query, key, value, dense_idx, softmax_scale, mask=dense_mask) + + _, varlen_idx = fused_qk_topk_naive( + q, + k_for_index, + weights, + topk, + mask=None, + varlen_starts=starts, + varlen_ends=ends, + key_positions=key_idx, + ) + out_varlen = unfused_dsa_fn( + query, + key, + value, + varlen_idx, + softmax_scale, + mask=None, + varlen_starts=starts, + varlen_ends=ends, + key_positions=key_idx, + ) + + torch.testing.assert_close(out_varlen, out_dense, rtol=0, atol=0) + + def test_cp_packed_varlen_uneven_rank_lengths_matches_dense_mask(self, monkeypatch): + """CP+THD varlen path should match dense mask under uneven per-rank query lengths.""" + # Simulate cp_size=2, cp_rank=1, local query lengths [3, 5]. + cp_size, cp_rank = 2, 1 + local_lengths = [3, 5] + sq, skv = local_lengths[cp_rank], sum(local_lengths) + bsz, nheads, dim, vdim = 1, 2, 8, 6 + topk = 4 + softmax_scale = dim**-0.5 + + fake_cp_group = _FakeCPGroup(cp_size) + + monkeypatch.setattr(torch.distributed, "is_available", lambda: True) + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) + + def _fake_all_gather(out, local_len, group=None): + del local_len, group + for i, tensor in enumerate(out): + tensor.copy_( + torch.tensor([local_lengths[i]], dtype=tensor.dtype, device=tensor.device) + ) + + monkeypatch.setattr(torch.distributed, "all_gather", _fake_all_gather) + + # Packed global stream has three sequences: [0,1], [2,3,4], [5,6,7] + cu_seqlens = torch.tensor([0, 2, 5, 8], dtype=torch.int32) + query_idx, key_idx = _get_cp_positions_from_layout( + sq=sq, + skv=skv, + cp_size=cp_size, + cp_rank=cp_rank, + cp_comm_type="allgather", + device=torch.device("cpu"), + cp_group=fake_cp_group, + ) + assert query_idx.tolist() == [3, 4, 5, 6, 7] + + starts_all, ends_all = _generate_varlen_mask_params(cu_seqlens.to(torch.int64)) + starts = starts_all.index_select(0, query_idx) + ends = ends_all.index_select(0, query_idx) + + q = torch.randn(sq, bsz, nheads, dim, dtype=torch.float32) + k_for_index = torch.randn(skv, bsz, dim, dtype=torch.float32) + weights = torch.randn(sq, bsz, nheads, dtype=torch.float32) + query = torch.randn(sq, bsz, nheads, dim, dtype=torch.float32) + key = torch.randn(skv, bsz, nheads, dim, dtype=torch.float32) + value = torch.randn(skv, bsz, nheads, vdim, dtype=torch.float32) + + dense_mask = _build_packed_causal_mask_for_test(query_idx, key_idx, cu_seqlens) + _, dense_idx = fused_qk_topk_naive(q, k_for_index, weights, topk, mask=dense_mask) + out_dense = unfused_dsa_fn(query, key, value, dense_idx, softmax_scale, mask=dense_mask) + + _, varlen_idx = fused_qk_topk_naive( + q, + k_for_index, + weights, + topk, + mask=None, + varlen_starts=starts, + varlen_ends=ends, + key_positions=key_idx, + ) + out_varlen = unfused_dsa_fn( + query, + key, + value, + 
varlen_idx, + softmax_scale, + mask=None, + varlen_starts=starts, + varlen_ends=ends, + key_positions=key_idx, + ) + + torch.testing.assert_close(out_varlen, out_dense, rtol=0, atol=0) + + def test_fused_topk_batched_loop_matches_reference(self): + """Fused batched/chunked top-k loop should match per-batch reference outputs.""" + sq, skv, bsz, heads, dim = 9, 13, 3, 4, 8 + topk = 5 + block_size = 4 + + q = torch.randn(sq, bsz, heads, dim, dtype=torch.float32) + k = torch.randn(skv, bsz, dim, dtype=torch.float32) + weights = torch.randn(sq, bsz, heads, dtype=torch.float32) + starts = torch.zeros(sq, dtype=torch.int32) + ends = torch.arange(1, sq + 1, dtype=torch.int32).clamp_max(skv) + + expected = [] + for bi in range(bsz): + _, ref_idx = _fake_lighting_indexer_for_test( + q[:, bi], k[:, bi], weights[:, bi], starts, ends, topk + ) + expected.append(ref_idx) + expected = torch.stack(expected, dim=0) + + with patch( + "megatron.core.transformer.experimental_attention_variant.dsa.lighting_indexer", + _fake_lighting_indexer_for_test, + ): + got = _fused_qk_topk_lighting( + q=q, + k=k, + weights=weights, + index_topk=topk, + starts=starts, + ends=ends, + block_size=block_size, + ) + + assert got is not None + assert got.shape == expected.shape + assert got.dtype == expected.dtype + assert torch.equal(got, expected) + + def test_fused_streaming_sparse_kl_matches_reference(self): + """Streaming fused sparse-KL path should match reference top-k sparse KL.""" + sq, skv, bsz, heads, dim = 10, 12, 2, 4, 8 + topk = 6 + block_size = 4 + softmax_scale = dim**-0.5 + + q = torch.randn(sq, bsz, heads, dim, dtype=torch.float32) + k = torch.randn(skv, bsz, dim, dtype=torch.float32) + weights = torch.randn(sq, bsz, heads, dtype=torch.float32) + # MQA key for target attention distribution. 
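+ # key carries a single KV head ([skv, b, 1, dim]) that the target
+ # distribution shares (MQA-style) across all query heads.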
+ query = torch.randn(sq, bsz, heads, dim, dtype=torch.float32) + key = torch.randn(skv, bsz, 1, dim, dtype=torch.float32) + starts = torch.zeros(sq, dtype=torch.int32) + ends = torch.arange(1, sq + 1, dtype=torch.int32).clamp_max(skv) + fake_pg = _FakePGCollection() + + ref_scores, ref_idx = _fake_fused_scores_indices_for_test( + q, k, weights, starts, ends, topk, block_size + ) + ref_loss = compute_dsa_indexer_loss_topk_sparse( + index_topk_scores=ref_scores, + topk_indices=ref_idx, + query=query, + key=key, + softmax_scale=softmax_scale, + loss_coeff=1.0, + pg_collection=fake_pg, + ) + + with patch( + "megatron.core.transformer.experimental_attention_variant.dsa.lighting_indexer", + _fake_lighting_indexer_for_test, + ): + fused_out = _fused_qk_topk_lighting_with_streaming_sparse_kl( + q=q, + k=k, + weights=weights, + index_topk=topk, + starts=starts, + ends=ends, + block_size=block_size, + query=query, + key=key, + softmax_scale=softmax_scale, + loss_coeff=1.0, + pg_collection=fake_pg, + ) + + assert fused_out is not None + got_idx, got_loss = fused_out + assert torch.equal(got_idx, ref_idx) + torch.testing.assert_close(got_loss, ref_loss, rtol=1e-5, atol=1e-5) + + def test_fused_bounds_disable_on_per_batch_mask_mismatch(self): + """Fused bounds should disable when batched masks are not identical.""" + sq, skv, bsz = 5, 7, 2 + base_mask = torch.triu( + torch.full((sq, skv), float("-inf"), dtype=torch.float32), diagonal=1 + ) + mask = base_mask.unsqueeze(0).expand(bsz, -1, -1).clone() + out = _build_fused_indexer_varlen_bounds( + sq=sq, + skv=skv, + device=mask.device, + mask=mask, + varlen_starts=None, + varlen_ends=None, + key_positions=None, + ) + assert out is not None + + # Change one batch mask so masks are no longer identical. + mask[1, 0, 0] = float("-inf") + out_mismatch = _build_fused_indexer_varlen_bounds( + sq=sq, + skv=skv, + device=mask.device, + mask=mask, + varlen_starts=None, + varlen_ends=None, + key_positions=None, + ) + assert out_mismatch is None + + def test_scatter_topk_chunked_matches_manual_with_negative_indices(self): + """Chunked top-k scatter should match manual behavior for -1 invalid indices.""" + b, sq, skv = 2, 4, 6 + topk_indices = torch.tensor( + [ + [[0, 2, -1], [1, -1, -1], [2, 4, 5], [3, -1, 0]], + [[5, 4, 1], [0, -1, 2], [3, -1, -1], [1, 2, 3]], + ], + dtype=torch.int32, + ) + got = torch.full((b, sq, skv), float("-inf"), dtype=torch.float32) + _scatter_topk_into_index_mask(got, topk_indices, seq_chunk_size=2) + + expected = torch.full((b, sq, skv), float("-inf"), dtype=torch.float32) + topk_i64 = topk_indices.to(torch.int64) + valid = topk_i64 >= 0 + b_idx, q_idx, t_idx = torch.where(valid) + k_idx = topk_i64[b_idx, q_idx, t_idx] + expected[b_idx, q_idx, k_idx] = 0.0 + + assert torch.equal(got, expected) + + +class TestDSAAbsorbedParityCPU: + """CPU parity tests for absorbed DSA rewrite.""" + + def test_absorbed_path_matches_non_absorbed_output(self): + """Absorbed attention + up_v projection should match non-absorbed attention output.""" + torch.manual_seed(1234) + + sq, skv, bsz, nheads = 6, 6, 1, 3 + qk_dim, qk_pos_dim = 5, 2 + kv_lora_rank, vdim = 4, 3 + softmax_scale = (qk_dim + qk_pos_dim) ** -0.5 + + # Build synthetic tensors consistent with the absorbed rewrite equations. 
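+ # Parity rests on the absorption identities: with k_nope = W_uk @ c and
+ # v = W_uv @ c, q_nope . k_nope = (W_uk^T q_nope) . c, so queries can
+ # absorb W_uk and attend over the latent c directly; and
+ # sum_j A_ij v_j = W_uv (sum_j A_ij c_j), so W_uv is applied once to the
+ # latent-space attention output.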
+ q_no_pe = torch.randn(sq, bsz, nheads, qk_dim, dtype=torch.float32) + q_pos = torch.randn(sq, bsz, nheads, qk_pos_dim, dtype=torch.float32) + kv_latent = torch.randn(skv, bsz, kv_lora_rank, dtype=torch.float32) + k_pos_shared = torch.randn(skv, bsz, 1, qk_pos_dim, dtype=torch.float32) + + up_k_weight = torch.randn(nheads, qk_dim, kv_lora_rank, dtype=torch.float32) + up_v_weight = torch.randn(nheads, vdim, kv_lora_rank, dtype=torch.float32) + + # Non-absorbed tensors. + query_non_abs = torch.cat([q_no_pe, q_pos], dim=-1).contiguous() + k_no_pe = torch.einsum("sbk,hqk->sbhq", kv_latent, up_k_weight) + key_non_abs = torch.cat([k_no_pe, k_pos_shared.expand(-1, -1, nheads, -1)], dim=-1) + value_non_abs = torch.einsum("sbk,hvk->sbhv", kv_latent, up_v_weight).contiguous() + + # Absorbed tensors. + q_content_abs = torch.einsum("sbhq,hqk->sbhk", q_no_pe, up_k_weight) + query_abs = torch.cat([q_content_abs, q_pos], dim=-1).contiguous() + key_abs = torch.cat([kv_latent.unsqueeze(2), k_pos_shared], dim=-1).contiguous() + + # Use full-key support and causal masking in both paths. + topk_indices = ( + torch.arange(skv, dtype=torch.int64).view(1, 1, skv).expand(bsz, sq, skv).contiguous() + ) + causal_mask = torch.triu( + torch.full((sq, skv), float("-inf"), dtype=torch.float32), diagonal=1 + ) + + out_non_abs = unfused_dsa_fn( + query_non_abs, key_non_abs, value_non_abs, topk_indices, softmax_scale, mask=causal_mask + ) + out_abs_latent = _unfused_absorbed_dsa_fn( + query_abs, + key_abs, + topk_indices, + softmax_scale, + v_channels=kv_lora_rank, + mask=causal_mask, + ) + out_abs = torch.einsum("sbhc,hdc->sbhd", out_abs_latent, up_v_weight).contiguous() + out_abs = out_abs.view(sq, bsz, -1) + + torch.testing.assert_close(out_abs, out_non_abs, rtol=1e-4, atol=1e-5) + + class TestRotateActivation: """Test rotate_activation function.""" @@ -381,6 +888,9 @@ def test_fused_indexer_loss_gradient_matches_autograd(self, seqlen_and_topk, spa mask, sparse_loss, self.pg_collection, + None, + None, + None, ) # Backward with manual implementation @@ -486,6 +996,9 @@ def test_fused_indexer_loss_gradient_tp_consistency( mask, sparse_loss, pg_collection_tp1, + None, + None, + None, ) loss_tp1.backward() @@ -534,6 +1047,9 @@ def test_fused_indexer_loss_gradient_tp_consistency( mask, sparse_loss, pg_collection_tpn, + None, + None, + None, ) loss_tpn.backward() @@ -695,6 +1211,36 @@ def test_dsa_indexer_forward_with_scores(self, seqlen): != torch.sort(topk_indices, dim=-1).values[:, :, :-1] ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_dsa_indexer_forward_with_scores_packed_thd(self, seqlen): + """Test indexer forward_with_scores works with packed THD inputs.""" + batch_size = 1 + self.indexer.cuda() + + x = torch.randn(seqlen, batch_size, self.config.hidden_size, dtype=torch.bfloat16).cuda() + qr = torch.randn(seqlen, batch_size, self.config.q_lora_rank, dtype=torch.bfloat16).cuda() + + cu_seqlens = torch.tensor([0, seqlen], dtype=torch.int32, device=x.device) + packed_seq_params = PackedSeqParams( + qkv_format="thd", + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=seqlen, + max_seqlen_kv=seqlen, + ) + token_idx = torch.arange(seqlen, dtype=torch.int64, device=x.device) + mask = _build_packed_causal_mask_for_test(token_idx, token_idx, cu_seqlens) + + index_scores, topk_indices = self.indexer.forward_with_scores( + x, qr, mask=mask, packed_seq_params=packed_seq_params + ) + + assert index_scores.shape == (batch_size, seqlen, seqlen) + assert 
topk_indices.shape == (batch_size, seqlen, min(self.config.dsa_indexer_topk, seqlen)) + assert index_scores.dtype == torch.float32 + assert topk_indices.dtype == torch.long + assert torch.all((topk_indices >= 0) & (topk_indices < seqlen)) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_dsa_indexer_with_mask(self, seqlen): """Test indexer with attention mask.""" From 1223eb55d64650e955f32fb68a15b865e24479b4 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Tue, 3 Mar 2026 14:41:39 +0200 Subject: [PATCH 2/6] Add cache for compiled kernel Signed-off-by: Hollow Man --- .../experimental_attention_variant/dsa.py | 10 +- .../ops/indexer.py | 2 +- .../ops/tilelang_indexer_bwd.py | 57 ++++++- .../ops/tilelang_indexer_fwd.py | 100 +++++++++--- .../ops/tilelang_sparse_mla_bwd.py | 142 +++++++++++++++++- .../ops/tilelang_sparse_mla_fwd.py | 130 ++++++++++++++-- 6 files changed, 402 insertions(+), 39 deletions(-) diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py index 9ce12433725..e0d9b3d00f2 100644 --- a/megatron/core/transformer/experimental_attention_variant/dsa.py +++ b/megatron/core/transformer/experimental_attention_variant/dsa.py @@ -34,12 +34,18 @@ from megatron.core.transformer.experimental_attention_variant.ops.indexer import ( lighting_indexer, ) -except Exception: +except (ImportError, OSError): + logger.debug( + "Failed to import fused TileLang indexer; lighting_indexer path disabled.", exc_info=True + ) lighting_indexer = None try: from megatron.core.transformer.experimental_attention_variant.ops.sparse_mla import SparseMLA -except Exception: +except (ImportError, OSError): + logger.debug( + "Failed to import fused TileLang SparseMLA; SparseMLA path disabled.", exc_info=True + ) SparseMLA = None # Reusable no-grad scratch buffers keyed by (name, shape, dtype, device). 
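The compiled-kernel caches introduced in the ops files below all share one pattern: a module-level OrderedDict keyed by the kernel's compile-time parameters, with hits moved to the most-recently-used end and FIFO eviction beyond a fixed capacity, so repeated shapes reuse a compiled TileLang kernel instead of re-JITting. A minimal self-contained sketch of that pattern follows; names are illustrative and `_compile_kernel` merely stands in for a `tilelang.jit` factory call:

from collections import OrderedDict
from typing import Any, Tuple

_KERNEL_CACHE_MAX = 64
_kernel_cache: "OrderedDict[Tuple[int, int, int], Any]" = OrderedDict()

def _compile_kernel(heads: int, dim: int, topk: int) -> Any:
    # Stand-in for the expensive JIT compilation the cache amortizes.
    return ("compiled", heads, dim, topk)

def get_kernel(heads: int, dim: int, topk: int) -> Any:
    key = (heads, dim, topk)
    # pop + reinsert keeps hits at the most-recently-used end.
    kernel = _kernel_cache.pop(key, None)
    if kernel is None:
        kernel = _compile_kernel(heads, dim, topk)
    _kernel_cache[key] = kernel
    while len(_kernel_cache) > _KERNEL_CACHE_MAX:
        _kernel_cache.popitem(last=False)  # evict least recently used
    return kernel

Later commits in this series wrap the same lookup in a threading.Lock and, for some kernels, drop batch/sequence sizes from the key via dynamic shapes so that only truly shape-specializing parameters trigger recompiles.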
diff --git a/megatron/core/transformer/experimental_attention_variant/ops/indexer.py b/megatron/core/transformer/experimental_attention_variant/ops/indexer.py index 86504585132..af23fe6efb5 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/indexer.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/indexer.py @@ -51,7 +51,7 @@ def backward(ctx, grad_scores, grad_indices): grad_q, grad_w, grad_k = indexer_bwd_interface( index_q, weights, index_k, topk_indices, grad_scores ) - return grad_q, grad_k, grad_w, None, None, None, None, None, None, None + return grad_q, grad_k, grad_w, None, None, None, None def lighting_indexer( diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py index 3189b9afe0a..c568b479e5a 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py @@ -2,6 +2,8 @@ # Adapted from: # https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/ # examples/dsa_sparse_finetune/indexer_bwd.py +from collections import OrderedDict + import tilelang as tl import tilelang.language as T import torch @@ -9,6 +11,8 @@ BF16 = T.bfloat16 FP32 = T.float32 INT32 = T.int32 +_TILELANG_KERNEL_CACHE_MAX = 64 +_tilelang_indexer_bwd_kernel_cache = OrderedDict() pass_configs = { tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, @@ -16,6 +20,36 @@ } +def _cache_put_lru(cache: OrderedDict, key, value): + cache[key] = value + cache.move_to_end(key) + while len(cache) > _TILELANG_KERNEL_CACHE_MAX: + cache.popitem(last=False) + + +def _next_power_of_two(x: int) -> int: + if x <= 1: + return 1 + return 1 << (x - 1).bit_length() + + +def _round_up(x: int, multiple: int) -> int: + return ((x + multiple - 1) // multiple) * multiple + + +def _canonical_topk(topk: int, block_i: int = 32) -> int: + return _round_up(_next_power_of_two(topk), block_i) + + +def _get_indexer_bwd_kernel(heads: int, dim: int, topk: int): + key = (heads, dim, topk) + kernel = _tilelang_indexer_bwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = tl_indexer_bwd_impl(heads, dim, topk) + _cache_put_lru(_tilelang_indexer_bwd_kernel_cache, key, kernel) + return kernel + + @tl.jit(pass_configs=pass_configs) def tl_indexer_bwd_impl( heads: int, dim: int, topk: int, block_I: int = 32, num_stages: int = 0, num_threads: int = 128 @@ -147,14 +181,33 @@ def indexer_bwd_interface( ): """Run indexer backward kernel and return gradients for q/w/k.""" _, head_num, head_dim = index_q.shape - k_top = topk_indices.shape[1] + k_top = int(topk_indices.shape[1]) + assert k_top > 0, "topk must be positive" + padded_topk = _canonical_topk(k_top) + + if padded_topk != k_top: + padded_indices = torch.full( + (topk_indices.size(0), padded_topk), + -1, + dtype=topk_indices.dtype, + device=topk_indices.device, + ) + padded_indices[:, :k_top].copy_(topk_indices) + topk_indices = padded_indices + + padded_grad_scores = torch.zeros( + (grad_scores.size(0), padded_topk), dtype=grad_scores.dtype, device=grad_scores.device + ) + padded_grad_scores[:, :k_top].copy_(grad_scores) + grad_scores = padded_grad_scores grad_scores = grad_scores.contiguous() grad_q = torch.empty_like(index_q) grad_w = torch.empty_like(weights, dtype=torch.float32) grad_k = torch.zeros_like(index_k, dtype=torch.float32) - tl_indexer_bwd_impl(head_num, head_dim, k_top)( 
+ bwd_kernel = _get_indexer_bwd_kernel(head_num, head_dim, padded_topk) + bwd_kernel( index_q.contiguous(), index_k.contiguous(), weights.squeeze(-1).contiguous(), diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py index 0e1e4a59b48..cc1c72cca22 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py @@ -2,10 +2,49 @@ # Adapted from: # https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/ # examples/deepseek_v32/fp8_lighting_indexer.py +from collections import OrderedDict + import tilelang import torch from tilelang import language as T +_TILELANG_KERNEL_CACHE_MAX = 64 +_tilelang_indexer_fwd_kernel_cache = OrderedDict() +_tilelang_indexer_clean_logits_kernel_cache = OrderedDict() + + +def _cache_put_lru(cache: OrderedDict, key, value): + cache[key] = value + cache.move_to_end(key) + while len(cache) > _TILELANG_KERNEL_CACHE_MAX: + cache.popitem(last=False) + + +def _get_clean_logits_kernel(threads: int = 512, block_K: int = 4096): + key = (threads, block_K) + kernel = _tilelang_indexer_clean_logits_kernel_cache.pop(key, None) + if kernel is None: + kernel = clean_logits_(threads=threads, block_K=block_K) + _cache_put_lru(_tilelang_indexer_clean_logits_kernel_cache, key, kernel) + return kernel + + +def _get_indexer_fwd_kernel( + heads: int, index_dim: int, block_N: int = 256, num_stages: int = 3, threads: int = 512 +): + key = (heads, index_dim, block_N, num_stages, threads) + kernel = _tilelang_indexer_fwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = tl_indexer_fwd_impl( + heads=heads, + index_dim=index_dim, + block_N=block_N, + num_stages=num_stages, + threads=threads, + ) + _cache_put_lru(_tilelang_indexer_fwd_kernel_cache, key, kernel) + return kernel + @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) def tl_indexer_fwd_impl(heads, index_dim, block_N=256, num_stages=3, threads=512, block_Q=None): @@ -37,7 +76,7 @@ def tl_indexer_fwd_kernel( index_k_shared = T.alloc_shared([block_N, index_dim], dtype) s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) s_reshaped = T.reshape(s, (block_N, block_Q, heads)) - logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + logits_shared = T.alloc_shared([block_N, block_Q], accum_dtype) weights = T.alloc_fragment([block_Q, heads], accum_dtype) seq_len_i = bx * block_Q @@ -49,17 +88,41 @@ def tl_indexer_fwd_kernel( cu_k_e_max = -2147483648 for bq_i in T.serial(block_Q): - cu_k_s_min = T.min(cu_k_s_min, T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)) + q_idx = seq_len_i + bq_i + if q_idx < seq_len: + k_s = T.max(T.min(CuSeqLenKS[q_idx], seq_len_kv), 0) + cu_k_s_min = T.min(cu_k_s_min, k_s) for bq_i in T.serial(block_Q): - cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)) - - T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) - T.copy(Weights[seq_len_i, 0], weights) + q_idx = seq_len_i + bq_i + if q_idx < seq_len: + k_e = T.max(T.min(CuSeqLenKE[q_idx], seq_len_kv), 0) + cu_k_e_max = T.max(cu_k_e_max, k_e) + + # Clamp bounds to [0, seq_len_kv] and normalize empty rows. 
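+ # Normalizing empty rows (cu_k_e_max < cu_k_s_min -> e = s) makes the
+ # pipelined loop below run T.ceildiv(0, block_N) = 0 iterations instead
+ # of computing a negative trip count for fully masked query blocks.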
+ cu_k_s_min = T.max(cu_k_s_min, 0) + cu_k_s_min = T.min(cu_k_s_min, seq_len_kv) + cu_k_e_max = T.max(cu_k_e_max, 0) + cu_k_e_max = T.min(cu_k_e_max, seq_len_kv) + if cu_k_e_max < cu_k_s_min: + cu_k_e_max = cu_k_s_min + + for bq_i, h_i, d_i in T.Parallel(block_Q, heads, index_dim): + q_idx = seq_len_i + bq_i + index_q_shared[bq_i * heads + h_i, d_i] = T.if_then_else( + q_idx < seq_len, IndexQ[q_idx * heads + h_i, d_i], 0 + ) + for bq_i, h_i in T.Parallel(block_Q, heads): + q_idx = seq_len_i + bq_i + weights[bq_i, h_i] = T.if_then_else(q_idx < seq_len, Weights[q_idx, h_i], 0) for nbn_i in T.Pipelined( T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages ): - T.copy(IndexK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + for bn_i, d_i in T.Parallel(block_N, index_dim): + k_idx = cu_k_s_min + nbn_i * block_N + bn_i + index_k_shared[bn_i, d_i] = T.if_then_else( + k_idx >= 0 and k_idx < cu_k_e_max, IndexK[k_idx, d_i], 0 + ) T.gemm( index_k_shared, @@ -75,12 +138,16 @@ def tl_indexer_fwd_kernel( T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i] ) - T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + T.reduce_sum(s_reshaped, logits_shared, dim=-1, clear=True) - for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[ - bn_i, bq_i - ] + # Keep this write deterministic to satisfy data-race verification. + for bq_i in T.serial(block_Q): + q_idx = seq_len_i + bq_i + if q_idx < seq_len: + for bn_i in T.serial(block_N): + k_idx = cu_k_s_min + nbn_i * block_N + bn_i + if k_idx >= 0 and k_idx < cu_k_e_max: + Logits[q_idx, k_idx] = logits_shared[bn_i, bq_i] return tl_indexer_fwd_kernel @@ -108,7 +175,7 @@ def clean_logits_kernel( for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): for k_i in T.serial(block_K // threads): idx = n_i * block_K + k_i * threads + tx - if idx < cu_k_s or idx >= cu_k_e: + if idx < seq_len_kv and (idx < cu_k_s or idx >= cu_k_e): Logits[bx, idx] = -T.infinity(dtype) return clean_logits_kernel @@ -119,14 +186,13 @@ def indexer_fwd_interface(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logi seq_len, heads, index_dim = q.shape seq_len_kv = kv.shape[0] - clean_logits_kernel = clean_logits_() - - tl_indexer_fwd_kernel = tl_indexer_fwd_impl(heads=heads, index_dim=index_dim) - + tl_indexer_fwd_kernel = _get_indexer_fwd_kernel(heads=heads, index_dim=index_dim) logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) tl_indexer_fwd_kernel( q.view(seq_len * heads, index_dim), kv, logits, weights, cu_seqlen_ks, cu_seqlen_ke ) + if clean_logits: + clean_logits_kernel = _get_clean_logits_kernel() clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke) return logits diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py index 01e13ac1efe..e350924d450 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py @@ -2,10 +2,85 @@ # Adapted from: # https://github.com/tile-ai/tilelang/blob/4ff81c7d40803d269569e157e847623e84553f78/ # examples/deepseek_v32/sparse_mla_bwd.py +import os +from collections import OrderedDict + import tilelang import torch from tilelang import language as T +_TILELANG_KERNEL_CACHE_MAX = 64 +_SPARSE_MLA_BWD_BLOCK_SIZE = 32 +_tilelang_sparse_mla_preprocess_kernel_cache = 
OrderedDict() +_tilelang_sparse_mla_bwd_kernel_cache = OrderedDict() +_tilelang_sparse_mla_postprocess_kernel_cache = OrderedDict() + + +def _env_int(name: str, default: int) -> int: + value = os.getenv(name) + if value is None: + return default + try: + parsed = int(value) + except ValueError: + return default + return parsed if parsed > 0 else default + + +def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def _round_up(x: int, multiple: int) -> int: + if multiple <= 1: + return x + return _ceil_div(x, multiple) * multiple + + +def _cache_put_lru(cache: OrderedDict, key, value): + cache[key] = value + cache.move_to_end(key) + while len(cache) > _TILELANG_KERNEL_CACHE_MAX: + cache.popitem(last=False) + + +def _get_preprocess_kernel(B: int, S: int, H: int, D: int): + key = (B, S, H, D) + kernel = _tilelang_sparse_mla_preprocess_kernel_cache.pop(key, None) + if kernel is None: + kernel = preprocess(B, S, H, D) + _cache_put_lru(_tilelang_sparse_mla_preprocess_kernel_cache, key, kernel) + return kernel + + +def _get_bwd_kernel( + B: int, + S: int, + S_kv: int, + H: int, + D: int, + D_tail: int, + topk: int, + kv_group: int, + sm_scale, + is_causal: bool, +): + key = (B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_causal) + kernel = _tilelang_sparse_mla_bwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_causal) + _cache_put_lru(_tilelang_sparse_mla_bwd_kernel_cache, key, kernel) + return kernel + + +def _get_postprocess_kernel(B: int, S_kv: int, D: int, D_tail: int, kv_group: int): + key = (B, S_kv, D, D_tail, kv_group) + kernel = _tilelang_sparse_mla_postprocess_kernel_cache.pop(key, None) + if kernel is None: + kernel = postprocess(B, S_kv, D, D_tail, kv_group) + _cache_put_lru(_tilelang_sparse_mla_postprocess_kernel_cache, key, kernel) + return kernel + @tilelang.jit(out_idx=[-1]) def preprocess(B, S, H, D, block_ND=32, num_stages=5, dtype=T.bfloat16, accum_dtype=T.float32): @@ -315,6 +390,9 @@ def sparse_mla_bwd( q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None ): """Run sparse-MLA backward kernels and return (dq, dkv).""" + seq_bucket = _env_int("MCORE_DSA_TILELANG_SEQ_BUCKET", 256) + topk_bucket = _env_int("MCORE_DSA_TILELANG_TOPK_BUCKET", _SPARSE_MLA_BWD_BLOCK_SIZE) + q = q.unsqueeze(0) kv = kv.unsqueeze(0) o = o.unsqueeze(0) @@ -330,18 +408,71 @@ def sparse_mla_bwd( _, S_kv, kv_group, _ = kv.shape assert kv.shape[-1] == dim_plus_tail_dim assert kv.shape[0] == B - # dim should be assigned + # This copied kernel currently assumes a fixed base value-channel dimension. 
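+ # With the 576-wide q/kv that the forward interface asserts, this gives
+ # D = 512 latent/value channels plus a D_tail = 64 positional tail.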
D = 512 + assert ( + dim_plus_tail_dim >= D + ), f"Invalid dimensions: dim_plus_tail_dim={dim_plus_tail_dim} is smaller than base D={D}" D_tail = dim_plus_tail_dim - D topk = indices.shape[-1] assert indices.shape == (B, S, kv_group, topk) assert lse.shape == (B, S, H) + seq_bucketed = _round_up(S, seq_bucket) + seq_kv_bucketed = _round_up(S_kv, seq_bucket) + topk_bucketed = _round_up(_round_up(topk, topk_bucket), _SPARSE_MLA_BWD_BLOCK_SIZE) + + if seq_bucketed != S: + q_padded = torch.zeros( + (B, seq_bucketed, H, dim_plus_tail_dim), dtype=q.dtype, device=q.device + ) + q_padded[:, :S].copy_(q) + q = q_padded + + o_padded = torch.zeros((B, seq_bucketed, H, D), dtype=o.dtype, device=o.device) + o_padded[:, :S].copy_(o) + o = o_padded + + do_padded = torch.zeros((B, seq_bucketed, H, D), dtype=do.dtype, device=do.device) + do_padded[:, :S].copy_(do) + do = do_padded + + lse_padded = torch.zeros((B, seq_bucketed, H), dtype=lse.dtype, device=lse.device) + lse_padded[:, :S].copy_(lse) + lse = lse_padded + + if seq_kv_bucketed != S_kv: + kv_padded = torch.zeros( + (B, seq_kv_bucketed, kv_group, dim_plus_tail_dim), dtype=kv.dtype, device=kv.device + ) + kv_padded[:, :S_kv].copy_(kv) + kv = kv_padded + + if seq_bucketed != S or topk_bucketed != topk: + indices_padded = torch.full( + (B, seq_bucketed, kv_group, topk_bucketed), + -1, + dtype=indices.dtype, + device=indices.device, + ) + indices_padded[:, :S, :, :topk].copy_(indices) + indices = indices_padded + + if delta is not None: + if delta.ndim == 2: + delta = delta.unsqueeze(0) + if seq_bucketed != S: + delta_padded = torch.zeros((B, seq_bucketed, H), dtype=delta.dtype, device=delta.device) + delta_padded[:, :S].copy_(delta) + delta = delta_padded + # Get kernels - preprocess_kernel = preprocess(B, S, H, D) - bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual) - postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group) + preprocess_kernel = _get_preprocess_kernel(B, seq_bucketed, H, D) + bwd_kernel = _get_bwd_kernel( + B, seq_bucketed, seq_kv_bucketed, H, D, D_tail, topk_bucketed, kv_group, sm_scale, is_casual + ) + postprocess_kernel = _get_postprocess_kernel(B, seq_kv_bucketed, D, D_tail, kv_group) if delta is None: delta = preprocess_kernel(o, do) @@ -349,6 +480,9 @@ def sparse_mla_bwd( dq = bwd_kernel(q, kv, do, indices, lse, delta, dkv) dkv = postprocess_kernel(dkv) + dq = dq[:, :S].contiguous() + dkv = dkv[:, :S_kv].contiguous() + dq = dq.squeeze(0) dkv = dkv.squeeze(0) diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py index 19d1a80963f..bfab43c2566 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py @@ -2,9 +2,75 @@ # Adapted from: # https://github.com/tile-ai/tilelang/blob/e666d2d3cc483829c57618c9ebf2e4f4ada0819d/ # examples/deepseek_v32/sparse_mla_fwd.py +import os +from collections import OrderedDict + import tilelang +import torch from tilelang import language as T +_TILELANG_KERNEL_CACHE_MAX = 64 +_tilelang_sparse_mla_fwd_kernel_cache = OrderedDict() + + +def _env_int(name: str, default: int) -> int: + value = os.getenv(name) + if value is None: + return default + try: + parsed = int(value) + except ValueError: + return default + return parsed if parsed > 0 else default + + +def _ceil_div(x: int, y: int) -> int: + 
return (x + y - 1) // y + + +def _round_up(x: int, multiple: int) -> int: + if multiple <= 1: + return x + return _ceil_div(x, multiple) * multiple + + +def _cache_put_lru(cache: OrderedDict, key, value): + cache[key] = value + cache.move_to_end(key) + while len(cache) > _TILELANG_KERNEL_CACHE_MAX: + cache.popitem(last=False) + + +def _get_sparse_mla_fwd_kernel( + heads: int, + dim: int, + tail_dim: int, + topk: int, + kv_group: int, + sm_scale, + is_causal: bool, + block_I: int, + num_stages: int, + threads: int, +): + key = (heads, dim, tail_dim, topk, kv_group, sm_scale, is_causal, block_I, num_stages, threads) + kernel = _tilelang_sparse_mla_fwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group, + sm_scale, + is_causal, + block_I=block_I, + num_stages=num_stages, + threads=threads, + ) + _cache_put_lru(_tilelang_sparse_mla_fwd_kernel_cache, key, kernel) + return kernel + @tilelang.jit( out_idx=[-2, -1], @@ -185,6 +251,9 @@ def sparse_mla_fwd_interface( threads=256, ): """Run sparse-MLA forward kernel and return (out, lse).""" + seq_bucket = _env_int("MCORE_DSA_TILELANG_SEQ_BUCKET", 256) + topk_bucket = _env_int("MCORE_DSA_TILELANG_TOPK_BUCKET", block_I) + q = q.unsqueeze(0) kv = kv.unsqueeze(0) indices = indices.unsqueeze(0) @@ -193,10 +262,15 @@ def sparse_mla_fwd_interface( assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() batch, seq_len, heads, dim_plus_tail_dim = q.shape - _, seq_len_kv, kv_group, _ = kv.shape - - assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + _, seq_len_kv, kv_group, kv_dim = kv.shape + assert ( + kv_dim == dim_plus_tail_dim + ), "q and kv must have the same embedding dimension on the last axis" + assert ( + dim_plus_tail_dim == 576 + ), "TileLang sparse MLA fwd is currently specialized for dim_plus_tail_dim=576" dim = d_v + assert 0 < dim <= dim_plus_tail_dim, f"d_v must be in (0, {dim_plus_tail_dim}], but got {dim}" assert kv.shape[-1] == dim_plus_tail_dim tail_dim = dim_plus_tail_dim - dim @@ -204,19 +278,49 @@ def sparse_mla_fwd_interface( _, _, _, topk = indices.shape assert indices.shape == (batch, seq_len, kv_group, topk) - kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_casual, + seq_len_bucketed = _round_up(seq_len, seq_bucket) + seq_len_kv_bucketed = _round_up(seq_len_kv, seq_bucket) + topk_bucketed = _round_up(_round_up(topk, topk_bucket), block_I) + + if seq_len_bucketed != seq_len: + q_padded = torch.zeros( + (batch, seq_len_bucketed, heads, dim_plus_tail_dim), dtype=q.dtype, device=q.device + ) + q_padded[:, :seq_len].copy_(q) + q = q_padded + + if seq_len_kv_bucketed != seq_len_kv: + kv_padded = torch.zeros( + (batch, seq_len_kv_bucketed, kv_group, dim_plus_tail_dim), + dtype=kv.dtype, + device=kv.device, + ) + kv_padded[:, :seq_len_kv].copy_(kv) + kv = kv_padded + + if seq_len_bucketed != seq_len or topk_bucketed != topk: + indices_padded = torch.full( + (batch, seq_len_bucketed, kv_group, topk_bucketed), + -1, + dtype=indices.dtype, + device=indices.device, + ) + indices_padded[:, :seq_len, :, :topk].copy_(indices) + indices = indices_padded + + kernel = _get_sparse_mla_fwd_kernel( + heads=heads, + dim=dim, + tail_dim=tail_dim, + topk=topk_bucketed, + kv_group=kv_group, + sm_scale=sm_scale, + is_causal=is_casual, block_I=block_I, num_stages=num_stages, threads=threads, ) out, lse = kernel(q, kv, indices) - out = 
out.squeeze(0)
- lse = lse.squeeze(0)
+ out = out[:, :seq_len].contiguous().squeeze(0)
+ lse = lse[:, :seq_len].contiguous().squeeze(0)
 return out, lse

From bf05dadf3e09559b01df06e49f6f3058bd3f37c4 Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Tue, 3 Mar 2026 19:06:55 +0200
Subject: [PATCH 3/6] Increase chunk size for better performance

Signed-off-by: Hollow Man
---
 .../experimental_attention_variant/dsa.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py
index e0d9b3d00f2..e01862d5b40 100644
--- a/megatron/core/transformer/experimental_attention_variant/dsa.py
+++ b/megatron/core/transformer/experimental_attention_variant/dsa.py
@@ -837,9 +837,9 @@ def _fused_qk_topk_lighting_with_streaming_sparse_kl(
 softmax_scale: float,
 loss_coeff: float,
 pg_collection: ProcessGroupCollection,
- seq_chunk_size: int = 32,
- head_chunk_size: int = 4,
- topk_chunk_size: int = 64,
+ seq_chunk_size: int = 512,
+ head_chunk_size: int = 16,
+ topk_chunk_size: int = 1024,
 ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
 """Run fused tilelang indexer and stream top-k logits directly into sparse KL accumulation."""
 if lighting_indexer is None:
@@ -1515,9 +1515,9 @@ def compute_dsa_indexer_loss_topk_sparse(
 # Compute KL in streaming chunks to avoid materializing full [b, sq, topk]
 # and avoid full-size valid/safe top-k tensors.
- seq_chunk_size = 32
- head_chunk_size = 4
- topk_chunk_size = 64
+ seq_chunk_size = 512
+ head_chunk_size = 16
+ topk_chunk_size = 1024
 kl_sum = torch.zeros((), dtype=torch.float32, device=query.device)
 tp_size = pg_collection.tp.size()
 pending_handle = None
@@ -2417,9 +2417,9 @@ def unfused_dsa_fn(
 device=query.device,
 )

- seq_chunk_size = 64
- head_chunk_size = 4
- topk_chunk_size = 256
+ seq_chunk_size = 512
+ head_chunk_size = 16
+ topk_chunk_size = 1024
 safe_k_max = max(0, skv - 1)

 output = torch.empty((sq, b, np * hnv), dtype=value.dtype, device=query.device)

From 30203d0177c056a4d10123de9061dd5e30d0c262 Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Tue, 3 Mar 2026 22:37:45 +0200
Subject: [PATCH 4/6] fix tilelang recompile under pp>1

Signed-off-by: Hollow Man
---
 .../experimental_attention_variant/dsa.py | 23 +++++----
 .../ops/tilelang_indexer_fwd.py | 50 ++++++++++++-------
 2 files changed, 46 insertions(+), 27 deletions(-)

diff --git a/megatron/core/transformer/experimental_attention_variant/dsa.py b/megatron/core/transformer/experimental_attention_variant/dsa.py
index e01862d5b40..1672c882ac4 100644
--- a/megatron/core/transformer/experimental_attention_variant/dsa.py
+++ b/megatron/core/transformer/experimental_attention_variant/dsa.py
@@ -2007,7 +2007,9 @@ class DSAIndexerLossAutoScaler(torch.autograd.Function):
 to train the indexer to predict attention scores without affecting the
 forward pass.
 """
- main_loss_backward_scale: torch.Tensor = None
+ # Keep scale as a Python float to avoid holding a long-lived CUDA tensor
+ # across iterations/streams. 
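+ # backward() below materializes the gradient per call via torch.full_like,
+ # so no device-resident scale tensor outlives a single backward pass.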
+ main_loss_backward_scale: Optional[float] = None @staticmethod def forward(ctx, output: torch.Tensor, indexer_loss: torch.Tensor): @@ -2036,24 +2038,25 @@ def backward(ctx, grad_output: torch.Tensor): """ (indexer_loss,) = ctx.saved_tensors if DSAIndexerLossAutoScaler.main_loss_backward_scale is None: - DSAIndexerLossAutoScaler.main_loss_backward_scale = torch.tensor( - 1.0, device=indexer_loss.device - ) - indexer_loss_backward_scale = DSAIndexerLossAutoScaler.main_loss_backward_scale - scaled_indexer_loss_grad = torch.ones_like(indexer_loss) * indexer_loss_backward_scale + DSAIndexerLossAutoScaler.main_loss_backward_scale = 1.0 + indexer_loss_backward_scale = float(DSAIndexerLossAutoScaler.main_loss_backward_scale) + scaled_indexer_loss_grad = torch.full_like( + indexer_loss, fill_value=indexer_loss_backward_scale + ) return grad_output, scaled_indexer_loss_grad @staticmethod - def set_loss_scale(scale: torch.Tensor): + def set_loss_scale(scale: Union[torch.Tensor, float]): """Set the scale of the indexer loss. Args: scale: The scale value to set. """ - if DSAIndexerLossAutoScaler.main_loss_backward_scale is None: - DSAIndexerLossAutoScaler.main_loss_backward_scale = scale + if isinstance(scale, torch.Tensor): + scale_value = float(scale.detach().item()) else: - DSAIndexerLossAutoScaler.main_loss_backward_scale.copy_(scale) + scale_value = float(scale) + DSAIndexerLossAutoScaler.main_loss_backward_scale = scale_value @dataclass diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py index cc1c72cca22..d7d6b10efab 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_fwd.py @@ -2,15 +2,29 @@ # Adapted from: # https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/ # examples/deepseek_v32/fp8_lighting_indexer.py +import os +import threading from collections import OrderedDict import tilelang import torch from tilelang import language as T -_TILELANG_KERNEL_CACHE_MAX = 64 +def _env_int(name: str, default: int) -> int: + value = os.getenv(name) + if value is None: + return default + try: + parsed = int(value) + except ValueError: + return default + return parsed if parsed > 0 else default + + +_TILELANG_KERNEL_CACHE_MAX = _env_int("MCORE_DSA_TILELANG_KERNEL_CACHE_MAX", 512) _tilelang_indexer_fwd_kernel_cache = OrderedDict() _tilelang_indexer_clean_logits_kernel_cache = OrderedDict() +_tilelang_indexer_fwd_cache_lock = threading.Lock() def _cache_put_lru(cache: OrderedDict, key, value): @@ -22,28 +36,30 @@ def _cache_put_lru(cache: OrderedDict, key, value): def _get_clean_logits_kernel(threads: int = 512, block_K: int = 4096): key = (threads, block_K) - kernel = _tilelang_indexer_clean_logits_kernel_cache.pop(key, None) - if kernel is None: - kernel = clean_logits_(threads=threads, block_K=block_K) - _cache_put_lru(_tilelang_indexer_clean_logits_kernel_cache, key, kernel) - return kernel + with _tilelang_indexer_fwd_cache_lock: + kernel = _tilelang_indexer_clean_logits_kernel_cache.pop(key, None) + if kernel is None: + kernel = clean_logits_(threads=threads, block_K=block_K) + _cache_put_lru(_tilelang_indexer_clean_logits_kernel_cache, key, kernel) + return kernel def _get_indexer_fwd_kernel( heads: int, index_dim: int, block_N: int = 256, num_stages: int = 3, threads: int = 512 ): key = (heads, index_dim, 
block_N, num_stages, threads) - kernel = _tilelang_indexer_fwd_kernel_cache.pop(key, None) - if kernel is None: - kernel = tl_indexer_fwd_impl( - heads=heads, - index_dim=index_dim, - block_N=block_N, - num_stages=num_stages, - threads=threads, - ) - _cache_put_lru(_tilelang_indexer_fwd_kernel_cache, key, kernel) - return kernel + with _tilelang_indexer_fwd_cache_lock: + kernel = _tilelang_indexer_fwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = tl_indexer_fwd_impl( + heads=heads, + index_dim=index_dim, + block_N=block_N, + num_stages=num_stages, + threads=threads, + ) + _cache_put_lru(_tilelang_indexer_fwd_kernel_cache, key, kernel) + return kernel @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) From f133b53c27abda0f4a1e90b88f43f2aada568182 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Wed, 4 Mar 2026 21:59:52 +0200 Subject: [PATCH 5/6] threading lock for compiling Signed-off-by: Hollow Man --- .../ops/tilelang_indexer_bwd.py | 29 +++++++-- .../ops/tilelang_sparse_mla_bwd.py | 52 ++++++++++----- .../ops/tilelang_sparse_mla_fwd.py | 63 +++++++++++++------ 3 files changed, 103 insertions(+), 41 deletions(-) diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py index c568b479e5a..aebe3d12012 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_indexer_bwd.py @@ -2,6 +2,8 @@ # Adapted from: # https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/ # examples/dsa_sparse_finetune/indexer_bwd.py +import os +import threading from collections import OrderedDict import tilelang as tl @@ -11,8 +13,8 @@ BF16 = T.bfloat16 FP32 = T.float32 INT32 = T.int32 -_TILELANG_KERNEL_CACHE_MAX = 64 _tilelang_indexer_bwd_kernel_cache = OrderedDict() +_tilelang_indexer_bwd_cache_lock = threading.Lock() pass_configs = { tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, @@ -20,6 +22,20 @@ } +def _env_int(name: str, default: int) -> int: + value = os.getenv(name) + if value is None: + return default + try: + parsed = int(value) + except ValueError: + return default + return parsed if parsed > 0 else default + + +_TILELANG_KERNEL_CACHE_MAX = _env_int("MCORE_DSA_TILELANG_KERNEL_CACHE_MAX", 512) + + def _cache_put_lru(cache: OrderedDict, key, value): cache[key] = value cache.move_to_end(key) @@ -43,11 +59,12 @@ def _canonical_topk(topk: int, block_i: int = 32) -> int: def _get_indexer_bwd_kernel(heads: int, dim: int, topk: int): key = (heads, dim, topk) - kernel = _tilelang_indexer_bwd_kernel_cache.pop(key, None) - if kernel is None: - kernel = tl_indexer_bwd_impl(heads, dim, topk) - _cache_put_lru(_tilelang_indexer_bwd_kernel_cache, key, kernel) - return kernel + with _tilelang_indexer_bwd_cache_lock: + kernel = _tilelang_indexer_bwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = tl_indexer_bwd_impl(heads, dim, topk) + _cache_put_lru(_tilelang_indexer_bwd_kernel_cache, key, kernel) + return kernel @tl.jit(pass_configs=pass_configs) diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py index e350924d450..767bcf5ee32 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py +++ 
b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py @@ -3,17 +3,18 @@ # https://github.com/tile-ai/tilelang/blob/4ff81c7d40803d269569e157e847623e84553f78/ # examples/deepseek_v32/sparse_mla_bwd.py import os +import threading from collections import OrderedDict import tilelang import torch from tilelang import language as T -_TILELANG_KERNEL_CACHE_MAX = 64 _SPARSE_MLA_BWD_BLOCK_SIZE = 32 _tilelang_sparse_mla_preprocess_kernel_cache = OrderedDict() _tilelang_sparse_mla_bwd_kernel_cache = OrderedDict() _tilelang_sparse_mla_postprocess_kernel_cache = OrderedDict() +_tilelang_sparse_mla_bwd_cache_lock = threading.Lock() def _env_int(name: str, default: int) -> int: @@ -27,6 +28,9 @@ def _env_int(name: str, default: int) -> int: return parsed if parsed > 0 else default +_TILELANG_KERNEL_CACHE_MAX = _env_int("MCORE_DSA_TILELANG_KERNEL_CACHE_MAX", 512) + + def _ceil_div(x: int, y: int) -> int: return (x + y - 1) // y @@ -44,13 +48,25 @@ def _cache_put_lru(cache: OrderedDict, key, value): cache.popitem(last=False) +def _normalize_sm_scale(sm_scale): + if sm_scale is None: + return None + if isinstance(sm_scale, torch.Tensor): + sm_scale = float(sm_scale.detach().item()) + else: + sm_scale = float(sm_scale) + # Avoid tiny floating-point jitter creating cache-key churn. + return round(sm_scale, 12) + + def _get_preprocess_kernel(B: int, S: int, H: int, D: int): key = (B, S, H, D) - kernel = _tilelang_sparse_mla_preprocess_kernel_cache.pop(key, None) - if kernel is None: - kernel = preprocess(B, S, H, D) - _cache_put_lru(_tilelang_sparse_mla_preprocess_kernel_cache, key, kernel) - return kernel + with _tilelang_sparse_mla_bwd_cache_lock: + kernel = _tilelang_sparse_mla_preprocess_kernel_cache.pop(key, None) + if kernel is None: + kernel = preprocess(B, S, H, D) + _cache_put_lru(_tilelang_sparse_mla_preprocess_kernel_cache, key, kernel) + return kernel def _get_bwd_kernel( @@ -65,21 +81,23 @@ def _get_bwd_kernel( sm_scale, is_causal: bool, ): - key = (B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_causal) - kernel = _tilelang_sparse_mla_bwd_kernel_cache.pop(key, None) - if kernel is None: - kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_causal) - _cache_put_lru(_tilelang_sparse_mla_bwd_kernel_cache, key, kernel) - return kernel + key = (B, S, S_kv, H, D, D_tail, topk, kv_group, _normalize_sm_scale(sm_scale), is_causal) + with _tilelang_sparse_mla_bwd_cache_lock: + kernel = _tilelang_sparse_mla_bwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_causal) + _cache_put_lru(_tilelang_sparse_mla_bwd_kernel_cache, key, kernel) + return kernel def _get_postprocess_kernel(B: int, S_kv: int, D: int, D_tail: int, kv_group: int): key = (B, S_kv, D, D_tail, kv_group) - kernel = _tilelang_sparse_mla_postprocess_kernel_cache.pop(key, None) - if kernel is None: - kernel = postprocess(B, S_kv, D, D_tail, kv_group) - _cache_put_lru(_tilelang_sparse_mla_postprocess_kernel_cache, key, kernel) - return kernel + with _tilelang_sparse_mla_bwd_cache_lock: + kernel = _tilelang_sparse_mla_postprocess_kernel_cache.pop(key, None) + if kernel is None: + kernel = postprocess(B, S_kv, D, D_tail, kv_group) + _cache_put_lru(_tilelang_sparse_mla_postprocess_kernel_cache, key, kernel) + return kernel @tilelang.jit(out_idx=[-1]) diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py 
b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py index bfab43c2566..cae3985c84b 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_fwd.py @@ -3,14 +3,15 @@ # https://github.com/tile-ai/tilelang/blob/e666d2d3cc483829c57618c9ebf2e4f4ada0819d/ # examples/deepseek_v32/sparse_mla_fwd.py import os +import threading from collections import OrderedDict import tilelang import torch from tilelang import language as T -_TILELANG_KERNEL_CACHE_MAX = 64 _tilelang_sparse_mla_fwd_kernel_cache = OrderedDict() +_tilelang_sparse_mla_fwd_cache_lock = threading.Lock() def _env_int(name: str, default: int) -> int: @@ -24,6 +25,9 @@ def _env_int(name: str, default: int) -> int: return parsed if parsed > 0 else default +_TILELANG_KERNEL_CACHE_MAX = _env_int("MCORE_DSA_TILELANG_KERNEL_CACHE_MAX", 512) + + def _ceil_div(x: int, y: int) -> int: return (x + y - 1) // y @@ -41,6 +45,17 @@ def _cache_put_lru(cache: OrderedDict, key, value): cache.popitem(last=False) +def _normalize_sm_scale(sm_scale): + if sm_scale is None: + return None + if isinstance(sm_scale, torch.Tensor): + sm_scale = float(sm_scale.detach().item()) + else: + sm_scale = float(sm_scale) + # Avoid tiny floating-point jitter creating cache-key churn. + return round(sm_scale, 12) + + def _get_sparse_mla_fwd_kernel( heads: int, dim: int, @@ -53,23 +68,35 @@ def _get_sparse_mla_fwd_kernel( num_stages: int, threads: int, ): - key = (heads, dim, tail_dim, topk, kv_group, sm_scale, is_causal, block_I, num_stages, threads) - kernel = _tilelang_sparse_mla_fwd_kernel_cache.pop(key, None) - if kernel is None: - kernel = sparse_mla_fwd( - heads, - dim, - tail_dim, - topk, - kv_group, - sm_scale, - is_causal, - block_I=block_I, - num_stages=num_stages, - threads=threads, - ) - _cache_put_lru(_tilelang_sparse_mla_fwd_kernel_cache, key, kernel) - return kernel + key = ( + heads, + dim, + tail_dim, + topk, + kv_group, + _normalize_sm_scale(sm_scale), + is_causal, + block_I, + num_stages, + threads, + ) + with _tilelang_sparse_mla_fwd_cache_lock: + kernel = _tilelang_sparse_mla_fwd_kernel_cache.pop(key, None) + if kernel is None: + kernel = sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group, + sm_scale, + is_causal, + block_I=block_I, + num_stages=num_stages, + threads=threads, + ) + _cache_put_lru(_tilelang_sparse_mla_fwd_kernel_cache, key, kernel) + return kernel @tilelang.jit( From 9ad3f653e00628eae38d300aa9963e635f08e4ad Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Thu, 5 Mar 2026 01:21:25 +0200 Subject: [PATCH 6/6] No recompile Signed-off-by: Hollow Man --- .../ops/tilelang_sparse_mla_bwd.py | 78 +++++++++---------- 1 file changed, 38 insertions(+), 40 deletions(-) diff --git a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py index 767bcf5ee32..b547d723b0a 100644 --- a/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py +++ b/megatron/core/transformer/experimental_attention_variant/ops/tilelang_sparse_mla_bwd.py @@ -59,61 +59,54 @@ def _normalize_sm_scale(sm_scale): return round(sm_scale, 12) -def _get_preprocess_kernel(B: int, S: int, H: int, D: int): - key = (B, S, H, D) +def _get_preprocess_kernel(H: int, D: int): + key = (H, D) with _tilelang_sparse_mla_bwd_cache_lock: kernel = 
_tilelang_sparse_mla_preprocess_kernel_cache.pop(key, None) if kernel is None: - kernel = preprocess(B, S, H, D) + kernel = preprocess(H, D) _cache_put_lru(_tilelang_sparse_mla_preprocess_kernel_cache, key, kernel) return kernel def _get_bwd_kernel( - B: int, - S: int, - S_kv: int, - H: int, - D: int, - D_tail: int, - topk: int, - kv_group: int, - sm_scale, - is_causal: bool, + H: int, D: int, D_tail: int, topk: int, kv_group: int, sm_scale, is_causal: bool ): - key = (B, S, S_kv, H, D, D_tail, topk, kv_group, _normalize_sm_scale(sm_scale), is_causal) + key = (H, D, D_tail, topk, kv_group, _normalize_sm_scale(sm_scale), is_causal) with _tilelang_sparse_mla_bwd_cache_lock: kernel = _tilelang_sparse_mla_bwd_kernel_cache.pop(key, None) if kernel is None: - kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_causal) + kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_causal) _cache_put_lru(_tilelang_sparse_mla_bwd_kernel_cache, key, kernel) return kernel -def _get_postprocess_kernel(B: int, S_kv: int, D: int, D_tail: int, kv_group: int): - key = (B, S_kv, D, D_tail, kv_group) +def _get_postprocess_kernel(D: int, D_tail: int, kv_group: int): + key = (D, D_tail, kv_group) with _tilelang_sparse_mla_bwd_cache_lock: kernel = _tilelang_sparse_mla_postprocess_kernel_cache.pop(key, None) if kernel is None: - kernel = postprocess(B, S_kv, D, D_tail, kv_group) + kernel = postprocess(D, D_tail, kv_group) _cache_put_lru(_tilelang_sparse_mla_postprocess_kernel_cache, key, kernel) return kernel @tilelang.jit(out_idx=[-1]) -def preprocess(B, S, H, D, block_ND=32, num_stages=5, dtype=T.bfloat16, accum_dtype=T.float32): +def preprocess(H, D, block_ND=32, num_stages=5, dtype=T.bfloat16, accum_dtype=T.float32): """Build preprocessing kernel that computes Delta = sum(O * dO) per row/head.""" assert dtype == T.bfloat16 assert accum_dtype == T.float32 - shape = [B, S, H, D] + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + shape = [batch, seq_len, H, D] @T.prim_func def preprocess_kernel( O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), - Delta: T.Tensor([B, S, H], accum_dtype), + Delta: T.Tensor([batch, seq_len, H], accum_dtype), ): - with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): + with T.Kernel(H, T.ceildiv(seq_len, block_ND), batch) as (bx, by, bz): o = T.alloc_fragment([block_ND, block_ND], accum_dtype) do = T.alloc_fragment([block_ND, block_ND], accum_dtype) delta = T.alloc_fragment([block_ND], accum_dtype) @@ -148,18 +141,24 @@ def preprocess_kernel( @tilelang.jit(out_idx=[-1]) def postprocess( - B, S_kv, D, D_tail, kv_group=1, block_N=64, threads=128, dtype=T.bfloat16, accum_dtype=T.float32 + D, D_tail, kv_group=1, block_N=64, threads=128, dtype=T.bfloat16, accum_dtype=T.float32 ): """Build postprocess kernel that casts/exports accumulated dKV.""" assert dtype == T.bfloat16 assert accum_dtype == T.float32 - dkv_shape = [B, S_kv, kv_group, D + D_tail] + batch = T.dynamic("batch") + seq_len_kv = T.dynamic("seq_len_kv") + dkv_shape = [batch, seq_len_kv, kv_group, D + D_tail] @T.prim_func def postprocess_kernel( dKV: T.Tensor(dkv_shape, accum_dtype), dKV_out: T.Tensor(dkv_shape, dtype) ): - with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz): + with T.Kernel(T.ceildiv(seq_len_kv, block_N), kv_group, batch, threads=threads) as ( + bx, + by, + bz, + ): T.copy( dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], @@ -177,9 +176,6 @@ def postprocess_kernel( }, ) def bwd( - 
B, - S, - S_kv, H, D, D_tail, @@ -207,13 +203,17 @@ def bwd( sm_scale = (D + D_tail) ** (-0.5) sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + H_kv = H // kv_group - q_shape = [B, S, H, D + D_tail] - k_shape = [B, S_kv, kv_group, D + D_tail] - o_shape = [B, S, H, D] - indices_shape = [B, S, kv_group, topk] - delta_shape = [B, S, H] - lse_shape = [B, S, H] + q_shape = [batch, seq_len, H, D + D_tail] + k_shape = [batch, seq_len_kv, kv_group, D + D_tail] + o_shape = [batch, seq_len, H, D] + indices_shape = [batch, seq_len, kv_group, topk] + delta_shape = [batch, seq_len, H] + lse_shape = [batch, seq_len, H] assert indices_dtype == T.int32 assert dtype == T.bfloat16 assert accum_dtype == T.float32 @@ -239,7 +239,7 @@ def sparse_mla_bwd_kernel( dQ: T.Tensor(q_shape, dtype), dKV: T.Tensor(k_shape, accum_dtype), ): - with T.Kernel(S, B, kv_group * NH, threads=threads) as (s_i, by, bz): + with T.Kernel(seq_len, batch, kv_group * NH, threads=threads) as (s_i, by, bz): Q_shared = T.alloc_shared([block_H, D], dtype) Q_tail_shared = T.alloc_shared([block_H, D_tail], dtype) KV_shared = T.alloc_shared([BS, D], dtype) @@ -486,11 +486,9 @@ def sparse_mla_bwd( delta = delta_padded # Get kernels - preprocess_kernel = _get_preprocess_kernel(B, seq_bucketed, H, D) - bwd_kernel = _get_bwd_kernel( - B, seq_bucketed, seq_kv_bucketed, H, D, D_tail, topk_bucketed, kv_group, sm_scale, is_casual - ) - postprocess_kernel = _get_postprocess_kernel(B, seq_kv_bucketed, D, D_tail, kv_group) + preprocess_kernel = _get_preprocess_kernel(H, D) + bwd_kernel = _get_bwd_kernel(H, D, D_tail, topk_bucketed, kv_group, sm_scale, is_casual) + postprocess_kernel = _get_postprocess_kernel(D, D_tail, kv_group) if delta is None: delta = preprocess_kernel(o, do)
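
The three `_get_*_kernel` helpers above all follow the same pattern: an `OrderedDict` used as an LRU cache, with a module-level `threading.Lock` held across both the lookup and the JIT compile so that two threads racing on the same key can never compile the same kernel twice. A minimal sketch of that pattern, with illustrative names rather than the exact helpers in this patch:

    import threading
    from collections import OrderedDict

    _KERNEL_CACHE_MAX = 512  # the patch reads this from MCORE_DSA_TILELANG_KERNEL_CACHE_MAX
    _kernel_cache = OrderedDict()
    _kernel_cache_lock = threading.Lock()

    def get_or_compile(key, compile_fn):
        with _kernel_cache_lock:
            # pop + reinsert moves a cache hit to the most-recently-used end.
            kernel = _kernel_cache.pop(key, None)
            if kernel is None:
                # Compiling under the lock serializes compilation across threads,
                # but it is the simplest way to rule out duplicate compiles.
                kernel = compile_fn()
            _kernel_cache[key] = kernel
            while len(_kernel_cache) > _KERNEL_CACHE_MAX:
                _kernel_cache.popitem(last=False)  # evict the least-recently-used entry
            return kernel

The tradeoff is that a slow compile for one key blocks lookups for every other key; a per-key lock (or compiling outside the lock with a re-check before insert) would restore concurrency at the cost of occasionally compiling the same key twice. `_env_int` also fails soft: a missing, malformed, or non-positive MCORE_DSA_TILELANG_KERNEL_CACHE_MAX falls back to the default of 512 instead of raising.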
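
`_normalize_sm_scale` exists because the softmax scale is part of the cache key. Two call sites can compute the "same" scale through different floating-point expressions and disagree in the last ULP, which would silently double the number of compiled kernels; rounding to 12 decimal places collapses such near-duplicates, and a 0-dim tensor is detached to a plain Python float so it hashes the same as the equivalent float. A small illustration (the exact ULP behavior depends on the platform):

    import math

    d = 512 + 64               # hypothetical D + D_tail
    a = d ** -0.5              # one way to compute the softmax scale
    b = 1.0 / math.sqrt(d)     # another; may differ from a in the last ULP
    # Unrounded, a != b would yield two distinct cache keys and two compiles.
    print(a == b, round(a, 12) == round(b, 12))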
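
Patch 6 is the larger win for compile time: by declaring batch, seq_len, and seq_len_kv with `T.dynamic(...)`, the preprocess/bwd/postprocess kernels become shape-polymorphic, so B, S, and S_kv drop out of the cache keys entirely and only compile-time constants remain. A toy count of distinct keys under a shape sweep (dimensions below are made up for illustration):

    H, D, D_tail, topk, kv_group = 16, 512, 64, 2048, 1
    shapes = [(1, 1024, 1024), (2, 1024, 1024), (1, 2048, 2048), (4, 4096, 4096)]

    # Static-shape keys trigger one compile per (B, S, S_kv) combination ...
    old_keys = {(B, S, S_kv, H, D, D_tail, topk, kv_group) for B, S, S_kv in shapes}
    # ... while dynamic-shape keys reuse a single kernel for all of them.
    new_keys = {(H, D, D_tail, topk, kv_group) for _ in shapes}

    print(len(old_keys), "compiles before vs", len(new_keys), "after")  # 4 vs 1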