
Conversation

@CosmoNaught
Contributor

Resolves #99

Adds a Mamba2 model implementation in JAX/Flax NNX.

Implements the State Space Duality (SSD) algorithm from "Transformers are SSMs" (Dao & Gu, ICML 2024).

Models Added

  • Mamba2Model - Base backbone
  • Mamba2ForCausalLM - Causal language modeling
  • Mamba2Forecaster - Time series forecasting

Reference

  • Dao, T. and Gu, A. "Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality." ICML 2024 (arXiv:2405.21060).

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@google-cla

google-cla bot commented Dec 11, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@CosmoNaught force-pushed the add-mamba2-model branch 2 times, most recently from 00c7589 to f661341 on December 11, 2025 at 16:13
def segsum(x: jnp.ndarray) -> jnp.ndarray:
    """Stable segment sum calculation. Input: (..., T) -> Output: (..., T, T)."""
    T = x.shape[-1]
    x_rep = repeat(x, "... d -> ... d e", e=T)
Collaborator

Can we use jax.numpy.tile to remove dependence on the einops library?
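For reference, the repeat in question only expands (..., T) to (..., T, T), so plain jax.numpy covers it. A minimal sketch (not necessarily the exact change that landed):

import jax.numpy as jnp

def expand_last_axis(x: jnp.ndarray) -> jnp.ndarray:
    """Equivalent of repeat(x, "... d -> ... d e", e=T) without einops."""
    T = x.shape[-1]
    # jnp.tile pads the reps tuple with leading 1s, so only the new trailing axis is tiled.
    # jnp.broadcast_to(x[..., None], (*x.shape, T)) gives the same result without a physical copy.
    return jnp.tile(x[..., None], (T,))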

Member

+1, although einops is useful at times we want Bonsai to be dependency-free and in pure jax and flax.

Contributor Author

Fixed in 90737bf


# Chunk everything
def chunk_tensor(t):
    return rearrange(t, "b (c l) ... -> b c l ...", l=chunk_size)
Collaborator

Can we use .reshape to remove dependence on the einops library? Could do the following

def chunk_tensor(t):
  b, cl, *remaining = t.shape
  return t.reshape(b, cl // chunk_size, chunk_size, *remaining)
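As a quick sanity check of the reshape semantics (shapes and chunk_size here are just an example):

import jax.numpy as jnp

chunk_size = 4                          # example value
t = jnp.zeros((2, 8, 64))               # (batch, seq_len, features)
b, cl, *remaining = t.shape
chunked = t.reshape(b, cl // chunk_size, chunk_size, *remaining)
assert chunked.shape == (2, 2, 4, 64)   # (batch, num_chunks, chunk_size, features)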

Contributor Author

Fixed in 90737bf

# limitations under the License.

"""Parameter utilities for Mamba2 models."""

Collaborator

Could you add a function to this file which loads mamba2 from pre-trained weights?

Contributor Author

Implemented in 472243a

Contributor Author

Additionally, please see 27f596c: we have now confirmed we can load all pre-trained models of the Mamba2 class architecture.
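For readers, a rough sketch of what such a loader can look like; the repo id and filename below are assumptions for illustration, and mapping the HF parameter names onto the NNX module tree is the part the PR's params utilities actually implement:

import jax.numpy as jnp
from huggingface_hub import hf_hub_download
from safetensors.numpy import load_file

def load_reference_state_dict(repo_id: str = "state-spaces/mamba2-130m",
                              filename: str = "model.safetensors") -> dict:
    """Download a checkpoint and return a flat dict of JAX arrays (illustrative)."""
    path = hf_hub_download(repo_id=repo_id, filename=filename)
    return {k: jnp.asarray(v) for k, v in load_file(path).items()}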


loss, _grads = nnx.value_and_grad(loss_fn)(model, input_ids, labels)
self.assertFalse(jnp.isnan(loss))

Collaborator

Could you add a test that compares the outputs of the bonsai implementation against a pretrained model?

Contributor Author

Implemented in d5ea00e

Member

@jenriver left a comment

Thanks for the Mamba2 implementation! It's super cool to have linear attention in our collection!

Comment on lines 70 to 81
def count_parameters(model: nnx.Module) -> int:
    """Count the total number of trainable parameters in a model.

    Args:
        model: NNX module to count parameters for.

    Returns:
        Total number of parameters.
    """
    _graphdef, state = nnx.split(model)
    params = state.filter(nnx.Param)
    return sum(p.size for p in jax.tree.leaves(params))
Member

Could we remove this? It's not critical path for code and quality test guarantees parameter is successfully loaded.

Contributor Author

Removed in aa2cbe4

def segsum(x: jnp.ndarray) -> jnp.ndarray:
    """Stable segment sum calculation. Input: (..., T) -> Output: (..., T, T)."""
    T = x.shape[-1]
    x_rep = repeat(x, "... d -> ... d e", e=T)
Member

Current use of repeat will create a huge physical copy, leading to potential OOMs. Instead, please use broadcasting, something like this:

x_cumsum = jnp.cumsum(x, axis=-1)
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = jnp.tril(jnp.ones((T, T), dtype=bool), k=0)
x_segsum = jnp.where(mask, x_segsum, -jnp.inf)
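Put together, a self-contained version of this broadcast-based segsum (a sketch, not necessarily the exact code that landed) is:

import jax.numpy as jnp

def segsum(x: jnp.ndarray) -> jnp.ndarray:
    """Segment sum via broadcasting: (..., T) -> (..., T, T), lower-triangular."""
    T = x.shape[-1]
    x_cumsum = jnp.cumsum(x, axis=-1)
    # x_segsum[..., i, j] = x[..., j+1] + ... + x[..., i]; no (..., T, T) copy of x is materialized first.
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool), k=0)
    return jnp.where(mask, x_segsum, -jnp.inf)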

Contributor Author

Good catch! Fixed in 494c166

@@ -0,0 +1,353 @@
# Copyright 2025 The JAX Authors.
Member

Could you add a quality test that checks your implementation's output against the original reference model?
Currently the tests only check shapes/dtypes, not the actual output, so we can't tell whether the quality of this model is correct.

ex: https://github.com/jax-ml/bonsai/blob/main/bonsai/models/vit/tests/test_outputs_vit.py

Contributor Author

Implemented in d5ea00e

@@ -0,0 +1,135 @@
# Copyright 2025 The JAX Authors.
Member

Nice work. Some quick notes:

  • Can we add some example questions and outputs to the logs? It helps verify that the generation "looks right" at a glance. (example)

  • Let's remove the error tolerance check from this file; we should handle quality/numerical parity in test_outputs.py instead to keep this as a pure smoke test.

  • Also, let's remove most of the print statements so the output focuses only on the model input/output.

Contributor Author

Thanks! See refactor under ec4611c


# Tolerances account for Triton vs JAX numerical differences across 24 layers
self.assertLess(max_diff, 1e-1, f"Max diff {max_diff:.2e} exceeds tolerance")
self.assertLess(mean_diff, 1e-3, f"Mean diff {mean_diff:.2e} exceeds tolerance")
Member

Thanks Cosmo -- could you compare with rtol instead? atol is usually model-dependent and gives less signal.

A good rule of thumb is 1e-5 for float32 and 1e-3 for bfloat16, although these can vary based on rng seed and model.

(i.e. via np.testing.assert_allclose or torch.testing.assert_close)
https://github.com/jax-ml/bonsai/blob/main/bonsai/models/vit/tests/test_outputs_vit.py

Contributor Author

Thanks Jen, good to know!

I've now fixed this in 706b62d

I updated the golden parity checks to use np.testing.assert_allclose with rtol as the primary signal (fp32 = 1e-5, bf16 = 1e-3), per your comment and the ViT tests.

I also pinned jax_default_matmul_precision to "highest" to match the golden generator; that removed the extra backend drift and made the rtol thresholds stable.
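The shape of such a parity check, roughly (the golden file name and model call are placeholders for what test_outputs.py actually does):

import jax
import numpy as np

def check_logits_parity(model, input_ids, golden_path="golden_logits.npy", rtol=1e-5):
    """Compare model logits against precomputed golden logits using rtol (illustrative)."""
    golden = np.load(golden_path)
    # Pin matmul precision so backend-dependent drift does not eat into the rtol budget.
    with jax.default_matmul_precision("highest"):
        logits = model(input_ids)
    np.testing.assert_allclose(np.asarray(logits), golden, rtol=rtol)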

last_hidden = outputs["last_hidden_state"][:, -1, :]
out = self.output_proj(last_hidden)
return out.reshape(x.shape[0], self.forecast_horizon, self.output_dim)

Collaborator

Could improve inference performance with caching.
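The cache in question is the SSM recurrence state rather than a transformer KV cache. A minimal sketch of the per-step update such a cache would carry (shapes and names are illustrative and independent of this PR's API):

import jax.numpy as jnp

def ssm_step(h, a_log, dt, B_t, C_t, x_t):
    """One decode step of a diagonal SSM head: h_t = exp(dt*A)*h + dt*B_t*x_t, y_t = h_t @ C_t.

    h: (d_head, d_state) carried state -- this is the cache.
    a_log, dt, x_t: (d_head,); B_t, C_t: (d_state,).
    """
    decay = jnp.exp(dt * -jnp.exp(a_log))                        # A = -exp(a_log), as in Mamba-style models
    h = decay[:, None] * h + (dt * x_t)[:, None] * B_t[None, :]  # update carried state
    y_t = h @ C_t                                                # read out this step's output
    return h, y_t

Carrying h across decode steps (e.g. through jax.lax.scan) avoids re-running the full prefix at every generation step.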

Member

@jenriver left a comment

Hello Cosmo, thanks for the implementation! This looks great, especially with the proper golden logits test. :)

I left the KV cache as a TODO item in #99. With this, I hope we can see the full performance benefits of linear attention!

@jenriver merged commit a907b75 into jax-ml:main on Dec 22, 2025
3 checks passed