From 56d45fce4c7be94742f6e1ac0764ae46255d5666 Mon Sep 17 00:00:00 2001
From: Simon Dirmeier
Date: Thu, 21 Aug 2025 11:20:40 +0200
Subject: [PATCH 1/2] Increment version

---
 .pre-commit-config.yaml                |  4 +-
 examples/cifar10_flow_matching/main.py | 58 +++++++++++++++++---------
 examples/mnist_classification/main.py  |  6 +--
 pyproject.toml                         |  3 +-
 4 files changed, 44 insertions(+), 27 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 930dc2e..76bceb2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -10,8 +10,6 @@ repos:
   - id: check-yaml
   - id: debug-statements
   - id: end-of-file-fixer
-  - id: no-commit-to-branch
-    args: [--branch, main]
   - id: requirements-txt-fixer
   - id: trailing-whitespace
 
@@ -20,7 +18,7 @@
   hooks:
   - id: mypy
     args: ["--ignore-missing-imports"]
-    files: "(fll)"
+    files: "(blaxbird)"
 
 - repo: https://github.com/astral-sh/ruff-pre-commit
   rev: v0.3.0
diff --git a/examples/cifar10_flow_matching/main.py b/examples/cifar10_flow_matching/main.py
index 08946ae..7a94859 100644
--- a/examples/cifar10_flow_matching/main.py
+++ b/examples/cifar10_flow_matching/main.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 
+import dataloader
 import jax
 import matplotlib.pyplot as plt
 import numpy as np
@@ -12,7 +13,6 @@
 from jax.experimental import mesh_utils
 
 import blaxbird
-import dataloader
 from blaxbird import get_default_checkpointer, train_fn
 from blaxbird.experimental import rfm
 
@@ -38,9 +38,9 @@ def visualize_hook(sample_fn, val_iter, hook_every_n_steps, log_to_wandb):
 
   def convert_batch_to_image_grid(image_batch):
     reshaped = (
-        image_batch.reshape(n_row, n_col, *img_size)
-        .transpose([0, 2, 1, 3, 4])
-        .reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
+      image_batch.reshape(n_row, n_col, *img_size)
+      .transpose([0, 2, 1, 3, 4])
+      .reshape(n_row * img_size[0], n_col * img_size[1], img_size[2])
     )
     return (reshaped + 1.0) / 2.0
 
@@ -48,9 +48,9 @@ def plot(images):
     fig = plt.figure(figsize=(16, 6))
     ax = fig.add_subplot(1, 1, 1)
     ax.imshow(
-        images,
-        interpolation="nearest",
-        cmap="gray",
+      images,
+      interpolation="nearest",
+      cmap="gray",
     )
     plt.axis("off")
     plt.tight_layout()
@@ -61,26 +61,31 @@ def fn(step, *, model, **kwargs):
       return
     all_samples = []
     for i, batch in enumerate(val_iter):
-      samples = sample_fn(model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape)
-      all_samples.append(samples)
-      if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
-        break
-    all_samples = np.concatenate(all_samples, axis=0)[:(n_row * n_col)]
+      samples = sample_fn(
+        model, jr.fold_in(jr.key(step), i), sample_shape=batch["inputs"].shape
+      )
+      all_samples.append(samples)
+      if len(all_samples) * all_samples[0].shape[0] >= n_row * n_col:
+        break
+    all_samples = np.concatenate(all_samples, axis=0)[: (n_row * n_col)]
     all_samples = convert_batch_to_image_grid(all_samples)
     fig = plot(all_samples)
     if jax.process_index() == 0 and log_to_wandb:
-        wandb.log({"images": wandb.Image(fig)}, step=step)
+      wandb.log({"images": wandb.Image(fig)}, step=step)
 
   return fn
 
 
-def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb ):
+def get_hooks(sample_fn, val_itr, hook_every_n_steps, log_to_wandb):
   return [visualize_hook(sample_fn, val_itr, hook_every_n_steps, log_to_wandb)]
 
 
 def get_train_and_val_itrs(rng_key, outfolder):
   return dataloader.data_loaders(
-    rng_key, outfolder, split=["train[:90%]", "train[90%:]"], shuffle=[True, False],
+    rng_key,
+    outfolder,
+    split=["train[:90%]", "train[90%:]"],
+    shuffle=[True, False],
   )
 
 
