8 changes: 4 additions & 4 deletions README.md
@@ -15,11 +15,11 @@ Using `blaxbird` one can
 - distribute data and model weights over multiple processes or GPUs,
 - define hooks that are periodically called during training.
 
-In addition, `blaxbird` offers high-quality implementation of common neural network modules and algorithms, such as:
+In addition, `blaxbird` offers high-quality implementations of common neural network modules and algorithms, such as:
 
-- MLP, Diffusion Transformer,
-- Flow Matching and Denoising Score Matching (EDM schedules) with Euler and Heun samplers,
-- Consistency Distillation/Matching.
+- MLPs, DiTs, UNets,
+- Flow Matching and Denoising Score Matching (EDM schedules) models with Euler and Heun samplers,
+- Consistency Distillation/Matching models.
 
 ## Example
 
2 changes: 1 addition & 1 deletion blaxbird/__init__.py
@@ -1,6 +1,6 @@
 """blaxbird: A high-level API for building and training Flax NNX models."""
 
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 from blaxbird._src.checkpointer import get_default_checkpointer
 from blaxbird._src.trainer import train_fn
51 changes: 51 additions & 0 deletions blaxbird/_src/experimental/consistency_distillation.py
@@ -0,0 +1,51 @@
+import numpy as np
+from flax import nnx
+from jax import numpy as jnp
+from jax import random as jr
+
+from blaxbird._src.experimental import samplers
+from blaxbird._src.experimental.parameterizations import RFMConfig
+
+
+def _forward_process(inputs, times, noise):
+  new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist())
+  times = times.reshape(new_shape)
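+  # broadcast times over the event axes, then interpolate: x_t = t * x + (1 - t) * noise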
+  inputs_t = times * inputs + (1.0 - times) * noise
+  return inputs_t
+
+
+def rfm(config: RFMConfig = RFMConfig()):
+  """Construct rectified flow matching functions.
+
+  Args:
+    config: an RFMConfig object
+
+  Returns:
+    a tuple consisting of the train_step, val_step, and sampling functions
+  """
+  parameterization = config.parameterization
+
+  def _loss_fn(model, rng_key, batch):
+    inputs = batch["inputs"]
+    time_key, rng_key = jr.split(rng_key)
+    times = jr.uniform(time_key, shape=(inputs.shape[0],))
+    times = (
+      times * (parameterization.t_max - parameterization.t_eps)
+      + parameterization.t_eps
+    )
+    noise_key, rng_key = jr.split(rng_key)
+    noise = jr.normal(noise_key, inputs.shape)
+    inputs_t = _forward_process(inputs, times, noise)
+    vt = model(inputs=inputs_t, times=times, context=batch.get("context"))
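+    # regression target: the conditional velocity u_t = x - noise of the linear path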
+    ut = inputs - noise
+    loss = jnp.mean(jnp.square(ut - vt))
+    return loss
+
+  def train_step(model, rng_key, batch, **kwargs):
+    return nnx.value_and_grad(_loss_fn)(model, rng_key, batch)
+
+  def val_step(model, rng_key, batch, **kwargs):
+    return _loss_fn(model, rng_key, batch)
+
+  sampler = getattr(samplers, config.sampler + "_sample_fn")(config)
+  return train_step, val_step, sampler
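
To see the returned triple in context, here is a minimal usage sketch; `MyVelocityModel` and the batch layout are hypothetical stand-ins, assuming any `nnx.Module` that is called as `model(inputs=..., times=..., context=...)`:

from flax import nnx
from jax import numpy as jnp
from jax import random as jr

from blaxbird._src.experimental.consistency_distillation import rfm
from blaxbird._src.experimental.parameterizations import RFMConfig

train_step, val_step, sample_fn = rfm(RFMConfig())

model = MyVelocityModel(rngs=nnx.Rngs(0))  # hypothetical nnx.Module
batch = {"inputs": jnp.zeros((8, 32))}  # dummy batch; "context" is optional

loss, grads = train_step(model, jr.key(0), batch)  # value and gradient of the FM loss
val_loss = val_step(model, jr.key(1), batch)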
66 changes: 1 addition & 65 deletions blaxbird/_src/experimental/edm.py
@@ -1,74 +1,10 @@
-import dataclasses
-
 import numpy as np
 from flax import nnx
 from jax import numpy as jnp
 from jax import random as jr
 
 from blaxbird._src.experimental import samplers
-
-
-@dataclasses.dataclass
-class EDMParameterization:
-  n_sampling_steps: int = 25
-  sigma_min: float = 0.002
-  sigma_max: float = 80.0
-  rho: float = 7.0
-  sigma_data: float = 0.5
-  P_mean: float = -1.2
-  P_std: float = 1.2
-  S_churn: float = 40
-  S_min: float = 0.05
-  S_max: float = 50
-  S_noise: float = 1.003
-
-  def sigma(self, eps):
-    return jnp.exp(eps * self.P_std + self.P_mean)
-
-  def loss_weight(self, sigma):
-    return (jnp.square(sigma) + jnp.square(self.sigma_data)) / jnp.square(
-      sigma * self.sigma_data
-    )
-
-  def skip_scaling(self, sigma):
-    return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
-
-  def out_scaling(self, sigma):
-    return sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
-
-  def in_scaling(self, sigma):
-    return 1 / (sigma**2 + self.sigma_data**2) ** 0.5
-
-  def noise_conditioning(self, sigma):
-    return 0.25 * jnp.log(sigma)
-
-  def sampling_sigmas(self, num_steps):
-    rho_inv = 1 / self.rho
-    step_idxs = jnp.arange(num_steps, dtype=jnp.float32)
-    sigmas = (
-      self.sigma_max**rho_inv
-      + step_idxs
-      / (num_steps - 1)
-      * (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
-    ) ** self.rho
-    return jnp.concatenate([sigmas, jnp.zeros_like(sigmas[:1])])
-
-  def sigma_hat(self, sigma, num_steps):
-    gamma = (
-      jnp.minimum(self.S_churn / num_steps, 2**0.5 - 1)
-      if self.S_min <= sigma <= self.S_max
-      else 0
-    )
-    return sigma + gamma * sigma
-
-
-@dataclasses.dataclass
-class EDMConfig:
-  n_sampling_steps: int = 25
-  sampler: str = "heun"
-  parameterization: EDMParameterization = dataclasses.field(
-    default_factory=EDMParameterization
-  )
+from blaxbird._src.experimental.parameterizations import EDMConfig
 
 
 def edm(config: EDMConfig):
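
For reference, the `sampling_sigmas` schedule relocated to `parameterizations.py` is the Karras et al. (2022) discretization, sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho with a final zero appended; a standalone NumPy sketch, assuming the defaults above:

import numpy as np

def karras_sigmas(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
  # noise levels decay from sigma_max to sigma_min, warped by the exponent rho
  step_idxs = np.arange(num_steps, dtype=np.float32)
  sigmas = (
    sigma_max ** (1 / rho)
    + step_idxs / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
  ) ** rho
  return np.concatenate([sigmas, np.zeros_like(sigmas[:1])])  # end exactly at 0

print(karras_sigmas(25)[:3])  # approx. [80.0, 63.5, 50.0]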
1 change: 0 additions & 1 deletion blaxbird/_src/experimental/nn/mlp.py
@@ -5,7 +5,6 @@
 
 
 class MLP(nnx.Module):
-  # ruff: noqa: PLR0913, ANN204, ANN101
   def __init__(
     self,
     in_features: int,