diff --git a/bonsai/models/mamba2/README.md b/bonsai/models/mamba2/README.md new file mode 100644 index 0000000..5831b4e --- /dev/null +++ b/bonsai/models/mamba2/README.md @@ -0,0 +1,56 @@ +# Mamba2 in JAX + +This directory contains a pure JAX implementation of the [Mamba2](https://arxiv.org/abs/2405.21060) model, using the Flax NNX API. + +## Model Configuration Support Status + +| Model Name | Config Support Status | +| :--- | :--- | +| [Mamba2ForCausalLM](https://arxiv.org/abs/2405.21060) | **✅ Supported** | +| [Mamba2Forecaster](https://arxiv.org/abs/2405.21060) | **✅ Supported** | + +## Pretrained Weights Support + +| Model | HuggingFace ID | Params | Status | +| :--- | :--- | :--- | :--- | +| Mamba2-130M | `state-spaces/mamba2-130m` | 130M | ✅ Verified | +| Mamba2-370M | `state-spaces/mamba2-370m` | 370M | ✅ Verified | +| Mamba2-780M | `state-spaces/mamba2-780m` | 780M | ✅ Verified | +| Mamba2-1.3B | `state-spaces/mamba2-1.3b` | 1.3B | ✅ Verified | +| Mamba2-2.7B | `state-spaces/mamba2-2.7b` | 2.7B | ✅ Verified | + +### Loading Pretrained Weights +```python +from bonsai.models.mamba2 import modeling + +# Load from HuggingFace Hub +model = modeling.Mamba2ForCausalLM.from_pretrained("state-spaces/mamba2-130m") +``` + +### Running this model + +To see Mamba2 model inference in action, run: + +```bash +python bonsai/models/mamba2/tests/run_model.py +``` + +### Hardware Validation Status + +| Hardware | Status | +| :--- | :--- | +| CPU | ✅ Runs | +| GPU (NVIDIA) | ✅ Runs | +| TPU v5e | ✅ Runs | + +## How to contribute to this model + +We welcome contributions! You can contribute to this model via the following: +* Add a new model config variant to `class Mamba2Config` in [modeling.py](modeling.py). Make sure your code runs on at least one hardware platform before creating a PR. +* Got some hardware? Run [run_model.py](tests/run_model.py) with the existing configs above on hardware not yet listed in the table, and mark it `✅ Runs` or `⛔️ Not supported`. + +## References + +* **Paper**: [Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality](https://arxiv.org/abs/2405.21060) (Dao & Gu, ICML 2024) +* **Reference PyTorch Implementation**: [state-spaces/mamba](https://github.com/state-spaces/mamba) +* **Original JAX Port**: [CosmoNaught/mamba2-jax](https://github.com/CosmoNaught/mamba2-jax) \ No newline at end of file diff --git a/bonsai/models/mamba2/modeling.py b/bonsai/models/mamba2/modeling.py new file mode 100644 index 0000000..af2e697 --- /dev/null +++ b/bonsai/models/mamba2/modeling.py @@ -0,0 +1,518 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Mamba2 JAX/Flax NNX Implementation. + +A pure JAX/Flax implementation of the Mamba2 architecture using the State Space Duality (SSD) mechanism.
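+
+Minimal usage sketch (all names are defined in this module; `tiny()` is the test-sized config):
+
+    import jax.numpy as jnp
+    from flax import nnx
+
+    cfg = Mamba2Config.tiny()
+    model = Mamba2ForCausalLM(cfg, rngs=nnx.Rngs(0))
+    logits = model(jnp.ones((1, 8), dtype=jnp.int32))["logits"]  # (1, 8, cfg.vocab_size)
+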
+Reference: "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality" +Paper: https://arxiv.org/abs/2405.21060 +""" + +from __future__ import annotations + +import dataclasses +from typing import Literal + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +# Configuration + + +@dataclasses.dataclass(frozen=True) +class Mamba2Config: + """Configuration for Mamba2 models.""" + + vocab_size: int = 50280 + pad_token_id: int = 0 + bos_token_id: int = 0 + eos_token_id: int = 0 + + hidden_size: int = 768 + state_size: int = 128 + head_dim: int = 64 + chunk_size: int = 256 + expand: int = 2 + conv_kernel: int = 4 + num_hidden_layers: int = 24 + layer_norm_epsilon: float = 1e-5 + + use_bias: bool = False + use_conv_bias: bool = True + hidden_act: Literal["silu", "gelu", "relu", "tanh"] = "silu" + + emb_initializer_range: float = 0.02 + A_initializer_range: tuple[float, float] = (1.0, 16.0) + + time_step_min: float = 0.001 + time_step_max: float = 0.1 + time_step_floor: float = 1e-4 + time_step_limit: tuple[float, float] = (0.0, float("inf")) + + residual_in_fp32: bool = True + tie_word_embeddings: bool = True + + @property + def intermediate_size(self) -> int: + return int(self.expand * self.hidden_size) + + @property + def num_heads(self) -> int: + return self.intermediate_size // self.head_dim + + @classmethod + def tiny(cls): + """Tiny configuration for testing.""" + return cls(vocab_size=1000, hidden_size=64, state_size=16, head_dim=16, chunk_size=32, num_hidden_layers=2) + + +# SSD Core Algorithm + + +def _pad_seq_dim(x: jnp.ndarray, pad_size: int) -> jnp.ndarray: + """Pad zeros at the end of the sequence dimension (axis=1).""" + if pad_size == 0: + return x + pad_width = [(0, 0)] * x.ndim + pad_width[1] = (0, pad_size) + return jnp.pad(x, pad_width, mode="constant", constant_values=0.0) + + +def segsum(x: jnp.ndarray) -> jnp.ndarray: + """Stable segment sum calculation. Input: (..., T) -> Output: (..., T, T).""" + T = x.shape[-1] + x_cumsum = jnp.cumsum(x, axis=-1) + x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :] + mask = jnp.tril(jnp.ones((T, T), dtype=bool), k=0) + x_segsum = jnp.where(mask, x_segsum, -jnp.inf) + return x_segsum + + +def ssd_forward( + x: jnp.ndarray, # (B, L, H, P) + dt: jnp.ndarray, # (B, L, H) + A: jnp.ndarray, # (H,) + B_mat: jnp.ndarray, # (B, L, H, N) + C_mat: jnp.ndarray, # (B, L, H, N) + chunk_size: int, + D: jnp.ndarray, # (H,) + dt_bias: jnp.ndarray, # (H,) + dt_min: float, + dt_max: float, + initial_states: jnp.ndarray | None = None, + return_final_states: bool = False, +) -> tuple[jnp.ndarray, jnp.ndarray | None]: + """SSD (State Space Duality) forward pass with chunked computation. 
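+
+    The sequence is padded to a multiple of `chunk_size` and processed in four steps:
+    (1) intra-chunk outputs from the diagonal blocks, (2) the final SSM state of each
+    chunk, (3) an inter-chunk recurrence that carries states across chunk boundaries,
+    and (4) conversion of the carried states back into outputs. A skip connection
+    scaled by `D` is added to the result.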
+ + Args: + x: Input tensor (batch_size, seq_len, num_heads, head_dim) + dt: Time deltas (batch_size, seq_len, num_heads) + A: State transition scalar per head (num_heads,) + B_mat: Input-to-state matrix (batch_size, seq_len, num_heads, state_size) + C_mat: State-to-output matrix (batch_size, seq_len, num_heads, state_size) + chunk_size: Size of chunks for efficient computation + D: Skip connection weights (num_heads,) + dt_bias: Bias for time deltas (num_heads,) + dt_min: Minimum time delta after clamping + dt_max: Maximum time delta after clamping + initial_states: Optional initial SSM states (batch, 1, heads, head_dim, state_size) + return_final_states: Whether to return final SSM states + + Returns: + y: Output tensor (batch_size, seq_len, num_heads, head_dim) + final_state: Optional final states (batch_size, num_heads, head_dim, state_size) + """ + _B_size, seq_len, num_heads, _head_dim = x.shape + pad_size = (chunk_size - seq_len % chunk_size) % chunk_size + + # Apply dt bias with softplus and clamp + dt = jax.nn.softplus(dt + dt_bias) + dt = jnp.clip(dt, dt_min, dt_max) + + # Pad tensors along sequence dimension + x_padded = _pad_seq_dim(x, pad_size) + dt_padded = _pad_seq_dim(dt, pad_size) + B_padded = _pad_seq_dim(B_mat, pad_size) + C_padded = _pad_seq_dim(C_mat, pad_size) + + # D residual connection + D_residual = D.reshape(1, 1, num_heads, 1) * x_padded + + # Discretize x and A + x_disc = x_padded * dt_padded[..., None] + A_disc = A.astype(x_disc.dtype) * dt_padded + + # Chunk everything + def chunk_tensor(t): + b, cl, *remaining = t.shape + return t.reshape(b, cl // chunk_size, chunk_size, *remaining) + + x_blk = chunk_tensor(x_disc) + A_blk = chunk_tensor(A_disc) + B_blk = chunk_tensor(B_padded) + C_blk = chunk_tensor(C_padded) + + # A cumsum over intra-chunk time dimension + A_blk2 = jnp.transpose(A_blk, (0, 3, 1, 2)) + A_cumsum = jnp.cumsum(A_blk2, axis=-1) + + # 1. Intra-chunk (diagonal blocks) + L_mat = jnp.exp(segsum(A_blk2)) + Y_diag = jnp.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C_blk, B_blk, L_mat, x_blk) + + # 2. States within each chunk + decay_states = jnp.exp(A_cumsum[..., -1:] - A_cumsum) + states = jnp.einsum("bclhn,bhcl,bclhp->bchpn", B_blk, decay_states, x_blk) + + # 3. Inter-chunk recurrence + if initial_states is None: + initial_states = jnp.zeros_like(states[:, :1, ...]) + states = jnp.concatenate([initial_states, states], axis=1) + + A_end = A_cumsum[..., -1] + A_end_padded = jnp.pad(A_end, ((0, 0), (0, 0), (1, 0))) + decay_chunk = jnp.exp(segsum(A_end_padded)) + new_states = jnp.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1, ...], new_states[:, -1, ...] + + # 4. 
Convert states -> outputs + state_decay_out = jnp.exp(A_cumsum) + Y_off = jnp.einsum("bclhn,bchpn,bhcl->bclhp", C_blk, states, state_decay_out) + + y = Y_diag + Y_off + b, c, l, h, p = y.shape + y = y.reshape(b, c * l, h, p) + y = y + D_residual + + # Remove padding + if pad_size > 0: + y = y[:, :seq_len, :, :] + + return (y, final_state) if return_final_states else (y, None) + + +# Model Components +ACT2FN = {"silu": nnx.silu, "gelu": nnx.gelu, "relu": nnx.relu, "tanh": jnp.tanh} + + +class RMSNorm(nnx.Module): + """RMSNorm with optional residual gating.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6, gate_residual: bool = False, *, rngs: nnx.Rngs): + self.hidden_size = hidden_size + self.eps = eps + self.gate_residual = gate_residual + self.weight = nnx.Param(jnp.ones((hidden_size,))) + + @jax.named_scope("rms_norm") + def __call__(self, hidden_states: jnp.ndarray, residual: jnp.ndarray | None = None) -> jnp.ndarray: + x = hidden_states.astype(jnp.float32) + if residual is not None and self.gate_residual: + x = x * nnx.silu(residual.astype(jnp.float32)) + variance = jnp.mean(x**2, axis=-1, keepdims=True) + x = x * jax.lax.rsqrt(variance + self.eps) * self.weight[:] + return x.astype(hidden_states.dtype) + + +class DepthwiseConv1d(nnx.Module): + """Depthwise causal 1D convolution. Expects (batch, seq_len, channels).""" + + def __init__(self, features: int, kernel_size: int, use_bias: bool = True, *, rngs: nnx.Rngs): + self.features = features + self.kernel_size = kernel_size + self.conv = nnx.Conv( + in_features=features, + out_features=features, + kernel_size=(kernel_size,), + padding=((kernel_size - 1, 0),), + feature_group_count=features, + use_bias=use_bias, + rngs=rngs, + ) + + @jax.named_scope("depthwise_conv1d") + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + return self.conv(x) + + +class Mamba2Mixer(nnx.Module): + """Mamba2 mixer block using the SSD algorithm.""" + + def __init__(self, cfg: Mamba2Config, layer_idx: int, *, rngs: nnx.Rngs): + self.cfg = cfg + self.layer_idx = layer_idx + self.hidden_size = cfg.hidden_size + self.ssm_state_size = cfg.state_size + self.intermediate_size = cfg.intermediate_size + self.head_dim = cfg.head_dim + self.num_heads = cfg.num_heads + self.chunk_size = cfg.chunk_size + self.dt_min, self.dt_max = cfg.time_step_limit + self.act = ACT2FN[cfg.hidden_act] + + # Input projection + proj_size = 2 * (self.intermediate_size + self.ssm_state_size) + self.num_heads + self.in_proj = nnx.Linear(cfg.hidden_size, proj_size, use_bias=cfg.use_bias, rngs=rngs) + + # Depthwise conv + conv1d_dim = self.intermediate_size + 2 * self.ssm_state_size + self.conv1d = DepthwiseConv1d(conv1d_dim, cfg.conv_kernel, use_bias=cfg.use_conv_bias, rngs=rngs) + + # SSM parameters + key = rngs.params() + low, high = cfg.time_step_min, cfg.time_step_max + floor = cfg.time_step_floor + dt_init = jnp.exp(jax.random.uniform(key, (cfg.num_heads,)) * (jnp.log(high) - jnp.log(low)) + jnp.log(low)) + dt_init = jnp.maximum(dt_init, floor) + self.dt_bias = nnx.Param(dt_init + jnp.log(-jnp.expm1(-dt_init))) # inverse softplus + + key = rngs.params() + A_low, A_high = cfg.A_initializer_range + A_init = jax.random.uniform(key, (cfg.num_heads,), minval=A_low, maxval=A_high) + self.A_log = nnx.Param(jnp.log(A_init)) + + self.D = nnx.Param(jnp.ones((cfg.num_heads,))) + + # Internal norm and output projection + self.norm = RMSNorm(self.intermediate_size, eps=1e-5, gate_residual=True, rngs=rngs) + self.out_proj = nnx.Linear(self.intermediate_size, cfg.hidden_size, 
use_bias=cfg.use_bias, rngs=rngs) + + @jax.named_scope("mamba2_mixer") + def __call__( + self, hidden_states: jnp.ndarray, initial_state: jnp.ndarray | None = None, return_final_state: bool = False + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: + B_size, L, _ = hidden_states.shape + + # 1) Parallel projection + zxbcdt = self.in_proj(hidden_states) + d_mlp = (zxbcdt.shape[-1] - 2 * self.intermediate_size - 2 * self.ssm_state_size - self.num_heads) // 2 + + z0, x0, z, xBC, dt = jnp.split( + zxbcdt, + [ + d_mlp, + 2 * d_mlp, + 2 * d_mlp + self.intermediate_size, + 2 * d_mlp + self.intermediate_size + self.intermediate_size + 2 * self.ssm_state_size, + ], + axis=-1, + ) + + # 2) Depthwise causal convolution + xBC = self.act(self.conv1d(xBC)) + x, B_t, C_t = jnp.split(xBC, [self.intermediate_size, self.intermediate_size + self.ssm_state_size], axis=-1) + + # 3) SSD forward + init_state = initial_state[:, None, ...] if initial_state is not None else None + A = -jnp.exp(self.A_log[:].astype(jnp.float32)) + + B_exp = jnp.broadcast_to(jnp.expand_dims(B_t, 2), (B_size, L, self.num_heads, self.ssm_state_size)) + C_exp = jnp.broadcast_to(jnp.expand_dims(C_t, 2), (B_size, L, self.num_heads, self.ssm_state_size)) + + y, final_state = ssd_forward( + x=x.reshape(B_size, L, -1, self.head_dim), + dt=dt, + A=A, + B_mat=B_exp, + C_mat=C_exp, + chunk_size=self.chunk_size, + D=self.D[:], + dt_bias=self.dt_bias[:], + dt_min=self.dt_min, + dt_max=self.dt_max, + initial_states=init_state, + return_final_states=return_final_state, + ) + y = y.reshape(B_size, L, -1) + + # 4) Residual gate normalization + y = self.norm(y, residual=z) + if d_mlp > 0: + y = jnp.concatenate([self.act(z0) * x0, y], axis=-1) + + # 5) Output projection + return self.out_proj(y), final_state + + +class Mamba2Block(nnx.Module): + """Single Mamba2 block with pre-norm and residual connection.""" + + def __init__(self, cfg: Mamba2Config, layer_idx: int, *, rngs: nnx.Rngs): + self.cfg = cfg + self.residual_in_fp32 = cfg.residual_in_fp32 + self.norm = RMSNorm(cfg.hidden_size, eps=cfg.layer_norm_epsilon, rngs=rngs) + self.mixer = Mamba2Mixer(cfg, layer_idx=layer_idx, rngs=rngs) + + def __call__( + self, hidden_states: jnp.ndarray, initial_state: jnp.ndarray | None = None, return_final_state: bool = False + ) -> tuple[jnp.ndarray, jnp.ndarray | None]: + residual = hidden_states + hs = self.norm(hidden_states.astype(jnp.float32)) + if self.residual_in_fp32: + residual = residual.astype(jnp.float32) + hs_out, last_state = self.mixer(hs, initial_state=initial_state, return_final_state=return_final_state) + return residual + hs_out, last_state + + +class Mamba2Model(nnx.Module): + """Mamba2 backbone model (no task-specific head).""" + + def __init__(self, cfg: Mamba2Config, *, rngs: nnx.Rngs): + self.cfg = cfg + self.embedder = nnx.Embed(num_embeddings=cfg.vocab_size, features=cfg.hidden_size, rngs=rngs) + self.layers = nnx.List([Mamba2Block(cfg, layer_idx=i, rngs=rngs) for i in range(cfg.num_hidden_layers)]) + self.final_norm = RMSNorm(cfg.hidden_size, eps=cfg.layer_norm_epsilon, rngs=rngs) + + @jax.named_scope("mamba2_backbone") + def __call__( + self, + input_ids: jnp.ndarray | None = None, + inputs_embeds: jnp.ndarray | None = None, + initial_states: list[jnp.ndarray] | None = None, + output_hidden_states: bool = False, + output_last_ssm_states: bool = False, + ) -> dict[str, jnp.ndarray | list[jnp.ndarray] | None]: + if (input_ids is None) == (inputs_embeds is None): + raise ValueError("Specify exactly one of input_ids or inputs_embeds") + + 
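+        # Embed token ids here, or accept precomputed embeddings directly
+        # (e.g. the projected continuous inputs used by Mamba2Forecaster).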
hidden_states = self.embedder(input_ids) if inputs_embeds is None else inputs_embeds + + if initial_states is None: + initial_states = [None] * self.cfg.num_hidden_layers + elif len(initial_states) != self.cfg.num_hidden_layers: + raise ValueError("initial_states length must equal num_hidden_layers") + + all_hidden_states = [] if output_hidden_states else None + all_last_states = [] if output_last_ssm_states else None + + for layer, init_state in zip(self.layers, initial_states): + hidden_states, last_state = layer( + hidden_states, initial_state=init_state, return_final_state=output_last_ssm_states + ) + if output_hidden_states: + all_hidden_states.append(hidden_states) + if output_last_ssm_states: + all_last_states.append(last_state) + + hidden_states = self.final_norm(hidden_states) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + return { + "last_hidden_state": hidden_states, + "hidden_states": all_hidden_states, + "last_ssm_states": all_last_states, + } + + +class Mamba2ForCausalLM(nnx.Module): + """Mamba2 model with causal language modeling head.""" + + def __init__(self, cfg: Mamba2Config, *, rngs: nnx.Rngs): + self.cfg = cfg + self.backbone = Mamba2Model(cfg, rngs=rngs) + if not cfg.tie_word_embeddings: + self.lm_head = nnx.Linear(cfg.hidden_size, cfg.vocab_size, use_bias=False, rngs=rngs) + else: + self.lm_head = None + + @jax.named_scope("mamba2_causal_lm") + def __call__(self, input_ids: jnp.ndarray, labels: jnp.ndarray | None = None) -> dict[str, jnp.ndarray | None]: + backbone_outputs = self.backbone(input_ids=input_ids) + hidden_states = backbone_outputs["last_hidden_state"] + + if self.cfg.tie_word_embeddings: + logits = hidden_states @ self.backbone.embedder.embedding[:].T + else: + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + shift_logits = logits[:, :-1, :].reshape(-1, logits.shape[-1]) + shift_labels = labels[:, 1:].reshape(-1) + loss = optax.softmax_cross_entropy_with_integer_labels(shift_logits, shift_labels).mean() + + return {"logits": logits, "loss": loss} + + @classmethod + def from_pretrained( + cls, + model_id_or_path: str, + *, + cfg: Mamba2Config | None = None, + dtype: jnp.dtype = jnp.float32, + seed: int = 0, + revision: str = "main", + ) -> "Mamba2ForCausalLM": + # Local import to avoid hard dependency cycles + from bonsai.models.mamba2 import params as mamba2_params + + # If cfg is None, params.create_model_from_huggingface already infers it. 
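+        # Heuristic: a string like "org/name" that is not a filesystem path is treated as
+        # a HuggingFace Hub id; anything else is assumed to be a local checkpoint path.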
+ if "/" in model_id_or_path and not model_id_or_path.startswith((".", "/")): + return mamba2_params.create_model_from_huggingface( + model_id_or_path, cfg=cfg, dtype=dtype, seed=seed, revision=revision + ) + + return mamba2_params.create_model_from_torch_checkpoint(model_id_or_path, cfg=cfg, dtype=dtype, seed=seed) + + +class Mamba2Forecaster(nnx.Module): + """Mamba2-based time series forecaster.""" + + def __init__( + self, + input_dim: int, + d_model: int = 768, + n_layers: int = 4, + output_dim: int = 1, + forecast_horizon: int = 24, + d_state: int = 128, + headdim: int = 64, + d_conv: int = 4, + chunk_size: int = 256, + *, + rngs: nnx.Rngs, + ): + self.forecast_horizon = forecast_horizon + self.output_dim = output_dim + + self.input_proj = nnx.Linear(input_dim, d_model, rngs=rngs) + cfg = Mamba2Config( + vocab_size=1, + hidden_size=d_model, + state_size=d_state, + head_dim=headdim, + conv_kernel=d_conv, + chunk_size=chunk_size, + num_hidden_layers=n_layers, + ) + self.mamba2 = Mamba2Model(cfg, rngs=rngs) + self.output_proj = nnx.Linear(d_model, output_dim * forecast_horizon, rngs=rngs) + + @jax.named_scope("mamba2_forecaster") + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Forward pass. Input: (batch, seq_len, input_dim) -> Output: (batch, forecast_horizon, output_dim).""" + x_proj = self.input_proj(x) + outputs = self.mamba2(input_ids=None, inputs_embeds=x_proj) + last_hidden = outputs["last_hidden_state"][:, -1, :] + out = self.output_proj(last_hidden) + return out.reshape(x.shape[0], self.forecast_horizon, self.output_dim) + + +@jax.jit +def forward(model: Mamba2ForCausalLM, input_ids: jnp.ndarray, labels: jnp.ndarray | None = None): + """JIT-compiled forward pass for Mamba2ForCausalLM.""" + return model(input_ids, labels) diff --git a/bonsai/models/mamba2/params.py b/bonsai/models/mamba2/params.py new file mode 100644 index 0000000..8ec7134 --- /dev/null +++ b/bonsai/models/mamba2/params.py @@ -0,0 +1,406 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Parameter utilities for Mamba2 models.""" + +import re +from collections.abc import Mapping +from typing import Any + +import jax +import jax.numpy as jnp +from flax import nnx + +from bonsai.models.mamba2 import modeling + + +def create_random_model(cfg: modeling.Mamba2Config, seed: int = 0) -> modeling.Mamba2ForCausalLM: + """Create a randomly initialized Mamba2ForCausalLM. + + Args: + cfg: Mamba2Config for the model. + seed: Random seed for initialization. + + Returns: + Randomly initialized Mamba2ForCausalLM. + """ + return modeling.Mamba2ForCausalLM(cfg, rngs=nnx.Rngs(seed)) + + +def create_random_forecaster( + input_dim: int, + d_model: int = 768, + n_layers: int = 4, + output_dim: int = 1, + forecast_horizon: int = 24, + seed: int = 0, + **kwargs, +) -> modeling.Mamba2Forecaster: + """Create a randomly initialized Mamba2Forecaster. + + Args: + input_dim: Number of input features per timestep. + d_model: Hidden dimension of the model. + n_layers: Number of Mamba2 layers. 
+ output_dim: Number of output features per timestep. + forecast_horizon: Number of future timesteps to predict. + seed: Random seed for initialization. + **kwargs: Additional arguments passed to Mamba2Forecaster. + + Returns: + Randomly initialized Mamba2Forecaster. + """ + return modeling.Mamba2Forecaster( + input_dim=input_dim, + d_model=d_model, + n_layers=n_layers, + output_dim=output_dim, + forecast_horizon=forecast_horizon, + rngs=nnx.Rngs(seed), + **kwargs, + ) + + +def count_parameters(model: nnx.Module) -> int: + """Count the total number of trainable parameters in a model. + + Args: + model: NNX module to count parameters for. + + Returns: + Total number of parameters. + """ + _graphdef, state = nnx.split(model) + params = state.filter(nnx.Param) + return sum(p.size for p in jax.tree.leaves(params)) + + +def _get_key_mapping() -> list[tuple[re.Pattern, str, str]]: + """Get mapping from PyTorch state-spaces/mamba2 keys to JAX parameter paths. + + Based on the official state-spaces/mamba repository checkpoint format. + Keys follow the pattern: backbone.layers.{idx}.mixer.{param} + + Returns list of (pattern, replacement, transform_type) tuples. + Transform types: LINEAR, CONV1D, EMBED, SCALE, NONE + """ + return [ + # Embedding - state-spaces uses "embedding" (singular) + (re.compile(r"^backbone\.embedding\.weight$"), "backbone.embedder.embedding", "EMBED"), + # Layer pre-norm (RMSNorm before mixer) + (re.compile(r"^backbone\.layers\.(\d+)\.norm\.weight$"), r"backbone.layers.\1.norm.weight", "SCALE"), + # Mixer input projection + ( + re.compile(r"^backbone\.layers\.(\d+)\.mixer\.in_proj\.weight$"), + r"backbone.layers.\1.mixer.in_proj.kernel", + "LINEAR", + ), + # Mixer conv1d (DepthwiseConv1d wraps nnx.Conv as self.conv) + ( + re.compile(r"^backbone\.layers\.(\d+)\.mixer\.conv1d\.weight$"), + r"backbone.layers.\1.mixer.conv1d.conv.kernel", + "CONV1D", + ), + ( + re.compile(r"^backbone\.layers\.(\d+)\.mixer\.conv1d\.bias$"), + r"backbone.layers.\1.mixer.conv1d.conv.bias", + "NONE", + ), + # SSM parameters (A_log, D, dt_bias) + (re.compile(r"^backbone\.layers\.(\d+)\.mixer\.A_log$"), r"backbone.layers.\1.mixer.A_log", "NONE"), + (re.compile(r"^backbone\.layers\.(\d+)\.mixer\.D$"), r"backbone.layers.\1.mixer.D", "NONE"), + (re.compile(r"^backbone\.layers\.(\d+)\.mixer\.dt_bias$"), r"backbone.layers.\1.mixer.dt_bias", "NONE"), + # Mixer internal norm (RMSNorm with residual gate) + ( + re.compile(r"^backbone\.layers\.(\d+)\.mixer\.norm\.weight$"), + r"backbone.layers.\1.mixer.norm.weight", + "SCALE", + ), + # Mixer output projection + ( + re.compile(r"^backbone\.layers\.(\d+)\.mixer\.out_proj\.weight$"), + r"backbone.layers.\1.mixer.out_proj.kernel", + "LINEAR", + ), + ( + re.compile(r"^backbone\.layers\.(\d+)\.mixer\.out_proj\.bias$"), + r"backbone.layers.\1.mixer.out_proj.bias", + "NONE", + ), + # Final norm + (re.compile(r"^backbone\.norm_f\.weight$"), "backbone.final_norm.weight", "SCALE"), + # LM head (may be tied to embeddings) + (re.compile(r"^lm_head\.weight$"), "lm_head.kernel", "LINEAR"), + ] + + +def _transform_tensor(tensor: jnp.ndarray, transform_type: str) -> jnp.ndarray: + """Apply transformation to convert PyTorch tensor to JAX format.""" + if transform_type == "LINEAR": + return tensor.T + elif transform_type == "CONV1D": + # PyTorch conv1d: (out_channels, in_channels/groups, kernel_size) + # JAX conv: (kernel_size, in_channels/groups, out_channels) + return jnp.transpose(tensor, (2, 1, 0)) + elif transform_type == "EMBED": + return tensor + elif transform_type == 
"SCALE": + return tensor + elif transform_type == "NONE": + return tensor + else: + raise ValueError(f"Unknown transform type: {transform_type}") + + +def _set_nested_attr(obj: Any, path: str, value: Any) -> None: + """Set a nested attribute on an object using dot-separated path.""" + parts = path.split(".") + for part in parts[:-1]: + if obj is None: + raise AttributeError(f"Encountered None while traversing path '{path}' at '{part}'") + if part.isdigit(): + obj = obj[int(part)] + else: + obj = getattr(obj, part) + + final_part = parts[-1] + if obj is None: + raise AttributeError(f"Encountered None while setting '{path}'") + + if final_part.isdigit(): + obj[int(final_part)] = value + return + + if not hasattr(obj, final_part): + raise AttributeError(f"Object of type {type(obj).__name__} has no attribute '{final_part}' (path='{path}')") + + attr = getattr(obj, final_part) + if isinstance(attr, nnx.Param): + attr[...] = value + else: + setattr(obj, final_part, value) + + +def load_pytorch_weights( + model: modeling.Mamba2ForCausalLM, + state_dict: Mapping[str, Any], + dtype: jnp.dtype = jnp.float32, + strict: bool = False, +) -> tuple[modeling.Mamba2ForCausalLM, list[str], list[str]]: + key_mapping = _get_key_mapping() + loaded_keys: list[str] = [] + skipped_keys: list[str] = [] + + tie = getattr(model.cfg, "tie_word_embeddings", False) + embedding_loaded = False + + for pt_key, pt_tensor in state_dict.items(): + matched_rule = False + + for pattern, replacement, transform_type in key_mapping: + if not pattern.match(pt_key): + continue + + matched_rule = True + jax_path = pattern.sub(replacement, pt_key) + + # Track embedding load + if pt_key == "backbone.embedding.weight": + embedding_loaded = True + + # Special-case tied head: + # - never overwrite embedding with transposed lm_head + if pt_key == "lm_head.weight" and tie and getattr(model, "lm_head", None) is None: + if embedding_loaded: + # Redundant in tied models (and HF mamba2 checkpoints contain both). + loaded_keys.append(f"{pt_key} (skipped: tied embeddings)") + else: + # Fallback: if embedding.weight is absent, use lm_head.weight *without transpose* + tensor = jnp.array(pt_tensor, dtype=dtype) # NO _transform_tensor here + try: + _set_nested_attr(model, "backbone.embedder.embedding", tensor) + loaded_keys.append(f"{pt_key} (used as embedding; no transpose)") + embedding_loaded = True + except (AttributeError, IndexError, KeyError, TypeError) as e: + if strict: + raise ValueError(f"Failed to set tied embedding from {pt_key}: {e}") from e + skipped_keys.append(f"{pt_key} (tied embedding set failed: {e})") + break + + # Normal path + tensor = jnp.array(pt_tensor, dtype=dtype) + tensor = _transform_tensor(tensor, transform_type) + + try: + _set_nested_attr(model, jax_path, tensor) + loaded_keys.append(pt_key) + except (AttributeError, IndexError, KeyError, TypeError) as e: + if strict: + raise ValueError(f"Failed to set {jax_path} from {pt_key}: {e}") from e + skipped_keys.append(f"{pt_key} (error: {e})") + + break # only first matching rule + + if not matched_rule: + skipped_keys.append(pt_key) + + if strict and skipped_keys: + raise ValueError(f"Unexpected/unloaded keys in state_dict (first 20): {skipped_keys[:20]}") + + return model, loaded_keys, skipped_keys + + +def create_model_from_torch_checkpoint( + checkpoint_path: str, + cfg: modeling.Mamba2Config | None = None, + dtype: jnp.dtype = jnp.float32, + seed: int = 0, +) -> modeling.Mamba2ForCausalLM: + """Create model from PyTorch checkpoint file. 
+ + Args: + checkpoint_path: Path to .pt/.bin checkpoint or directory with model files. + cfg: Model config. Required for now. + dtype: Target dtype. + seed: Random seed for initialization (weights will be overwritten). + + Returns: + Mamba2ForCausalLM with loaded weights. + """ + import os + + if os.path.isdir(checkpoint_path): + safetensors_path = os.path.join(checkpoint_path, "model.safetensors") + pytorch_path = os.path.join(checkpoint_path, "pytorch_model.bin") + if os.path.exists(safetensors_path): + checkpoint_path = safetensors_path + elif os.path.exists(pytorch_path): + checkpoint_path = pytorch_path + else: + raise FileNotFoundError(f"No checkpoint found in {checkpoint_path}") + + if checkpoint_path.endswith(".safetensors"): + try: + from safetensors import safe_open + except ImportError as e: + raise ImportError("safetensors required: pip install safetensors") from e + + state_dict = {} + with safe_open(checkpoint_path, framework="numpy") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + else: + try: + import torch + except ImportError as e: + raise ImportError("torch required: pip install torch") from e + + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + state_dict = {k: v.numpy() for k, v in checkpoint.items()} + + if cfg is None: + raise ValueError("Config inference not yet implemented. Please provide cfg.") + + model = create_random_model(cfg, seed=seed) + model, loaded, skipped = load_pytorch_weights(model, state_dict, dtype=dtype) + + print(f"Loaded {len(loaded)} parameters") + if skipped: + print(f"Skipped {len(skipped)} keys: {skipped[:5]}...") + + return model + + +def create_model_from_huggingface( + model_id: str, + cfg: modeling.Mamba2Config | None = None, + dtype: jnp.dtype = jnp.float32, + seed: int = 0, + revision: str = "main", +) -> modeling.Mamba2ForCausalLM: + """Create model from HuggingFace Hub. + + Args: + model_id: HuggingFace model ID (e.g., "state-spaces/mamba2-130m"). + cfg: Model config. If None, will try to infer from config.json. + dtype: Target dtype. + seed: Random seed for initialization. + revision: Git revision to use. + + Returns: + Mamba2ForCausalLM with loaded weights. + + Example: + >>> cfg = modeling.Mamba2Config( + ... vocab_size=50280, hidden_size=768, + ... state_size=128, num_hidden_layers=24, head_dim=64 + ... 
) + >>> model = create_model_from_huggingface("state-spaces/mamba2-130m", cfg=cfg) + """ + try: + from huggingface_hub import hf_hub_download + except ImportError as e: + raise ImportError("huggingface_hub required: pip install huggingface_hub") from e + + # Try safetensors first, fall back to pytorch_model.bin + try: + checkpoint_path = hf_hub_download(model_id, "model.safetensors", revision=revision) + except Exception: + try: + checkpoint_path = hf_hub_download(model_id, "pytorch_model.bin", revision=revision) + except Exception as e: + raise FileNotFoundError(f"Could not find model.safetensors or pytorch_model.bin in {model_id}") from e + + if cfg is None: + import json + + config_path = hf_hub_download(model_id, "config.json", revision=revision) + with open(config_path) as f: + hf_config = json.load(f) + + cfg = modeling.Mamba2Config( + vocab_size=hf_config.get("vocab_size", 50280), + hidden_size=hf_config.get("d_model", hf_config.get("hidden_size", 768)), + state_size=hf_config.get("d_state", hf_config.get("state_size", 128)), + num_hidden_layers=hf_config.get("n_layer", hf_config.get("num_hidden_layers", 24)), + expand=hf_config.get("expand", 2), + conv_kernel=hf_config.get("d_conv", hf_config.get("conv_kernel", 4)), + head_dim=hf_config.get("headdim", hf_config.get("head_dim", 64)), + ) + + return create_model_from_torch_checkpoint(checkpoint_path, cfg=cfg, dtype=dtype, seed=seed) + + +def print_checkpoint_keys(checkpoint_path: str) -> None: + """Print keys from a checkpoint file for debugging. + + Args: + checkpoint_path: Path to checkpoint file. + """ + if checkpoint_path.endswith(".safetensors"): + from safetensors import safe_open + + with safe_open(checkpoint_path, framework="numpy") as f: + print(f"Keys in {checkpoint_path}:") + for key in sorted(f.keys()): + shape = f.get_tensor(key).shape + print(f" {key}: {shape}") + else: + import torch + + state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + print(f"Keys in {checkpoint_path}:") + for key in sorted(state_dict.keys()): + shape = tuple(state_dict[key].shape) + print(f" {key}: {shape}") diff --git a/bonsai/models/mamba2/tests/__init__.py b/bonsai/models/mamba2/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bonsai/models/mamba2/tests/artifacts/generate_golden_logits.py b/bonsai/models/mamba2/tests/artifacts/generate_golden_logits.py new file mode 100644 index 0000000..d8b2ac2 --- /dev/null +++ b/bonsai/models/mamba2/tests/artifacts/generate_golden_logits.py @@ -0,0 +1,57 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
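+
+"""Generate golden reference outputs for the Mamba2 parity tests.
+
+Runs the PyTorch `mamba_ssm` implementation of `state-spaces/mamba2-130m` in float32
+on CUDA and saves the input ids, backbone hidden states, and a logits slice to
+`golden_mamba2_130m.npz`, which `test_outputs_mamba_2.py` compares against.
+"""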
+ +import os + +import numpy as np +import torch + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +from mamba_ssm import MambaLMHeadModel + +OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "artifacts") + + +def main(): + model = MambaLMHeadModel.from_pretrained( + "state-spaces/mamba2-130m", + device="cuda", + dtype=torch.float32, + ).eval() + + for mod in model.modules(): + if hasattr(mod, "use_mem_eff_path"): + mod.use_mem_eff_path = False + + input_ids = np.array([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=np.int32) + torch_input = torch.tensor(input_ids, device="cuda") + + with torch.no_grad(): + hidden = model.backbone(torch_input).cpu().numpy() + logits = model.lm_head(model.backbone(torch_input)).cpu().numpy() + + os.makedirs(OUTPUT_DIR, exist_ok=True) + np.savez_compressed( + os.path.join(OUTPUT_DIR, "golden_mamba2_130m.npz"), + input_ids=input_ids, + last_hidden_state=hidden.astype(np.float32), + logits_slice=logits[:, :, :256].astype(np.float32), + ) + print(f"Saved to {OUTPUT_DIR}/golden_mamba2_130m.npz") + + +if __name__ == "__main__": + main() diff --git a/bonsai/models/mamba2/tests/artifacts/golden_mamba2_130m.npz b/bonsai/models/mamba2/tests/artifacts/golden_mamba2_130m.npz new file mode 100644 index 0000000..30b802c Binary files /dev/null and b/bonsai/models/mamba2/tests/artifacts/golden_mamba2_130m.npz differ diff --git a/bonsai/models/mamba2/tests/run_model.py b/bonsai/models/mamba2/tests/run_model.py new file mode 100644 index 0000000..a4242f2 --- /dev/null +++ b/bonsai/models/mamba2/tests/run_model.py @@ -0,0 +1,153 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +from transformers import AutoTokenizer + +from bonsai.models.mamba2 import modeling + + +@jax.jit +def _decode_step( + model: modeling.Mamba2ForCausalLM, + tokens: jnp.ndarray, + cur: jnp.ndarray, +) -> jnp.ndarray: + out = model(tokens, labels=None) + prev_logits = jax.lax.dynamic_index_in_dim(out["logits"], cur - jnp.int32(1), axis=1, keepdims=False) + next_tok = jnp.argmax(prev_logits, axis=-1) + tokens = tokens.at[:, cur].set(next_tok) + return tokens + + +def _greedy_generate( + model: modeling.Mamba2ForCausalLM, + prompt_ids_1d: jnp.ndarray, + *, + max_new_tokens: int, + pad_id: int, + buffer_len: int, +) -> jnp.ndarray: + prompt_len = int(prompt_ids_1d.shape[0]) + total_len = prompt_len + max_new_tokens + if buffer_len < total_len: + raise ValueError(f"buffer_len ({buffer_len}) must be >= prompt_len+max_new_tokens ({total_len})") + + # Fixed shape buffer: (batch=1, buffer_len) + tokens = jnp.full((1, buffer_len), int(pad_id), dtype=jnp.int32) + tokens = tokens.at[0, :prompt_len].set(jnp.asarray(prompt_ids_1d, dtype=jnp.int32)) + + # Warmup compile once for this (1, buffer_len) shape. 
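+    # The first call compiles _decode_step for the fixed (1, buffer_len) token shape;
+    # subsequent calls only change the scalar position, so the compiled function is reused.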
+ tokens = _decode_step(model, tokens, jnp.asarray(prompt_len, dtype=jnp.int32)) + jax.block_until_ready(tokens) + + # Now generate the remaining tokens. + for t in range(1, max_new_tokens + 1): + cur = jnp.asarray(prompt_len + t, dtype=jnp.int32) # dynamic scalar => no per-token recompiles + tokens = _decode_step(model, tokens, cur) + + out = tokens[:, :total_len] + jax.block_until_ready(out) + return out + + +def _to_host_1d(x: jnp.ndarray) -> jnp.ndarray: + try: + return x.get(out_sharding=P(None)) + except Exception: + return jax.device_get(x) + + +def run_model(*, max_new_tokens: int = 32) -> None: + """Run the Mamba2 generation smoke test.""" + query = [ + "Why is the sky blue instead of any other color like purple?", + "What is the capital city of England?", + ] + + cfg = modeling.Mamba2Config( + vocab_size=50288, + hidden_size=768, + state_size=128, + num_hidden_layers=24, + head_dim=64, + expand=2, + conv_kernel=4, + ) + + model = modeling.Mamba2ForCausalLM.from_pretrained("state-spaces/mamba2-130m", cfg=cfg) + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + + # Determine padding id robustly. + pad_id = tokenizer.pad_token_id + if pad_id is None: + pad_id = tokenizer.eos_token_id + if pad_id is None: + pad_id = cfg.pad_token_id + + # Share a single static buffer length across all prompts to reuse compilation. + prompt_ids_list = [jnp.asarray(tokenizer.encode(q), dtype=jnp.int32) for q in query] + max_prompt_len = max(int(x.shape[0]) for x in prompt_ids_list) + buffer_len = max_prompt_len + max_new_tokens + + for q, prompt_ids in zip(query, prompt_ids_list): + tokens_2d = _greedy_generate( + model, + prompt_ids, + max_new_tokens=max_new_tokens, + pad_id=int(pad_id), + buffer_len=buffer_len, + ) + + host_ids = _to_host_1d(tokens_2d.at[0].get()) + + generated_ids_only = host_ids[len(prompt_ids) :] + + text = tokenizer.decode(generated_ids_only.tolist(), skip_special_tokens=True) + + print(f"User:\n {q}") + print(f"Answer:\n {text.strip()}\n\n") + + +def run_forecaster() -> None: + """Run a tiny Mamba2Forecaster smoke test (shape-only).""" + from bonsai.models.mamba2 import params + + model = params.create_random_forecaster( + input_dim=10, + d_model=64, + n_layers=2, + output_dim=1, + forecast_horizon=24, + seed=42, + ) + + x = jax.random.normal(jax.random.PRNGKey(0), (4, 100, 10)) + y = model(x) + jax.block_until_ready(y) + + print(f"Forecaster input shape: {tuple(x.shape)}") + print(f"Forecaster output shape: {tuple(y.shape)}") + + +if __name__ == "__main__": + run_model() + run_forecaster() + + +__all__ = ["run_forecaster", "run_model"] diff --git a/bonsai/models/mamba2/tests/test_outputs_mamba_2.py b/bonsai/models/mamba2/tests/test_outputs_mamba_2.py new file mode 100644 index 0000000..5597e7b --- /dev/null +++ b/bonsai/models/mamba2/tests/test_outputs_mamba_2.py @@ -0,0 +1,413 @@ +# Copyright 2025 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
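+
+"""Unit and parity tests for the Mamba2 JAX implementation.
+
+Covers Mamba2Config properties, RMSNorm, segsum, the SSD forward pass, the backbone,
+the causal LM and forecaster heads, JIT compilation, gradient computation, and
+numerical parity against golden outputs from the PyTorch `mamba_ssm` reference.
+"""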
+ +import os + +os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "highest" +import jax + +jax.config.update("jax_default_matmul_precision", "highest") + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest +from flax import nnx + +from bonsai.models.mamba2 import modeling, params + + +class TestMamba2Config(absltest.TestCase): + """Tests for Mamba2Config.""" + + def test_default_config(self): + """Test default config values.""" + cfg = modeling.Mamba2Config() + self.assertEqual(cfg.vocab_size, 50280) + self.assertEqual(cfg.hidden_size, 768) + self.assertEqual(cfg.num_hidden_layers, 24) + + def test_intermediate_size(self): + """Test intermediate_size property.""" + cfg = modeling.Mamba2Config(hidden_size=512, expand=2) + self.assertEqual(cfg.intermediate_size, 1024) + + def test_num_heads(self): + """Test num_heads property.""" + cfg = modeling.Mamba2Config(hidden_size=512, expand=2, head_dim=64) + # intermediate_size = 1024, head_dim = 64 -> num_heads = 16 + self.assertEqual(cfg.num_heads, 16) + + def test_predefined_configs(self): + """Test predefined configuration methods.""" + cfg_tiny = modeling.Mamba2Config.tiny() + self.assertEqual(cfg_tiny.hidden_size, 64) + self.assertEqual(cfg_tiny.num_hidden_layers, 2) + + +class TestRMSNorm(absltest.TestCase): + """Tests for RMSNorm layer.""" + + def test_output_shape(self): + """Test RMSNorm output shape.""" + norm = modeling.RMSNorm(hidden_size=64, rngs=nnx.Rngs(0)) + x = jnp.ones((2, 16, 64)) + out = norm(x) + self.assertEqual(out.shape, x.shape) + + def test_output_dtype(self): + """Test RMSNorm preserves dtype.""" + norm = modeling.RMSNorm(hidden_size=64, rngs=nnx.Rngs(0)) + x = jnp.ones((2, 16, 64), dtype=jnp.float16) + out = norm(x) + self.assertEqual(out.dtype, jnp.float16) + + def test_with_residual_gate(self): + """Test RMSNorm with residual gating.""" + norm = modeling.RMSNorm(hidden_size=64, gate_residual=True, rngs=nnx.Rngs(0)) + x = jnp.ones((2, 16, 64)) + residual = jnp.ones((2, 16, 64)) * 0.5 + out = norm(x, residual=residual) + self.assertEqual(out.shape, x.shape) + + +class TestSegsum(absltest.TestCase): + """Tests for segsum function.""" + + def test_output_shape(self): + """Test segsum output shape.""" + x = jnp.ones((2, 4, 8)) + out = modeling.segsum(x) + self.assertEqual(out.shape, (2, 4, 8, 8)) + + def test_lower_triangular(self): + """Test that segsum produces lower-triangular + -inf structure.""" + x = jnp.ones((4,)) + out = modeling.segsum(x) + # Upper triangle should be -inf + self.assertTrue(jnp.isinf(out[0, 1])) + self.assertTrue(jnp.isinf(out[0, 2])) + self.assertTrue(jnp.isinf(out[0, 3])) + + +class TestSSDForward(absltest.TestCase): + """Tests for SSD forward function.""" + + def test_output_shape(self): + """Test SSD forward output shape.""" + batch_size, seq_len, num_heads, head_dim, state_size = 2, 32, 4, 16, 8 + x = jnp.ones((batch_size, seq_len, num_heads, head_dim)) + dt = jnp.ones((batch_size, seq_len, num_heads)) * 0.1 + A = -jnp.ones((num_heads,)) + B_mat = jnp.ones((batch_size, seq_len, num_heads, state_size)) + C_mat = jnp.ones((batch_size, seq_len, num_heads, state_size)) + D = jnp.ones((num_heads,)) + dt_bias = jnp.zeros((num_heads,)) + + y, _ = modeling.ssd_forward( + x, dt, A, B_mat, C_mat, chunk_size=16, D=D, dt_bias=dt_bias, dt_min=0.001, dt_max=0.1 + ) + self.assertEqual(y.shape, x.shape) + + def test_with_initial_states(self): + """Test SSD forward with initial states.""" + batch_size, seq_len, num_heads, head_dim, state_size = 2, 32, 4, 16, 8 + x = 
jnp.ones((batch_size, seq_len, num_heads, head_dim)) + dt = jnp.ones((batch_size, seq_len, num_heads)) * 0.1 + A = -jnp.ones((num_heads,)) + B_mat = jnp.ones((batch_size, seq_len, num_heads, state_size)) + C_mat = jnp.ones((batch_size, seq_len, num_heads, state_size)) + D = jnp.ones((num_heads,)) + dt_bias = jnp.zeros((num_heads,)) + initial_states = jnp.zeros((batch_size, 1, num_heads, head_dim, state_size)) + + y, final_state = modeling.ssd_forward( + x, + dt, + A, + B_mat, + C_mat, + chunk_size=16, + D=D, + dt_bias=dt_bias, + dt_min=0.001, + dt_max=0.1, + initial_states=initial_states, + return_final_states=True, + ) + self.assertEqual(y.shape, x.shape) + self.assertEqual(final_state.shape, (batch_size, num_heads, head_dim, state_size)) + + +class TestMamba2Model(absltest.TestCase): + """Tests for Mamba2Model.""" + + def setUp(self): + super().setUp() + self.cfg = modeling.Mamba2Config.tiny() + self.model = modeling.Mamba2Model(self.cfg, rngs=nnx.Rngs(42)) + + def test_output_shape(self): + """Test Mamba2Model output shape.""" + batch_size, seq_len = 2, 32 + input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + outputs = self.model(input_ids=input_ids) + + self.assertEqual(outputs["last_hidden_state"].shape, (batch_size, seq_len, self.cfg.hidden_size)) + self.assertIsNone(outputs["hidden_states"]) + self.assertIsNone(outputs["last_ssm_states"]) + + def test_output_hidden_states(self): + """Test output_hidden_states flag.""" + batch_size, seq_len = 2, 32 + input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + outputs = self.model(input_ids=input_ids, output_hidden_states=True) + + self.assertIsNotNone(outputs["hidden_states"]) + # num_layers + 1 (final norm output) + self.assertLen(outputs["hidden_states"], self.cfg.num_hidden_layers + 1) + + def test_inputs_embeds(self): + """Test using inputs_embeds instead of input_ids.""" + batch_size, seq_len = 2, 32 + inputs_embeds = jnp.ones((batch_size, seq_len, self.cfg.hidden_size)) + outputs = self.model(inputs_embeds=inputs_embeds) + + self.assertEqual(outputs["last_hidden_state"].shape, (batch_size, seq_len, self.cfg.hidden_size)) + + def test_no_nans(self): + """Test that outputs don't contain NaNs.""" + input_ids = jnp.ones((2, 32), dtype=jnp.int32) + outputs = self.model(input_ids=input_ids) + self.assertFalse(jnp.any(jnp.isnan(outputs["last_hidden_state"]))) + + def test_invalid_inputs(self): + """Test that providing both input_ids and inputs_embeds raises error.""" + input_ids = jnp.ones((2, 32), dtype=jnp.int32) + inputs_embeds = jnp.ones((2, 32, self.cfg.hidden_size)) + with self.assertRaises(ValueError): + self.model(input_ids=input_ids, inputs_embeds=inputs_embeds) + + +class TestMamba2ForCausalLM(absltest.TestCase): + """Tests for Mamba2ForCausalLM.""" + + def setUp(self): + super().setUp() + self.cfg = modeling.Mamba2Config.tiny() + self.model = modeling.Mamba2ForCausalLM(self.cfg, rngs=nnx.Rngs(42)) + + def test_output_shape(self): + """Test Mamba2ForCausalLM logits shape.""" + batch_size, seq_len = 2, 32 + input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + outputs = self.model(input_ids=input_ids) + + self.assertEqual(outputs["logits"].shape, (batch_size, seq_len, self.cfg.vocab_size)) + self.assertIsNone(outputs["loss"]) + + def test_loss_computation(self): + """Test loss computation with labels.""" + batch_size, seq_len = 2, 32 + input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32) + outputs = self.model(input_ids=input_ids, 
labels=labels) + + self.assertIsNotNone(outputs["loss"]) + self.assertEqual(outputs["loss"].shape, ()) # scalar + self.assertFalse(jnp.isnan(outputs["loss"])) + + def test_no_nans_in_logits(self): + """Test that logits don't contain NaNs.""" + input_ids = jnp.ones((2, 32), dtype=jnp.int32) + outputs = self.model(input_ids=input_ids) + self.assertFalse(jnp.any(jnp.isnan(outputs["logits"]))) + + +class TestMamba2Forecaster(absltest.TestCase): + """Tests for Mamba2Forecaster.""" + + def test_output_shape(self): + """Test Mamba2Forecaster output shape.""" + model = modeling.Mamba2Forecaster(input_dim=10, forecast_horizon=24, output_dim=1, n_layers=2, rngs=nnx.Rngs(0)) + x = jnp.ones((2, 100, 10)) + out = model(x) + self.assertEqual(out.shape, (2, 24, 1)) + + def test_multi_output(self): + """Test Mamba2Forecaster with multiple output dimensions.""" + model = modeling.Mamba2Forecaster(input_dim=10, forecast_horizon=12, output_dim=3, n_layers=2, rngs=nnx.Rngs(0)) + x = jnp.ones((4, 50, 10)) + out = model(x) + self.assertEqual(out.shape, (4, 12, 3)) + + def test_no_nans(self): + """Test that forecaster outputs don't contain NaNs.""" + model = modeling.Mamba2Forecaster(input_dim=5, forecast_horizon=10, n_layers=2, rngs=nnx.Rngs(0)) + x = jnp.ones((2, 32, 5)) + out = model(x) + self.assertFalse(jnp.any(jnp.isnan(out))) + + +class TestParameters(absltest.TestCase): + """Tests for parameter utilities.""" + + def test_create_random_model(self): + """Test random model creation.""" + cfg = modeling.Mamba2Config.tiny() + model = params.create_random_model(cfg, seed=42) + self.assertIsInstance(model, modeling.Mamba2ForCausalLM) + + # Test forward pass + input_ids = jnp.ones((2, 16), dtype=jnp.int32) + outputs = model(input_ids) + self.assertEqual(outputs["logits"].shape, (2, 16, cfg.vocab_size)) + + def test_create_random_forecaster(self): + """Test random forecaster creation.""" + model = params.create_random_forecaster(input_dim=10, forecast_horizon=24, seed=42) + self.assertIsInstance(model, modeling.Mamba2Forecaster) + + # Test forward pass + x = jnp.ones((2, 50, 10)) + out = model(x) + self.assertEqual(out.shape, (2, 24, 1)) + + +class TestJIT(absltest.TestCase): + """Tests for JIT compilation.""" + + def setUp(self): + super().setUp() + self.cfg = modeling.Mamba2Config.tiny() + + def test_jit_backbone(self): + """Test that backbone can be JIT compiled.""" + model = modeling.Mamba2Model(self.cfg, rngs=nnx.Rngs(42)) + + @jax.jit + def forward_fn(model, x): + return model(input_ids=x) + + input_ids = jnp.ones((2, 32), dtype=jnp.int32) + outputs = forward_fn(model, input_ids) + self.assertEqual(outputs["last_hidden_state"].shape, (2, 32, 64)) + + def test_jit_causal_lm(self): + """Test that CausalLM can be JIT compiled.""" + model = modeling.Mamba2ForCausalLM(self.cfg, rngs=nnx.Rngs(42)) + + input_ids = jnp.ones((2, 32), dtype=jnp.int32) + labels = jnp.ones((2, 32), dtype=jnp.int32) + outputs = modeling.forward(model, input_ids, labels) + self.assertIsNotNone(outputs["loss"]) + + +class TestGradients(absltest.TestCase): + """Tests for gradient computation.""" + + def setUp(self): + super().setUp() + self.cfg = modeling.Mamba2Config.tiny() + + def test_gradients_exist(self): + """Test that gradients can be computed.""" + model = modeling.Mamba2ForCausalLM(self.cfg, rngs=nnx.Rngs(42)) + + def loss_fn(model, x, labels): + outputs = model(input_ids=x, labels=labels) + return outputs["loss"] + + input_ids = jnp.ones((2, 16), dtype=jnp.int32) + labels = jnp.ones((2, 16), dtype=jnp.int32) + + loss, _grads = 
nnx.value_and_grad(loss_fn)(model, input_ids, labels) + self.assertIsNotNone(_grads) + self.assertTrue(jnp.isfinite(loss)) + + def test_no_nan_gradients(self): + """Test that gradients don't contain NaNs.""" + model = modeling.Mamba2ForCausalLM(self.cfg, rngs=nnx.Rngs(42)) + + def loss_fn(model, x, labels): + outputs = model(input_ids=x, labels=labels) + return outputs["loss"] + + input_ids = jnp.ones((2, 16), dtype=jnp.int32) + labels = jnp.ones((2, 16), dtype=jnp.int32) + + loss, _grads = nnx.value_and_grad(loss_fn)(model, input_ids, labels) + self.assertFalse(jnp.isnan(loss)) + + +class TestGoldenParity(absltest.TestCase): + """Tests for parity with mamba_ssm reference outputs.""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + artifacts_dir = os.path.join(os.path.dirname(__file__), "artifacts") + cls.golden = np.load(os.path.join(artifacts_dir, "golden_mamba2_130m.npz")) + + def test_hidden_state_parity(self): + """Test last_hidden_state matches mamba_ssm reference within numerical tolerance.""" + cfg = modeling.Mamba2Config( + vocab_size=50288, + hidden_size=768, + state_size=128, + num_hidden_layers=24, + head_dim=64, + expand=2, + conv_kernel=4, + ) + model = modeling.Mamba2ForCausalLM.from_pretrained("state-spaces/mamba2-130m", cfg=cfg) + + input_ids = jnp.array(self.golden["input_ids"], dtype=jnp.int32) + outputs = model.backbone(input_ids=input_ids) + bonsai_hidden = np.array(outputs["last_hidden_state"]) + golden_hidden = self.golden["last_hidden_state"] + + # fp32=1e-5, bf16=1e-3 (see ViT parity tests). + # atol is an output-level floor to avoid near-zero blowups + rtol = 1e-5 if bonsai_hidden.dtype == np.float32 else 1e-3 + atol = 1e-1 + np.testing.assert_allclose(bonsai_hidden, golden_hidden, rtol=rtol, atol=atol) + + def test_logits_parity(self): + """Test logits match mamba_ssm reference within numerical tolerance.""" + cfg = modeling.Mamba2Config( + vocab_size=50288, + hidden_size=768, + state_size=128, + num_hidden_layers=24, + head_dim=64, + expand=2, + conv_kernel=4, + ) + model = modeling.Mamba2ForCausalLM.from_pretrained("state-spaces/mamba2-130m", cfg=cfg) + + input_ids = jnp.array(self.golden["input_ids"], dtype=jnp.int32) + outputs = model(input_ids=input_ids) + bonsai_logits = np.array(outputs["logits"])[:, :, :256] + golden_logits = self.golden["logits_slice"] + + # fp32=1e-5, bf16=1e-3 (see ViT parity tests). + # atol is an output-level floor to avoid near-zero blowups + rtol = 1e-5 if bonsai_logits.dtype == np.float32 else 1e-3 + atol = 2e-1 + np.testing.assert_allclose(bonsai_logits, golden_logits, rtol=rtol, atol=atol) + + +if __name__ == "__main__": + absltest.main()