Add mamba2 model #103
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Force-pushed 00c7589 to f661341
Force-pushed f661341 to ce45e8e
bonsai/models/mamba2/modeling.py
Outdated
def segsum(x: jnp.ndarray) -> jnp.ndarray:
    """Stable segment sum calculation. Input: (..., T) -> Output: (..., T, T)."""
    T = x.shape[-1]
    x_rep = repeat(x, "... d -> ... d e", e=T)
Can we use jax.numpy.tile to remove dependence on the einops library?
+1. Although einops is useful at times, we want Bonsai to be dependency-free and in pure JAX and Flax.
Fixed in 90737bf
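For reference, a minimal einops-free sketch of that expansion (the helper name here is made up; the actual change in 90737bf may differ). Expressing it as a broadcast lets XLA usually fuse the expansion instead of materializing a tiled copy.
import jax.numpy as jnp

def expand_last(x: jnp.ndarray, T: int) -> jnp.ndarray:
    # Same result as einops.repeat(x, "... d -> ... d e", e=T):
    # out[..., d, e] == x[..., d] for every e in [0, T).
    # An explicit jnp.tile(x[..., None], T) would also work but makes a physical copy.
    return jnp.broadcast_to(x[..., None], (*x.shape, T))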
bonsai/models/mamba2/modeling.py
Outdated
# Chunk everything
def chunk_tensor(t):
    return rearrange(t, "b (c l) ... -> b c l ...", l=chunk_size)
Can we use .reshape to remove dependence on the einops library? Could do the following:
def chunk_tensor(t):
    b, cl, *remaining = t.shape
    return t.reshape(b, cl // chunk_size, chunk_size, *remaining)
Fixed in 90737bf
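A quick shape check of that reshape-based chunking (toy shapes, not taken from the PR):
import jax.numpy as jnp

chunk_size = 4
t = jnp.zeros((2, 8, 16))  # (batch, seq_len, features) with seq_len divisible by chunk_size
b, cl, *remaining = t.shape
chunked = t.reshape(b, cl // chunk_size, chunk_size, *remaining)
assert chunked.shape == (2, 2, 4, 16)  # (batch, num_chunks, chunk_size, features)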
# limitations under the License.

"""Parameter utilities for Mamba2 models."""
Could you add a function to this file which loads mamba2 from pre-trained weights?
Implemented in 472243a
Additionally, please see 27f596c, as we have now confirmed we can load all pre-trained models of the mamba2 class architecture.
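For readers following along, a minimal sketch of what the download-and-read half of such a loader can look like (the function name and the filename default are assumptions, not the code in 472243a; mapping the flat checkpoint names onto the nnx module tree is the model-specific part that lives in params.py):
from huggingface_hub import hf_hub_download
from safetensors.flax import load_file

def load_flat_checkpoint(repo_id: str, filename: str = "model.safetensors") -> dict:
    # Download a single checkpoint file from the Hugging Face Hub and read it
    # as a flat {parameter_name: array} dict. Some mamba2 checkpoints ship
    # PyTorch .bin files instead, which would need torch.load here.
    ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    return load_file(ckpt_path)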
loss, _grads = nnx.value_and_grad(loss_fn)(model, input_ids, labels)
self.assertFalse(jnp.isnan(loss))
Could you add a test to compare the outputs of the bonsai implementation against a pretrained model?
Implemented in d5ea00e
jenriver left a comment
Thanks for the Mamba2 implementation! It's super cool to have linear attention in our collection!
def count_parameters(model: nnx.Module) -> int:
    """Count the total number of trainable parameters in a model.

    Args:
        model: NNX module to count parameters for.

    Returns:
        Total number of parameters.
    """
    _graphdef, state = nnx.split(model)
    params = state.filter(nnx.Param)
    return sum(p.size for p in jax.tree.leaves(params))
Could we remove this? It's not on the critical path, and the quality test already guarantees the parameters are loaded successfully.
Removed in aa2cbe4
bonsai/models/mamba2/modeling.py
Outdated
def segsum(x: jnp.ndarray) -> jnp.ndarray:
    """Stable segment sum calculation. Input: (..., T) -> Output: (..., T, T)."""
    T = x.shape[-1]
    x_rep = repeat(x, "... d -> ... d e", e=T)
The current use of repeat will create a huge physical copy, leading to potential OOMs. Instead, please use broadcasting, something like this:
x_cumsum = jnp.cumsum(x, axis=-1)
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = jnp.tril(jnp.ones((T, T), dtype=bool), k=0)
x_segsum = jnp.where(mask, x_segsum, -jnp.inf)
Good catch! Fixed in 494c166
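Putting the suggested pieces together, a pure-JAX segsum along those lines could look like this (a sketch of the broadcast approach, not necessarily the exact code in 494c166):
import jax.numpy as jnp

def segsum(x: jnp.ndarray) -> jnp.ndarray:
    """Stable segment sum via broadcasting: out[..., t, s] = sum(x[..., s+1:t+1])."""
    T = x.shape[-1]
    x_cumsum = jnp.cumsum(x, axis=-1)
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = jnp.tril(jnp.ones((T, T), dtype=bool), k=0)
    # Entries above the diagonal are set to -inf so that exp(segsum) is zero there.
    return jnp.where(mask, x_segsum, -jnp.inf)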
@@ -0,0 +1,353 @@
# Copyright 2025 The JAX Authors.
Could you add a quality test to check your implementation output vs. original reference model?
Currently it only checks the shapes/dtypes and doesn't check the actual output, so we don't know whether the quality of this model is correct.
ex: https://github.com/jax-ml/bonsai/blob/main/bonsai/models/vit/tests/test_outputs_vit.py
Implemented in d5ea00e
@@ -0,0 +1,135 @@
# Copyright 2025 The JAX Authors.
Nice work. Some quick notes:
- Can we add some example questions and outputs to the logs? It helps verify that the generation "looks right" at a glance. (example)
- Let's remove the error tolerance check from this file; we should handle quality/numerical parity in test_outputs.py instead to keep this as a pure smoke test.
- Also, let's remove most of the print statements to focus on only the model input / output.
Thanks! See refactor under ec4611c
# Tolerances account for Triton vs JAX numerical differences across 24 layers
self.assertLess(max_diff, 1e-1, f"Max diff {max_diff:.2e} exceeds tolerance")
self.assertLess(mean_diff, 1e-3, f"Mean diff {mean_diff:.2e} exceeds tolerance")
Thanks Cosmo -- could you compare using rtol instead? Usually atol is model-dependent and gives less signal.
A good rule of thumb is 1e-5 for float32 and 1e-3 for bfloat16, although these can vary based on rng seed and model.
(i.e. via np.testing.assert_allclose or torch.testing.assert_close)
https://github.com/jax-ml/bonsai/blob/main/bonsai/models/vit/tests/test_outputs_vit.py
Thanks Jen, good to know this!
I've now fixed this in 706b62d.
I updated the golden parity checks to use np.testing.assert_allclose with rtol as the primary signal (fp32 = 1e-5, bf16 = 1e-3), per your comment and the ViT tests.
I also pinned jax_default_matmul_precision to "highest" to match the golden generator, which removed extra backend drift and made the rtol thresholds stable.
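In code form, the core of that check reduces to something like this (the helper name and defaults are illustrative, not the exact test in 706b62d):
import jax
import numpy as np

def assert_logits_match(bonsai_logits, golden_logits, rtol: float = 1e-5) -> None:
    # rtol is the primary signal: ~1e-5 for float32 goldens, ~1e-3 for bfloat16.
    np.testing.assert_allclose(np.asarray(bonsai_logits), np.asarray(golden_logits), rtol=rtol)

# Pin matmul precision so the JAX forward pass matches the precision used when
# the golden logits were generated, removing extra backend drift.
jax.config.update("jax_default_matmul_precision", "highest")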
Force-pushed 706b62d to d0185ae
last_hidden = outputs["last_hidden_state"][:, -1, :]
out = self.output_proj(last_hidden)
return out.reshape(x.shape[0], self.forecast_horizon, self.output_dim)
Could improve inference performance with caching.
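For context, a recurrent-state cache for Mamba2-style decoding typically carries two pieces per layer; a minimal container sketch (shapes are assumptions, not taken from this PR):
from typing import NamedTuple
import jax.numpy as jnp

class Mamba2LayerCache(NamedTuple):
    # Rolling buffer of the last (d_conv - 1) inputs for the causal depthwise conv.
    conv_state: jnp.ndarray  # (batch, d_conv - 1, d_inner)
    # Recurrent SSM state carried across single-token decoding steps.
    ssm_state: jnp.ndarray   # (batch, nheads, headdim, d_state)
With such a cache, each new token is decoded in O(1) work with respect to sequence length instead of re-running the chunked scan over the whole prefix.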
jenriver left a comment
Hello Cosmo, thanks for the implementation! This looks great, especially with the proper golden logits test. :)
I left the KV Cache as a TODO item in #99. With this, I hope we can see the full benefits in performance from linear attention!
Resolves #99
Adds Mamba2 model implementation in JAX/Flax NNX.
Implements the State Space Duality (SSD) algorithm from "Transformers are SSMs" (Dao & Gu, ICML 2024).
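For context, in the scalar-A case described in the paper the SSD output can be written as a masked matrix product, which is where the segsum discussed in the review threads comes from (a sketch of the relation, not the exact formulation used in the code):
$$y_t = \sum_{s \le t} C_t^\top \Big(\prod_{r=s+1}^{t} a_r\Big) B_s x_s, \qquad \prod_{r=s+1}^{t} a_r = \exp\big(\mathrm{segsum}(\log a)_{t,s}\big)$$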
Models Added
- Mamba2Model: Base backbone
- Mamba2ForCausalLM: Causal language modeling
- Mamba2Forecaster: Time series forecasting
Reference
Checklist
- run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality.