From 3c45c7eea8f8f9ed1b7e462852ffdeb026bf7c0b Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Sun, 8 Feb 2026 13:31:38 -0500 Subject: [PATCH 01/11] Get varlen code from vLLM and integrate in MLM Signed-off-by: Onur Yilmaz --- megatron/core/ssm/mamba_mixer.py | 195 +++++- megatron/core/ssm/ops/__init__.py | 6 + megatron/core/ssm/ops/ssd_bmm.py | 207 ++++++ megatron/core/ssm/ops/ssd_chunk_scan.py | 453 +++++++++++++ megatron/core/ssm/ops/ssd_chunk_state.py | 718 +++++++++++++++++++++ megatron/core/ssm/ops/ssd_combined.py | 241 +++++++ megatron/core/ssm/ops/ssd_state_passing.py | 153 +++++ 7 files changed, 1952 insertions(+), 21 deletions(-) create mode 100644 megatron/core/ssm/ops/__init__.py create mode 100644 megatron/core/ssm/ops/ssd_bmm.py create mode 100644 megatron/core/ssm/ops/ssd_chunk_scan.py create mode 100644 megatron/core/ssm/ops/ssd_chunk_state.py create mode 100644 megatron/core/ssm/ops/ssd_combined.py create mode 100644 megatron/core/ssm/ops/ssd_state_passing.py diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index cc71cdc32f6..0fc6f36b389 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -65,6 +65,19 @@ HAVE_MAMBA_SSM = True except ImportError: + mamba_chunk_scan_combined = None + mamba_split_conv1d_scan_combined = None + HAVE_MAMBA_SSM = False + +try: + from megatron.core.ssm.ops.ssd_combined import mamba_chunk_scan_combined_varlen + + HAVE_SSM_OPS_VARLEN = True +except ImportError: + mamba_chunk_scan_combined_varlen = None + HAVE_SSM_OPS_VARLEN = False + +if not HAVE_MAMBA_SSM: from unittest.mock import MagicMock RMSNormGated = MagicMock() @@ -847,27 +860,167 @@ def _ssm_prefill( # Note that both `seq_idx` and `cu_seqlens` must be passed in # for variable length generation. 
# See https://github.com/state-spaces/mamba/blob/e0761ece1db07e0949dd88b4f4cd440420a19fd9/tests/test_generation.py#L97 # pylint: disable=line-too-long - y = mamba_chunk_scan_combined( - x, - dt, - A, - B, - C, - self.chunk_size, - D=( - rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.cp.get_D() - ), - z=z if not self.rmsnorm else None, - dt_bias=self.cp.get_dt_bias().float(), - dt_softplus=True, - return_final_states=ssm_state is not None, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - return_varlen_states=return_varlen_states, - initial_states=initial_ssm_state, - ) + + # Use local varlen kernels only for batch > 1; for batch==1 use mamba_ssm which + # is designed for that layout and avoids packing/alignment issues. + if ( + cu_seqlens is not None + and x.shape[0] > 1 + and HAVE_SSM_OPS_VARLEN + and mamba_chunk_scan_combined_varlen is not None + ): + # Variable-length path using local Triton kernels (megatron.core.ssm.ops) + batch, max_seqlen = x.shape[0], x.shape[1] + total_tokens = cu_seqlens[-1].item() + chunk_size = self.chunk_size + device = x.device + + if total_tokens > 0: + # Build chunk boundaries so no chunk spans two sequences (fixes junk output + # when multiple short sequences share a chunk). Merge fixed-size boundaries + # with sequence boundaries from cu_seqlens. 
+ boundaries_set = {0, total_tokens} + for s in range(1, batch): + boundaries_set.add(cu_seqlens[s].item()) + for pos in range(0, total_tokens, chunk_size): + boundaries_set.add(min(pos, total_tokens)) + boundaries = sorted(boundaries_set) + cu_chunk_seqlens = torch.tensor( + boundaries, device=device, dtype=cu_seqlens.dtype + ) + nchunks = len(boundaries) - 1 + + seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + # Chunk index that contains the last token of each sequence + last_token_pos = (cu_seqlens[1:] - 1).clamp(min=0) + last_chunk_indices = ( + torch.searchsorted( + cu_chunk_seqlens, last_token_pos.to(cu_chunk_seqlens.dtype), right=False + ) + - 1 + ) + last_chunk_indices = last_chunk_indices.clamp(0, nchunks - 1).to( + device=device, dtype=torch.int64 + ) + + # Chunk-level seq_idx: which sequence each chunk belongs to + chunk_starts = cu_chunk_seqlens[:-1].to(cu_seqlens.dtype) + seq_idx_chunk = ( + (torch.searchsorted(cu_seqlens, chunk_starts, right=False) - 1) + .clamp(0, batch - 1) + .to(device=device, dtype=torch.int32) + ) + + # Pack tensors to (total_tokens, ...); use .item() for safe Python int slicing + x_packed = torch.cat( + [ + x[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + dt_packed = torch.cat( + [dt[b, : seq_lengths[b].item(), :].contiguous() for b in range(batch)], + dim=0, + ).contiguous() + B_packed = torch.cat( + [ + B[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + C_packed = torch.cat( + [ + C[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + z_packed = ( + torch.cat( + [ + z[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + if not self.rmsnorm + else None + ) + + out_packed = torch.empty_like(x_packed) + D_val = ( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else 
self.cp.get_D() + ) + dt_bias_val = self.cp.get_dt_bias().float() + + varlen_states = mamba_chunk_scan_combined_varlen( + x_packed, + dt_packed, + A, + B_packed, + C_packed, + chunk_size, + cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunk, + out=out_packed, + D=D_val, + z=z_packed, + dt_bias=dt_bias_val, + initial_states=initial_ssm_state, + dt_softplus=True, + return_intermediate_states=False, + ) + + # Unpack output to (batch, max_seqlen, nheads, headdim) + y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) + for b in range(batch): + length_b = seq_lengths[b].item() + if length_b > 0: + y_unpacked[b, :length_b, :, :] = out_packed[ + cu_seqlens[b] : cu_seqlens[b + 1] + ] + else: + # Zero tokens: no chunks, return zeros without calling kernel + y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) + varlen_states = x.new_zeros( + batch, x.shape[2], x.shape[3], B.shape[-1], + device=device, dtype=x.dtype, + ) + + if ssm_state is not None and return_varlen_states: + y = (y_unpacked, None, varlen_states) + elif ssm_state is not None: + y = (y_unpacked, None) + else: + y = y_unpacked + else: + y = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + self.chunk_size, + D=( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ), + z=z if not self.rmsnorm else None, + dt_bias=self.cp.get_dt_bias().float(), + dt_softplus=True, + return_final_states=ssm_state is not None, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + return_varlen_states=return_varlen_states, + initial_states=initial_ssm_state, + ) if ssm_state is not None: if return_varlen_states: diff --git a/megatron/core/ssm/ops/__init__.py b/megatron/core/ssm/ops/__init__.py new file mode 100644 index 00000000000..03b4a09a529 --- /dev/null +++ b/megatron/core/ssm/ops/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. 
All rights reserved. +# Triton kernels for Mamba SSM (adapted from vLLM / state-spaces/mamba). + +from .ssd_combined import mamba_chunk_scan_combined_varlen + +__all__ = ["mamba_chunk_scan_combined_varlen"] diff --git a/megatron/core/ssm/ops/ssd_bmm.py b/megatron/core/ssm/ops/ssd_bmm.py new file mode 100644 index 00000000000..57731ba5f98 --- /dev/null +++ b/megatron/core/ssm/ops/ssd_bmm.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py +# Adapted from vLLM project (Apache-2.0). + +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["chunk_size", "K", "IS_CAUSAL"], +) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimensions + seqlen, + chunk_size: tl.constexpr, + K: tl.constexpr, + 
ngroups: tl.constexpr, + stride_a_seqlen: tl.int64, + stride_a_head: tl.int64, + stride_ak: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_bk: tl.constexpr, + stride_out_chunk: tl.int64, + stride_out_head: tl.int64, + stride_outm: tl.int64, + stride_outn: tl.constexpr, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_ch = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head + b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # compute a * b.T + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ).to(dot_dtype) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) + & (offs_n[None, :] < chunk_size_limit), + other=0.0, + ).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = 
pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + out = acc.to(out_ptr.dtype.element_ty) + out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), + ) + + +def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None): + """ + Argument: + a: (seqlen, ngroups, k) + b: (seqlen, ngroups, k) + chunk_size: int + cu_chunk_seq_lens: (nchunks+1,) + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (nchunks, ngroups, chunk_size, chunk_size) + """ + seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if a.stride(-1) != 1 and a.stride(0) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(0) != 1: + b = b.contiguous() + + nchunks = len(cu_chunk_seqlens) - 1 + # Allocates output. 
+ out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + (nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype + ) + dot_dtype = ( + tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 + else ( + tl.float16 + if a.dtype == torch.float16 or b.dtype == torch.float16 + else tl.float32 + ) + ) + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), + nchunks * ngroups, + ) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a_ptr=a, + b_ptr=b, + out_ptr=out, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + seqlen=seqlen, + chunk_size=chunk_size, + K=k, + ngroups=ngroups, + stride_a_seqlen=a.stride(0), + stride_a_head=a.stride(1), + stride_ak=a.stride(2), + stride_b_seqlen=b.stride(0), + stride_b_head=b.stride(1), + stride_bk=b.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_outm=out.stride(-2), + stride_outn=out.stride(-1), + IS_CAUSAL=causal, + dot_dtype=dot_dtype, + ) + return out diff --git a/megatron/core/ssm/ops/ssd_chunk_scan.py b/megatron/core/ssm/ops/ssd_chunk_scan.py new file mode 100644 index 00000000000..a1715935c97 --- /dev/null +++ b/megatron/core/ssm/ops/ssd_chunk_scan.py @@ -0,0 +1,453 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py +# Adapted from vLLM project (Apache-2.0). 
+ +from packaging import version + +import triton +import triton.language as tl + +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimensions + chunk_size: tl.constexpr, + hdim: tl.constexpr, + dstate: tl.constexpr, + seqlen, + nheads_ngroups_ratio: tl.constexpr, + # Strides + stride_cb_chunk: tl.int64, + stride_cb_head: tl.int64, + 
stride_cb_csize_m: tl.int64, + stride_cb_csize_k: tl.constexpr, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_z_seqlen: tl.int64, + stride_z_head: tl.int64, + stride_z_hdim: tl.constexpr, + stride_out_seqlen: tl.int64, + stride_out_head: tl.int64, + stride_out_hdim: tl.constexpr, + stride_dt_chunk: tl.int64, + stride_dt_head: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_head: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, + stride_C_seqlen: tl.int64, + stride_C_head: tl.int64, + stride_C_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, + stride_D_head: tl.constexpr, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, +): + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += ( + chunk_seqlen_start * stride_C_seqlen + + (pid_h // nheads_ngroups_ratio) * 
stride_C_head + ) + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + seq_idx_ptr += pid_c * stride_seq_idx_chunk + seq_idx = tl.load(seq_idx_ptr) + seq_idx_prev = tl.load( + seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1 + ) + + if HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states_ptr = ( + initstates_ptr + + seq_idx * stride_init_states_batch + + pid_h * stride_init_states_head + ) + prev_states_hdim = stride_init_states_hdim + prev_states_dstate = stride_init_states_dstate + else: + prev_states_ptr = ( + states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head + ) + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load( + dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0 + ).to(tl.float32) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K + ) + C_ptrs = C_ptr + ( + offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate + ) + + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate), + other=0.0, + ) + + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + # if no init states AND starting a new sequence, we need zeros + prev_states = tl.zeros( + (BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + ) + else: + # otherwise read the 
previous state + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + + acc = tl.dot(C, prev_states) * scale_m[:, None] + + else: + prev_states_ptrs = ( + prev_states_ptr + + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) + if not HAS_INITSTATES and (seq_idx != seq_idx_prev): + prev_states = tl.zeros( + (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty + ) + else: + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + ( + offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k + ) + x_ptrs = x_ptr + ( + offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = ( + chunk_size_limit + if not IS_CAUSAL + else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + ) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load( + cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to( + tl.float32 + ) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. 
+ # So we don't need masking wrt seq_idx here. + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load( + x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), + other=0.0, + ) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load( + D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0 + ).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load( + x_ptr + + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + ( + stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :] + ) + z = tl.load( + z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim + ) + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), + ) + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + cu_chunk_seqlens, + out, + seq_idx, + 
D=None, + z=None, + initial_states=None, +): + assert seq_idx is not None, "this implementation requires seq_idx" + + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (seqlen, ngroups, dstate) + assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if z is not None: + assert z.shape == x.shape + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + assert states.shape == (nchunks, nheads, headdim, dstate) + assert seq_idx.shape == (nchunks,) + + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) + + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + initial_states_strides = ( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ) + + _chunk_scan_fwd_kernel[grid]( + cb_ptr=cb, + x_ptr=x, + z_ptr=z, + out_ptr=out, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + seq_idx_ptr=seq_idx, + C_ptr=C, + states_ptr=states, + D_ptr=D, + initstates_ptr=initial_states, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + chunk_size=chunk_size, + hdim=headdim, + dstate=dstate, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_cb_chunk=cb.stride(0), + stride_cb_head=cb.stride(1), + stride_cb_csize_m=cb.stride(2), + stride_cb_csize_k=cb.stride(3), + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_z_seqlen=z_strides[0], + stride_z_head=z_strides[1], + stride_z_hdim=z_strides[2], + stride_out_seqlen=out.stride(0), + stride_out_head=out.stride(1), + stride_out_hdim=out.stride(2), + stride_dt_chunk=dt.stride(1), + stride_dt_head=dt.stride(0), + 
stride_dt_csize=dt.stride(2), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_seq_idx_chunk=seq_idx.stride(0), + stride_C_seqlen=C.stride(0), + stride_C_head=C.stride(1), + stride_C_dstate=C.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_init_states_batch=initial_states_strides[0], + stride_init_states_head=initial_states_strides[1], + stride_init_states_hdim=initial_states_strides[2], + stride_init_states_dstate=initial_states_strides[3], + stride_D_head=D.stride(0) if D is not None else 0, + IS_CAUSAL=True, + HAS_D=D is not None, + D_HAS_HDIM=D.dim() == 2 if D is not None else True, + HAS_Z=z is not None, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return diff --git a/megatron/core/ssm/ops/ssd_chunk_state.py b/megatron/core/ssm/ops/ssd_chunk_state.py new file mode 100644 index 00000000000..9e2fdaf867b --- /dev/null +++ b/megatron/core/ssm/ops/ssd_chunk_state.py @@ -0,0 +1,718 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py +# Adapted from vLLM project (Apache-2.0). 
+ +import torch +import triton +import triton.language as tl +from packaging import version + +try: + TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") +except: + raise ImportError("Triton version 3.0.0 or higher is required") + +if TRITON3: + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_H": 2}), + triton.Config({"BLOCK_SIZE_H": 4}), + triton.Config({"BLOCK_SIZE_H": 8}), + triton.Config({"BLOCK_SIZE_H": 16}), + triton.Config({"BLOCK_SIZE_H": 32}), + triton.Config({"BLOCK_SIZE_H": 64}), + ], + key=["chunk_size", "nheads"], +) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimension + seqlen, + nheads: tl.constexpr, + chunk_size: tl.constexpr, + dt_min: tl.constexpr, + dt_max: tl.constexpr, + # Strides + stride_dt_seqlen: tl.int64, + stride_dt_head: tl.constexpr, + stride_A_head: tl.constexpr, + stride_dt_bias_head: tl.constexpr, + stride_dt_out_head: tl.int64, + stride_dt_out_chunk: tl.int64, + stride_dt_out_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, +): + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=0).to(tl.int64) + pid_h = tl.program_id(axis=1) + + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + + dt_ptr += chunk_seqlen_start * stride_dt_seqlen + dt_out_ptr += pid_c * stride_dt_out_chunk + dA_cumsum_ptr += 
pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + ( + offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen + ) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + ( + offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize + ) + dA_cs_ptrs = dA_cumsum_ptr + ( + offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize + ) + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + + dt = tl.load( + dt_ptrs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), + other=0.0, + ).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load( + dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0 + ).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + + dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0 + ) + tl.store( + dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store( + dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, 
+ num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["hdim", "dstate", "chunk_size"], +) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimensions + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, + seqlen, + nheads_ngroups_ratio: tl.constexpr, + # Strides + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) + b_ptr += ( + chunk_seqlen_start * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * 
stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to( + tl.float32 + ) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) + c_mask = 
(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), +): + seqlen, nheads = dt.shape + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + nchunks = cu_chunk_seqlens.shape[0] - 1 + dt_out = torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + dA_cumsum = torch.empty( + nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"])) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt_ptr=dt, + A_ptr=A, + dt_bias_ptr=dt_bias, + dt_out_ptr=dt_out, + dA_cumsum_ptr=dA_cumsum, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + seqlen=seqlen, + nheads=nheads, + chunk_size=chunk_size, + dt_min=dt_limit[0], + dt_max=dt_limit[1], + stride_dt_seqlen=dt.stride(0), + stride_dt_head=dt.stride(1), + stride_A_head=A.stride(0), + stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0, + stride_dt_out_head=dt_out.stride(0), + stride_dt_out_chunk=dt_out.stride(1), + stride_dt_out_csize=dt_out.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + DT_SOFTPLUS=dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd( + B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True +): + seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + + if states is not None: + assert states.shape == 
(nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty( + (nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype + ) + + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + nchunks, + nheads, + ) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x_ptr=x, + b_ptr=B, + states_ptr=states, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + hdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + seqlen=seqlen, + nheads_ngroups_ratio=nheads // ngroups, + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_b_seqlen=B.stride(0), + stride_b_head=B.stride(1), + stride_b_dstate=B.stride(2), + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_dt_head=dt.stride(0), + stride_dt_chunk=dt.stride(1), + stride_dt_csize=dt.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + ) + return states + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, + num_warps=8, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=4, + ), + 
triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=5, + num_warps=2, + ), + triton.Config( + {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=4, + num_warps=2, + ), + ], + key=["hdim", "dstate", "chunk_size"], +) +@triton.jit +def _chunk_state_varlen_kernel( + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + last_chunk_indices_ptr, + cu_chunk_seqlens_ptr, + states_ptr, + initstates_ptr, + hdim: tl.constexpr, + dstate: tl.constexpr, + chunk_size: tl.constexpr, + nheads_ngroups_ratio: tl.constexpr, + stride_x_seqlen: tl.int64, + stride_x_head: tl.int64, + stride_x_hdim: tl.constexpr, + stride_b_seqlen: tl.int64, + stride_b_head: tl.int64, + stride_b_dstate: tl.constexpr, + stride_dt_head: tl.int64, + stride_dt_chunk: tl.int64, + stride_dt_csize: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_chunk_states_chunk: tl.int64, + stride_chunk_states_head: tl.int64, + stride_chunk_states_hdim: tl.int64, + stride_chunk_states_dstate: tl.constexpr, + stride_states_batch: tl.int64, + stride_states_head: tl.int64, + stride_states_hdim: tl.int64, + stride_states_dstate: tl.constexpr, + stride_init_states_batch: tl.int64, + stride_init_states_head: tl.int64, + stride_init_states_hdim: tl.int64, + stride_init_states_dstate: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + HAS_INITSTATES: tl.constexpr, + USE_LAST_CHUNK_INDICES: tl.constexpr, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + start_idx = tl.load(cu_seqlens_ptr + pid_b) + if 
USE_LAST_CHUNK_INDICES: + pid_c = tl.load(last_chunk_indices_ptr + pid_b).to(tl.int64) + chunk_start = tl.load(cu_chunk_seqlens_ptr + pid_c) + chunk_size_limit = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - chunk_start + else: + pid_c = (end_idx - 1) // chunk_size + chunk_start = pid_c * chunk_size + chunk_size_limit = end_idx - chunk_start + b_ptr += ( + chunk_start * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += chunk_start * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += ( + pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + ) + + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load( + dA_cumsum_ptr + (end_idx - 1 - chunk_start) * stride_dA_cs_csize + ).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + start_idx_cur = tl.maximum(start_idx - chunk_start, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) + & (offs_k[None, :] < chunk_size_limit - k) + & (offs_k[None, :] >= start_idx_cur - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) + & (offs_n[None, :] < dstate) + & (offs_k[:, None] >= start_idx_cur - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + 
).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, + 0.0, + ) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + if (start_idx < chunk_start) or (HAS_INITSTATES): + dA_cs_boundary = 0.0 + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) + else: + if start_idx < chunk_start: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) + else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate + ) + if start_idx > chunk_start: + dA_cs_boundary = tl.load( + dA_cumsum_ptr + + (start_idx - chunk_start - 1) * stride_dA_cs_csize + ).to(tl.float32) + + past_states = tl.load( + past_states_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def chunk_state_varlen( + B, + x, + dt, + 
dA_cumsum, + cu_seqlens, + chunk_states, + initial_states=None, + last_chunk_indices=None, + cu_chunk_seqlens=None, +): + """Compute per-sequence final SSM state from chunk states (correct when sequences share chunks).""" + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + use_last_chunk = ( + last_chunk_indices is not None and cu_chunk_seqlens is not None + ) + if use_last_chunk: + last_chunk_indices = last_chunk_indices.contiguous().to(x.device) + cu_chunk_seqlens = cu_chunk_seqlens.contiguous().to(x.device) + else: + last_chunk_indices = torch.zeros(1, dtype=torch.int64, device=x.device) + cu_chunk_seqlens = cu_seqlens + + states = torch.empty( + batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device, + ) + initial_states_strides = ( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ) + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + batch, + nheads, + ) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x_ptr=x, + b_ptr=B, + dt_ptr=dt, + dA_cumsum_ptr=dA_cumsum, + chunk_states_ptr=chunk_states, + cu_seqlens_ptr=cu_seqlens, + last_chunk_indices_ptr=last_chunk_indices, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + states_ptr=states, + initstates_ptr=initial_states, + hdim=headdim, + dstate=dstate, + chunk_size=chunk_size, + nheads_ngroups_ratio=nheads 
// ngroups, + stride_x_seqlen=x.stride(0), + stride_x_head=x.stride(1), + stride_x_hdim=x.stride(2), + stride_b_seqlen=B.stride(0), + stride_b_head=B.stride(1), + stride_b_dstate=B.stride(2), + stride_dt_head=dt.stride(0), + stride_dt_chunk=dt.stride(1), + stride_dt_csize=dt.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_chunk_states_chunk=chunk_states.stride(0), + stride_chunk_states_head=chunk_states.stride(1), + stride_chunk_states_hdim=chunk_states.stride(2), + stride_chunk_states_dstate=chunk_states.stride(3), + stride_states_batch=states.stride(0), + stride_states_head=states.stride(1), + stride_states_hdim=states.stride(2), + stride_states_dstate=states.stride(3), + stride_init_states_batch=initial_states_strides[0], + stride_init_states_head=initial_states_strides[1], + stride_init_states_hdim=initial_states_strides[2], + stride_init_states_dstate=initial_states_strides[3], + HAS_INITSTATES=initial_states is not None, + USE_LAST_CHUNK_INDICES=use_last_chunk, + ) + return states diff --git a/megatron/core/ssm/ops/ssd_combined.py b/megatron/core/ssm/ops/ssd_combined.py new file mode 100644 index 00000000000..c6a8a363a5c --- /dev/null +++ b/megatron/core/ssm/ops/ssd_combined.py @@ -0,0 +1,241 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py +# Adapted from vLLM project (Apache-2.0). 
+ +import torch +from einops import rearrange +from packaging import version + +import triton + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import ( + _chunk_cumsum_fwd, + _chunk_state_fwd, + chunk_state_varlen, +) +from .ssd_state_passing import _state_passing_fwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") + + +def is_int_pow_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + +def _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + return_intermediate_states=False, + seq_idx=None, + cu_seqlens=None, + cu_chunk_seqlens=None, + last_chunk_indices=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + state_dtype=None, +): + assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" + seqlen, nheads, headdim = x.shape + _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (seqlen, ngroups, dstate) + assert dt.shape == (seqlen, nheads) + assert A.shape == (nheads,) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if ( + x.stride(-1) != 1 and x.stride(0) != 1 + ): # Either M or K dimension should be contiguous + x = x.contiguous() + if ( + z is not None and z.stride(-1) != 1 and z.stride(0) != 1 + ): # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens" + + if initial_states is not None: + assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate) + + # This function 
executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + ) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd( + B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True + ) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states and + # ii) seq_idx to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum, # (nheads, nchunks, chunk_size) + cu_chunk_seqlens, + initial_states=rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None + else None, # (batch, nheads, headdim*dstate) + seq_idx=seq_idx, + out_dtype=state_dtype if state_dtype is not None else C.dtype, + ) + states = rearrange(states, "... (p n) -> ... p n", n=dstate) + + # 4. 
Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. + _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + cu_chunk_seqlens, + out, # in-place update + seq_idx, + D=D, + z=z, + initial_states=initial_states, + ) + + if return_intermediate_states: + return states + else: + # Per-sequence final state at exact last token (correct when sequences share chunks) + return chunk_state_varlen( + B, + x, + dt, + dA_cumsum, + cu_seqlens, + states, + initial_states=initial_states, + last_chunk_indices=last_chunk_indices, + cu_chunk_seqlens=cu_chunk_seqlens, + ) + + +def mamba_chunk_scan_combined_varlen( + x, + dt, + A, + B, + C, + chunk_size, + cu_seqlens, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + state_dtype=None, +): + """ + Argument: + x: (seqlen, nheads, headdim) + dt: (seqlen, nheads) + A: (nheads) + B: (seqlen, ngroups, dstate) + C: (seqlen, ngroups, dstate) + chunk_size: int + cu_seqlens: (batch + 1,) + cu_chunk_seqlens: (nchunks + 1,) + last_chunk_indices: (batch,) + seq_idx: (nchunks,) + out: (seqlen, nheads, headdim) preallocated output tensor + D: (nheads, headdim) or (nheads,) + z: (seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, 
nheads, headdim, dstate) + dt_softplus: Whether to apply softplus to dt + out: (seqlen, nheads, headdim) preallocated output tensor + state_dtype: The data type of the ssm state + Return: + varlen_states: (batch, nheads, headdim, dstate) + """ + + assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input" + assert seq_idx is not None + + varlen_states = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + out, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + return_intermediate_states=return_intermediate_states, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + state_dtype=state_dtype, + ) + + return varlen_states diff --git a/megatron/core/ssm/ops/ssd_state_passing.py b/megatron/core/ssm/ops/ssd_state_passing.py new file mode 100644 index 00000000000..a121a860be4 --- /dev/null +++ b/megatron/core/ssm/ops/ssd_state_passing.py @@ -0,0 +1,153 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py +# Adapted from vLLM project (Apache-2.0). 
+ +import torch +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + ], + key=["dim"], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimensions + dim: tl.constexpr, + nchunks, + seqlen, + chunk_size: tl.constexpr, + # Strides + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_dim: tl.constexpr, + stride_out_chunk: tl.int64, + stride_out_head: tl.int64, + stride_out_dim: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_initstates_batch: tl.int64, + stride_initstates_head: tl.int64, + stride_initstates_dim: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_h = tl.program_id(axis=1) + pid_m = tl.program_id(axis=0) + + states_ptr += pid_h * stride_states_head + dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize + out_ptr += pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + if HAS_INITSTATES: + initstates_ptrs = ( + initstates_ptr + + pid_h * stride_initstates_head + + offs_m * stride_initstates_dim + ) + + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + prev_seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + seq_idx = 
tl.load(seq_idx_ptr + c * stride_seq_idx_chunk) + # we have started a new sequence + if prev_seq_idx != seq_idx: + if HAS_INITSTATES: + initstates_ptrs = ( + initstates_ptr + + seq_idx * stride_initstates_batch + + pid_h * stride_initstates_head + + offs_m * stride_initstates_dim + ) + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to( + tl.float32 + ) + else: + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + prev_seq_idx = seq_idx + states = tl.exp(dA_cs) * states + new_states + tl.store(out_ptrs, states, mask=offs_m < dim) + + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_cumsum, + cu_chunk_seqlens, + seq_idx, + initial_states=None, + out_dtype=None, +): + nchunks, nheads, dim = states.shape + chunk_size = dA_cumsum.shape[-1] + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + seqlen = seq_idx.shape[-1] + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype) + + initial_states_strides = ( + (initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) + if initial_states is not None + else (0, 0, 0) + ) + + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states_ptr=states, + out_ptr=out, + dA_cs_ptr=dA_cumsum, + initstates_ptr=initial_states, + seq_idx_ptr=seq_idx, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + dim=dim, + nchunks=nchunks, + seqlen=seqlen if seq_idx is not None else 0, + chunk_size=chunk_size if seq_idx is not None else 0, + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_dim=states.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_out_dim=out.stride(2), + stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + 
stride_dA_cs_csize=dA_cumsum.stride(2), + stride_initstates_batch=initial_states_strides[0], + stride_initstates_head=initial_states_strides[1], + stride_initstates_dim=initial_states_strides[2], + stride_seq_idx_chunk=seq_idx.stride(0), + HAS_INITSTATES=initial_states is not None, + ) + return out From 13746c90e74d2a965112382c7d8a60d40772ccd1 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Sun, 8 Feb 2026 15:51:47 -0500 Subject: [PATCH 02/11] Add init states Signed-off-by: Onur Yilmaz --- megatron/core/ssm/mamba_mixer.py | 208 +++++++++++++++---------------- 1 file changed, 101 insertions(+), 107 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 0fc6f36b389..04b8f676f93 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -875,123 +875,117 @@ def _ssm_prefill( chunk_size = self.chunk_size device = x.device - if total_tokens > 0: - # Build chunk boundaries so no chunk spans two sequences (fixes junk output - # when multiple short sequences share a chunk). Merge fixed-size boundaries - # with sequence boundaries from cu_seqlens. 
- boundaries_set = {0, total_tokens} - for s in range(1, batch): - boundaries_set.add(cu_seqlens[s].item()) - for pos in range(0, total_tokens, chunk_size): - boundaries_set.add(min(pos, total_tokens)) - boundaries = sorted(boundaries_set) - cu_chunk_seqlens = torch.tensor( - boundaries, device=device, dtype=cu_seqlens.dtype - ) - nchunks = len(boundaries) - 1 - - seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] - # Chunk index that contains the last token of each sequence - last_token_pos = (cu_seqlens[1:] - 1).clamp(min=0) - last_chunk_indices = ( - torch.searchsorted( - cu_chunk_seqlens, last_token_pos.to(cu_chunk_seqlens.dtype), right=False - ) - - 1 - ) - last_chunk_indices = last_chunk_indices.clamp(0, nchunks - 1).to( - device=device, dtype=torch.int64 - ) + initial_ssm_state = ssm_state[batch_indices] - # Chunk-level seq_idx: which sequence each chunk belongs to - chunk_starts = cu_chunk_seqlens[:-1].to(cu_seqlens.dtype) - seq_idx_chunk = ( - (torch.searchsorted(cu_seqlens, chunk_starts, right=False) - 1) - .clamp(0, batch - 1) - .to(device=device, dtype=torch.int32) + # Build chunk boundaries so no chunk spans two sequences (fixes junk output + # when multiple short sequences share a chunk). Merge fixed-size boundaries + # with sequence boundaries from cu_seqlens. 
+ boundaries_set = {0, total_tokens} + for s in range(1, batch): + boundaries_set.add(cu_seqlens[s].item()) + for pos in range(0, total_tokens, chunk_size): + boundaries_set.add(min(pos, total_tokens)) + boundaries = sorted(boundaries_set) + cu_chunk_seqlens = torch.tensor( + boundaries, device=device, dtype=cu_seqlens.dtype + ) + nchunks = len(boundaries) - 1 + + seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + # Chunk index that contains the last token of each sequence + last_token_pos = (cu_seqlens[1:] - 1).clamp(min=0) + last_chunk_indices = ( + torch.searchsorted( + cu_chunk_seqlens, last_token_pos.to(cu_chunk_seqlens.dtype), right=False ) + - 1 + ) + last_chunk_indices = last_chunk_indices.clamp(0, nchunks - 1).to( + device=device, dtype=torch.int64 + ) - # Pack tensors to (total_tokens, ...); use .item() for safe Python int slicing - x_packed = torch.cat( - [ - x[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - dt_packed = torch.cat( - [dt[b, : seq_lengths[b].item(), :].contiguous() for b in range(batch)], - dim=0, - ).contiguous() - B_packed = torch.cat( - [ - B[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - C_packed = torch.cat( + # Chunk-level seq_idx: which sequence each chunk belongs to + chunk_starts = cu_chunk_seqlens[:-1].to(cu_seqlens.dtype) + seq_idx_chunk = ( + (torch.searchsorted(cu_seqlens, chunk_starts, right=False) - 1) + .clamp(0, batch - 1) + .to(device=device, dtype=torch.int32) + ) + + # Pack tensors to (total_tokens, ...); use .item() for safe Python int slicing + x_packed = torch.cat( + [ + x[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + dt_packed = torch.cat( + [dt[b, : seq_lengths[b].item(), :].contiguous() for b in range(batch)], + dim=0, + ).contiguous() + B_packed = torch.cat( + [ + B[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + 
dim=0, + ).contiguous() + C_packed = torch.cat( + [ + C[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + z_packed = ( + torch.cat( [ - C[b, : seq_lengths[b].item(), :, :].contiguous() + z[b, : seq_lengths[b].item(), :, :].contiguous() for b in range(batch) ], dim=0, ).contiguous() - z_packed = ( - torch.cat( - [ - z[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - if not self.rmsnorm - else None - ) + if not self.rmsnorm + else None + ) - out_packed = torch.empty_like(x_packed) - D_val = ( - rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.cp.get_D() - ) - dt_bias_val = self.cp.get_dt_bias().float() - - varlen_states = mamba_chunk_scan_combined_varlen( - x_packed, - dt_packed, - A, - B_packed, - C_packed, - chunk_size, - cu_seqlens=cu_seqlens, - cu_chunk_seqlens=cu_chunk_seqlens, - last_chunk_indices=last_chunk_indices, - seq_idx=seq_idx_chunk, - out=out_packed, - D=D_val, - z=z_packed, - dt_bias=dt_bias_val, - initial_states=initial_ssm_state, - dt_softplus=True, - return_intermediate_states=False, - ) + out_packed = torch.empty_like(x_packed) + D_val = ( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ) + dt_bias_val = self.cp.get_dt_bias().float() - # Unpack output to (batch, max_seqlen, nheads, headdim) - y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) - for b in range(batch): - length_b = seq_lengths[b].item() - if length_b > 0: - y_unpacked[b, :length_b, :, :] = out_packed[ - cu_seqlens[b] : cu_seqlens[b + 1] - ] - else: - # Zero tokens: no chunks, return zeros without calling kernel - y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) - varlen_states = x.new_zeros( - batch, x.shape[2], x.shape[3], B.shape[-1], - device=device, dtype=x.dtype, - ) + varlen_states = mamba_chunk_scan_combined_varlen( + 
x_packed, + dt_packed, + A, + B_packed, + C_packed, + chunk_size, + cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunk, + out=out_packed, + D=D_val, + z=z_packed, + dt_bias=dt_bias_val, + initial_states=initial_ssm_state, + dt_softplus=True, + return_intermediate_states=False, + ) + + # Unpack output to (batch, max_seqlen, nheads, headdim) + y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) + for b in range(batch): + length_b = seq_lengths[b].item() + if length_b > 0: + y_unpacked[b, :length_b, :, :] = out_packed[ + cu_seqlens[b] : cu_seqlens[b + 1] + ] if ssm_state is not None and return_varlen_states: y = (y_unpacked, None, varlen_states) From adc4107648c93f84df30ef98d578c78f1ab13709 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Sun, 8 Feb 2026 16:03:54 -0500 Subject: [PATCH 03/11] Working except small prompts Signed-off-by: Onur Yilmaz --- megatron/core/ssm/mamba_mixer.py | 209 ++++++++++++++++--------------- 1 file changed, 108 insertions(+), 101 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 04b8f676f93..6c73eaa09f2 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -864,8 +864,7 @@ def _ssm_prefill( # Use local varlen kernels only for batch > 1; for batch==1 use mamba_ssm which # is designed for that layout and avoids packing/alignment issues. if ( - cu_seqlens is not None - and x.shape[0] > 1 + cu_seqlens is not None and HAVE_SSM_OPS_VARLEN and mamba_chunk_scan_combined_varlen is not None ): @@ -877,115 +876,123 @@ def _ssm_prefill( initial_ssm_state = ssm_state[batch_indices] - # Build chunk boundaries so no chunk spans two sequences (fixes junk output - # when multiple short sequences share a chunk). Merge fixed-size boundaries - # with sequence boundaries from cu_seqlens. 
- boundaries_set = {0, total_tokens} - for s in range(1, batch): - boundaries_set.add(cu_seqlens[s].item()) - for pos in range(0, total_tokens, chunk_size): - boundaries_set.add(min(pos, total_tokens)) - boundaries = sorted(boundaries_set) - cu_chunk_seqlens = torch.tensor( - boundaries, device=device, dtype=cu_seqlens.dtype - ) - nchunks = len(boundaries) - 1 - - seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] - # Chunk index that contains the last token of each sequence - last_token_pos = (cu_seqlens[1:] - 1).clamp(min=0) - last_chunk_indices = ( - torch.searchsorted( - cu_chunk_seqlens, last_token_pos.to(cu_chunk_seqlens.dtype), right=False + if total_tokens > 0: + # Build chunk boundaries so no chunk spans two sequences (fixes junk output + # when multiple short sequences share a chunk). Merge fixed-size boundaries + # with sequence boundaries from cu_seqlens. + boundaries_set = {0, total_tokens} + for s in range(1, batch): + boundaries_set.add(cu_seqlens[s].item()) + for pos in range(0, total_tokens, chunk_size): + boundaries_set.add(min(pos, total_tokens)) + boundaries = sorted(boundaries_set) + cu_chunk_seqlens = torch.tensor( + boundaries, device=device, dtype=cu_seqlens.dtype + ) + nchunks = len(boundaries) - 1 + + seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + # Chunk index that contains the last token of each sequence + last_token_pos = (cu_seqlens[1:] - 1).clamp(min=0) + last_chunk_indices = ( + torch.searchsorted( + cu_chunk_seqlens, last_token_pos.to(cu_chunk_seqlens.dtype), right=False + ) + - 1 + ) + last_chunk_indices = last_chunk_indices.clamp(0, nchunks - 1).to( + device=device, dtype=torch.int64 ) - - 1 - ) - last_chunk_indices = last_chunk_indices.clamp(0, nchunks - 1).to( - device=device, dtype=torch.int64 - ) - # Chunk-level seq_idx: which sequence each chunk belongs to - chunk_starts = cu_chunk_seqlens[:-1].to(cu_seqlens.dtype) - seq_idx_chunk = ( - (torch.searchsorted(cu_seqlens, chunk_starts, right=False) - 1) - .clamp(0, batch - 1) - 
.to(device=device, dtype=torch.int32) - ) + # Chunk-level seq_idx: which sequence each chunk belongs to + chunk_starts = cu_chunk_seqlens[:-1].to(cu_seqlens.dtype) + seq_idx_chunk = ( + (torch.searchsorted(cu_seqlens, chunk_starts, right=False) - 1) + .clamp(0, batch - 1) + .to(device=device, dtype=torch.int32) + ) - # Pack tensors to (total_tokens, ...); use .item() for safe Python int slicing - x_packed = torch.cat( - [ - x[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - dt_packed = torch.cat( - [dt[b, : seq_lengths[b].item(), :].contiguous() for b in range(batch)], - dim=0, - ).contiguous() - B_packed = torch.cat( - [ - B[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - C_packed = torch.cat( - [ - C[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - z_packed = ( - torch.cat( + # Pack tensors to (total_tokens, ...); use .item() for safe Python int slicing + x_packed = torch.cat( [ - z[b, : seq_lengths[b].item(), :, :].contiguous() + x[b, : seq_lengths[b].item(), :, :].contiguous() for b in range(batch) ], dim=0, ).contiguous() - if not self.rmsnorm - else None - ) - - out_packed = torch.empty_like(x_packed) - D_val = ( - rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) - if self.D_has_hdim - else self.cp.get_D() - ) - dt_bias_val = self.cp.get_dt_bias().float() + dt_packed = torch.cat( + [dt[b, : seq_lengths[b].item(), :].contiguous() for b in range(batch)], + dim=0, + ).contiguous() + B_packed = torch.cat( + [ + B[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + C_packed = torch.cat( + [ + C[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + z_packed = ( + torch.cat( + [ + z[b, : seq_lengths[b].item(), :, :].contiguous() + for b in range(batch) + ], + dim=0, + ).contiguous() + 
if not self.rmsnorm + else None + ) - varlen_states = mamba_chunk_scan_combined_varlen( - x_packed, - dt_packed, - A, - B_packed, - C_packed, - chunk_size, - cu_seqlens=cu_seqlens, - cu_chunk_seqlens=cu_chunk_seqlens, - last_chunk_indices=last_chunk_indices, - seq_idx=seq_idx_chunk, - out=out_packed, - D=D_val, - z=z_packed, - dt_bias=dt_bias_val, - initial_states=initial_ssm_state, - dt_softplus=True, - return_intermediate_states=False, - ) + out_packed = torch.empty_like(x_packed) + D_val = ( + rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) + if self.D_has_hdim + else self.cp.get_D() + ) + dt_bias_val = self.cp.get_dt_bias().float() + + varlen_states = mamba_chunk_scan_combined_varlen( + x_packed, + dt_packed, + A, + B_packed, + C_packed, + chunk_size, + cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx_chunk, + out=out_packed, + D=D_val, + z=z_packed, + dt_bias=dt_bias_val, + initial_states=initial_ssm_state, + dt_softplus=True, + return_intermediate_states=False, + ) - # Unpack output to (batch, max_seqlen, nheads, headdim) - y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) - for b in range(batch): - length_b = seq_lengths[b].item() - if length_b > 0: - y_unpacked[b, :length_b, :, :] = out_packed[ - cu_seqlens[b] : cu_seqlens[b + 1] - ] + # Unpack output to (batch, max_seqlen, nheads, headdim) + y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) + for b in range(batch): + length_b = seq_lengths[b].item() + if length_b > 0: + y_unpacked[b, :length_b, :, :] = out_packed[ + cu_seqlens[b] : cu_seqlens[b + 1] + ] + else: + # Zero tokens: skip kernel to avoid illegal memory access with empty inputs + y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) + varlen_states = x.new_zeros( + batch, x.shape[2], x.shape[3], B.shape[-1], + device=device, dtype=x.dtype, + ) if ssm_state is not None and return_varlen_states: y = (y_unpacked, 
None, varlen_states) From aa056cf0016052afcedb5fb3e43353ea151ecdc5 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Sun, 8 Feb 2026 17:29:34 -0500 Subject: [PATCH 04/11] Working with 2 prompts Signed-off-by: Onur Yilmaz --- megatron/core/ssm/mamba_mixer.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 6c73eaa09f2..104ac9c53fe 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -863,17 +863,24 @@ def _ssm_prefill( # Use local varlen kernels only for batch > 1; for batch==1 use mamba_ssm which # is designed for that layout and avoids packing/alignment issues. + + #print("") + #print("------------ cu_seqlens shape ------------", cu_seqlens.shape if cu_seqlens is not None else None) + #print("------------ cu_seqlens ------------", cu_seqlens) + # print("------------ seq_idx shape ------------", seq_idx.shape if seq_idx is not None else None) + #print("------------ seq_idx ------------", seq_idx) + if ( - cu_seqlens is not None - and HAVE_SSM_OPS_VARLEN - and mamba_chunk_scan_combined_varlen is not None + cu_seqlens is not None ): + print("******** Using variable-length path ********") # Variable-length path using local Triton kernels (megatron.core.ssm.ops) batch, max_seqlen = x.shape[0], x.shape[1] total_tokens = cu_seqlens[-1].item() chunk_size = self.chunk_size device = x.device + #initial_ssm_state = None initial_ssm_state = ssm_state[batch_indices] if total_tokens > 0: @@ -911,6 +918,8 @@ def _ssm_prefill( .clamp(0, batch - 1) .to(device=device, dtype=torch.int32) ) + #print("------------ seq_idx_chunk shape ------------", seq_idx_chunk.shape if seq_idx_chunk is not None else None) + #print("------------ seq_idx_chunk ------------", seq_idx_chunk) # Pack tensors to (total_tokens, ...); use .item() for safe Python int slicing x_packed = torch.cat( @@ -1001,6 +1010,7 @@ def _ssm_prefill( else: y = y_unpacked else: + 
print("********************************* Using chunked path *********************************") y = mamba_chunk_scan_combined( x, dt, From e35ec71f2c0956da100feffcc113e3b062c5e0eb Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Mon, 9 Feb 2026 01:01:51 -0500 Subject: [PATCH 05/11] First working version Signed-off-by: Onur Yilmaz --- megatron/core/ssm/mamba_mixer.py | 236 +++++++++++-------------------- 1 file changed, 80 insertions(+), 156 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 104ac9c53fe..b4b87c14c0d 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -11,6 +11,8 @@ from dataclasses import dataclass, replace from typing import List, Optional, Tuple, Union +from yaml.error import YAMLError + import torch import torch.nn as nn import torch.nn.functional as F @@ -857,160 +859,82 @@ def _ssm_prefill( else: initial_ssm_state = None - # Note that both `seq_idx` and `cu_seqlens` must be passed in - # for variable length generation. - # See https://github.com/state-spaces/mamba/blob/e0761ece1db07e0949dd88b4f4cd440420a19fd9/tests/test_generation.py#L97 # pylint: disable=line-too-long - - # Use local varlen kernels only for batch > 1; for batch==1 use mamba_ssm which - # is designed for that layout and avoids packing/alignment issues. 
- - #print("") - #print("------------ cu_seqlens shape ------------", cu_seqlens.shape if cu_seqlens is not None else None) - #print("------------ cu_seqlens ------------", cu_seqlens) - # print("------------ seq_idx shape ------------", seq_idx.shape if seq_idx is not None else None) - #print("------------ seq_idx ------------", seq_idx) - if ( cu_seqlens is not None ): - print("******** Using variable-length path ********") - # Variable-length path using local Triton kernels (megatron.core.ssm.ops) - batch, max_seqlen = x.shape[0], x.shape[1] - total_tokens = cu_seqlens[-1].item() - chunk_size = self.chunk_size - device = x.device - - #initial_ssm_state = None + # Variable-length path: sequences are concatenated in one row (x.shape[0] == 1). + # Batch size = number of sequences from cu_seqlens; max_seqlen = max sequence length. + x = x.squeeze(0) + dt = dt.squeeze(0) + A = A.squeeze(0) + B = B.squeeze(0) + C = C.squeeze(0) + z = z.squeeze(0) + y = torch.empty_like(x) + initial_ssm_state = ssm_state[batch_indices] - if total_tokens > 0: - # Build chunk boundaries so no chunk spans two sequences (fixes junk output - # when multiple short sequences share a chunk). Merge fixed-size boundaries - # with sequence boundaries from cu_seqlens. 
- boundaries_set = {0, total_tokens} - for s in range(1, batch): - boundaries_set.add(cu_seqlens[s].item()) - for pos in range(0, total_tokens, chunk_size): - boundaries_set.add(min(pos, total_tokens)) - boundaries = sorted(boundaries_set) - cu_chunk_seqlens = torch.tensor( - boundaries, device=device, dtype=cu_seqlens.dtype - ) - nchunks = len(boundaries) - 1 - - seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] - # Chunk index that contains the last token of each sequence - last_token_pos = (cu_seqlens[1:] - 1).clamp(min=0) - last_chunk_indices = ( - torch.searchsorted( - cu_chunk_seqlens, last_token_pos.to(cu_chunk_seqlens.dtype), right=False + # Calculate cu_chunk_seqlens using seq_idx and cu_seqlens + cu_chunk_seqlens = None + if seq_idx is not None and cu_seqlens is not None: + # seq_idx: shape (1, total_tokens), where seq_idx[0, i] maps token position i to the packed sequence index + # cu_seqlens: shape (num_sequences + 1), e.g. [0, 5, 7, 11, 16], where cu_seqlens[-1] == total_tokens + + # The number of sequences = cu_seqlens.numel() - 1 = N + # The output cu_chunk_seqlens should be a cumulative sum of the number of tokens in each sequence, + # i.e. 
same as cu_seqlens + cu_chunk_seqlens = cu_seqlens + + # However, double check if seq_idx indicates any extra grouping, or if an extra entry is required: + # If seq_idx.max() + 1 > cu_seqlens.numel() - 1, then extra sequences + n_seq_from_seq_idx = int(seq_idx.max().item() + 1) + n_seq_from_cu = cu_seqlens.numel() - 1 + if n_seq_from_seq_idx > n_seq_from_cu: + # Need to extend cu_seqlens to include the rest of tokens counted in seq_idx + # This can happen if the last part is treated as an extra seq + cu_chunk_seqlens = torch.cat( + [cu_seqlens, cu_seqlens.new_tensor([seq_idx.shape[1]])] ) - - 1 - ) - last_chunk_indices = last_chunk_indices.clamp(0, nchunks - 1).to( - device=device, dtype=torch.int64 - ) - - # Chunk-level seq_idx: which sequence each chunk belongs to - chunk_starts = cu_chunk_seqlens[:-1].to(cu_seqlens.dtype) - seq_idx_chunk = ( - (torch.searchsorted(cu_seqlens, chunk_starts, right=False) - 1) - .clamp(0, batch - 1) - .to(device=device, dtype=torch.int32) - ) - #print("------------ seq_idx_chunk shape ------------", seq_idx_chunk.shape if seq_idx_chunk is not None else None) - #print("------------ seq_idx_chunk ------------", seq_idx_chunk) - - # Pack tensors to (total_tokens, ...); use .item() for safe Python int slicing - x_packed = torch.cat( - [ - x[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - dt_packed = torch.cat( - [dt[b, : seq_lengths[b].item(), :].contiguous() for b in range(batch)], - dim=0, - ).contiguous() - B_packed = torch.cat( - [ - B[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - C_packed = torch.cat( - [ - C[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - z_packed = ( - torch.cat( - [ - z[b, : seq_lengths[b].item(), :, :].contiguous() - for b in range(batch) - ], - dim=0, - ).contiguous() - if not self.rmsnorm - else None - ) - out_packed = 
torch.empty_like(x_packed) - D_val = ( + # Kernel expects seq_idx of shape (nchunks,) — one sequence index per chunk. + # We have seq_idx of shape (1, total_tokens); take seq index at start of each chunk. + seq_idx_for_varlen = None + if seq_idx is not None and cu_chunk_seqlens is not None: + chunk_starts = cu_chunk_seqlens[:-1] + seq_idx_for_varlen = seq_idx[0, chunk_starts].contiguous() + + ssm_varlen_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_seqlens=cu_seqlens, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=None, + seq_idx=seq_idx_for_varlen, + out=y, + D=( rearrange(self.cp.get_D().float(), "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.cp.get_D() - ) - dt_bias_val = self.cp.get_dt_bias().float() - - varlen_states = mamba_chunk_scan_combined_varlen( - x_packed, - dt_packed, - A, - B_packed, - C_packed, - chunk_size, - cu_seqlens=cu_seqlens, - cu_chunk_seqlens=cu_chunk_seqlens, - last_chunk_indices=last_chunk_indices, - seq_idx=seq_idx_chunk, - out=out_packed, - D=D_val, - z=z_packed, - dt_bias=dt_bias_val, - initial_states=initial_ssm_state, - dt_softplus=True, - return_intermediate_states=False, - ) + ), + z=z if not self.rmsnorm else None, + dt_bias=self.cp.get_dt_bias().float(), + initial_states=initial_ssm_state, + return_intermediate_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + state_dtype=ssm_state.dtype, + ) - # Unpack output to (batch, max_seqlen, nheads, headdim) - y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) - for b in range(batch): - length_b = seq_lengths[b].item() - if length_b > 0: - y_unpacked[b, :length_b, :, :] = out_packed[ - cu_seqlens[b] : cu_seqlens[b + 1] - ] - else: - # Zero tokens: skip kernel to avoid illegal memory access with empty inputs - y_unpacked = x.new_zeros(batch, max_seqlen, x.shape[2], x.shape[3]) - varlen_states = x.new_zeros( - batch, x.shape[2], x.shape[3], B.shape[-1], - 
device=device, dtype=x.dtype, - ) + y = y.unsqueeze(0) + z = z.unsqueeze(0) - if ssm_state is not None and return_varlen_states: - y = (y_unpacked, None, varlen_states) - elif ssm_state is not None: - y = (y_unpacked, None) - else: - y = y_unpacked + tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) + else: - print("********************************* Using chunked path *********************************") y = mamba_chunk_scan_combined( x, dt, @@ -1033,23 +957,23 @@ def _ssm_prefill( initial_states=initial_ssm_state, ) - if ssm_state is not None: - if return_varlen_states: - assert batch_indices is not None + if ssm_state is not None: + if return_varlen_states: + assert batch_indices is not None - y, _, ssm_varlen_states = y + y, _, ssm_varlen_states = y - # This has to be varlen_states, NOT last_state - # See reference implementation: - # https://github.com/state-spaces/mamba/blob/e0761ece1db07e0949dd88b4f4cd440420a19fd9/mamba_ssm/modules/mamba2.py#L267 # pylint: disable=line-too-long - tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) - elif is_chunked_prefill: - assert batch_indices is not None - y, last_state = y - tensor_masked_update(ssm_state, batch_indices, last_state) - else: - y, last_state = y - ssm_state.copy_(last_state) + # This has to be varlen_states, NOT last_state + # See reference implementation: + # https://github.com/state-spaces/mamba/blob/e0761ece1db07e0949dd88b4f4cd440420a19fd9/mamba_ssm/modules/mamba2.py#L267 # pylint: disable=line-too-long + tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) + elif is_chunked_prefill: + assert batch_indices is not None + y, last_state = y + tensor_masked_update(ssm_state, batch_indices, last_state) + else: + y, last_state = y + ssm_state.copy_(last_state) y = rearrange(y, "b l h p -> l b (h p)").contiguous() y = self.cp.post_conv_ssm(y) From 59a813bc1dae2c35aa57e7c141dbf18d3569c1df Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Mon, 9 Feb 2026 12:24:22 -0500 
Subject: [PATCH 06/11] Add cuda graph check Signed-off-by: Onur Yilmaz --- megatron/core/ssm/mamba_mixer.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index b4b87c14c0d..bb553526f94 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -886,15 +886,17 @@ def _ssm_prefill( cu_chunk_seqlens = cu_seqlens # However, double check if seq_idx indicates any extra grouping, or if an extra entry is required: - # If seq_idx.max() + 1 > cu_seqlens.numel() - 1, then extra sequences - n_seq_from_seq_idx = int(seq_idx.max().item() + 1) - n_seq_from_cu = cu_seqlens.numel() - 1 - if n_seq_from_seq_idx > n_seq_from_cu: - # Need to extend cu_seqlens to include the rest of tokens counted in seq_idx - # This can happen if the last part is treated as an extra seq - cu_chunk_seqlens = torch.cat( - [cu_seqlens, cu_seqlens.new_tensor([seq_idx.shape[1]])] - ) + # If seq_idx.max() + 1 > cu_seqlens.numel() - 1, then extra sequences. + # Skip this during CUDA graph capture: .item() would sync the stream and is not permitted. + if not torch.cuda.is_current_stream_capturing(): + n_seq_from_seq_idx = int(seq_idx.max().item() + 1) + n_seq_from_cu = cu_seqlens.numel() - 1 + if n_seq_from_seq_idx > n_seq_from_cu: + # Need to extend cu_seqlens to include the rest of tokens counted in seq_idx + # This can happen if the last part is treated as an extra seq + cu_chunk_seqlens = torch.cat( + [cu_seqlens, cu_seqlens.new_tensor([seq_idx.shape[1]])] + ) # Kernel expects seq_idx of shape (nchunks,) — one sequence index per chunk. # We have seq_idx of shape (1, total_tokens); take seq index at start of each chunk. 
From 6d2ad5d1cab6932afee1e143adc47767f5f5b27a Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Mon, 9 Feb 2026 14:03:30 -0500 Subject: [PATCH 07/11] Init state with varlen works Signed-off-by: Onur Yilmaz --- megatron/core/ssm/mamba_mixer.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index bb553526f94..9658243fc8c 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -854,10 +854,7 @@ def _ssm_prefill( self.cp.cp_size == 1 or self.rmsnorm ), "Context parallel not supported for use_mem_eff_path==False and rmsnorm==False" - if is_chunked_prefill: - initial_ssm_state = ssm_state[batch_indices] - else: - initial_ssm_state = None + initial_ssm_state = ssm_state[batch_indices] if ( cu_seqlens is not None @@ -872,8 +869,6 @@ def _ssm_prefill( z = z.squeeze(0) y = torch.empty_like(x) - initial_ssm_state = ssm_state[batch_indices] - # Calculate cu_chunk_seqlens using seq_idx and cu_seqlens cu_chunk_seqlens = None if seq_idx is not None and cu_seqlens is not None: @@ -886,17 +881,15 @@ def _ssm_prefill( cu_chunk_seqlens = cu_seqlens # However, double check if seq_idx indicates any extra grouping, or if an extra entry is required: - # If seq_idx.max() + 1 > cu_seqlens.numel() - 1, then extra sequences. - # Skip this during CUDA graph capture: .item() would sync the stream and is not permitted. 
- if not torch.cuda.is_current_stream_capturing(): - n_seq_from_seq_idx = int(seq_idx.max().item() + 1) - n_seq_from_cu = cu_seqlens.numel() - 1 - if n_seq_from_seq_idx > n_seq_from_cu: - # Need to extend cu_seqlens to include the rest of tokens counted in seq_idx - # This can happen if the last part is treated as an extra seq - cu_chunk_seqlens = torch.cat( - [cu_seqlens, cu_seqlens.new_tensor([seq_idx.shape[1]])] - ) + # If seq_idx.max() + 1 > cu_seqlens.numel() - 1, then extra sequences + n_seq_from_seq_idx = int(seq_idx.max().item() + 1) + n_seq_from_cu = cu_seqlens.numel() - 1 + if n_seq_from_seq_idx > n_seq_from_cu: + # Need to extend cu_seqlens to include the rest of tokens counted in seq_idx + # This can happen if the last part is treated as an extra seq + cu_chunk_seqlens = torch.cat( + [cu_seqlens, cu_seqlens.new_tensor([seq_idx.shape[1]])] + ) # Kernel expects seq_idx of shape (nchunks,) — one sequence index per chunk. # We have seq_idx of shape (1, total_tokens); take seq index at start of each chunk. 
From 9f2210aaf110e7db530d193f0307665010db8c14 Mon Sep 17 00:00:00 2001 From: Onur Yilmaz Date: Mon, 16 Feb 2026 12:51:29 -0500 Subject: [PATCH 08/11] Adding unit tests Signed-off-by: Onur Yilmaz --- tests/unit_tests/ssm/ops/test_ops_init.py | 27 +++ tests/unit_tests/ssm/ops/test_ssd_bmm.py | 78 +++++++++ .../unit_tests/ssm/ops/test_ssd_chunk_scan.py | 98 +++++++++++ .../ssm/ops/test_ssd_chunk_state.py | 164 ++++++++++++++++++ tests/unit_tests/ssm/ops/test_ssd_combined.py | 150 ++++++++++++++++ .../ssm/ops/test_ssd_state_passing.py | 87 ++++++++++ tests/unit_tests/ssm/ops/test_ssm_kernel.py | 162 +++++++++++++++++ 7 files changed, 766 insertions(+) create mode 100644 tests/unit_tests/ssm/ops/test_ops_init.py create mode 100644 tests/unit_tests/ssm/ops/test_ssd_bmm.py create mode 100644 tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py create mode 100644 tests/unit_tests/ssm/ops/test_ssd_chunk_state.py create mode 100644 tests/unit_tests/ssm/ops/test_ssd_combined.py create mode 100644 tests/unit_tests/ssm/ops/test_ssd_state_passing.py create mode 100644 tests/unit_tests/ssm/ops/test_ssm_kernel.py diff --git a/tests/unit_tests/ssm/ops/test_ops_init.py b/tests/unit_tests/ssm/ops/test_ops_init.py new file mode 100644 index 00000000000..2a7c8b42a4a --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ops_init.py @@ -0,0 +1,27 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +"""Test that the megatron.core.ssm.ops package exports the public API.""" + +import unittest + +try: + from megatron.core.ssm import ops as ssm_ops + HAVE_SSD_OPS = True +except (ImportError, Exception): + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +class TestOpsPackagePublicAPI(unittest.TestCase): + """Ensure the ops package exposes the documented public API.""" + + def test_all_exported(self): + self.assertIn("mamba_chunk_scan_combined_varlen", ssm_ops.__all__) + + def test_mamba_chunk_scan_combined_varlen_importable(self): + self.assertTrue(hasattr(ssm_ops, "mamba_chunk_scan_combined_varlen")) + self.assertTrue(callable(ssm_ops.mamba_chunk_scan_combined_varlen)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssd_bmm.py b/tests/unit_tests/ssm/ops/test_ssd_bmm.py new file mode 100644 index 00000000000..cc15c758291 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssd_bmm.py @@ -0,0 +1,78 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +import unittest +import torch + +try: + from megatron.core.ssm.ops.ssd_bmm import _bmm_chunk_fwd + HAVE_SSD_OPS = True +except (ImportError, Exception): + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestBmmChunkFwd(unittest.TestCase): + """Tests for _bmm_chunk_fwd (C^T @ B per chunk).""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.chunk_size = 16 + self.seqlen = 32 + self.ngroups = 2 + self.dstate = 8 # K dimension + self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + + def test_bmm_chunk_fwd_shape(self): + # a: (seqlen, ngroups, k), b: (seqlen, ngroups, k) -> out: (nchunks, ngroups, chunk_size, chunk_size) + a = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + b = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + + out = _bmm_chunk_fwd( + a, b, self.chunk_size, self.cu_chunk_seqlens, causal=True, output_dtype=torch.float32 + ) + + nchunks = self.cu_chunk_seqlens.shape[0] - 1 + self.assertEqual(out.shape, (nchunks, self.ngroups, self.chunk_size, self.chunk_size)) + self.assertFalse(torch.isnan(out).any()) + + def test_bmm_chunk_fwd_vs_torch_per_chunk(self): + """Compare first chunk with explicit C^T @ B for that chunk.""" + a = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + b = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + + out = _bmm_chunk_fwd( + a, b, self.chunk_size, self.cu_chunk_seqlens, causal=False, output_dtype=torch.float32 + ) + + # Chunk 0: rows 0:16 of a and b. out[0, g] = a[0:16, g] @ b[0:16, g].T + # Relaxed tolerances: Triton block-wise reduction order can differ from torch; + # atol is the main check (max abs diff was ~0.008 in practice). 
+ for g in range(self.ngroups): + a_chunk = a[0:16, g, :].contiguous() # (16, dstate) + b_chunk = b[0:16, g, :].contiguous() # (16, dstate) + expected = torch.mm(a_chunk, b_chunk.T) # (16, 16) + torch.testing.assert_close(out[0, g], expected, rtol=1.0, atol=0.02) + + def test_bmm_chunk_fwd_causal_vs_non_causal_shape(self): + a = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + b = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + + out_causal = _bmm_chunk_fwd(a, b, self.chunk_size, self.cu_chunk_seqlens, causal=True) + out_noncausal = _bmm_chunk_fwd(a, b, self.chunk_size, self.cu_chunk_seqlens, causal=False) + + self.assertEqual(out_causal.shape, out_noncausal.shape) + # Causal: lower triangle is correct; upper can differ + for c in range(out_causal.shape[0]): + for g in range(self.ngroups): + for i in range(self.chunk_size): + for j in range(i + 1): + self.assertTrue( + torch.allclose(out_causal[c, g, i, j], out_noncausal[c, g, i, j]), + f"c={c} g={g} i={i} j={j}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py b/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py new file mode 100644 index 00000000000..1c6d4ecfbc3 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py @@ -0,0 +1,98 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +import unittest +import torch + +try: + from megatron.core.ssm.ops.ssd_chunk_scan import _chunk_scan_fwd + HAVE_SSD_OPS = True +except (ImportError, Exception): + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestChunkScanFwd(unittest.TestCase): + """Tests for _chunk_scan_fwd.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.seqlen = 32 + self.nheads = 4 + self.headdim = 16 + self.ngroups = 2 + self.dstate = 8 + self.chunk_size = 16 + self.nchunks = 2 + self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + self.seq_idx = torch.tensor([0, 1], dtype=torch.int32, device=self.device) + + def test_chunk_scan_fwd_shape_and_inplace_out(self): + cb = torch.randn( + self.nchunks, self.ngroups, self.chunk_size, self.chunk_size, + device=self.device, dtype=torch.float32, + ) + x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + C = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + states = torch.randn(self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32) + out = torch.zeros(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + _chunk_scan_fwd( + cb, x, dt, dA_cumsum, C, states, + self.cu_chunk_seqlens, out, self.seq_idx, + D=None, z=None, initial_states=None, + ) + + self.assertEqual(out.shape, (self.seqlen, self.nheads, self.headdim)) + self.assertFalse(torch.isnan(out).any()) + # Output should be non-zero (scan writes to out) + self.assertGreater(out.abs().max().item(), 0.0) + + def 
test_chunk_scan_fwd_with_D(self): + cb = torch.randn( + self.nchunks, self.ngroups, self.chunk_size, self.chunk_size, + device=self.device, dtype=torch.float32, + ) + x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + C = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + states = torch.randn(self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32) + out = torch.zeros(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + D = torch.ones(self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + _chunk_scan_fwd( + cb, x, dt, dA_cumsum, C, states, + self.cu_chunk_seqlens, out, self.seq_idx, + D=D, z=None, initial_states=None, + ) + + self.assertFalse(torch.isnan(out).any()) + + def test_chunk_scan_fwd_with_z(self): + cb = torch.randn( + self.nchunks, self.ngroups, self.chunk_size, self.chunk_size, + device=self.device, dtype=torch.float32, + ) + x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + z = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32) + C = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + states = torch.randn(self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32) + out = torch.zeros(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + _chunk_scan_fwd( + cb, 
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Unit tests for the chunk-state Triton kernels (cumsum, state, varlen state)."""

import unittest

import torch

try:
    from megatron.core.ssm.ops.ssd_chunk_state import (
        _chunk_cumsum_fwd,
        _chunk_state_fwd,
        chunk_state_varlen,
    )

    HAVE_SSD_OPS = True
except Exception:  # Triton import can raise more than ImportError (e.g. driver errors)
    HAVE_SSD_OPS = False


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestChunkCumsumFwd(unittest.TestCase):
    """Tests for _chunk_cumsum_fwd."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.seqlen = 32
        self.nheads = 4
        self.chunk_size = 16
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # Number of chunks implied by the chunk boundaries.
        self.nchunks = self.cu_chunk_seqlens.shape[0] - 1

    def test_chunk_cumsum_fwd_shape(self):
        """Outputs are (nheads, nchunks, chunk_size) and free of NaNs."""
        dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32)
        A = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        dA_cumsum, dt_out = _chunk_cumsum_fwd(dt, A, self.chunk_size, self.cu_chunk_seqlens)

        self.assertEqual(dA_cumsum.shape, (self.nheads, self.nchunks, self.chunk_size))
        self.assertEqual(dt_out.shape, (self.nheads, self.nchunks, self.chunk_size))
        self.assertFalse(torch.isnan(dA_cumsum).any())
        self.assertFalse(torch.isnan(dt_out).any())

    def test_chunk_cumsum_fwd_cumsum_per_chunk(self):
        """dA_cumsum should be cumsum of dt * A along the chunk dimension."""
        dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32)
        A = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        # Disable softplus/bias/limits so the reference cumsum below is exact.
        dA_cumsum, dt_out = _chunk_cumsum_fwd(
            dt, A, self.chunk_size, self.cu_chunk_seqlens,
            dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")),
        )

        for c in range(self.nchunks):
            start = self.cu_chunk_seqlens[c].item()
            end = self.cu_chunk_seqlens[c + 1].item()
            chunk_len = end - start
            for h in range(self.nheads):
                dA_chunk = (dt_out[h, c, :chunk_len] * A[h]).cpu()
                expected_cumsum = torch.cumsum(dA_chunk, dim=0)
                torch.testing.assert_close(
                    dA_cumsum[h, c, :chunk_len].cpu(), expected_cumsum, rtol=1e-4, atol=1e-4
                )

    def test_chunk_cumsum_fwd_with_dt_bias(self):
        """A per-head dt bias is accepted and does not introduce NaNs."""
        dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32)
        A = torch.randn(self.nheads, device=self.device, dtype=torch.float32)
        dt_bias = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        dA_cumsum, dt_out = _chunk_cumsum_fwd(
            dt, A, self.chunk_size, self.cu_chunk_seqlens, dt_bias=dt_bias
        )

        self.assertEqual(dA_cumsum.shape, (self.nheads, self.nchunks, self.chunk_size))
        self.assertFalse(torch.isnan(dA_cumsum).any())


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestChunkStateFwd(unittest.TestCase):
    """Tests for _chunk_state_fwd."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.seqlen = 32
        self.nheads = 4
        self.headdim = 16
        self.ngroups = 2
        self.dstate = 8
        self.chunk_size = 16
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # Derive nchunks from the boundaries instead of hard-coding it.
        self.nchunks = self.cu_chunk_seqlens.shape[0] - 1

    def test_chunk_state_fwd_shape(self):
        """Per-chunk states come back as (nchunks, nheads, headdim, dstate)."""
        x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32)
        B = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32)
        dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)
        dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)

        states = _chunk_state_fwd(B, x, dt, dA_cumsum, self.cu_chunk_seqlens)

        self.assertEqual(states.shape, (self.nchunks, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(states).any())


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestChunkStateVarlen(unittest.TestCase):
    """Tests for chunk_state_varlen."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.seqlen = 32
        self.nheads = 4
        self.headdim = 16
        self.ngroups = 2
        self.dstate = 8
        self.chunk_size = 16
        self.batch = 2
        # Two sequences of 16 tokens, each occupying exactly one chunk.
        self.cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        self.last_chunk_indices = torch.tensor([0, 1], dtype=torch.int64, device=self.device)
        self.nchunks = self.cu_chunk_seqlens.shape[0] - 1

    def test_chunk_state_varlen_shape(self):
        """Final per-sequence states are (batch, nheads, headdim, dstate)."""
        x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32)
        B = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32)
        dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)
        dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)
        chunk_states = torch.randn(
            self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32
        )

        states = chunk_state_varlen(
            B, x, dt, dA_cumsum, self.cu_seqlens, chunk_states,
            last_chunk_indices=self.last_chunk_indices,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
        )

        self.assertEqual(states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(states).any())

    def test_chunk_state_varlen_with_initial_states(self):
        """Per-sequence initial states are accepted and do not introduce NaNs."""
        x = torch.randn(self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32)
        B = torch.randn(self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32)
        dt = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)
        dA_cumsum = torch.randn(self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32)
        chunk_states = torch.randn(
            self.nchunks, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32
        )
        initial_states = torch.randn(
            self.batch, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32
        )

        states = chunk_state_varlen(
            B, x, dt, dA_cumsum, self.cu_seqlens, chunk_states,
            initial_states=initial_states,
            last_chunk_indices=self.last_chunk_indices,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
        )

        self.assertEqual(states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(states).any())


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Unit tests for the combined varlen chunk-scan entry point."""

import unittest

import torch

try:
    from megatron.core.ssm.ops.ssd_combined import (
        is_int_pow_2,
        mamba_chunk_scan_combined_varlen,
    )

    HAVE_SSD_OPS = True
except Exception:  # Triton import can raise more than ImportError (e.g. driver errors)
    HAVE_SSD_OPS = False


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestIsIntPow2(unittest.TestCase):
    """Tests for is_int_pow_2 utility."""

    def test_powers_of_two(self):
        for exp in range(12):
            n = 2 ** exp
            self.assertTrue(is_int_pow_2(n), f"2^{exp}={n} should be power of 2")

    def test_non_powers_of_two(self):
        for n in [0, 3, 5, 6, 7, 9, 10, 12, 15, 18]:
            self.assertFalse(is_int_pow_2(n), f"{n} should not be power of 2")

    def test_negative_and_float(self):
        # Negative values, non-int types, and zero are all rejected.
        self.assertFalse(is_int_pow_2(-1))
        self.assertFalse(is_int_pow_2(-4))
        self.assertFalse(is_int_pow_2(2.0))
        self.assertFalse(is_int_pow_2(0))


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestMambaChunkScanCombinedVarlen(unittest.TestCase):
    """Tests for mamba_chunk_scan_combined_varlen."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.chunk_size = 16
        self.seqlen = 32
        self.nheads = 4
        self.headdim = 16
        self.ngroups = 2
        self.dstate = 8
        self.batch = 2
        # cu_seqlens: [0, 16, 32] -> two sequences of length 16 each
        self.cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # 2 chunks of 16 each
        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        # last chunk index per sequence: seq0 ends in chunk 0, seq1 ends in chunk 1
        self.last_chunk_indices = torch.tensor([0, 1], dtype=torch.int64, device=self.device)
        # seq_idx: which sequence each chunk belongs to (nchunks,)
        self.seq_idx = torch.tensor([0, 1], dtype=torch.int32, device=self.device)

    def _make_inputs(self, seqlen):
        """Random x/dt/A/B/C plus an uninitialized output buffer for `seqlen` tokens."""
        t = dict(device=self.device, dtype=torch.float32)
        x = torch.randn(seqlen, self.nheads, self.headdim, **t)
        dt = torch.randn(seqlen, self.nheads, **t)
        A = torch.randn(self.nheads, **t)
        B = torch.randn(seqlen, self.ngroups, self.dstate, **t)
        C = torch.randn(seqlen, self.ngroups, self.dstate, **t)
        out = torch.empty(seqlen, self.nheads, self.headdim, **t)
        return x, dt, A, B, C, out

    def test_mamba_chunk_scan_combined_varlen_shape_and_no_nan(self):
        x, dt, A, B, C, out = self._make_inputs(self.seqlen)

        varlen_states = mamba_chunk_scan_combined_varlen(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            chunk_size=self.chunk_size,
            cu_seqlens=self.cu_seqlens,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
            last_chunk_indices=self.last_chunk_indices,
            seq_idx=self.seq_idx,
            out=out,
        )

        self.assertEqual(varlen_states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertEqual(out.shape, (self.seqlen, self.nheads, self.headdim))
        self.assertFalse(torch.isnan(out).any(), "output should have no NaN")
        self.assertFalse(torch.isnan(varlen_states).any(), "varlen_states should have no NaN")

    def test_mamba_chunk_scan_combined_varlen_with_D_and_dt_bias(self):
        x, dt, A, B, C, out = self._make_inputs(self.seqlen)
        D = torch.ones(self.nheads, self.headdim, device=self.device, dtype=torch.float32)
        dt_bias = torch.randn(self.nheads, device=self.device, dtype=torch.float32)

        varlen_states = mamba_chunk_scan_combined_varlen(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            chunk_size=self.chunk_size,
            cu_seqlens=self.cu_seqlens,
            cu_chunk_seqlens=self.cu_chunk_seqlens,
            last_chunk_indices=self.last_chunk_indices,
            seq_idx=self.seq_idx,
            out=out,
            D=D,
            dt_bias=dt_bias,
        )

        self.assertEqual(varlen_states.shape, (self.batch, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(out).any())

    def test_mamba_chunk_scan_combined_varlen_single_sequence(self):
        """Single sequence: cu_seqlens [0, 32], one sequence of 32."""
        cu_seqlens = torch.tensor([0, 32], dtype=torch.int32, device=self.device)
        cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
        last_chunk_indices = torch.tensor([1], dtype=torch.int64, device=self.device)
        seq_idx = torch.tensor([0, 0], dtype=torch.int32, device=self.device)

        x, dt, A, B, C, out = self._make_inputs(32)

        varlen_states = mamba_chunk_scan_combined_varlen(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            chunk_size=self.chunk_size,
            cu_seqlens=cu_seqlens,
            cu_chunk_seqlens=cu_chunk_seqlens,
            last_chunk_indices=last_chunk_indices,
            seq_idx=seq_idx,
            out=out,
        )

        self.assertEqual(varlen_states.shape, (1, self.nheads, self.headdim, self.dstate))
        self.assertFalse(torch.isnan(out).any())


if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Unit tests for the state-passing forward Triton kernel (_state_passing_fwd)."""

import unittest

import torch

try:
    from megatron.core.ssm.ops.ssd_state_passing import _state_passing_fwd

    HAVE_SSD_OPS = True
except Exception:  # Triton import can raise more than ImportError (e.g. driver errors)
    HAVE_SSD_OPS = False


@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
class TestStatePassingFwd(unittest.TestCase):
    """Tests for _state_passing_fwd: recurrence out = exp(dA_cs_last) * prev + new_states."""

    def setUp(self):
        torch.manual_seed(42)
        self.device = torch.device("cuda")
        self.nchunks = 4
        self.nheads = 2
        self.chunk_size = 16
        self.dim = self.chunk_size * 8  # headdim * dstate flattened
        self.cu_chunk_seqlens = torch.tensor(
            [0, 16, 32, 48, 64], dtype=torch.int32, device=self.device
        )

    def test_state_passing_fwd_shape(self):
        """Output keeps the (nchunks, nheads, dim) layout and is NaN-free."""
        states = torch.randn(
            self.nchunks, self.nheads, self.dim, device=self.device, dtype=torch.float32
        )
        dA_cumsum = torch.randn(
            self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32
        )
        # All chunks belong to sequence 0.
        seq_idx = torch.zeros(self.nchunks, dtype=torch.int32, device=self.device)

        out = _state_passing_fwd(
            states, dA_cumsum, self.cu_chunk_seqlens, seq_idx, initial_states=None
        )

        self.assertEqual(out.shape, (self.nchunks, self.nheads, self.dim))
        self.assertFalse(torch.isnan(out).any())

    def test_state_passing_fwd_with_initial_states(self):
        """Two sequences (chunks [0,1] and [2,3]) each with their own initial state."""
        states = torch.randn(
            self.nchunks, self.nheads, self.dim, device=self.device, dtype=torch.float32
        )
        dA_cumsum = torch.randn(
            self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32
        )
        seq_idx = torch.tensor([0, 0, 1, 1], dtype=torch.int32, device=self.device)
        initial_states = torch.randn(
            2, self.nheads, self.dim, device=self.device, dtype=torch.float32
        )

        out = _state_passing_fwd(
            states,
            dA_cumsum,
            self.cu_chunk_seqlens,
            seq_idx,
            initial_states=initial_states,
        )

        self.assertEqual(out.shape, (self.nchunks, self.nheads, self.dim))
        self.assertFalse(torch.isnan(out).any())

    def test_state_passing_fwd_recurrence_single_head_single_dim(self):
        """Sanity: single head, small dim, check recurrence manually for first elements."""
        dim = 4
        nchunks = 2
        nheads = 1
        chunk_size = 2
        cu_chunk_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32, device=self.device)
        seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device)

        states = torch.randn(nchunks, nheads, dim, device=self.device, dtype=torch.float32)
        dA_cumsum = torch.randn(nheads, nchunks, chunk_size, device=self.device, dtype=torch.float32)

        out = _state_passing_fwd(states, dA_cumsum, cu_chunk_seqlens, seq_idx)

        # Chunk 0: out[0] = exp(dA_cumsum[0,-1]) * 0 + states[0] = states[0] (no initial state)
        # So out[0] should equal states[0]
        torch.testing.assert_close(out[0], states[0], rtol=1e-4, atol=1e-4)
        self.assertEqual(out.shape, (nchunks, nheads, dim))


if __name__ == "__main__":
    unittest.main()
+ """ + def __init__(self, d_inner, ngroups, nheads, d_state, device): + self.d_inner_local_tpcp = d_inner + self.ngroups_local_tpcp = ngroups + self.nheads_local_tpcp = nheads + self.cp_size = 1 + + # Random weights for the mock + self.conv1d_weight = torch.randn(d_inner + 2 * ngroups * d_state, 1, 4, device=device) + self.conv1d_bias = torch.randn(d_inner + 2 * ngroups * d_state, device=device) + self.A_log = torch.randn(nheads, device=device) + self.D = torch.ones(nheads, device=device) + self.dt_bias = torch.randn(nheads, device=device) + + # Simple conv1d layer for the fallback path if needed + self.conv1d_layer = nn.Conv1d( + in_channels=self.conv1d_weight.shape[0], + out_channels=self.conv1d_weight.shape[0], + kernel_size=4, groups=self.conv1d_weight.shape[0], padding=3 + ).to(device) + + def get_A_log(self): return self.A_log + def get_D(self): return self.D + def get_dt_bias(self): return self.dt_bias + def get_conv1d_weight(self): return self.conv1d_weight + def get_conv1d_bias(self): return self.conv1d_bias + + def conv1d(self, x): + return self.conv1d_layer(x) + + def pre_conv_ssm(self, x): return x + def post_conv_ssm(self, x): return x + + +class TestMambaDynamicInference(unittest.TestCase): + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.device.type == 'cpu': + self.skipTest("Mamba Triton kernels require CUDA") + + # --- Configuration --- + self.d_model = 256 + self.d_state = 16 + self.headdim = 64 + self.d_conv = 4 + self.ngroups = 1 + self.d_inner = self.d_model * 2 # expand=2 + self.nheads = self.d_inner // self.headdim + + # Create the Mixer instance directly + self.mixer = MagicMock(spec=MambaMixer) + self.mixer.d_state = self.d_state + self.mixer.d_conv = self.d_conv + self.mixer.headdim = self.headdim + self.mixer.chunk_size = 256 + self.mixer.activation = "silu" + self.mixer.act = nn.SiLU() + self.mixer.D_has_hdim = False + self.mixer.rmsnorm = True + + # Mock 
the Context Parallel wrapper (used by ssm_prefill) + self.mixer.cp = MockContextParallel( + d_inner=self.d_inner, + ngroups=self.ngroups, + nheads=self.nheads, + d_state=self.d_state, + device=self.device + ) + + # --- Setup for ssm_decode --- + # ssm_decode accesses attributes directly from self, not self.cp + self.mixer.d_inner_local_tp = self.d_inner + self.mixer.ngroups_local_tp = self.ngroups + self.mixer.nheads_local_tp = self.nheads + + # Create real parameters for ssm_decode to access + conv_dim = self.d_inner + 2 * self.ngroups * self.d_state + self.mixer.conv1d = nn.Conv1d( + in_channels=conv_dim, out_channels=conv_dim, + kernel_size=self.d_conv, groups=conv_dim, padding=self.d_conv - 1, + bias=True, device=self.device + ) + self.mixer.dt_bias = nn.Parameter(torch.randn(self.nheads, device=self.device)) + self.mixer.A_log = nn.Parameter(torch.randn(self.nheads, device=self.device)) + self.mixer.D = nn.Parameter(torch.ones(self.nheads, device=self.device)) + + # Bind methods + self.mixer._ssm_prefill = MambaMixer._ssm_prefill.__get__(self.mixer, MambaMixer) + self.mixer._ssm_decode = MambaMixer._ssm_decode.__get__(self.mixer, MambaMixer) + + def test_ssm_prefill_padding_isolation(self): + """ + Tests that ssm_prefill only updates states for the real request + and outputs zeros for padding tokens. 
+ """ + num_requests = 48 + real_seq_len = 6 + total_tokens = 63 + + # Inputs + dim_inputs = self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads + zxBCdt = torch.randn(total_tokens, 1, dim_inputs, device=self.device, dtype=torch.float32) + + # Metadata + seq_idx = torch.full((total_tokens,), -1, dtype=torch.int32, device=self.device) + seq_idx[:real_seq_len] = 0 + seq_idx = seq_idx.unsqueeze(0) + + cu_seqlens = torch.full((num_requests + 1,), real_seq_len, dtype=torch.int32, device=self.device) + cu_seqlens[0] = 0 + + batch_indices = torch.full((num_requests,), -1, dtype=torch.long, device=self.device) + batch_indices[0] = 0 + + # States + conv_dim = self.d_inner + 2 * self.ngroups * self.d_state + conv_state = torch.zeros(num_requests, conv_dim, self.d_conv, device=self.device) + ssm_state = torch.zeros(num_requests, self.nheads, self.headdim, self.d_state, device=self.device) + + # Run + self.mixer.norm = MagicMock(side_effect=lambda x, z: x * z) + output = self.mixer._ssm_prefill( + zxBCdt=zxBCdt, conv_state=conv_state, ssm_state=ssm_state, + seq_idx=seq_idx, cu_seqlens=cu_seqlens, + batch_indices=batch_indices, return_varlen_states=True + ) + + # Assertions + real_output = output[0:real_seq_len] + padding_output = output[real_seq_len:] + + self.assertTrue(torch.allclose(padding_output, torch.zeros_like(padding_output)), + "Output for padding tokens should be 0") + self.assertTrue(conv_state[0].abs().max() > 0, "Real request conv_state should be modified") + + # Verify isolation of padding states + remaining_conv_states = conv_state[1:num_requests] + remaining_ssm_states = ssm_state[1:num_requests] + + self.assertTrue(torch.allclose(remaining_conv_states, torch.zeros_like(remaining_conv_states)), + "Conv states for padding requests (indices 1 to N-1) should remain 0") + self.assertTrue(torch.allclose(remaining_ssm_states, torch.zeros_like(remaining_ssm_states)), + "SSM states for padding requests (indices 1 to N-1) should remain 0") + print("Prefill 
Test Passed!") + + +if __name__ == '__main__': + unittest.main(argv=['first-arg-is-ignored'], exit=False) From 196c3b3a1d7c936b159797a3220ff395995be3c7 Mon Sep 17 00:00:00 2001 From: Lawrence McAfee Date: Thu, 5 Mar 2026 05:38:58 -0800 Subject: [PATCH 09/11] Fix SSM kernel bugs and unify Mamba prefill through varlen path Bug fixes: 1. conv_state save-before-read: extract initial conv states BEFORE causal_conv1d_varlen_states + tensor_masked_update overwrites the conv_state buffer. Previously, initial_conv_states was read AFTER the buffer was updated, so restored requests would see their own newly-computed final states instead of the pre-existing initial states, corrupting the convolution output. 2. cu_chunk_seqlens OOB: the SSM Triton kernels allocate per-chunk output arrays of size chunk_size (128). Passing cu_seqlens directly as cu_chunk_seqlens caused out-of-bounds memory access when any sequence exceeded chunk_size tokens. Fix: subdivide each sequence into chunks of at most self.chunk_size, producing correct cu_chunk_seqlens boundaries. 3. zxBCdt padding mismatch: after conv1d, the per-request loop rebuilt xBC with only real tokens while dt and z retained padded token count. This caused a shape assertion failure in the SSM kernel. Fix: strip padded tokens from zxBCdt before _ssm_prefill, then pad the output back to the original padded size for downstream residual add. 4. Per-request conv1d with initial_states: causal_conv1d_fn cannot accept both seq_idx and initial_states simultaneously. The old code passed seq_idx to handle multiple sequences but this zeroes state at sequence boundaries instead of using the cached initial states. Fix: loop over requests, calling causal_conv1d_fn per-request with initial_states and channels-last layout. 
Improvements: - Unify all Mamba prefill (including chunked) through single varlen SSM kernel call, removing separate chunked-prefill routing and the _batch_indices_chunked_prefill / _device_chunked_prefill metadata - Simplify _dynamic_inference to flat decode + prefill structure - Add _dynamic_inference_prefill helper that strips CUDA-graph padding from metadata and data tensors before calling _ssm_prefill - Remove deprecated constructor parameters (use_mem_eff_path, d_state, headdim, ngroups) and their warnings - Add assertion format string in ssd_combined.py for easier debugging Co-Authored-By: Claude Opus 4.6 --- .../attention_context/mamba_metadata.py | 123 +++--- megatron/core/ssm/mamba_mixer.py | 394 +++++++----------- megatron/core/ssm/ops/ssd_combined.py | 2 +- 3 files changed, 215 insertions(+), 304 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index 6cf45aeb9e1..03b19bc0570 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -60,11 +60,6 @@ def __init__(self, max_requests: int, max_tokens: int): (self.max_requests,), -1, dtype=torch.int32, device=self.device ) - # Map from the active chunked prefill request to its slot in the static Mamba state buffer - self._batch_indices_chunked_prefill_buffer = torch.full( - (1,), -1, dtype=torch.int32, device=self.device - ) - # Map from token id to request id for active prefill requests self._seq_idx_buffer = torch.full( (1, self.max_tokens), -1, dtype=torch.int32, device=self.device @@ -80,20 +75,14 @@ def __init__(self, max_requests: int, max_tokens: int): (2,), dtype=torch.int32, device=self.device ) - # Tuple of ( - # total prefill sequence length excluding chunked prefill, - # chunked prefill sequence length - # ) - self._device_chunked_prefill_buffer = torch.zeros( - (2,), 
dtype=torch.int32, device=self.device - ) - # Allocator for Mamba state slots self.mamba_state_free_slots = torch.arange( self.max_requests, dtype=torch.int32, device=torch.cuda.current_device() ) self.mamba_state_free_slot_count = self.max_requests + self.reset_varlen_metadata() + def reset(self) -> None: """ Resets all Mamba states and frees all allocated slots. @@ -112,11 +101,9 @@ def reset_varlen_metadata(self) -> None: """Resets varlen metadata.""" self.batch_indices_decode = None self.batch_indices_prefill = None - self.batch_indices_chunked_prefill = None self.cu_seqlens = None self.seq_idx = None self.device_decode_prefill = None - self.device_chunked_prefill = None def update( self, @@ -133,20 +120,17 @@ def update( Args: active_mamba_indices (Tensor): Tensor containing the Mamba slot indices for active requests. - num_active_requests (int): The number of active requests. + token_to_request_idx (Tensor): Map from token index to request index. + cu_seqlens (Tensor): Cumulative sequence lengths. + batch_dimensions (InferenceBatchDimensions): Dimensions of the current batch. + padded_batch_dimensions (InferenceBatchDimensions): Dimensions of the padded batch. 
""" real_decode_count = batch_dimensions.decode_req_count real_prefill_count = batch_dimensions.prefill_req_count - real_token_count = batch_dimensions.token_count - has_explicit_chunked_prefill_req = batch_dimensions.has_explicit_chunked_prefill_req padded_decode_count = padded_batch_dimensions.decode_req_count padded_prefill_count = padded_batch_dimensions.prefill_req_count padded_token_count = padded_batch_dimensions.token_count - assert ( - has_explicit_chunked_prefill_req - == padded_batch_dimensions.has_explicit_chunked_prefill_req - ) if padded_decode_count > 0: # Update decode indices @@ -157,83 +141,61 @@ def update( self._batch_indices_decode_buffer[real_decode_count:padded_decode_count] = -1 self.batch_indices_decode = self._batch_indices_decode_buffer[:padded_decode_count] - # Determine if we have a chunked prefill request and adjust counts for regular prefill - regular_prefill_count = real_prefill_count - if has_explicit_chunked_prefill_req: - # The last prefill request is the chunked one - regular_prefill_count -= 1 - chunked_req_idx = real_decode_count + regular_prefill_count - - # Update chunked prefill indices - self._batch_indices_chunked_prefill_buffer[0] = active_mamba_indices[chunked_req_idx] - self.batch_indices_chunked_prefill = self._batch_indices_chunked_prefill_buffer - else: - self.batch_indices_chunked_prefill = None - if padded_prefill_count > 0: - # Update prefill indices (excluding chunked prefill from regular prefill buffer) - if regular_prefill_count > 0: - self._batch_indices_prefill_buffer[:regular_prefill_count].copy_( - active_mamba_indices[ - real_decode_count : real_decode_count + regular_prefill_count - ] + # Update prefill indices (all prefill requests go through varlen) + if real_prefill_count > 0: + prefill_start_idx = real_decode_count + self._batch_indices_prefill_buffer[:real_prefill_count].copy_( + active_mamba_indices[prefill_start_idx : prefill_start_idx + real_prefill_count] ) - if padded_prefill_count > 
regular_prefill_count: - self._batch_indices_prefill_buffer[regular_prefill_count:padded_prefill_count] = -1 + if padded_prefill_count > real_prefill_count: + self._batch_indices_prefill_buffer[ + real_prefill_count:padded_prefill_count + ] = -1 self.batch_indices_prefill = self._batch_indices_prefill_buffer[:padded_prefill_count] - # Update seq_idx - end_regular_prefill_token_idx = cu_seqlens[real_decode_count + regular_prefill_count] + # Update seq_idx for all prefill requests + prefill_start_req_idx = real_decode_count + end_prefill_req_idx = real_decode_count + real_prefill_count + + start_prefill_token_idx = cu_seqlens[prefill_start_req_idx] + end_prefill_token_idx = cu_seqlens[end_prefill_req_idx] - # The length of tokens belonging to regular prefill requests (excluding decode tokens) - seq_len = end_regular_prefill_token_idx - real_decode_count + seq_len = end_prefill_token_idx - start_prefill_token_idx if seq_len > 0: + # Normalize request IDs to 0-based relative to prefill requests self._seq_idx_buffer[:, :seq_len].copy_( - token_to_request_idx[real_decode_count:end_regular_prefill_token_idx] - - real_decode_count + token_to_request_idx[start_prefill_token_idx:end_prefill_token_idx] + - prefill_start_req_idx ) if padded_token_count > seq_len: self._seq_idx_buffer[:, seq_len:padded_token_count] = -1 self.seq_idx = self._seq_idx_buffer[:, :padded_token_count] - # Update cu_seqlens + # Update cu_seqlens for all prefill requests self._cu_seqlens_buffer[0] = 0 - if regular_prefill_count > 0: - self._cu_seqlens_buffer[1 : regular_prefill_count + 1].copy_( - cu_seqlens[ - real_decode_count + 1 : real_decode_count + regular_prefill_count + 1 - ] - - real_decode_count + if real_prefill_count > 0: + self._cu_seqlens_buffer[1 : real_prefill_count + 1].copy_( + cu_seqlens[prefill_start_req_idx + 1 : end_prefill_req_idx + 1] + - cu_seqlens[prefill_start_req_idx] ) # Pad the rest with the last value (effectively length 0 segments) - last_val = 
self._cu_seqlens_buffer[regular_prefill_count] - self._cu_seqlens_buffer[regular_prefill_count + 1 : padded_prefill_count + 1].fill_( + last_val = self._cu_seqlens_buffer[real_prefill_count] + self._cu_seqlens_buffer[real_prefill_count + 1 : padded_prefill_count + 1].fill_( last_val ) self.cu_seqlens = self._cu_seqlens_buffer[: padded_prefill_count + 1] if padded_decode_count > 0 and padded_prefill_count > 0: self._device_decode_prefill_buffer[0] = real_decode_count - self._device_decode_prefill_buffer[1] = regular_prefill_count + self._device_decode_prefill_buffer[1] = real_prefill_count self.device_decode_prefill = self._device_decode_prefill_buffer - # If using chunked prefill for this batch, store the number of regular prefill tokens - # and the number of tokens in the chunked prefill request - if has_explicit_chunked_prefill_req: - chunked_prefill_token_count = ( - cu_seqlens[real_decode_count + real_prefill_count] - - cu_seqlens[real_decode_count + real_prefill_count - 1] - ) - assert self.cu_seqlens is not None - self._device_chunked_prefill_buffer[0] = self.cu_seqlens[regular_prefill_count] - self._device_chunked_prefill_buffer[1] = chunked_prefill_token_count - self.device_chunked_prefill = self._device_chunked_prefill_buffer - def allocate_slot(self) -> Optional[int]: """ Allocates a new slot for a request in the Mamba state buffers. @@ -251,6 +213,25 @@ def allocate_slot(self) -> Optional[int]: return mamba_idx + def batch_allocate_slots(self, num_slots: int) -> Optional[torch.Tensor]: + """ + Allocates new slots for the given number of requests in the Mamba state buffers. + + Returns: + torch.Tensor: The indices of the allocated slots. + Returns None if not enough slots are available. 
+ """ + if self.mamba_state_free_slot_count < num_slots: + return None + + # Get free slots + self.mamba_state_free_slot_count -= num_slots + mamba_idx = self.mamba_state_free_slots[ + self.mamba_state_free_slot_count : self.mamba_state_free_slot_count + num_slots + ] + + return mamba_idx + def free_slots(self, request_indices: torch.Tensor) -> None: """ Frees the Mamba state slots associated with the given request indices. diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 9658243fc8c..94a7c7c3a3d 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -7,12 +7,9 @@ import logging import math -import warnings from dataclasses import dataclass, replace from typing import List, Optional, Tuple, Union -from yaml.error import YAMLError - import torch import torch.nn as nn import torch.nn.functional as F @@ -178,10 +175,6 @@ def __init__( # Fused kernel and sharding options chunk_size=128, layer_number=None, - use_mem_eff_path=None, - d_state=None, - headdim=None, - ngroups=None, pg_collection: ProcessGroupCollection = None, pp_layer_offset: int = 0, ): @@ -209,33 +202,6 @@ def __init__( self.cached_batch_size = None assert pg_collection is not None, "pg_collection must be provided for MambaMixer" self.pg_collection = pg_collection - - # Check for deprecated arguments and raise warnings - if use_mem_eff_path is not None: - warnings.warn( - "The 'use_mem_eff_path' argument is deprecated and will be removed in the future. " - "Please use the value from the TransformerConfig object instead.", - DeprecationWarning, - ) - if d_state is not None: - warnings.warn( - "The 'd_state' argument is deprecated and will be removed in the future. " - "Please use the value from the TransformerConfig object instead.", - DeprecationWarning, - ) - if headdim is not None: - warnings.warn( - "The 'headdim' argument is deprecated and will be removed in the future. 
" - "Please use the value from the TransformerConfig object instead.", - DeprecationWarning, - ) - if ngroups is not None: - warnings.warn( - "The 'ngroups' argument is deprecated and will be removed in the future. " - "Please use the value from the TransformerConfig object instead.", - DeprecationWarning, - ) - self.use_mem_eff_path = self.config.use_mamba_mem_eff_path self.d_state = self.config.mamba_state_dim self.headdim = self.config.mamba_head_dim @@ -466,8 +432,7 @@ def forward( def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInferenceContext): """ Executes dynamic inference by separating decode and prefill requests and - running them independently. Also runs the chunked prefill request independently - if it exists. + running them independently. """ sequence_packing_available, reason_for_no_sequence_packing = ( _check_mamba_sequence_packing_support(for_inference_not_training=True) @@ -477,123 +442,48 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere conv_state, ssm_state = context.mamba_states_cache(self.layer_number - self.pp_layer_offset) padded_dims = context.padded_batch_dimensions - token_count = padded_dims.token_count decode_req_count = padded_dims.decode_req_count prefill_req_count = padded_dims.prefill_req_count - has_explicit_chunked_prefill_req = padded_dims.has_explicit_chunked_prefill_req # Input projection zxBCdt, _ = self.in_proj(hidden_states) - if decode_req_count > 0 and prefill_req_count == 0: - # Decode-only - y = self._ssm_decode( - zxBCdt.transpose(0, 1), - conv_state, - ssm_state, - context.mamba_metadata.batch_indices_decode, - ).transpose(0, 1) - elif decode_req_count == 0 and (prefill_req_count > 0 or has_explicit_chunked_prefill_req): - if prefill_req_count > 0: - # Prefill only (regular prefill requests) - y_prefill = self._ssm_prefill( - zxBCdt, - conv_state=conv_state, - ssm_state=ssm_state, - seq_idx=context.mamba_metadata.seq_idx, - 
cu_seqlens=context.mamba_metadata.cu_seqlens, - return_varlen_states=True, - batch_indices=context.mamba_metadata.batch_indices_prefill, - ) - if has_explicit_chunked_prefill_req: - # Prefill only (chunked prefill request) - zxBCdt_chunked_prefill = torch.empty_like(zxBCdt) - tensor_get_slice_after( - zxBCdt, - zxBCdt_chunked_prefill, - context.mamba_metadata.device_chunked_prefill, - check_bounds=False, - ) - y_chunked_prefill = self._ssm_prefill( - zxBCdt_chunked_prefill[: context.mamba_metadata.device_chunked_prefill[1]], - conv_state=conv_state, - ssm_state=ssm_state, - batch_indices=context.mamba_metadata.batch_indices_chunked_prefill, - is_chunked_prefill=True, - ) - if prefill_req_count > 0 and has_explicit_chunked_prefill_req: - # Merge regular prefill and chunked prefill parts - tensor_merge( - y_prefill, y_chunked_prefill, context.mamba_metadata.device_chunked_prefill - ) - y = y_prefill - elif prefill_req_count > 0: - # Prefill-only without chunked prefill - y = y_prefill - else: - # Prefill-only with only chunked prefill - y = y_chunked_prefill - else: - # Mix of decode and prefill - zxBCdt_prefill = torch.empty_like(zxBCdt) - tensor_get_slice_after( - zxBCdt, - zxBCdt_prefill, - context.mamba_metadata.device_decode_prefill, - check_bounds=False, - ) - # Decode requests + y_decode = None + y_prefill = None + + # Decode + if decode_req_count > 0: + # For mixed batch, the decode tokens are at the start of zxBCdt + zxBCdt_decode = zxBCdt[:decode_req_count] if prefill_req_count > 0 else zxBCdt + y_decode = self._ssm_decode( - zxBCdt[:decode_req_count].transpose(0, 1), + zxBCdt_decode.transpose(0, 1), conv_state, ssm_state, context.mamba_metadata.batch_indices_decode, ).transpose(0, 1) - y_prefill, y_chunked_prefill = None, None - if prefill_req_count > 0: - # Regular prefill requests - y_prefill = self._ssm_prefill( - zxBCdt_prefill, - conv_state=conv_state, - ssm_state=ssm_state, - seq_idx=context.mamba_metadata.seq_idx, - 
cu_seqlens=context.mamba_metadata.cu_seqlens, - return_varlen_states=True, - batch_indices=context.mamba_metadata.batch_indices_prefill, - ) - if has_explicit_chunked_prefill_req: - # Chunked prefill request - zxBCdt_chunked_prefill = torch.empty_like(zxBCdt_prefill) + + # Prefill + if prefill_req_count > 0: + if decode_req_count > 0: + # If mixed, slice the prefill portion out of zxBCdt + zxBCdt_prefill = torch.empty_like(zxBCdt) tensor_get_slice_after( + zxBCdt, zxBCdt_prefill, - zxBCdt_chunked_prefill, - context.mamba_metadata.device_chunked_prefill, + context.mamba_metadata.device_decode_prefill, check_bounds=False, ) - y_chunked_prefill = self._ssm_prefill( - zxBCdt_chunked_prefill[: context.mamba_metadata.device_chunked_prefill[1]], - conv_state=conv_state, - ssm_state=ssm_state, - batch_indices=context.mamba_metadata.batch_indices_chunked_prefill, - is_chunked_prefill=True, - ) - if prefill_req_count > 0 and has_explicit_chunked_prefill_req: - # Merge regular prefill and chunked prefill parts - assert y_prefill is not None - assert y_chunked_prefill is not None - tensor_merge( - y_prefill, y_chunked_prefill, context.mamba_metadata.device_chunked_prefill - ) - elif has_explicit_chunked_prefill_req: - # Chunked prefill only - assert y_prefill is None - assert y_chunked_prefill is not None - y_prefill = y_chunked_prefill else: - # Regular prefill only; y_prefill is already set, nothing more to be done - assert y_prefill is not None - # Merge decode and prefill parts + zxBCdt_prefill = zxBCdt + + y_prefill = self._dynamic_inference_prefill( + zxBCdt_prefill, context, conv_state, ssm_state + ) + + # Merge decode and prefill results if necessary + if y_decode is not None and y_prefill is not None: y = torch.empty( [token_count, 1, y_prefill.shape[-1]], dtype=y_prefill.dtype, @@ -602,12 +492,71 @@ def _dynamic_inference(self, hidden_states: torch.Tensor, context: DynamicInfere tensor_merge( y_decode, y_prefill, context.mamba_metadata.device_decode_prefill, 
output_tensor=y ) + elif y_decode is not None: + y = y_decode + elif y_prefill is not None: + y = y_prefill + else: + raise RuntimeError("Dynamic inference called with 0 decode and 0 prefill requests") # Output projection out, out_bias = self.out_proj(y) return out, out_bias + def _dynamic_inference_prefill( + self, + zxBCdt: torch.Tensor, + context: DynamicInferenceContext, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + """Helper to run dynamic inference prefill. + + All prefill requests (including chunked prefill) are processed together + through the unified varlen path. + """ + metadata = context.mamba_metadata + real_prefill_count = context.batch_dimensions.prefill_req_count + if real_prefill_count <= 0: + return None + + # Strip CUDA-graph padding from metadata tensors. Padded + # entries have -1 batch_indices and zero-length cu_seqlens segments, + # which cause out-of-bounds indexing in the varlen SSM kernel. + # Also strip padded tokens from zxBCdt to keep all downstream tensor + # shapes (z, xBC, dt) consistent. Pad output back afterward for + # residual add compatibility. + cu_seqlens = metadata.cu_seqlens[: real_prefill_count + 1] + batch_indices = metadata.batch_indices_prefill[:real_prefill_count] + real_token_count = cu_seqlens[-1].item() + seq_idx = ( + metadata.seq_idx[:, :real_token_count] + if metadata.seq_idx is not None + else None + ) + + padded_token_count = zxBCdt.shape[0] + zxBCdt = zxBCdt[:real_token_count] + + y_prefill = self._ssm_prefill( + zxBCdt, + conv_state=conv_state, + ssm_state=ssm_state, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + batch_indices=batch_indices, + ) + + # Pad output back to padded token count. The caller (residual add, + # tensor_merge) expects the output to match the padded input shape. 
+ if y_prefill.shape[0] < padded_token_count: + y_prefill = F.pad( + y_prefill, (0, 0, 0, 0, 0, padded_token_count - y_prefill.shape[0]) + ) + + return y_prefill + def _decode( self, hidden_states, conv_state, ssm_state, batch_indices: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -735,9 +684,7 @@ def _ssm_prefill( ssm_state: Optional[torch.Tensor], seq_idx: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, - return_varlen_states: bool = False, batch_indices: Optional[torch.Tensor] = None, - is_chunked_prefill: bool = False, ) -> torch.Tensor: """ Performs SSM computation for inference prefill step. @@ -749,18 +696,13 @@ def _ssm_prefill( ssm_state: The selective scan state tensor for inference. seq_idx: A map from token index to request index for variable-length sequences. cu_seqlens: Cumulative sequence lengths for variable-length sequences. - return_varlen_states: Whether to return variable-length states from the SSM kernel. batch_indices: A map from batch id to position in the Mamba state tensors for dynamic inference. - is_chunked_prefill: Whether the request is a chunked prefill request. Returns: The output tensor of shape (l, b, d). """ is_dynamic_batching = seq_idx is not None - assert not ( - is_dynamic_batching and is_chunked_prefill - ), "Cannot use chunked prefill with dynamic batching" # transpose: l b pd --> b l pd zxBCdt = rearrange(zxBCdt, "l b d -> b l d").contiguous() @@ -779,31 +721,53 @@ def _ssm_prefill( ) # Compute short convolution - initial_conv_state = None if conv_state is not None and is_dynamic_batching: - # xBC should have shape (b l d) for causal_conv1d_varlen_states assert batch_indices is not None + + # Extract initial conv states BEFORE saving new ones. + # causal_conv1d_varlen_states computes the final conv state from the + # input sequence and tensor_masked_update writes it into the conv_state + # buffer. 
If we read initial_conv_states after this write, restored + # requests see their own newly-computed states instead of the cached + # initial states from a previous request, corrupting the conv output. + initial_conv_states = conv_state[batch_indices, :, 1:] + + # Save final conv states from the input sequence conv_varlen_states = causal_conv1d_varlen_states( xBC.squeeze(0), cu_seqlens, state_len=conv_state.shape[-1] ) tensor_masked_update(conv_state, batch_indices, conv_varlen_states) - # Maintain channels-last memory layout to use seq_idx for causal_conv1d_fn - # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L174 # pylint: disable=line-too-long - xBC = xBC.transpose(1, 2) - elif is_chunked_prefill: - # Maintain channels-last memory layout to use initial_states for causal_conv1d_fn - # See https://github.com/Dao-AILab/causal-conv1d/blob/69e6dadc28b169a4c49cb86b586f64ee90242c70/csrc/causal_conv1d.cpp#L200 # pylint: disable=line-too-long - assert batch_indices is not None - initial_conv_state = ( - conv_state[batch_indices, :, 1:].permute(0, 2, 1).contiguous().transpose(1, 2) - ) - xBC = xBC.transpose(1, 2) - tensor_masked_update( - conv_state, batch_indices, F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) - ) + conv_weight = rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w") + conv_bias = self.cp.get_conv1d_bias() + + # causal_conv1d_fn cannot accept both seq_idx and + # initial_states simultaneously. Using seq_idx with packed sequences + # zeroes the conv state at sequence boundaries instead of using the + # cached initial states. We must loop over requests individually, + # passing initial_states per-request with channels-last layout. 
+ num_requests = cu_seqlens.shape[0] - 1 + xBC_parts = [] + for r in range(num_requests): + start = cu_seqlens[r].item() + end = cu_seqlens[r + 1].item() + if end <= start: + continue + # xBC is (1, total_tokens, conv_dim); slice gives channels-last via transpose + xBC_r = xBC[:, start:end, :].transpose(1, 2) # channels-last (1, C, L) + init_r = initial_conv_states[r : r + 1] # (1, conv_dim, d_conv-1) + init_r = init_r.permute(0, 2, 1).contiguous().transpose(1, 2) # channels-last + xBC_r = causal_conv1d_fn( + x=xBC_r, + weight=conv_weight, + bias=conv_bias, + activation=self.activation, + initial_states=init_r, + ) + xBC_parts.append(xBC_r.transpose(1, 2).contiguous()) # (1, L, C) + xBC = torch.cat(xBC_parts, dim=1) # (1, total_tokens, conv_dim) else: - # transpose: b l pd --> b pd l + # Non-dynamic-batching path (static batching / training fallback) xBC = rearrange(xBC, "b l d -> b d l").contiguous() if conv_state is not None: # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv @@ -812,22 +776,19 @@ def _ssm_prefill( F.pad(xBC, (self.d_conv - xBC.shape[-1], 0)) ) # Update state (B D W) - seqlen = xBC.size(2) - if causal_conv1d_fn is None: - xBC = self.act(self.cp.conv1d(xBC)[..., :seqlen]) - else: - assert self.activation in ["silu", "swish"] - xBC = causal_conv1d_fn( - x=xBC, - weight=rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), - bias=self.cp.get_conv1d_bias(), - activation=self.activation, - seq_idx=seq_idx, - initial_states=initial_conv_state, - ) - - # transpose b pd l --> b l pd - xBC = rearrange(xBC, "b d l -> b l d").contiguous() + seqlen = xBC.size(2) + if causal_conv1d_fn is None: + xBC = self.act(self.cp.conv1d(xBC)[..., :seqlen]) + else: + assert self.activation in ["silu", "swish"] + xBC = causal_conv1d_fn( + x=xBC, + weight=rearrange(self.cp.get_conv1d_weight(), "d 1 w -> d w"), + bias=self.cp.get_conv1d_bias(), + activation=self.activation, + seq_idx=seq_idx, + ) + xBC = rearrange(xBC, "b d l -> b l 
d").contiguous() x, B, C = torch.split( xBC, @@ -839,28 +800,20 @@ def _ssm_prefill( dim=-1, ) - # TODO Vijay: fuse most of the transposes with the GEMMS x = rearrange(x, "b l (h p) -> b l h p", p=self.headdim).contiguous() dt = dt.contiguous() B = rearrange(B, "b l (g n) -> b l g n", n=self.d_state).contiguous() C = rearrange(C, "b l (g n) -> b l g n", n=self.d_state).contiguous() z = rearrange(z, "b l (h p) -> b l h p", p=self.headdim).contiguous() - # If `rmsnorm == False`, then the norm inside `mamba_chunk_scan_combined` will be used. - # In this case, if `cp_size > 1` then that norm could be performed on less heads than if - # `cp_size == 1` (groups of heads can be sharded across CP ranks), which would be - # mathematically incorrect, and potentially arithmetically unstable. assert ( self.cp.cp_size == 1 or self.rmsnorm ), "Context parallel not supported for use_mem_eff_path==False and rmsnorm==False" - initial_ssm_state = ssm_state[batch_indices] + if is_dynamic_batching: + # Unified varlen SSM path: all prefill requests through single kernel call + initial_ssm_state = ssm_state[batch_indices] - if ( - cu_seqlens is not None - ): - # Variable-length path: sequences are concatenated in one row (x.shape[0] == 1). - # Batch size = number of sequences from cu_seqlens; max_seqlen = max sequence length. x = x.squeeze(0) dt = dt.squeeze(0) A = A.squeeze(0) @@ -869,32 +822,25 @@ def _ssm_prefill( z = z.squeeze(0) y = torch.empty_like(x) - # Calculate cu_chunk_seqlens using seq_idx and cu_seqlens - cu_chunk_seqlens = None - if seq_idx is not None and cu_seqlens is not None: - # seq_idx: shape (1, total_tokens), where seq_idx[0, i] maps token position i to the packed sequence index - # cu_seqlens: shape (num_sequences + 1), e.g. [0, 5, 7, 11, 16], where cu_seqlens[-1] == total_tokens - - # The number of sequences = cu_seqlens.numel() - 1 = N - # The output cu_chunk_seqlens should be a cumulative sum of the number of tokens in each sequence, - # i.e. 
same as cu_seqlens - cu_chunk_seqlens = cu_seqlens - - # However, double check if seq_idx indicates any extra grouping, or if an extra entry is required: - # If seq_idx.max() + 1 > cu_seqlens.numel() - 1, then extra sequences - n_seq_from_seq_idx = int(seq_idx.max().item() + 1) - n_seq_from_cu = cu_seqlens.numel() - 1 - if n_seq_from_seq_idx > n_seq_from_cu: - # Need to extend cu_seqlens to include the rest of tokens counted in seq_idx - # This can happen if the last part is treated as an extra seq - cu_chunk_seqlens = torch.cat( - [cu_seqlens, cu_seqlens.new_tensor([seq_idx.shape[1]])] - ) - - # Kernel expects seq_idx of shape (nchunks,) — one sequence index per chunk. - # We have seq_idx of shape (1, total_tokens); take seq index at start of each chunk. + # Build cu_chunk_seqlens by subdividing each sequence into + # chunks of at most self.chunk_size tokens. The SSM Triton kernels + # allocate per-chunk output arrays of size chunk_size, so passing + # cu_seqlens directly would cause out-of-bounds memory access when + # any sequence exceeds chunk_size (default 128) tokens. 
+ chunk_boundaries = [0] + num_seqs = cu_seqlens.numel() - 1 + for i in range(num_seqs): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + pos = start + self.chunk_size + while pos < end: + chunk_boundaries.append(pos) + pos += self.chunk_size + chunk_boundaries.append(end) + cu_chunk_seqlens = cu_seqlens.new_tensor(chunk_boundaries) + seq_idx_for_varlen = None - if seq_idx is not None and cu_chunk_seqlens is not None: + if seq_idx is not None: chunk_starts = cu_chunk_seqlens[:-1] seq_idx_for_varlen = seq_idx[0, chunk_starts].contiguous() @@ -928,8 +874,9 @@ def _ssm_prefill( z = z.unsqueeze(0) tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) - else: + # Non-dynamic-batching path (static batching) + initial_ssm_state = None y = mamba_chunk_scan_combined( x, dt, @@ -946,29 +893,12 @@ def _ssm_prefill( dt_bias=self.cp.get_dt_bias().float(), dt_softplus=True, return_final_states=ssm_state is not None, - seq_idx=seq_idx, - cu_seqlens=cu_seqlens, - return_varlen_states=return_varlen_states, initial_states=initial_ssm_state, ) if ssm_state is not None: - if return_varlen_states: - assert batch_indices is not None - - y, _, ssm_varlen_states = y - - # This has to be varlen_states, NOT last_state - # See reference implementation: - # https://github.com/state-spaces/mamba/blob/e0761ece1db07e0949dd88b4f4cd440420a19fd9/mamba_ssm/modules/mamba2.py#L267 # pylint: disable=line-too-long - tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) - elif is_chunked_prefill: - assert batch_indices is not None - y, last_state = y - tensor_masked_update(ssm_state, batch_indices, last_state) - else: - y, last_state = y - ssm_state.copy_(last_state) + y, last_state = y + ssm_state.copy_(last_state) y = rearrange(y, "b l h p -> l b (h p)").contiguous() y = self.cp.post_conv_ssm(y) diff --git a/megatron/core/ssm/ops/ssd_combined.py b/megatron/core/ssm/ops/ssd_combined.py index c6a8a363a5c..b7918fedf77 100644 --- 
a/megatron/core/ssm/ops/ssd_combined.py +++ b/megatron/core/ssm/ops/ssd_combined.py @@ -49,7 +49,7 @@ def _mamba_chunk_scan_combined_fwd( seqlen, nheads, headdim = x.shape _, ngroups, dstate = B.shape assert nheads % ngroups == 0 - assert B.shape == (seqlen, ngroups, dstate) + assert B.shape == (seqlen, ngroups, dstate), f"B.shape={B.shape} != ({seqlen}, {ngroups}, {dstate})" assert dt.shape == (seqlen, nheads) assert A.shape == (nheads,) assert C.shape == B.shape From 0678931983c06faf13c57270954d8faee003ce45 Mon Sep 17 00:00:00 2001 From: Lawrence McAfee Date: Thu, 5 Mar 2026 11:05:05 -0800 Subject: [PATCH 10/11] Eliminate chunk_state_varlen kernel call with chunk-aligned state extraction With chunk-aligned sequences (one sequence per chunk boundary), the final SSM state for each sequence is simply states[last_chunk_indices], making the separate chunk_state_varlen Triton kernel unnecessary. Construct last_chunk_indices in mamba_mixer.py alongside cu_chunk_seqlens and remove the cu_seqlens parameter from the varlen API since it was only needed by chunk_state_varlen. Co-Authored-By: Claude Opus 4.6 --- megatron/core/ssm/mamba_mixer.py | 9 +++++-- megatron/core/ssm/ops/ssd_combined.py | 27 +++++-------------- tests/unit_tests/ssm/ops/test_ssd_combined.py | 8 +----- 3 files changed, 15 insertions(+), 29 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 94a7c7c3a3d..541cd3f1fb2 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -827,7 +827,11 @@ def _ssm_prefill( # allocate per-chunk output arrays of size chunk_size, so passing # cu_seqlens directly would cause out-of-bounds memory access when # any sequence exceeds chunk_size (default 128) tokens. + # Also build last_chunk_indices: the index of the last chunk for + # each sequence, used to extract final SSM states directly from + # the states tensor without a separate kernel call. 
chunk_boundaries = [0] + last_chunk_indices = [] num_seqs = cu_seqlens.numel() - 1 for i in range(num_seqs): start = cu_seqlens[i].item() @@ -837,7 +841,9 @@ def _ssm_prefill( chunk_boundaries.append(pos) pos += self.chunk_size chunk_boundaries.append(end) + last_chunk_indices.append(len(chunk_boundaries) - 2) cu_chunk_seqlens = cu_seqlens.new_tensor(chunk_boundaries) + last_chunk_indices = cu_seqlens.new_tensor(last_chunk_indices) seq_idx_for_varlen = None if seq_idx is not None: @@ -851,9 +857,8 @@ def _ssm_prefill( B=B, C=C, chunk_size=self.chunk_size, - cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, - last_chunk_indices=None, + last_chunk_indices=last_chunk_indices, seq_idx=seq_idx_for_varlen, out=y, D=( diff --git a/megatron/core/ssm/ops/ssd_combined.py b/megatron/core/ssm/ops/ssd_combined.py index b7918fedf77..b599702d2c2 100644 --- a/megatron/core/ssm/ops/ssd_combined.py +++ b/megatron/core/ssm/ops/ssd_combined.py @@ -13,7 +13,6 @@ from .ssd_chunk_state import ( _chunk_cumsum_fwd, _chunk_state_fwd, - chunk_state_varlen, ) from .ssd_state_passing import _state_passing_fwd @@ -38,7 +37,6 @@ def _mamba_chunk_scan_combined_fwd( initial_states=None, return_intermediate_states=False, seq_idx=None, - cu_seqlens=None, cu_chunk_seqlens=None, last_chunk_indices=None, dt_softplus=False, @@ -73,10 +71,12 @@ def _mamba_chunk_scan_combined_fwd( z = z.contiguous() if D is not None and D.stride(-1) != 1: D = D.contiguous() - assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens" + assert cu_chunk_seqlens is not None, "Assuming varlen input - must supply cu_chunk_seqlens" + assert last_chunk_indices is not None, "last_chunk_indices must be provided" if initial_states is not None: - assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate) + num_seqs = last_chunk_indices.shape[0] + assert initial_states.shape == (num_seqs, nheads, headdim, dstate) # This function executes 5 sub-functions for computing mamba # - a good 
resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ @@ -154,18 +154,9 @@ def _mamba_chunk_scan_combined_fwd( if return_intermediate_states: return states else: - # Per-sequence final state at exact last token (correct when sequences share chunks) - return chunk_state_varlen( - B, - x, - dt, - dA_cumsum, - cu_seqlens, - states, - initial_states=initial_states, - last_chunk_indices=last_chunk_indices, - cu_chunk_seqlens=cu_chunk_seqlens, - ) + # With chunk-aligned sequences (one sequence per chunk), the final + # state for each sequence is simply the state at its last chunk. + return states[last_chunk_indices] def mamba_chunk_scan_combined_varlen( @@ -175,7 +166,6 @@ def mamba_chunk_scan_combined_varlen( B, C, chunk_size, - cu_seqlens, cu_chunk_seqlens, last_chunk_indices, seq_idx, @@ -197,7 +187,6 @@ def mamba_chunk_scan_combined_varlen( B: (seqlen, ngroups, dstate) C: (seqlen, ngroups, dstate) chunk_size: int - cu_seqlens: (batch + 1,) cu_chunk_seqlens: (nchunks + 1,) last_chunk_indices: (batch,) seq_idx: (nchunks,) @@ -213,7 +202,6 @@ def mamba_chunk_scan_combined_varlen( varlen_states: (batch, nheads, headdim, dstate) """ - assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input" assert seq_idx is not None varlen_states = _mamba_chunk_scan_combined_fwd( @@ -230,7 +218,6 @@ def mamba_chunk_scan_combined_varlen( initial_states=initial_states, return_intermediate_states=return_intermediate_states, seq_idx=seq_idx, - cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, last_chunk_indices=last_chunk_indices, dt_softplus=dt_softplus, diff --git a/tests/unit_tests/ssm/ops/test_ssd_combined.py b/tests/unit_tests/ssm/ops/test_ssd_combined.py index 52bceb94a91..f8289bc83f9 100644 --- a/tests/unit_tests/ssm/ops/test_ssd_combined.py +++ b/tests/unit_tests/ssm/ops/test_ssd_combined.py @@ -49,8 +49,6 @@ def setUp(self): self.ngroups = 2 self.dstate = 8 self.batch = 2 - # cu_seqlens: [0, 16, 32] -> two sequences of 
length 16 each - self.cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) # 2 chunks of 16 each self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) # last chunk index per sequence: seq0 ends in chunk 0, seq1 ends in chunk 1 @@ -73,7 +71,6 @@ def test_mamba_chunk_scan_combined_varlen_shape_and_no_nan(self): B=B, C=C, chunk_size=self.chunk_size, - cu_seqlens=self.cu_seqlens, cu_chunk_seqlens=self.cu_chunk_seqlens, last_chunk_indices=self.last_chunk_indices, seq_idx=self.seq_idx, @@ -102,7 +99,6 @@ def test_mamba_chunk_scan_combined_varlen_with_D_and_dt_bias(self): B=B, C=C, chunk_size=self.chunk_size, - cu_seqlens=self.cu_seqlens, cu_chunk_seqlens=self.cu_chunk_seqlens, last_chunk_indices=self.last_chunk_indices, seq_idx=self.seq_idx, @@ -115,8 +111,7 @@ def test_mamba_chunk_scan_combined_varlen_with_D_and_dt_bias(self): self.assertFalse(torch.isnan(out).any()) def test_mamba_chunk_scan_combined_varlen_single_sequence(self): - """Single sequence: cu_seqlens [0, 32], one sequence of 32.""" - cu_seqlens = torch.tensor([0, 32], dtype=torch.int32, device=self.device) + """Single sequence of 32 tokens, split into 2 chunks of 16.""" cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) last_chunk_indices = torch.tensor([1], dtype=torch.int64, device=self.device) seq_idx = torch.tensor([0, 0], dtype=torch.int32, device=self.device) @@ -135,7 +130,6 @@ def test_mamba_chunk_scan_combined_varlen_single_sequence(self): B=B, C=C, chunk_size=self.chunk_size, - cu_seqlens=cu_seqlens, cu_chunk_seqlens=cu_chunk_seqlens, last_chunk_indices=last_chunk_indices, seq_idx=seq_idx, From 89449df21f66c9e9732bab9ff9724c498b0c07db Mon Sep 17 00:00:00 2001 From: Lawrence McAfee Date: Thu, 5 Mar 2026 12:08:20 -0800 Subject: [PATCH 11/11] Add intermediate SSM and conv state extraction at chunk boundaries For prefix caching of Mamba layers, extract SSM and conv states at block-aligned chunk boundaries 
during varlen prefill. Since block_size_tokens % chunk_size == 0, every block boundary falls on a chunk boundary, making intermediate SSM state extraction pure indexing with no extra computation. Conv states are sliced from the pre-conv input tensor. Co-Authored-By: Claude Opus 4.6 --- megatron/core/ssm/mamba_mixer.py | 88 ++++++++- megatron/core/ssm/ops/ssd_combined.py | 19 +- tests/unit_tests/ssm/ops/test_ssd_combined.py | 168 ++++++++++++++++++ 3 files changed, 265 insertions(+), 10 deletions(-) diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 541cd3f1fb2..9b69a7c0615 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -510,7 +510,8 @@ def _dynamic_inference_prefill( context: DynamicInferenceContext, conv_state: torch.Tensor, ssm_state: torch.Tensor, - ) -> torch.Tensor: + intermediate_token_offsets: Optional[List[List[int]]] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """Helper to run dynamic inference prefill. All prefill requests (including chunked prefill) are processed together @@ -539,15 +540,22 @@ def _dynamic_inference_prefill( padded_token_count = zxBCdt.shape[0] zxBCdt = zxBCdt[:real_token_count] - y_prefill = self._ssm_prefill( + result = self._ssm_prefill( zxBCdt, conv_state=conv_state, ssm_state=ssm_state, seq_idx=seq_idx, cu_seqlens=cu_seqlens, batch_indices=batch_indices, + intermediate_token_offsets=intermediate_token_offsets, ) + if intermediate_token_offsets is not None: + y_prefill, intermediate_states = result + else: + y_prefill = result + intermediate_states = None + # Pad output back to padded token count. The caller (residual add, # tensor_merge) expects the output to match the padded input shape. 
if y_prefill.shape[0] < padded_token_count: @@ -555,6 +563,8 @@ def _dynamic_inference_prefill( y_prefill, (0, 0, 0, 0, 0, padded_token_count - y_prefill.shape[0]) ) + if intermediate_states is not None: + return y_prefill, intermediate_states return y_prefill def _decode( @@ -685,7 +695,8 @@ def _ssm_prefill( seq_idx: Optional[torch.Tensor] = None, cu_seqlens: Optional[torch.Tensor] = None, batch_indices: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + intermediate_token_offsets: Optional[List[List[int]]] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ Performs SSM computation for inference prefill step. @@ -698,9 +709,14 @@ def _ssm_prefill( cu_seqlens: Cumulative sequence lengths for variable-length sequences. batch_indices: A map from batch id to position in the Mamba state tensors for dynamic inference. + intermediate_token_offsets: Per-request list of token offsets (relative to + sequence start) at which to extract intermediate SSM and conv states. + Offsets must be multiples of chunk_size. Returns: - The output tensor of shape (l, b, d). + If intermediate_token_offsets is None: output tensor of shape (l, b, d). + If provided: (output, intermediate_states_per_request) where each entry is + (ssm_states, conv_states) or None. """ is_dynamic_batching = seq_idx is not None @@ -747,6 +763,7 @@ def _ssm_prefill( # cached initial states. We must loop over requests individually, # passing initial_states per-request with channels-last layout. num_requests = cu_seqlens.shape[0] - 1 + xBC_pre_conv = xBC if intermediate_token_offsets is not None else None xBC_parts = [] for r in range(num_requests): start = cu_seqlens[r].item() @@ -832,25 +849,55 @@ def _ssm_prefill( # the states tensor without a separate kernel call. 
chunk_boundaries = [0] last_chunk_indices = [] + intermediate_chunk_indices_list = [] + per_request_intermediate_counts = [] num_seqs = cu_seqlens.numel() - 1 for i in range(num_seqs): start = cu_seqlens[i].item() end = cu_seqlens[i + 1].item() + first_chunk_idx = len(chunk_boundaries) - 1 pos = start + self.chunk_size while pos < end: chunk_boundaries.append(pos) pos += self.chunk_size chunk_boundaries.append(end) last_chunk_indices.append(len(chunk_boundaries) - 2) + + # Build intermediate chunk indices for this sequence + if intermediate_token_offsets is not None: + seq_len = end - start + offsets = intermediate_token_offsets[i] + count = 0 + for offset in offsets: + assert offset > 0 and offset <= seq_len, ( + f"intermediate offset {offset} out of range for " + f"sequence {i} with length {seq_len}" + ) + assert offset % self.chunk_size == 0, ( + f"intermediate offset {offset} is not a multiple " + f"of chunk_size {self.chunk_size}" + ) + chunk_idx = first_chunk_idx + (offset // self.chunk_size) - 1 + intermediate_chunk_indices_list.append(chunk_idx) + count += 1 + per_request_intermediate_counts.append(count) + cu_chunk_seqlens = cu_seqlens.new_tensor(chunk_boundaries) last_chunk_indices = cu_seqlens.new_tensor(last_chunk_indices) + if intermediate_token_offsets is not None and intermediate_chunk_indices_list: + intermediate_chunk_indices = cu_seqlens.new_tensor( + intermediate_chunk_indices_list, dtype=torch.int64 + ) + else: + intermediate_chunk_indices = None + seq_idx_for_varlen = None if seq_idx is not None: chunk_starts = cu_chunk_seqlens[:-1] seq_idx_for_varlen = seq_idx[0, chunk_starts].contiguous() - ssm_varlen_states = mamba_chunk_scan_combined_varlen( + ssm_varlen_result = mamba_chunk_scan_combined_varlen( x=x, dt=dt, A=A, @@ -870,15 +917,44 @@ def _ssm_prefill( dt_bias=self.cp.get_dt_bias().float(), initial_states=initial_ssm_state, return_intermediate_states=False, + intermediate_chunk_indices=intermediate_chunk_indices, dt_softplus=True, 
dt_limit=(0.0, float("inf")), state_dtype=ssm_state.dtype, ) + if intermediate_chunk_indices is not None: + ssm_varlen_states, intermediate_ssm_states = ssm_varlen_result + else: + ssm_varlen_states = ssm_varlen_result + y = y.unsqueeze(0) z = z.unsqueeze(0) tensor_masked_update(ssm_state, batch_indices, ssm_varlen_states) + + # Assemble per-request intermediate states (SSM + conv) + if intermediate_chunk_indices is not None: + conv_dim = xBC_pre_conv.shape[-1] + intermediate_states_per_request = [] + ssm_offset = 0 + for i in range(num_seqs): + count = per_request_intermediate_counts[i] + if count == 0: + intermediate_states_per_request.append(None) + else: + req_ssm = intermediate_ssm_states[ssm_offset : ssm_offset + count] + # Extract conv states: last d_conv tokens of pre-conv xBC at each offset + req_conv_list = [] + for offset in intermediate_token_offsets[i]: + abs_pos = cu_seqlens[i].item() + offset + conv_state_at_offset = xBC_pre_conv[ + 0, abs_pos - self.d_conv : abs_pos, : + ].t() + req_conv_list.append(conv_state_at_offset) + req_conv = torch.stack(req_conv_list) + intermediate_states_per_request.append((req_ssm, req_conv)) + ssm_offset += count else: # Non-dynamic-batching path (static batching) initial_ssm_state = None @@ -913,6 +989,8 @@ def _ssm_prefill( z = self.cp.post_conv_ssm(z) y = self.norm(y, z) + if intermediate_token_offsets is not None and is_dynamic_batching: + return y, intermediate_states_per_request return y def _ssm_decode( diff --git a/megatron/core/ssm/ops/ssd_combined.py b/megatron/core/ssm/ops/ssd_combined.py index b599702d2c2..ce8664e433f 100644 --- a/megatron/core/ssm/ops/ssd_combined.py +++ b/megatron/core/ssm/ops/ssd_combined.py @@ -39,6 +39,7 @@ def _mamba_chunk_scan_combined_fwd( seq_idx=None, cu_chunk_seqlens=None, last_chunk_indices=None, + intermediate_chunk_indices=None, dt_softplus=False, dt_limit=(0.0, float("inf")), state_dtype=None, @@ -153,10 +154,13 @@ def _mamba_chunk_scan_combined_fwd( if 
return_intermediate_states: return states + + final_states = states[last_chunk_indices] + if intermediate_chunk_indices is not None: + intermediate_states = states[intermediate_chunk_indices] + return final_states, intermediate_states else: - # With chunk-aligned sequences (one sequence per chunk), the final - # state for each sequence is simply the state at its last chunk. - return states[last_chunk_indices] + return final_states def mamba_chunk_scan_combined_varlen( @@ -177,6 +181,7 @@ def mamba_chunk_scan_combined_varlen( dt_softplus=False, dt_limit=(0.0, float("inf")), return_intermediate_states=False, + intermediate_chunk_indices=None, state_dtype=None, ): """ @@ -196,10 +201,13 @@ def mamba_chunk_scan_combined_varlen( dt_bias: (nheads,) initial_states: (batch, nheads, headdim, dstate) dt_softplus: Whether to apply softplus to dt - out: (seqlen, nheads, headdim) preallocated output tensor + intermediate_chunk_indices: (N,) optional int64 tensor of chunk indices at which to + extract intermediate SSM states. When provided, returns (final_states, + intermediate_states) instead of just final_states. 
state_dtype: The data type of the ssm state Return: - varlen_states: (batch, nheads, headdim, dstate) + varlen_states: (batch, nheads, headdim, dstate), or + (varlen_states, intermediate_states) if intermediate_chunk_indices is provided """ assert seq_idx is not None @@ -220,6 +228,7 @@ def mamba_chunk_scan_combined_varlen( seq_idx=seq_idx, cu_chunk_seqlens=cu_chunk_seqlens, last_chunk_indices=last_chunk_indices, + intermediate_chunk_indices=intermediate_chunk_indices, dt_softplus=dt_softplus, dt_limit=dt_limit, state_dtype=state_dtype, diff --git a/tests/unit_tests/ssm/ops/test_ssd_combined.py b/tests/unit_tests/ssm/ops/test_ssd_combined.py index f8289bc83f9..ad1b60541e9 100644 --- a/tests/unit_tests/ssm/ops/test_ssd_combined.py +++ b/tests/unit_tests/ssm/ops/test_ssd_combined.py @@ -140,5 +140,173 @@ def test_mamba_chunk_scan_combined_varlen_single_sequence(self): self.assertFalse(torch.isnan(out).any()) +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestIntermediateStateExtraction(unittest.TestCase): + """Tests for intermediate_chunk_indices parameter.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.chunk_size = 16 + self.nheads = 4 + self.headdim = 16 + self.ngroups = 2 + self.dstate = 8 + + def _make_inputs(self, seqlen): + """Create random inputs for a single sequence of given length.""" + x = torch.randn(seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(seqlen, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + B = torch.randn(seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + C = torch.randn(seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + out = torch.empty(seqlen, self.nheads, self.headdim, device=self.device, 
dtype=torch.float32) + return x, dt, A, B, C, out + + def test_intermediate_states_shape_and_no_nan(self): + """1 sequence, 4 chunks. Request intermediates at chunks [0, 1, 2].""" + seqlen = 64 # 4 chunks of 16 + nchunks = seqlen // self.chunk_size + x, dt, A, B, C, out = self._make_inputs(seqlen) + cu_chunk_seqlens = torch.arange(0, seqlen + 1, self.chunk_size, dtype=torch.int32, device=self.device) + last_chunk_indices = torch.tensor([nchunks - 1], dtype=torch.int64, device=self.device) + seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device) + intermediate_chunk_indices = torch.tensor([0, 1, 2], dtype=torch.int64, device=self.device) + + result = mamba_chunk_scan_combined_varlen( + x=x, dt=dt, A=A, B=B, C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out, + intermediate_chunk_indices=intermediate_chunk_indices, + ) + + self.assertIsInstance(result, tuple) + final_states, intermediate_states = result + self.assertEqual(final_states.shape, (1, self.nheads, self.headdim, self.dstate)) + self.assertEqual(intermediate_states.shape, (3, self.nheads, self.headdim, self.dstate)) + self.assertFalse(torch.isnan(final_states).any()) + self.assertFalse(torch.isnan(intermediate_states).any()) + + def test_intermediate_states_match_full_states(self): + """Intermediate states should match corresponding entries from full states.""" + seqlen = 64 # 4 chunks + nchunks = seqlen // self.chunk_size + x, dt, A, B, C, out = self._make_inputs(seqlen) + cu_chunk_seqlens = torch.arange(0, seqlen + 1, self.chunk_size, dtype=torch.int32, device=self.device) + last_chunk_indices = torch.tensor([nchunks - 1], dtype=torch.int64, device=self.device) + seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device) + + # Run with return_intermediate_states=True to get all states + out1 = torch.empty_like(out) + all_states = mamba_chunk_scan_combined_varlen( + x=x, dt=dt, A=A, B=B, C=C, 
+ chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out1, + return_intermediate_states=True, + ) + + # Run with intermediate_chunk_indices + indices = [0, 1, 2] + intermediate_chunk_indices = torch.tensor(indices, dtype=torch.int64, device=self.device) + out2 = torch.empty_like(out) + final_states, intermediate_states = mamba_chunk_scan_combined_varlen( + x=x, dt=dt, A=A, B=B, C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out2, + intermediate_chunk_indices=intermediate_chunk_indices, + ) + + # Intermediate states should match the corresponding all_states entries + for i, chunk_idx in enumerate(indices): + torch.testing.assert_close( + intermediate_states[i], all_states[chunk_idx], + msg=f"intermediate state at index {i} (chunk {chunk_idx}) does not match", + ) + + # Final state should match last chunk + torch.testing.assert_close(final_states[0], all_states[nchunks - 1]) + + def test_intermediate_states_multi_sequence(self): + """2 packed sequences, verify intermediate extraction across sequence boundaries.""" + seq1_len = 32 # 2 chunks + seq2_len = 48 # 3 chunks + total_len = seq1_len + seq2_len + x, dt, A, B, C, out = self._make_inputs(total_len) + + # cu_chunk_seqlens: seq1 has chunks at [0, 16, 32], seq2 at [32, 48, 64, 80] + boundaries = list(range(0, seq1_len + 1, self.chunk_size)) + \ + list(range(seq1_len + self.chunk_size, total_len + 1, self.chunk_size)) + cu_chunk_seqlens = torch.tensor(boundaries, dtype=torch.int32, device=self.device) + nchunks = len(boundaries) - 1 # 5 chunks total + # Last chunk for seq1 is chunk 1, for seq2 is chunk 4 + last_chunk_indices = torch.tensor([1, 4], dtype=torch.int64, device=self.device) + # seq_idx: [0, 0, 1, 1, 1] + seq_idx = torch.tensor([0, 0, 1, 1, 1], dtype=torch.int32, device=self.device) + + # Request chunk 0 from seq1 and chunks 2, 3 from 
seq2 + intermediate_chunk_indices = torch.tensor([0, 2, 3], dtype=torch.int64, device=self.device) + + # Also get full states for comparison + out_full = torch.empty_like(out) + all_states = mamba_chunk_scan_combined_varlen( + x=x, dt=dt, A=A, B=B, C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out_full, + return_intermediate_states=True, + ) + + out2 = torch.empty_like(out) + final_states, intermediate_states = mamba_chunk_scan_combined_varlen( + x=x, dt=dt, A=A, B=B, C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out2, + intermediate_chunk_indices=intermediate_chunk_indices, + ) + + self.assertEqual(final_states.shape, (2, self.nheads, self.headdim, self.dstate)) + self.assertEqual(intermediate_states.shape, (3, self.nheads, self.headdim, self.dstate)) + + # Verify intermediate states match full states + for i, chunk_idx in enumerate([0, 2, 3]): + torch.testing.assert_close(intermediate_states[i], all_states[chunk_idx]) + + def test_no_intermediate_returns_tensor(self): + """Without intermediate_chunk_indices, result should be a plain tensor.""" + seqlen = 32 + nchunks = seqlen // self.chunk_size + x, dt, A, B, C, out = self._make_inputs(seqlen) + cu_chunk_seqlens = torch.arange(0, seqlen + 1, self.chunk_size, dtype=torch.int32, device=self.device) + last_chunk_indices = torch.tensor([nchunks - 1], dtype=torch.int64, device=self.device) + seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device) + + result = mamba_chunk_scan_combined_varlen( + x=x, dt=dt, A=A, B=B, C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out, + ) + + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result.shape, (1, self.nheads, self.headdim, self.dstate)) + + if __name__ == "__main__": 
unittest.main()