@@ -92,14 +97,19 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
     jr.key(0), os.path.join(outfolder, "data")
   )
 
-  model = getattr(blaxbird.experimental, dit_type)(image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1)))
+  model = getattr(blaxbird.experimental, dit_type)(
+    image_size=(32, 32, 3), rngs=nnx.rnglib.Rngs(jr.key(1))
+  )
   train_step, val_step, sample_fn = rfm()
   optimizer = get_optimizer(model)
 
   save_fn, _, restore_last_fn = get_default_checkpointer(
-    os.path.join(outfolder, "checkpoints"), save_every_n_steps=eval_every_n_steps
+    os.path.join(outfolder, "checkpoints"),
+    save_every_n_steps=eval_every_n_steps,
   )
-  hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [save_fn]
+  hooks = get_hooks(sample_fn, val_itr, eval_every_n_steps, log_to_wandb) + [
+    save_fn
+  ]
 
   model_sharding, data_sharding = get_sharding()
   model, optimizer = restore_last_fn(model, optimizer)
@@ -121,7 +131,15 @@ def run(n_steps, eval_every_n_steps, n_eval_batches, dit_type, log_to_wandb):
   parser.add_argument("--n-steps", type=int, default=1_000)
   parser.add_argument("--eval-every-n-steps", type=int, default=50)
   parser.add_argument("--n-eval-batches", type=int, default=10)
-  parser.add_argument("--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT")
+  parser.add_argument(
+    "--dit", type=str, choices=["SmallDiT", "BaseDiT"], default="SmallDiT"
+  )
   parser.add_argument("--log-to-wandb", action="store_true")
   args = parser.parse_args()
-  run(args.n_steps, args.eval_every_n_steps, args.n_eval_batches, args.dit, args.log_to_wandb)
+  run(
+    args.n_steps,
+    args.eval_every_n_steps,
+    args.n_eval_batches,
+    args.dit,
+    args.log_to_wandb,
+  )
diff --git a/examples/mnist_classification/main.py b/examples/mnist_classification/main.py
index 16db26a..b0df4c5 100644
--- a/examples/mnist_classification/main.py
+++ b/examples/mnist_classification/main.py
@@ -8,11 +8,10 @@
 from flax import nnx
 from jax import random as jr
 from jax.experimental import mesh_utils
+from model import CNN, train_step, val_step
 
 from blaxbird import get_default_checkpointer, train_fn
 
-from model import CNN, train_step, val_step
-
 
 def get_optimizer(model, lr=1e-4):
   tx = optax.adamw(lr)
@@ -83,7 +82,8 @@ def run(n_steps, eval_every_n_steps, n_eval_batches):
   optimizer = get_optimizer(model)
 
   save_fn, _, restore_last_fn = get_default_checkpointer(
-    os.path.join(outfolder, "checkpoints"), save_every_n_steps=eval_every_n_steps
+    os.path.join(outfolder, "checkpoints"),
+    save_every_n_steps=eval_every_n_steps,
   )
   hooks = get_hooks(val_itr, eval_every_n_steps) + [save_fn]
 
diff --git a/pyproject.toml b/pyproject.toml
index 29e0bea..122aca8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,7 +68,6 @@ addopts = "-v --doctest-modules --cov=./blaxbird --cov-report=xml"
 [tool.ruff]
 indent-width = 2
 line-length = 80
-exclude = ["*_test.py", "docs/**", "examples/**"]
 
 [tool.ruff.lint]
 select = ["D", "E", "F", "W", "I001"]
@@ -76,6 +75,8 @@ extend-select = [
   "UP", "I", "PL", "S"
 ]
 ignore = ["S101", "ANN101", "PLR2044", "PLR0913"]
+exclude = ["*_test.py", "docs/**", "examples/**"]
+
 
 [tool.ruff.lint.pydocstyle]
 convention= 'google'

From b67d95bf9c650ef54590857898e91efd8b316098 Mon Sep 17 00:00:00 2001
From: Simon Dirmeier
Date: Thu, 21 Aug 2025 11:26:47 +0200
Subject: [PATCH 2/2] fix

