From 6a9f875aac05a846a958ebafe28b232d3e77becb Mon Sep 17 00:00:00 2001 From: Larry Du Date: Sat, 22 Nov 2025 10:51:02 -0800 Subject: [PATCH 1/2] Hydra --- pyproject.toml | 1 + src/diffumon/config/sample_config.py | 22 ++++++++++++++++++++++ src/diffumon/diffusion/sampler.py | 4 ++-- 3 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 src/diffumon/config/sample_config.py diff --git a/pyproject.toml b/pyproject.toml index 88f3a15..ef6bcd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "einops>=0.8.0", "py7zr>=0.22.0", "gdown>=5.2.0", + "hydra-core>=1.3.2", ] dependency-groups.dev = [ diff --git a/src/diffumon/config/sample_config.py b/src/diffumon/config/sample_config.py new file mode 100644 index 0000000..6a28955 --- /dev/null +++ b/src/diffumon/config/sample_config.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field + +from diffumon.diffusion.sampler import SamplerType + + +@dataclass +class SamplerConfig: + sampler_type: SamplerType = SamplerType.DDPM + eta: float = 0.0 + num_inference_steps: int | None = None + save_every_k_time_steps: int = -1 + + +@dataclass +class SampleConfig: + checkpoint_path: str = "checkpoints/last_diffumon_checkpoint.pth" + output_dir: str = "samples" + num_samples: int = 32 + seed: int = 1999 + device: str | None = None + chw_dims_override: list[int] | None = None + sampler: SamplerConfig = field(default_factory=SamplerConfig) diff --git a/src/diffumon/diffusion/sampler.py b/src/diffumon/diffusion/sampler.py index a15fb63..3d0f8a6 100644 --- a/src/diffumon/diffusion/sampler.py +++ b/src/diffumon/diffusion/sampler.py @@ -199,11 +199,11 @@ def p_sampler_to_images( if device is None: device = get_device() - sampler = create_sampler( + sampler = _resolve_sampler( sampler_type=sampler_type, eta=eta, num_inference_steps=num_inference_steps ) - sample_batches: list[Tensor] = sampler.sample( + sample_batches: list[Tensor] = sampler( model=model, ns=ns, chw_dims=chw_dims, From dd29571cc41368abc3990185c6a030592f7a4d20 Mon Sep 17 00:00:00 2001 From: Larry Du Date: Sat, 22 Nov 2025 13:01:36 -0800 Subject: [PATCH 2/2] Add hydra. --- README.md | 31 +++++++++----------------- pyproject.toml | 1 + src/diffumon/cli.py | 26 +--------------------- src/diffumon/sample_app.py | 45 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 57 insertions(+), 46 deletions(-) create mode 100644 src/diffumon/sample_app.py diff --git a/README.md b/README.md index 9b4d581..ed34357 100644 --- a/README.md +++ b/README.md @@ -121,32 +121,21 @@ diffumon sample --checkpoint-path checkpoints/fashion_mnist_100epochs.pth --num- diffumon sample --checkpoint-path checkpoints/pokemon_11k_800epochs_32dim.pth --num-samples 32 --output-dir samples/pokemon_11k_800epochs_32dim ``` -### Generate samples with DDIM +### Hydra-powered sampling (DDPM or DDIM) -Use the deterministic DDIM sampler to cut down sampling steps: +Use the Hydra entrypoint to configure samplers without adding more CLI flags: ```bash -diffumon sample \ - --checkpoint-path checkpoints/fashion_mnist_100epochs.pth \ - --num-samples 16 \ - --sampler ddim \ - --num-inference-steps 50 \ - --output-dir samples/fashion_mnist_ddim_50 +python -m diffumon.sample_app \ + checkpoint_path=checkpoints/fashion_mnist_100epochs.pth \ + output_dir=samples/fashion_mnist_ddim \ + num_samples=16 \ + sampler.sampler_type=ddim \ + sampler.num_inference_steps=50 \ + sampler.eta=0.0 ``` -Add a bit of stochasticity (non‑zero eta) if you want more diverse outputs: - -```bash -diffumon sample \ - --checkpoint-path checkpoints/fashion_mnist_100epochs.pth \ - --num-samples 16 \ - --sampler ddim \ - --num-inference-steps 50 \ - --ddim-eta 0.2 \ - --output-dir samples/fashion_mnist_ddim_eta02 -``` - -Omitting `--num-inference-steps` runs DDIM across the full training schedule. +Hydra makes overrides easy, e.g. add stochasticity with `sampler.eta=0.2` or save intermediate steps with `sampler.save_every_k_time_steps=50`. ## Useful resources diff --git a/pyproject.toml b/pyproject.toml index ef6bcd5..be2369b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ packages = ["src/diffumon"] [project.scripts] diffumon = 'diffumon.cli:main' +diffumon-sample = 'diffumon.sample_app:main' [tool.isort] profile = "black" diff --git a/src/diffumon/cli.py b/src/diffumon/cli.py index 2891daa..d55c440 100644 --- a/src/diffumon/cli.py +++ b/src/diffumon/cli.py @@ -13,7 +13,7 @@ download_pokemon_sprites_11k, ) from diffumon.data.transforms import forward_transform -from diffumon.diffusion.sampler import SamplerType, p_sampler_to_images +from diffumon.diffusion.sampler import p_sampler_to_images from diffumon.models.unet import Unet from diffumon.trainers.training_loop import train_noise_predictor from diffumon.utils import get_device, load_unet_checkpoint @@ -280,33 +280,12 @@ class CustomURLopener(urllib.request.FancyURLopener): @click.option( "--seed", default=1999, type=int, help="Random seed for generating samples" ) -@click.option( - "--sampler", - default="ddpm", - type=click.Choice([s.value for s in SamplerType]), - help="Sampling strategy to use for denoising.", -) -@click.option( - "--ddim-eta", - default=0.0, - type=float, - help="Amount of stochasticity for DDIM sampling (0.0 = deterministic). Ignored for DDPM.", -) -@click.option( - "--num-inference-steps", - default=None, - type=int, - help="Number of inference steps for DDIM sampling. Defaults to the full schedule when omitted.", -) def sample( num_samples: int, output_dir: str, checkpoint_path: str, device: str | None, seed: int, - sampler: str, - ddim_eta: float, - num_inference_steps: int | None, ) -> None: # Code for sampling diffumon @@ -327,9 +306,6 @@ def sample( chw_dims=chw_dims, seed=seed, output_dir=output_dir, - sampler_type=SamplerType(sampler), - eta=ddim_eta, - num_inference_steps=num_inference_steps, device=device, ) diff --git a/src/diffumon/sample_app.py b/src/diffumon/sample_app.py new file mode 100644 index 0000000..609a5d3 --- /dev/null +++ b/src/diffumon/sample_app.py @@ -0,0 +1,45 @@ +from pathlib import Path + +import hydra +from hydra.core.config_store import ConfigStore +from omegaconf import OmegaConf +import torch + +from diffumon.config.sample_config import SampleConfig +from diffumon.diffusion.sampler import SamplerType, p_sampler_to_images +from diffumon.utils import get_device, load_unet_checkpoint + +cs = ConfigStore.instance() +cs.store(name="sample_config", node=SampleConfig) + + +@hydra.main(config_name="sample_config", version_base=None) +def main(cfg: SampleConfig) -> None: + print("Sampling with config:") + print(OmegaConf.to_yaml(cfg)) + + device = torch.device(cfg.device) if cfg.device else get_device() + model, noise_schedule, _, chw_dims = load_unet_checkpoint( + cfg.checkpoint_path, device=device + ) + + output_dir = Path(cfg.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + p_sampler_to_images( + model=model, + ns=noise_schedule, + num_samples=cfg.num_samples, + chw_dims=cfg.chw_dims_override or chw_dims, + seed=cfg.seed, + output_dir=output_dir, + sampler_type=cfg.sampler.sampler_type, + eta=cfg.sampler.eta, + num_inference_steps=cfg.sampler.num_inference_steps, + save_every_k_time_steps=cfg.sampler.save_every_k_time_steps, + device=device, + ) + + +if __name__ == "__main__": + main()