4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
@@ -10,8 +10,6 @@ repos:
       - id: check-yaml
       - id: debug-statements
       - id: end-of-file-fixer
-      - id: no-commit-to-branch
-        args: [--branch, main]
       - id: requirements-txt-fixer
       - id: trailing-whitespace

@@ -20,7 +18,7 @@ repos:
     hooks:
       - id: mypy
         args: ["--ignore-missing-imports"]
-        files: "(fll)"
+        files: "(blaxbird)"
 
   - repo: https://github.com/astral-sh/ruff-pre-commit
     rev: v0.3.0
2 changes: 1 addition & 1 deletion blaxbird/_src/experimental/nn/dit.py
@@ -124,7 +124,7 @@ def __call__(self, inputs: jax.Array, context: jax.Array) -> jax.Array:
 class DiT(nnx.Module):
   def __init__(  # noqa: PLR0913
     self,
-    image_size: int,
+    image_size: tuple[int, int, int],
     n_hidden_channels: int,
     patch_size: int,
     n_layers: int,
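
The constructor now takes the full `(height, width, channels)` shape instead of a single side length, matching how the CIFAR-10 example further down builds its model. A minimal sketch of the new call, using only the two arguments that appear in this PR (any remaining constructor arguments keep their defaults):

```python
from flax import nnx
from jax import random as jr

import blaxbird

# image_size now spells out height, width and channels,
# e.g. 32x32 RGB images for CIFAR-10.
model = blaxbird.experimental.SmallDiT(
  image_size=(32, 32, 3),
  rngs=nnx.rnglib.Rngs(jr.key(1)),
)
```
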
2 changes: 1 addition & 1 deletion blaxbird/_src/experimental/nn/mlp.py
@@ -14,7 +14,7 @@ def __init__(
     kernel_init: nnx.initializers.Initializer = nnx.initializers.lecun_normal(),
     bias_init: nnx.initializers.Initializer = nnx.initializers.zeros_init(),
     use_bias: bool = True,
-    dropout_rate: float = None,
+    dropout_rate: float | None = None,
     activation: Callable[[jax.Array], jax.Array] = jax.nn.silu,
     activate_last: bool = False,
     rngs: nnx.rnglib.Rngs,
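
`dropout_rate: float = None` is rejected by strict type checkers because `None` is not a `float`; `float | None` makes the optional default explicit. A small illustrative sketch of the pattern (the class below is a stand-in, not blaxbird's actual `MLP`):

```python
from flax import nnx


class Block(nnx.Module):
  """Stand-in module showing the optional-dropout pattern."""

  def __init__(self, dropout_rate: float | None = None, *, rngs: nnx.rnglib.Rngs):
    # Only build a Dropout layer when a rate is given; None means "no dropout".
    if dropout_rate is not None:
      self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
    else:
      self.dropout = None

  def __call__(self, x):
    if self.dropout is not None:
      x = self.dropout(x)
    return x
```
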
2 changes: 1 addition & 1 deletion blaxbird/_src/experimental/samplers.py
@@ -69,7 +69,7 @@ def heun_sampler_fn(config: EDMConfig):
   params = config.parameterization
 
   # ruff: noqa: ANN001, ANN202, ANN003
-  def _denoise(model, rng_key, inputs, sigma, context, params):
+  def _denoise(model, rng_key, inputs, sigma, context):
     new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist())
     inputs_t = inputs * params.in_scaling(sigma).reshape(new_shape)
     noise_cond = params.noise_conditioning(sigma)
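
Dropping the trailing `params` argument removes a parameter that shadowed the `params` defined a few lines above in `heun_sampler_fn`; the inner function now simply closes over it. A stripped-down sketch of the pattern (the factory name and body are simplified, not the full Heun sampler):

```python
def make_denoiser(config):
  # `params` lives in the enclosing factory's scope ...
  params = config.parameterization

  def _denoise(inputs, sigma):
    # ... and is captured by the closure, so it no longer needs to be passed
    # in explicitly (where it previously shadowed the outer variable).
    return inputs * params.in_scaling(sigma)

  return _denoise
```
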
11 changes: 7 additions & 4 deletions blaxbird/_src/trainer.py
@@ -80,16 +80,18 @@ def train(
   step_fn, eval_fn = _step_and_val_fns(fns)
   # get model and replicate
   state = nnx.state((model, optimizer))
-  state = jax.device_put(state, shardings[0])
-  nnx.update((model, optimizer), state)
+  if shardings is not None:
+    state = jax.device_put(state, shardings[0])
+    nnx.update((model, optimizer), state)
   # metrics
   metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"))
   metrics_history = {}
   # run training
   step_key, rng_key = jr.split(rng_key)
   for step, batch in zip(range(1, n_steps + 1), train_itr):
     train_key, val_key = jr.split(jr.fold_in(step_key, step))
-    batch = jax.device_put(batch, shardings[1])
+    if shardings is not None:
+      batch = jax.device_put(batch, shardings[1])
     # do a gradient step
     step_fn(
       model=model,
@@ -107,7 +109,8 @@
         metrics_history[f"train/{metric}"] = float(value)
       # do evaluation loop
       for val_idx, batch in zip(range(n_eval_batches), val_itr):
-        batch = jax.device_put(batch, shardings[1])
+        if shardings is not None:
+          batch = jax.device_put(batch, shardings[1])
         eval_fn(
           model=model,
           rng_key=jr.fold_in(val_key, val_idx),
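
The new `if shardings is not None` guards make the sharding pair optional: `shardings[0]` places the model/optimizer state, `shardings[1]` the data batches, and passing `None` skips both `jax.device_put` calls so `train` also runs on a single device. A hedged sketch of how such a pair could be built; the CIFAR-10 example's actual `get_sharding` is not shown in this diff, so the details below are assumptions:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def get_sharding():
  # One mesh axis across all devices: replicate parameters,
  # shard data batches along their leading (batch) dimension.
  devices = mesh_utils.create_device_mesh((jax.device_count(),))
  mesh = Mesh(devices, axis_names=("data",))
  model_sharding = NamedSharding(mesh, P())       # fully replicated
  data_sharding = NamedSharding(mesh, P("data"))  # batch-sharded
  return model_sharding, data_sharding
```

With `shardings=None`, model state and batches simply stay wherever JAX placed them, which is what single-process debugging runs need.
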
58 changes: 38 additions & 20 deletions examples/cifar10_flow_matching/main.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 
+import dataloader
 import jax
 import matplotlib.pyplot as plt
 import numpy as np
@@ -12,7 +13,6 @@
 from jax.experimental import mesh_utils
 
 import blaxbird
-import dataloader
 from blaxbird import get_default_checkpointer, train_fn
 from blaxbird.experimental import rfm
 
@@ -38,19 +38,19 @@ def visualize_hook(sample_fn, val_iter, hook_every_n_steps, log_to_wandb):
 
   def convert_batch_to_image_grid(image_batch):
     reshaped = (
-        image_batch.reshape(n_row, n_col, *img_size)
-        .transpose([0, 2, 1, 3, 4])
-        .reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
+      image_batch.reshape(n_row, n_col, *img_size)
+      .transpose([0, 2, 1, 3, 4])
+      .reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
     )
     return (reshaped + 1.0) / 2.0
 
   def plot(images):
     fig = plt.figure(figsize=(16, 6))
     ax = fig.add_subplot(1, 1, 1)
     ax.imshow(
-        images,
-        interpolation="nearest",
-        cmap="gray",
+      images,
+      interpolation="nearest",
+      cmap="gray",
     )
     plt.axis("off")
     plt.tight_layout()
@@ -61,26 +61,31 @@ def fn(step, *, model, **kwargs):
       return
     all_samples = []
     for i, batch in enumerate(val_iter):
-        samples = sample_fn(model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape)
-        all_samples.append(samples)
-        if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
-            break
-    all_samples = np.concatenate(all_samples, axis=0)[:(n_row * n_col)]
+      samples = sample_fn(
+        model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape
+      )
+      all_samples.append(samples)
+      if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
+        break
+    all_samples = np.concatenate(all_samples, axis=0)[: (n_row * n_col)]
     all_samples = convert_batch_to_image_grid(all_samples)
     fig = plot(all_samples)
     if jax.process_index() == 0 and log_to_wandb:
-        wandb.log({"images": wandb.Image(fig)}, step=step)
+      wandb.log({"images": wandb.Image(fig)}, step=step)
 
   return fn
 
 
-def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb ):
+def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb):
   return [visualize_hook(sample_fn, val_itr, hook_every_n_steps, log_to_wandb)]
 
 
 def get_train_and_val_itrs(rng_key, outfolder):
   return dataloader.data_loaders(
-      rng_key, outfolder, split=["train[:90%]", "train[90%:]"], shuffle=[True, False],
+    rng_key,
+    outfolder,
+    split=["train[:90%]", "train[90%:]"],
+    shuffle=[True, False],
   )
 
 
@@ -92,14 +97,19 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
     jr.key(0), os.path.join(outfolder, "data")
   )
 
-  model = getattr(blaxbird.experimental, dit_type)(image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1)))
+  model = getattr(blaxbird.experimental, dit_type)(
+    image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1))
+  )
   train_step, val_step, sample_fn = rfm()
   optimizer = get_optimizer(model)
 
   save_fn, _, restore_last_fn = get_default_checkpointer(
-    os.path.join(outfolder, "checkpoints"), save_every_n_steps=eval_every_n_steps
+    os.path.join(outfolder, "checkpoints"),
+    save_every_n_steps=eval_every_n_steps,
   )
-  hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [save_fn]
+  hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [
+    save_fn
+  ]
 
   model_sharding, data_sharding = get_sharding()
   model, optimizer = restore_last_fn(model, optimizer)
@@ -121,7 +131,15 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
   parser.add_argument("--n-steps", type=int, default=1_000)
   parser.add_argument("--eval-every-n-steps", type=int, default=50)
   parser.add_argument("--n-eval-batches", type=int, default=10)
-  parser.add_argument("--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT")
+  parser.add_argument(
+    "--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT"
+  )
   parser.add_argument("--log-to-wandb", action="store_true")
   args = parser.parse_args()
-  run(args.n_steps, args.eval_every_n_steps, args.n_eval_batches, args.dit, args.log_to_wandb)
+  run(
+    args.n_steps,
+    args.eval_every_n_steps,
+    args.n_eval_batches,
+    args.dit,
+    args.log_to_wandb,
+  )
6 changes: 3 additions & 3 deletions examples/mnist_classification/main.py
@@ -8,11 +8,10 @@
 from flax import nnx
 from jax import random as jr
 from jax.experimental import mesh_utils
+from model import CNN, train_step, val_step
 
 from blaxbird import get_default_checkpointer, train_fn
 
-from model import CNN, train_step, val_step
-
 
 def get_optimizer(model, lr=1e-4):
   tx = optax.adamw(lr)
@@ -83,7 +82,8 @@ def run(n_steps, eval_every_n_steps, n_eval_batches):
   optimizer = get_optimizer(model)
 
   save_fn, _, restore_last_fn = get_default_checkpointer(
-    os.path.join(outfolder, "checkpoints"), save_every_n_steps=eval_every_n_steps
+    os.path.join(outfolder, "checkpoints"),
+    save_every_n_steps=eval_every_n_steps,
   )
   hooks = get_hooks(val_itr, eval_every_n_steps) + [save_fn]
 
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -68,14 +68,15 @@ addopts = "-v --doctest-modules --cov=./blaxbird --cov-report=xml"
 [tool.ruff]
 indent-width = 2
 line-length = 80
-exclude = ["*_test.py", "docs/**", "examples/**"]
 
 [tool.ruff.lint]
 select = ["D", "E", "F", "W", "I001"]
 extend-select = [
   "UP", "I", "PL", "S"
 ]
 ignore = ["S101", "ANN101", "PLR2044", "PLR0913"]
+exclude = ["*_test.py", "docs/**", "examples/**"]
+
 
 [tool.ruff.lint.pydocstyle]
 convention= 'google'