From b730250c1d140bf161ca5ca7d6d968189aa09d4c Mon Sep 17 00:00:00 2001
From: Simon Dirmeier
Date: Tue, 19 Aug 2025 10:13:44 +0200
Subject: [PATCH 1/5] Add some more documentation

---
 README.md                              | 216 +++++++++++++++++++++++--
 blaxbird/__init__.py                   |   2 +-
 blaxbird/_src/checkpointer.py          |  10 +-
 blaxbird/_src/trainer.py               |   7 +-
 examples/mnist_classification/main.py  |  15 +-
 examples/mnist_classification/model.py |   2 +-
 6 files changed, 219 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index 7847c24..53fbd10 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,22 @@
 # blaxbird [blækbɜːd]
 
+[![ci](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml)
+[![version](https://img.shields.io/pypi/v/blaxbird.svg?colorB=black&style=flat)](https://pypi.org/project/blaxbird/)
+
+> A high-level API to build and train NNX models
+
 ## About
 
-A high-level API to build and train NNX models.
+`Blaxbird` [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.
+Using `blaxbird` one can
+- concisely define models and loss functions without the usual JAX/Flax verbosity,
+- easily define checkpointers that save the best and most current network weights,
+- distribute data and model weights over multiple processes or GPUs,
+- define hooks that are periodically called during training.
+
+## Example
 
-Define the module
+To use `blaxbird`, one only needs to define a model, a loss function, and train and validation step functions:
 
 ```python
 import optax
 from flax import nnx
@@ -18,34 +30,212 @@ def loss_fn(model, images, labels):
     logits=logits, labels=labels
   ).mean()
 
-
-def train_step(model, rng_key, batch, **kwargs):
+def train_step(model, rng_key, batch):
   return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])
 
-
-def val_step(model, rng_key, batch, **kwargs):
+def val_step(model, rng_key, batch):
   return loss_fn(model, batch["image"], batch["label"])
 ```
 
-Define the trainer
+You can then construct (and use) a training function like this:
+
 ```python
-from jax import random as jr
+import optax
 from flax import nnx
+from jax import random as jr
 
 from blaxbird import train_fn
 
 model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
-optimizer = get_optimizer(model)
+optimizer = nnx.Optimizer(model, optax.adam(1e-4))
 
 train = train_fn(
-  fns=(train_step, val_step),
-  n_steps=n_steps,
-  n_eval_frequency=n_eval_frequency,
-  n_eval_batches=n_eval_batches,
+  fns=(train_step, val_step),
+  n_steps=100,
+  eval_every_n_steps=10,
+  n_eval_batches=10
 )
 train(jr.key(2), model, optimizer, train_itr, val_itr)
 ```
+See a self-contained example in [examples/mnist_classification](examples/mnist_classification).
+
+## Usage
+
+`train_fn` is a higher-order function with the following signature:
+
+```python
+def train_fn(
+    *,
+    fns: tuple[Callable, Callable],
+    shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
+    n_steps: int,
+    eval_every_n_steps: int,
+    n_eval_batches: int,
+    log_to_wandb: bool = False,
+    hooks: Iterable[Callable] = (),
+) -> Callable:
+    ...
+```
+
+We briefly explain the more ambiguous argument types below.
+
+### `fns`
+
+`fns` is a required argument consisting of a tuple of two functions, a step function and a validation function.
+In the simplest case they look like this:
+
+```python
+def train_step(model, rng_key, batch):
+  return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])
+
+def val_step(model, rng_key, batch):
+  return loss_fn(model, batch["image"], batch["label"])
+```
+
+Both `train_step` and `val_step` have the same arguments and argument types:
+- `model` specifies a `nnx.Module`, i.e., a neural network like the CNN shown above.
+- `rng_key` is a `jax.random.key` in case you need to generate random numbers.
+- `batch` is a sample from a data loader (to be specified later).
+
+The loss function that is called by both computes a *scalar* loss value.
+While `train_step` has to return the loss and gradients, `val_step` only needs
+to return the loss.
+
+### `shardings`
+
+To specify how data and model weights are distributed over devices and processes,
+`blaxbird` uses JAX' [sharding](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) functionality.
+
+`shardings` is again specified by a tuple, one for the model sharding, the other for the data sharding.
+An example is shown below, where we only distributed the data over `num_devices` devices.
+You can, if you don't want to distribute anything, just set the argument to `None` or not specify it.
+
+```python
+def get_sharding():
+  num_devices = jax.local_device_count()
+  mesh = jax.sharding.Mesh(
+    mesh_utils.create_device_mesh((num_devices,)), ("data",)
+  )
+  model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
+  data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
+  return model_sharding, data_sharding
+```
+
+### `hooks`
+
+`hooks` is a list of callables which are periodically called during training.
+Each hook has to have the following signature:
+
+```python
+def hook_fn(step, *, model, **kwargs) -> None:
+  ...
+```
+
+It takes an integer `step` specifying the current training iteration and the model itself.
+For instance, if you want to track custom metrics during validation, you could create a hook like this:
+
+```python
+def hook_fn(metrics, val_iter, hook_every_n_steps):
+  def fn(step, *, model, **kwargs):
+    if step % hook_every_n_steps != 0:
+      return
+    for batch in val_iter:
+      logits = model(batch["image"])
+      loss = optax.softmax_cross_entropy_with_integer_labels(
+        logits=logits, labels=batch["label"]
+      ).mean()
+      metrics.update(loss=loss, logits=logits, labels=batch["label"])
+    if jax.process_index() == 0:
+      curr_metrics = ", ".join(
+        [f"{k}: {v}" for k, v in metrics.compute().items()]
+      )
+      logging.info(f"metrics at step {step}: {curr_metrics}")
+    metrics.reset()
+  return fn
+
+metrics = nnx.MultiMetric(
+  accuracy=nnx.metrics.Accuracy(),
+  loss=nnx.metrics.Average("loss"),
+)
+hook = hook_fn(metrics, val_iter, eval_every_n_steps)
+```
+
+This creates a hook function `hook` that after `eval_every_n_steps` steps iterates over the validation set
+computes accuracy and loss, and then logs everything.
+
+To provide multiple hooks to the train function, just concatenate them in a list.
+
+#### A checkpointing `hook`
+
+We provide a convenient hook for checkpointing which can be constructed using
+`get_default_checkpointer`. The checkpointer saves both the last `k` checkpoints with the lowest
+validation loss and the last training checkpoint.
+
+The signature of the hook is:
+
+```python
+def get_default_checkpointer(
+    outfolder: str,
+    *,
+    save_every_n_steps: int,
+    max_to_keep: int = 5,
+) -> tuple[Callable, Callable, Callable]
+```
+
+Its arguments are:
+- `outfolder`: a folder specifying where to store the checkpoints.
+- `save_every_n_steps`: after how many training steps to store a checkpoint.
+- `max_to_keep`: the number of checkpoints to keep before starting to remove old checkpoints (to not clog the device).
+
+For instance, you would then construct the checkpointing function like this:
+
+```python
+from blaxbird import get_default_checkpointer
+
+hook_save, *_ = get_default_checkpointer(
+  os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
+)
+```
+
+### Restoring a run
+
+You can also use `get_default_checkpointer` to restart the run where you left off.
+`get_default_checkpointer` in fact returns three functions, one for saving checkpoints and two for restoring
+checkpoints:
+
+```python
+from blaxbird import get_default_checkpointer
+
+save, restore_best, restore_last = get_default_checkpointer(
+  os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
+)
+```
+
+You can then do either of:
+
+```python
+model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
+optimizer = nnx.Optimizer(model, optax.adam(1e-4))
+
+model, optimizer = restore_best(model, optimizer)
+model, optimizer = restore_last(model, optimizer)
+```
+
+## Installation
+
+To install the package from PyPI, call:
+
+```bash
+pip install blaxbird
+```
+
+To install the latest version from GitHub, just call the following on the command line:
+
+```bash
+pip install git+https://github.com/dirmeier/blaxbird@
+```
+
 ## Author
 
 Simon Dirmeier simd@mailbox.org
diff --git a/blaxbird/__init__.py b/blaxbird/__init__.py
index 2c60007..1fbf835 100644
--- a/blaxbird/__init__.py
+++ b/blaxbird/__init__.py
@@ -1,4 +1,4 @@
-"""fll: A high-level API for building and training Flax NNX models."""
+"""blaxbird: A high-level API for building and training Flax NNX models."""
 
 __version__ = "0.0.1"
 
diff --git a/blaxbird/_src/checkpointer.py b/blaxbird/_src/checkpointer.py
index 5925310..d8aaccc 100644
--- a/blaxbird/_src/checkpointer.py
+++ b/blaxbird/_src/checkpointer.py
@@ -9,10 +9,8 @@ def get_default_checkpointer(
     outfolder: str,
     *,
-    save_every_n_steps: int = 1,
+    save_every_n_steps: int,
     max_to_keep: int = 5,
-    best_fn: Callable = lambda x: x["val/loss"],
-    best_mode: str = "min",
 ) -> tuple[Callable, Callable, Callable]:
     """Construct functions for checkpointing functionality.
 
     Args:
       outfolder: a path specifying where checkpoints are stored
       save_every_n_steps: how often to store checkpoints
       max_to_keep: number of checkpoints to store before they get deleted
-      best_fn: function that maintains checkpoints using a specific criterion for
-        quality
-      best_mode: use `min`, e.g., if your criterion is a loss function.
-        Use 'max' if the criterion is an ELBO or something.
Returns: returns function to saev and restore checkpoints """ checkpointer = ocp.PyTreeCheckpointer() options = ocp.CheckpointManagerOptions( - max_to_keep=max_to_keep, create=True, best_mode=best_mode, best_fn=best_fn + max_to_keep=max_to_keep, create=True, best_mode="min", best_fn=lambda x: x["val/loss"] ) checkpoint_manager = ocp.CheckpointManager( os.path.join(outfolder, "best"), diff --git a/blaxbird/_src/trainer.py b/blaxbird/_src/trainer.py index 1f2b844..dc36b55 100644 --- a/blaxbird/_src/trainer.py +++ b/blaxbird/_src/trainer.py @@ -1,4 +1,5 @@ from collections.abc import Callable, Iterable +from typing import Optional import grain.python as grain import jax @@ -32,7 +33,7 @@ def _eval_step(model, rng_key, metrics, batch, **kwargs): def train_fn( *, fns: tuple[Callable, Callable], - shardings: tuple[jax.NamedSharding, jax.NamedSharding], + shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None, n_steps: int, eval_every_n_steps: int, n_eval_batches: int, @@ -70,8 +71,8 @@ def train( rng_key: a jax.random.key object model: a NNX model optimizer: a nnx.Optimizer object - train_itr: a data laoder - val_itr: a data laoder + train_itr: a data loader + val_itr: a data loader """ # get train and val fns step_fn, eval_fn = _step_and_val_fns(fns) diff --git a/examples/mnist_classification/main.py b/examples/mnist_classification/main.py index fb40e9e..09d3db6 100644 --- a/examples/mnist_classification/main.py +++ b/examples/mnist_classification/main.py @@ -7,10 +7,11 @@ from flax import nnx from jax import random as jr from jax.experimental import mesh_utils -from model import CNN, train_step, val_step from blaxbird import get_default_checkpointer, train_fn +from model import CNN, train_step, val_step + def get_optimizer(model, lr=1e-4): tx = optax.adamw(lr) @@ -28,10 +29,10 @@ def get_sharding(): return model_sharding, data_sharding -def visualize_hook(val_iter, n_eval_frequency): - def hook_fn(metrics, val_iter, n_eval_frequency): +def visualize_hook(val_iter, eval_every_n_steps): + def hook_fn(metrics, val_iter, eval_every_n_steps): def fn(step, *, model, **kwargs): - if step % n_eval_frequency != 0: + if step % eval_every_n_steps != 0: return batch = next(iter(val_iter)) logits = model(batch["image"]) @@ -52,11 +53,11 @@ def fn(step, *, model, **kwargs): accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average("loss"), ) - return hook_fn(metrics, val_iter, n_eval_frequency) + return hook_fn(metrics, val_iter, eval_every_n_steps) -def get_hooks(val_itr, n_eval_frequency): - return [visualize_hook(val_itr, n_eval_frequency)] +def get_hooks(val_itr, eval_every_n_steps): + return [visualize_hook(val_itr, eval_every_n_steps)] def get_train_and_val_itrs(rng_key, outfolder): diff --git a/examples/mnist_classification/model.py b/examples/mnist_classification/model.py index ad1e0ab..a9ad54c 100644 --- a/examples/mnist_classification/model.py +++ b/examples/mnist_classification/model.py @@ -5,7 +5,7 @@ class CNN(nnx.Module): - """Copy==pasted from the Flax documentation.""" + """Copy-pasted from the Flax documentation.""" def __init__(self, *, rngs: nnx.Rngs): self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) From 905a405da3bbb297a78821a818001d422f8a4b0a Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 19 Aug 2025 10:14:26 +0200 Subject: [PATCH 2/5] Add some more documentation --- README.md | 20 ++++++++++---------- blaxbird/_src/checkpointer.py | 5 ++++- blaxbird/_src/trainer.py | 3 +-- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/README.md 
b/README.md index 53fbd10..6a2680f 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Using `blaxbird` one can - distribute data and model weights over multiple processes or GPUs, - define hooks that are periodically called during training. -## Example +## Example To use `blaxbird`, one only needs to define a model, a loss function, and train and validation step functions: ```python @@ -50,8 +50,8 @@ model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1))) optimizer = nnx.Optimizer(model, optax.adam(1e-4)) train = train_fn( - fns=(train_step, val_step), - n_steps=100, + fns=(train_step, val_step), + n_steps=100, eval_every_n_steps=10, n_eval_batches=10 ) @@ -109,7 +109,7 @@ To specify how data and model weights are distributed over devices and processes `shardings` is again specified by a tuple, one for the model sharding, the other for the data sharding. An example is shown below, where we only distributed the data over `num_devices` devices. -You can, if you don't want to distribute anything, just set the argument to `None` or not specify it. +You can, if you don't want to distribute anything, just set the argument to `None` or not specify it. ```python def get_sharding(): @@ -140,7 +140,7 @@ def hook_fn(metrics, val_iter, hook_every_n_steps): def fn(step, *, model, **kwargs): if step % hook_every_n_steps != 0: return - for batch in val_iter: + for batch in val_iter: logits = model(batch["image"]) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch["label"] @@ -159,17 +159,17 @@ metrics = nnx.MultiMetric( loss=nnx.metrics.Average("loss"), ) hook = hook_fn(metrics, val_iter, eval_every_n_steps) -``` +``` This creates a hook function `hook` that after `eval_every_n_steps` steps iterates over the validation set -computes accuracy and loss, and then logs everything. +computes accuracy and loss, and then logs everything. To provide multiple hooks to the train function, just concatenate them in a list. #### A checkpointing `hook` We provide a convenient hook for checkpointing which can be constructed using -`get_default_checkpointer`. The checkpointer saves both the last `k` checkpoints with the lowest +`get_default_checkpointer`. The checkpointer saves both the last `k` checkpoints with the lowest validation loss and the last training checkpoint. The signature of the hook is: @@ -179,13 +179,13 @@ def get_default_checkpointer( outfolder: str, *, save_every_n_steps: int, - max_to_keep: int = 5, + max_to_keep: int = 5, ) -> tuple[Callable, Callable, Callable] ``` Its arguments are: - `outfolder`: a folder specifying where to store the checkpoints. -- `save_every_n_steps`: after how many training steps to store a checkpoint. +- `save_every_n_steps`: after how many training steps to store a checkpoint. - `max_to_keep`: the number of checkpoints to keep before starting to remove old checkpoints (to not clog the device). 
 For instance, you would then construct the checkpointing function like this:
diff --git a/blaxbird/_src/checkpointer.py b/blaxbird/_src/checkpointer.py
index d8aaccc..8314089 100644
--- a/blaxbird/_src/checkpointer.py
+++ b/blaxbird/_src/checkpointer.py
@@ -24,7 +24,10 @@ def get_default_checkpointer(
     """
     checkpointer = ocp.PyTreeCheckpointer()
     options = ocp.CheckpointManagerOptions(
-        max_to_keep=max_to_keep, create=True, best_mode="min", best_fn=lambda x: x["val/loss"]
+        max_to_keep=max_to_keep,
+        create=True,
+        best_mode="min",
+        best_fn=lambda x: x["val/loss"],
     )
     checkpoint_manager = ocp.CheckpointManager(
         os.path.join(outfolder, "best"),
diff --git a/blaxbird/_src/trainer.py b/blaxbird/_src/trainer.py
index dc36b55..483ac64 100644
--- a/blaxbird/_src/trainer.py
+++ b/blaxbird/_src/trainer.py
@@ -1,5 +1,4 @@
 from collections.abc import Callable, Iterable
-from typing import Optional
 
 import grain.python as grain
 import jax
@@ -33,7 +32,7 @@ def train_fn(
     *,
     fns: tuple[Callable, Callable],
-    shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
+    shardings: tuple[jax.NamedSharding, jax.NamedSharding] | None = None,
     n_steps: int,
     eval_every_n_steps: int,
     n_eval_batches: int,

From 67e49fe936839a11f42769db638f20746ab96743 Mon Sep 17 00:00:00 2001
From: Simon Dirmeier
Date: Tue, 19 Aug 2025 10:15:43 +0200
Subject: [PATCH 3/5] Add some more documentation

---
 .github/workflows/release.yaml | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml
index 3c918b9..d1b137c 100644
--- a/.github/workflows/release.yaml
+++ b/.github/workflows/release.yaml
@@ -10,11 +10,11 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.9]
+        python-version: [3.11]
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v3
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install pypa/build

From 628aef0b190a83d8c9e9f6e74bb71f170a25f658 Mon Sep 17 00:00:00 2001
From: Simon Dirmeier
Date: Tue, 19 Aug 2025 10:19:50 +0200
Subject: [PATCH 4/5] Add some more documentation

---
 README.md | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/README.md b/README.md
index 6a2680f..2a908ef 100644
--- a/README.md
+++ b/README.md
@@ -222,6 +222,26 @@ model, optimizer = restore_best(model, optimizer)
 model, optimizer = restore_last(model, optimizer)
 ```
 
+### Doing training
+
+After having defined train functions, hooks and shardings, you can train your model like this:
+
+```python
+train = train_fn(
+  fns=(train_step, val_step),
+  n_steps=n_steps,
+  eval_every_n_steps=eval_every_n_steps,
+  n_eval_batches=n_eval_batches,
+  shardings=(model_sharding, data_sharding),
+  hooks=hooks,
+  log_to_wandb=False,
+)
+train(jr.key(1), model, optimizer, train_itr, val_itr)
+```
+
+A self-contained example that also explains what the data loaders should look like can be found
+in [examples/mnist_classification](examples/mnist_classification).
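+
+The step functions and hooks above only ever index a batch like a dictionary (`batch["image"]`,
+`batch["label"]`). As a purely illustrative sketch (the helper name, array shapes and in-memory
+batches below are assumptions for MNIST, not part of `blaxbird`'s API), loaders that yield such
+dictionary batches could look like this:
+
+```python
+# Hypothetical stand-in for real data loaders: plain lists of dictionary
+# batches with the "image" and "label" entries the step functions expect.
+import numpy as np
+
+
+def toy_batches(n_batches=10, batch_size=32, seed=0):
+  rng = np.random.default_rng(seed)
+  return [
+    {
+      "image": rng.normal(size=(batch_size, 28, 28, 1)).astype("float32"),
+      "label": rng.integers(0, 10, size=(batch_size,)),
+    }
+    for _ in range(n_batches)
+  ]
+
+
+train_itr, val_itr = toy_batches(n_batches=100), toy_batches(n_batches=10, seed=1)
+```
+
+In practice you would use a real data loader (for instance `grain`, which the trainer module already
+imports); the sketch above only illustrates the batch format assumed in this README.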
+ ## Installation To install the package from PyPI, call: From 288cb9f26f3836e57aea2a4defd61ed2e10fb94d Mon Sep 17 00:00:00 2001 From: Simon Dirmeier Date: Tue, 19 Aug 2025 10:38:20 +0200 Subject: [PATCH 5/5] Add some more documentation --- .github/workflows/ci.yaml | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 505e493..81ffc4e 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -36,30 +36,3 @@ jobs: - name: Run lints run: | make lints - - tests: - runs-on: ubuntu-latest - needs: - - lints - strategy: - matrix: - python-version: [ 3.11, 3.12 ] - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - uses: astral-sh/setup-uv@v5 - with: - version: "latest" - - name: Install dependencies - run: | - uv sync --dev - - name: Run tests - run: | - make tests - - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v3 - env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}