2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -20,7 +20,7 @@ jobs:
- precommit
strategy:
matrix:
python-version: [ 3.11, 3.12 ]
python-version: [ 3.11 ]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
19 changes: 13 additions & 6 deletions README.md
@@ -8,12 +8,19 @@
## About

`Blaxbird` [blækbɜːd] is a high-level API to easily build NNX models and train them on CPU or GPU.

Using `blaxbird` one can
- concisely define models and loss functions without the usual JAX/Flax verbosity,
- easily define checkpointers that save the best and most current network weights,
- distribute data and model weights over multiple processes or GPUs,
- define hooks that are periodically called during training.

In addition, `blaxbird` offers high-quality implementations of common neural network modules and algorithms, such as:

- MLP, Diffusion Transformer,
- Flow Matching and Denoising Score Matching (EDM schedules) with Euler and Heun samplers,
- Consistency Distillation/Matching.

## Example

To use `blaxbird`, one only needs to define a model, a loss function, and train and validation step functions:
@@ -58,7 +65,7 @@ train = train_fn(
train(jr.key(2), model, optimizer, train_itr, val_itr)
```
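The collapsed lines above are where the model, the loss, and the step functions get defined. As a rough sketch of that part (the `(model, rng_key, batch, **kwargs)` step signature is inferred from the step functions in `blaxbird/_src/experimental/edm.py` further down in this PR; the classifier and loss below are illustrative and not part of `blaxbird`):

```python
# Illustrative only: a plain NNX model plus step functions whose signatures
# mirror the train/val steps used elsewhere in this repository.
import optax
from flax import nnx


class SimpleClassifier(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(784, 128, rngs=rngs)
        self.linear2 = nnx.Linear(128, 10, rngs=rngs)

    def __call__(self, x):
        return self.linear2(nnx.relu(self.linear1(x)))


def loss_fn(model, rng_key, batch):
    logits = model(batch["inputs"])
    labels = batch["labels"]
    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()


def train_step(model, rng_key, batch, **kwargs):
    # returns (loss, gradients) via NNX's lifted value_and_grad
    return nnx.value_and_grad(loss_fn)(model, rng_key, batch)


def val_step(model, rng_key, batch, **kwargs):
    return loss_fn(model, rng_key, batch)
```

Together with an optimizer and the train/validation iterators, these are then passed to `train` as in the call above.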

See a self-contained example in [examples/mnist_classification](examples/mnist_classification).
See the entire self-contained example in [examples/mnist_classification](examples/mnist_classification).

## Usage

@@ -158,7 +165,7 @@ metrics = nnx.MultiMetric(
accuracy=nnx.metrics.Accuracy(),
loss=nnx.metrics.Average("loss"),
)
hook = hook_fn(metrics, val_iter, eval_every_n_steps)
hook = hook_fn(metrics, val_iter, hook_every_n_steps)
```

This creates a hook function `hook` that after `eval_every_n_steps` steps iterates over the validation set
Expand Down Expand Up @@ -194,7 +201,7 @@ For instance, you would construct the checkpointing function then like this:
from blaxbird import get_default_checkpointer

hook_save, *_ = get_default_checkpointer(
os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
"checkpoints", save_every_n_steps=100
)
```

@@ -208,7 +215,7 @@ checkpoints:
from blaxbird import get_default_checkpointer

save, restore_best, restore_last = get_default_checkpointer(
os.path.join(outfolder, "checkpoints"), save_every_n_steps=100
"checkpoints", save_every_n_steps=100
)
```
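Judging from the function bodies in `blaxbird/_src/checkpointer.py` (shown further down in this diff), both restore functions take the model and the optimizer and return the restored pair, falling back to the state that was passed in when no checkpoint is found; the two-argument call below is an assumption based on those bodies. A short usage sketch, reusing the `model` and `optimizer` defined earlier:

```python
save, restore_best, restore_last = get_default_checkpointer(
    "checkpoints", save_every_n_steps=100
)

# resume from the most recent checkpoint (a no-op if none exists yet)
model, optimizer = restore_last(model, optimizer)
# ... training ...
# load the weights from the best step, e.g., for final evaluation
model, optimizer = restore_best(model, optimizer)
```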

@@ -239,8 +246,8 @@ train = train_fn(
train(jr.key(1), model, optimizer, train_itr, val_itr)
```

An self-contained example that also explains how the data loaders should look like can be found
in [examples/mnist_classification](examples/mnist_classification).
Self-contained examples that also explain what the data loaders should look like can be found
in [examples](examples).

## Installation

2 changes: 1 addition & 1 deletion blaxbird/__init__.py
@@ -1,6 +1,6 @@
"""blaxbird: A high-level API for building and training Flax NNX models."""

__version__ = "0.0.1"
__version__ = "0.0.2"

from blaxbird._src.checkpointer import get_default_checkpointer
from blaxbird._src.trainer import train_fn
28 changes: 15 additions & 13 deletions blaxbird/_src/checkpointer.py
@@ -65,7 +65,7 @@ def save_fn(
)
checkpoint_manager.wait_until_finished()
except Exception as e:
logging.error(f"could not save checkpoint because of: {e}")
logging.error(f"could not save better checkpoint because of: {e}")
logging.error("resuming nonetheless")
try:
logging.info("saving last checkpoint")
@@ -87,15 +87,19 @@ def restore_best_fn(
"""
graph_def, state = nnx.split(model)
opt_def, opt_state = nnx.split(optimizer.opt_state)
restored = checkpoint_manager.restore(
checkpoint_manager.best_step(),
args=ocp.args.Composite(
state=ocp.args.StandardRestore(nnx.eval_shape(lambda: state)),
opt_state=ocp.args.StandardRestore(nnx.eval_shape(lambda: opt_state)),
),
)
model = nnx.merge(graph_def, restored["state"])
optimizer.opt_state = nnx.merge(opt_def, restored["opt_state"])
try:
logging.info("trying to restore best checkpoint")
restored = checkpoint_manager.restore(
checkpoint_manager.best_step(),
args=ocp.args.Composite(
state=ocp.args.StandardRestore(nnx.eval_shape(lambda: state)),
opt_state=ocp.args.StandardRestore(nnx.eval_shape(lambda: opt_state)),
),
)
model = nnx.merge(graph_def, restored["state"])
optimizer.opt_state = nnx.merge(opt_def, restored["opt_state"])
except FileNotFoundError:
logging.warning("could not find checkpoint. resuming with blank state")
return model, optimizer

def restore_last_fn(
@@ -119,9 +123,7 @@ def restore_last_fn(
model = nnx.merge(graphdef, restored[0])
optimizer.opt_state = nnx.merge(optdef, restored[1])
except FileNotFoundError:
logging.warning(
"could not find last checkpoint. resuming with blank state"
)
logging.warning("could not find checkpoint. resuming with blank state")
return model, optimizer

return save_fn, restore_best_fn, restore_last_fn
130 changes: 130 additions & 0 deletions blaxbird/_src/experimental/edm.py
@@ -0,0 +1,130 @@
import dataclasses

import numpy as np
from flax import nnx
from jax import numpy as jnp
from jax import random as jr

from blaxbird._src.experimental import samplers


@dataclasses.dataclass
class EDMParameterization:
n_sampling_steps: int = 25
sigma_min: float = 0.002
sigma_max: float = 80.0
rho: float = 7.0
sigma_data: float = 0.5
P_mean: float = -1.2
P_std: float = 1.2
S_churn: float = 40
S_min: float = 0.05
S_max: float = 50
S_noise: float = 1.003

def sigma(self, eps):
return jnp.exp(eps * self.P_std + self.P_mean)

def loss_weight(self, sigma):
return (jnp.square(sigma) + jnp.square(self.sigma_data)) / jnp.square(
sigma * self.sigma_data
)

def skip_scaling(self, sigma):
return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)

def out_scaling(self, sigma):
return sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5

def in_scaling(self, sigma):
return 1 / (sigma**2 + self.sigma_data**2) ** 0.5

def noise_conditioning(self, sigma):
return 0.25 * jnp.log(sigma)

def sampling_sigmas(self, num_steps):
rho_inv = 1 / self.rho
step_idxs = jnp.arange(num_steps, dtype=jnp.float32)
sigmas = (
self.sigma_max**rho_inv
+ step_idxs
/ (num_steps - 1)
* (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
) ** self.rho
return jnp.concatenate([sigmas, jnp.zeros_like(sigmas[:1])])

def sigma_hat(self, sigma, num_steps):
gamma = (
jnp.minimum(self.S_churn / num_steps, 2**0.5 - 1)
if self.S_min <= sigma <= self.S_max
else 0
)
return sigma + gamma * sigma


@dataclasses.dataclass
class EDMConfig:
n_sampling_steps: int = 25
sampler: str = "heun"
parameterization: EDMParameterization = dataclasses.field(
default_factory=EDMParameterization
)


def edm(config: EDMConfig):
"""Construct denoising score-matching functions.

Uses the EDM parameterization.

Args:
config: an EDMConfig object

Returns:
returns a tuple consisting of train_step, val_step and sampling functions
"""
parameterization = config.parameterization

def denoise(model, rng_key, inputs, sigma, context):
new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist())
inputs_t = inputs * parameterization.in_scaling(sigma).reshape(new_shape)
noise_cond = parameterization.noise_conditioning(sigma)
outputs = model(
inputs=inputs_t,
context=context,
times=noise_cond,
)
skip = inputs * parameterization.skip_scaling(sigma).reshape(new_shape)
outputs = outputs * parameterization.out_scaling(sigma).reshape(new_shape)
outputs = skip + outputs
return outputs

def loss_fn(model, rng_key, batch):
inputs = batch["inputs"]
new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist())

epsilon_key, noise_key, rng_key = jr.split(rng_key, 3)
epsilon = jr.normal(epsilon_key, (inputs.shape[0],))
sigma = parameterization.sigma(epsilon)

noise = jr.normal(noise_key, inputs.shape) * sigma.reshape(new_shape)
denoise_key, rng_key = jr.split(rng_key)
target_hat = denoise(
model,
denoise_key,
inputs=inputs + noise,
sigma=sigma,
context=batch.get("context"),
)

loss = jnp.square(inputs - target_hat)
loss = parameterization.loss_weight(sigma).reshape(new_shape) * loss
return loss.mean()

def train_step(model, rng_key, batch, **kwargs):
return nnx.value_and_grad(loss_fn)(model, rng_key, batch)

def val_step(model, rng_key, batch, **kwargs):
return loss_fn(model, rng_key, batch)

sampler = getattr(samplers, config.sampler + "sample_fn")(config)
return train_step, val_step, sampler
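A hypothetical smoke test for the helpers above, to show how the returned functions are meant to be called; the toy model only mimics the `(inputs, context, times)` keyword interface that `denoise` relies on and is not one of `blaxbird`'s modules, and the snippet has not been run against the repository:

```python
# Hypothetical usage sketch for edm(); the model below is a stand-in, not a
# blaxbird module.
from flax import nnx
from jax import numpy as jnp
from jax import random as jr

from blaxbird._src.experimental.edm import EDMConfig, edm


class ToyDenoiser(nnx.Module):
    def __init__(self, dim: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(dim + 1, dim, rngs=rngs)

    def __call__(self, inputs, context, times):
        # append the noise-conditioning value to every sample
        times = jnp.broadcast_to(times.reshape(-1, 1), (inputs.shape[0], 1))
        return self.linear(jnp.concatenate([inputs, times], axis=-1))


train_step, val_step, sample = edm(EDMConfig(n_sampling_steps=25, sampler="heun"))

model = ToyDenoiser(dim=8, rngs=nnx.Rngs(0))
batch = {"inputs": jr.normal(jr.key(1), (4, 8))}
loss, grads = train_step(model, jr.key(2), batch)  # (scalar loss, gradients w.r.t. model)
val_loss = val_step(model, jr.key(3), batch)
```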