27 changes: 0 additions & 27 deletions .github/workflows/ci.yaml
@@ -36,30 +36,3 @@ jobs:
      - name: Run lints
        run: |
          make lints

  tests:
    runs-on: ubuntu-latest
    needs:
      - lints
    strategy:
      matrix:
        python-version: [ 3.11, 3.12 ]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - uses: astral-sh/setup-uv@v5
        with:
          version: "latest"
      - name: Install dependencies
        run: |
          uv sync --dev
      - name: Run tests
        run: |
          make tests
      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v3
        env:
          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
6 changes: 3 additions & 3 deletions .github/workflows/release.yaml
@@ -10,11 +10,11 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.9]
        python-version: [3.11]
    steps:
      - uses: actions/checkout@v2
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install pypa/build
232 changes: 221 additions & 11 deletions README.md
@@ -1,10 +1,22 @@
# blaxbird [blækbɜːd]

[![ci](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml/badge.svg)](https://github.com/dirmeier/blaxbird/actions/workflows/ci.yaml)
[![version](https://img.shields.io/pypi/v/blaxbird.svg?colorB=black&style=flat)](https://pypi.org/project/blaxbird/)

> A high-level API to build and train NNX models

## About

A high-level API to build and train NNX models.
`Blaxbird` [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.
Using `blaxbird` one can
- concisely define models and loss functions without the usual JAX/Flax verbosity,
- easily define checkpointers that save the best and most current network weights,
- distribute data and model weights over multiple processes or GPUs,
- define hooks that are periodically called during training.

## Example

Define the module
To use `blaxbird`, one only needs to define a model, a loss function, and train and validation step functions:
```python
import optax
from flax import nnx
@@ -18,32 +30,230 @@ def loss_fn(model, images, labels):
        logits=logits, labels=labels
    ).mean()


def train_step(model, rng_key, batch, **kwargs):
def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])


def val_step(model, rng_key, batch, **kwargs):
def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
```

Define the trainer
You can then construct (and use) a training function like this:

```python
from jax import random as jr
import optax
from flax import nnx
from jax import random as jr

from blaxbird import train_fn

model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = get_optimizer(model)
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

train = train_fn(
    fns=(train_step, val_step),
    n_steps=100,
    eval_every_n_steps=10,
    n_eval_batches=10
)
train(jr.key(2), model, optimizer, train_itr, val_itr)
```

See a self-contained example in [examples/mnist_classification](examples/mnist_classification).

## Usage

`train_fn` is a higher-order function with the following signature:

```python
def train_fn(
    *,
    fns: tuple[Callable, Callable],
    shardings: Optional[tuple[jax.NamedSharding, jax.NamedSharding]] = None,
    n_steps: int,
    eval_every_n_steps: int,
    n_eval_batches: int,
    log_to_wandb: bool = False,
    hooks: Iterable[Callable] = (),
) -> Callable:
    ...
```

We briefly explain the less obvious arguments below.

### `fns`

`fns` is a required argument consisting of a tuple of two functions: a train step function and a validation step function.
In the simplest case they look like this:

```python
def train_step(model, rng_key, batch):
    return nnx.value_and_grad(loss_fn)(model, batch["image"], batch["label"])

def val_step(model, rng_key, batch):
    return loss_fn(model, batch["image"], batch["label"])
```

Both `train_step` and `val_step` have the same arguments and argument types:
- `model` specifies a `nnx.Module`, i.e., a neural network like the CNN shown above.
- `rng_key` is a `jax.random.key` in case you need to generate random numbers.
- `batch` is a sample from a data loader (to be specified later).

The loss function called by both computes a *scalar* loss value.
While `train_step` has to return the loss and the gradients, `val_step` only needs
to return the loss.
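
To make this concrete, here is a minimal sketch of a batch and the two return values. The shapes, dtypes, and use of NumPy arrays are assumptions for illustration only; `model`, `train_step`, and `val_step` are the ones defined above.

```python
import numpy as np
from jax import random as jr

# a batch is simply a dictionary of arrays; the MNIST-like shapes below are
# assumptions for illustration
batch = {
    "image": np.zeros((32, 28, 28, 1), dtype="float32"),
    "label": np.zeros((32,), dtype="int32"),
}

loss, grads = train_step(model, jr.key(0), batch)  # scalar loss and gradient pytree
val_loss = val_step(model, jr.key(0), batch)       # scalar loss only
```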

### `shardings`

To specify how data and model weights are distributed over devices and processes,
`blaxbird` uses JAX's [sharding](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html) functionality.

`shardings` is again a tuple: one element specifies the model sharding, the other the data sharding.
An example is shown below, where the model weights are replicated (an empty `PartitionSpec()`) and only the data is distributed over `num_devices` devices.
If you don't want to distribute anything, just set the argument to `None` or leave it unspecified.

```python
import jax
from jax.experimental import mesh_utils


def get_sharding():
    num_devices = jax.local_device_count()
    mesh = jax.sharding.Mesh(
        mesh_utils.create_device_mesh((num_devices,)), ("data",)
    )
    model_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec())
    data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
    return model_sharding, data_sharding
```

### `hooks`

`hooks` is a list of callables that are periodically called during training.
Each hook must have the following signature:

```python
def hook_fn(step, *, model, **kwargs) -> None:
    ...
```

It takes an integer `step` specifying the current training iteration and, as a keyword argument, the model itself.
For instance, if you want to track custom metrics during validation, you could create a hook like this:

```python
import logging

import jax
import optax
from flax import nnx


def hook_fn(metrics, val_iter, hook_every_n_steps):
    def fn(step, *, model, **kwargs):
        if step % hook_every_n_steps != 0:
            return
        for batch in val_iter:
            logits = model(batch["image"])
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits=logits, labels=batch["label"]
            ).mean()
            metrics.update(loss=loss, logits=logits, labels=batch["label"])
        if jax.process_index() == 0:
            curr_metrics = ", ".join(
                [f"{k}: {v}" for k, v in metrics.compute().items()]
            )
            logging.info(f"metrics at step {step}: {curr_metrics}")
        metrics.reset()

    return fn


metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average("loss"),
)
hook = hook_fn(metrics, val_iter, eval_every_n_steps)
```

This creates a hook function `hook` that, every `eval_every_n_steps` steps, iterates over the validation set,
computes accuracy and loss, and then logs everything.

To provide multiple hooks to the train function, just pass them as a list.
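
For example (a sketch that reuses `hook` from above and the checkpointing hook `hook_save` constructed in the next subsection; the step counts are placeholders):

```python
hooks = [hook, hook_save]

train = train_fn(
    fns=(train_step, val_step),
    n_steps=1_000,
    eval_every_n_steps=100,
    n_eval_batches=10,
    hooks=hooks,
)
```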

#### A checkpointing `hook`

We provide a convenient hook for checkpointing which can be constructed using
`get_default_checkpointer`. The checkpointer saves both the `max_to_keep` checkpoints with the lowest
validation loss and the most recent training checkpoint.

The signature of `get_default_checkpointer` is:

```python
def get_default_checkpointer(
    outfolder: str,
    *,
    save_every_n_steps: int,
    max_to_keep: int = 5,
) -> tuple[Callable, Callable, Callable]
```

Its arguments are:
- `outfolder`: a folder specifying where to store the checkpoints.
- `save_every_n_steps`: after how many training steps to store a checkpoint.
- `max_to_keep`: the number of checkpoints to keep before old checkpoints get removed (to avoid clogging the disk).

For instance, you would then construct the checkpointing hook like this:

```python
from blaxbird import get_default_checkpointer

hook_save, *_ = get_default_checkpointer(
    os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
)
```

### Restoring a run

You can also use `get_default_checkpointer` to restart a run where you left off.
`get_default_checkpointer` in fact returns three functions: one for saving checkpoints and two for restoring
checkpoints:

```python
from blaxbird import get_default_checkpointer

save, restore_best, restore_last = get_default_checkpointer(
    os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
)
```

You can then restore either the best or the most recent checkpoint:

```python
model = CNN(rngs=nnx.rnglib.Rngs(jr.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(1e-4))

model, optimizer = restore_best(model, optimizer)
model, optimizer = restore_last(model, optimizer)
```

### Doing training

After having defined the train functions, hooks, and shardings, you can train your model like this:

```python
train = train_fn(
    fns=(train_step, val_step),
    n_steps=n_steps,
    n_eval_frequency=n_eval_frequency,
    eval_every_n_steps=eval_every_n_steps,
    n_eval_batches=n_eval_batches,
    shardings=(model_sharding, data_sharding),
    hooks=hooks,
    log_to_wandb=False,
)
train(jr.key(2), model, optimizer, train_itr, val_itr)
train(jr.key(1), model, optimizer, train_itr, val_itr)
```

A self-contained example that also shows what the data loaders should look like can be found
in [examples/mnist_classification](examples/mnist_classification).

## Installation

To install the package from PyPI, call:

```bash
pip install blaxbird
```

To install a specific GitHub release `<RELEASE>`, call the following on the command line:

```bash
pip install git+https://github.com/dirmeier/blaxbird@<RELEASE>
```

## Author
2 changes: 1 addition & 1 deletion blaxbird/__init__.py
@@ -1,4 +1,4 @@
"""fll: A high-level API for building and training Flax NNX models."""
"""blaxbird: A high-level API for building and training Flax NNX models."""

__version__ = "0.0.1"

13 changes: 5 additions & 8 deletions blaxbird/_src/checkpointer.py
@@ -9,28 +9,25 @@
def get_default_checkpointer(
    outfolder: str,
    *,
    save_every_n_steps: int = 1,
    save_every_n_steps: int,
    max_to_keep: int = 5,
    best_fn: Callable = lambda x: x["val/loss"],
    best_mode: str = "min",
) -> tuple[Callable, Callable, Callable]:
"""Construct functions for checkpointing functionality.

Args:
outfolder: a path specifying where checkpoints are stored
save_every_n_steps: how often to store checkpoints
max_to_keep: number of checkpoints to store before they get deleted
best_fn: function that maintains checkpoints using a specific criterion for
quality
best_mode: use `min`, e.g., if your criterion is a loss function.
Use 'max' if the criterion is an ELBO or something.

Returns:
returns function to saev and restore checkpoints
"""
    checkpointer = ocp.PyTreeCheckpointer()
    options = ocp.CheckpointManagerOptions(
        max_to_keep=max_to_keep, create=True, best_mode=best_mode, best_fn=best_fn
        max_to_keep=max_to_keep,
        create=True,
        best_mode="min",
        best_fn=lambda x: x["val/loss"],
    )
    checkpoint_manager = ocp.CheckpointManager(
        os.path.join(outfolder, "best"),
6 changes: 3 additions & 3 deletions blaxbird/_src/trainer.py
@@ -32,7 +32,7 @@ def _eval_step(model, rng_key, metrics, batch, **kwargs):
def train_fn(
    *,
    fns: tuple[Callable, Callable],
    shardings: tuple[jax.NamedSharding, jax.NamedSharding],
    shardings: tuple[jax.NamedSharding, jax.NamedSharding] | None = None,
    n_steps: int,
    eval_every_n_steps: int,
    n_eval_batches: int,
@@ -70,8 +70,8 @@ def train(
            rng_key: a jax.random.key object
            model: a NNX model
            optimizer: a nnx.Optimizer object
            train_itr: a data laoder
            val_itr: a data laoder
            train_itr: a data loader
            val_itr: a data loader
        """
        # get train and val fns
        step_fn, eval_fn = _step_and_val_fns(fns)