Skip to content
Open
26 changes: 26 additions & 0 deletions bonsai/models/gemma3/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Gemma3 in JAX

This directory contains a pure JAX implementation of the [Gemma3 model](https://deepmind.google/models/gemma/gemma-3/), using the [Flax NNX](https://flax.readthedocs.io/en/v0.8.3/experimental/nnx/index.html) API.

Note that you need an access token to download the model weights. In order to run the scripts, make sure to save an environment variable `HF_TOKEN` with your huggingface access token.


## Model Configuration Support Status


### Running this model


```sh
python3 -m bonsai.models.gemma3.tests.run_model
```


## How to contribute to this model

### Remaining Tasks

1. Implement with batching. Need this for FSDP.
2. Optimize based on the profiling.
3. Clean up code (variable names, etc.). Simplify unused configs (marked these with TODO) or use them.
4. Update to include other model sizes and optimize parameter loading.
Loading