---
 blaxbird/_src/experimental/nn/dit.py   |  2 +-
 blaxbird/_src/experimental/nn/mlp.py   |  2 +-
 blaxbird/_src/experimental/samplers.py |  2 +-
 blaxbird/_src/trainer.py               | 11 +++++++----
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/blaxbird/_src/experimental/nn/dit.py b/blaxbird/_src/experimental/nn/dit.py
index 01c1824..3c11dbf 100644
--- a/blaxbird/_src/experimental/nn/dit.py
+++ b/blaxbird/_src/experimental/nn/dit.py
@@ -124,7 +124,7 @@ def __call__(self, inputs: jax.Array, context: jax.Array) -> jax.Array:
 class DiT(nnx.Module):
   def __init__(  # noqa: PLR0913
     self,
-    image_size: int,
+    image_size: tuple[int, int, int],
     n_hidden_channels: int,
     patch_size: int,
     n_layers: int,
diff --git a/blaxbird/_src/experimental/nn/mlp.py b/blaxbird/_src/experimental/nn/mlp.py
index cd18af5..f7bec36 100644
--- a/blaxbird/_src/experimental/nn/mlp.py
+++ b/blaxbird/_src/experimental/nn/mlp.py
@@ -14,7 +14,7 @@ def __init__(
     kernel_init: nnx.initializers.Initializer = nnx.initializers.lecun_normal(),
     bias_init: nnx.initializers.Initializer = nnx.initializers.zeros_init(),
     use_bias: bool = True,
-    dropout_rate: float = None,
+    dropout_rate: float | None = None,
     activation: Callable[[jax.Array], jax.Array] = jax.nn.silu,
     activate_last: bool = False,
     rngs: nnx.rnglib.Rngs,
diff --git a/blaxbird/_src/experimental/samplers.py b/blaxbird/_src/experimental/samplers.py
index 61246d1..1a72535 100644
--- a/blaxbird/_src/experimental/samplers.py
+++ b/blaxbird/_src/experimental/samplers.py
@@ -69,7 +69,7 @@ def heun_sampler_fn(config: EDMConfig):
   params = config.parameterization
 
   # ruff: noqa: ANN001, ANN202, ANN003
-  def _denoise(model, rng_key, inputs, sigma, context, params):
+  def _denoise(model, rng_key, inputs, sigma, context):
     new_shape = (-1,) + tuple(np.ones(inputs.ndim - 1, dtype=np.int32).tolist())
     inputs_t = inputs * params.in_scaling(sigma).reshape(new_shape)
     noise_cond = params.noise_conditioning(sigma)
diff --git a/blaxbird/_src/trainer.py b/blaxbird/_src/trainer.py
index ec2771f..06a9aee 100644
--- a/blaxbird/_src/trainer.py
+++ b/blaxbird/_src/trainer.py
@@ -80,8 +80,9 @@ def train(
   step_fn, eval_fn = _step_and_val_fns(fns)
   # get model and replicate
   state = nnx.state((model, optimizer))
-  state = jax.device_put(state, shardings[0])
-  nnx.update((model, optimizer), state)
+  if shardings is not None:
+    state = jax.device_put(state, shardings[0])
+    nnx.update((model, optimizer), state)
   # metrics
   metrics = nnx.MultiMetric(loss=nnx.metrics.Average("loss"))
   metrics_history = {}
@@ -89,7 +90,8 @@ def train(
   step_key, rng_key = jr.split(rng_key)
   for step, batch in zip(range(1, n_steps + 1), train_itr):
     train_key, val_key = jr.split(jr.fold_in(step_key, step))
-    batch = jax.device_put(batch, shardings[1])
+    if shardings is not None:
+      batch = jax.device_put(batch, shardings[1])
     # do a gradient step
     step_fn(
       model=model,
@@ -107,7 +109,8 @@ def train(
         metrics_history[f"train/{metric}"] = float(value)
       # do evaluation loop
       for val_idx, batch in zip(range(n_eval_batches), val_itr):
-        batch = jax.device_put(batch, shardings[1])
+        if shardings is not None:
+          batch = jax.device_put(batch, shardings[1])
         eval_fn(
           model=model,
           rng_key=jr.fold_in(val_key, val_idx),