Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,32 +121,21 @@ diffumon sample --checkpoint-path checkpoints/fashion_mnist_100epochs.pth --num-
diffumon sample --checkpoint-path checkpoints/pokemon_11k_800epochs_32dim.pth --num-samples 32 --output-dir samples/pokemon_11k_800epochs_32dim
```

### Generate samples with DDIM
### Hydra-powered sampling (DDPM or DDIM)

Use the deterministic DDIM sampler to cut down sampling steps:
Use the Hydra entrypoint to configure samplers without adding more CLI flags:

```bash
diffumon sample \
--checkpoint-path checkpoints/fashion_mnist_100epochs.pth \
--num-samples 16 \
--sampler ddim \
--num-inference-steps 50 \
--output-dir samples/fashion_mnist_ddim_50
python -m diffumon.sample_app \
checkpoint_path=checkpoints/fashion_mnist_100epochs.pth \
output_dir=samples/fashion_mnist_ddim \
num_samples=16 \
sampler.sampler_type=ddim \
sampler.num_inference_steps=50 \
sampler.eta=0.0
```

Add a bit of stochasticity (non-zero `eta`) if you want more diverse outputs:

```bash
diffumon sample \
--checkpoint-path checkpoints/fashion_mnist_100epochs.pth \
--num-samples 16 \
--sampler ddim \
--num-inference-steps 50 \
--ddim-eta 0.2 \
--output-dir samples/fashion_mnist_ddim_eta02
```

Omitting `--num-inference-steps` runs DDIM across the full training schedule.
Hydra makes overrides easy: for example, add stochasticity with `sampler.eta=0.2`, or save intermediate steps with `sampler.save_every_k_time_steps=50`.

## Useful resources

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"einops>=0.8.0",
"py7zr>=0.22.0",
"gdown>=5.2.0",
"hydra-core>=1.3.2",
]

dependency-groups.dev = [
Expand Down Expand Up @@ -51,6 +52,7 @@ packages = ["src/diffumon"]

[project.scripts]
diffumon = 'diffumon.cli:main'
diffumon-sample = 'diffumon.sample_app:main'

[tool.isort]
profile = "black"
26 changes: 1 addition & 25 deletions src/diffumon/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
download_pokemon_sprites_11k,
)
from diffumon.data.transforms import forward_transform
from diffumon.diffusion.sampler import SamplerType, p_sampler_to_images
from diffumon.diffusion.sampler import p_sampler_to_images
from diffumon.models.unet import Unet
from diffumon.trainers.training_loop import train_noise_predictor
from diffumon.utils import get_device, load_unet_checkpoint
Expand Down Expand Up @@ -280,33 +280,12 @@ class CustomURLopener(urllib.request.FancyURLopener):
@click.option(
"--seed", default=1999, type=int, help="Random seed for generating samples"
)
@click.option(
"--sampler",
default="ddpm",
type=click.Choice([s.value for s in SamplerType]),
help="Sampling strategy to use for denoising.",
)
@click.option(
"--ddim-eta",
default=0.0,
type=float,
help="Amount of stochasticity for DDIM sampling (0.0 = deterministic). Ignored for DDPM.",
)
@click.option(
"--num-inference-steps",
default=None,
type=int,
help="Number of inference steps for DDIM sampling. Defaults to the full schedule when omitted.",
)
def sample(
num_samples: int,
output_dir: str,
checkpoint_path: str,
device: str | None,
seed: int,
sampler: str,
ddim_eta: float,
num_inference_steps: int | None,
) -> None:
# Code for sampling diffumon

Expand All @@ -327,9 +306,6 @@ def sample(
chw_dims=chw_dims,
seed=seed,
output_dir=output_dir,
sampler_type=SamplerType(sampler),
eta=ddim_eta,
num_inference_steps=num_inference_steps,
device=device,
)

Expand Down
22 changes: 22 additions & 0 deletions src/diffumon/config/sample_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from dataclasses import dataclass, field

from diffumon.diffusion.sampler import SamplerType


@dataclass
class SamplerConfig:
    """Sampler options, addressed on the command line as ``sampler.<field>``.

    Each field is forwarded by name to ``p_sampler_to_images`` by the
    sampling entry point.
    """

    # Sampling strategy; DDPM unless overridden (e.g. sampler.sampler_type=ddim).
    sampler_type: SamplerType = SamplerType.DDPM
    # Stochasticity knob for DDIM sampling; 0.0 is the deterministic setting.
    # NOTE(review): the former CLI help said this is ignored for DDPM -- confirm
    # the sampler still behaves that way.
    eta: float = 0.0
    # Number of inference steps. None leaves the choice to the sampler
    # (presumably the full training schedule -- verify against the sampler).
    num_inference_steps: int | None = None
    # Interval for saving intermediate steps.
    # NOTE(review): -1 presumably disables intermediate saving -- confirm.
    save_every_k_time_steps: int = -1


@dataclass
class SampleConfig:
    """Top-level options for one sampling run.

    Registered with Hydra as ``sample_config``; every field can be overridden
    from the command line as ``key=value``.
    """

    # Checkpoint file handed to load_unet_checkpoint.
    checkpoint_path: str = "checkpoints/last_diffumon_checkpoint.pth"
    # Directory for the run's output; the entry point creates it if missing.
    output_dir: str = "samples"
    # How many samples to generate.
    num_samples: int = 32
    # Seed forwarded to the sampler.
    seed: int = 1999
    # Explicit torch device string; when unset the entry point falls back to
    # get_device().
    device: str | None = None
    # When set (and non-empty), replaces the dims stored in the checkpoint.
    # NOTE(review): presumably (channels, height, width) -- confirm.
    chw_dims_override: list[int] | None = None
    # Nested sampler options; built via a factory so instances never share one.
    sampler: SamplerConfig = field(default_factory=SamplerConfig)
4 changes: 2 additions & 2 deletions src/diffumon/diffusion/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,11 @@ def p_sampler_to_images(
if device is None:
device = get_device()

sampler = create_sampler(
sampler = _resolve_sampler(
sampler_type=sampler_type, eta=eta, num_inference_steps=num_inference_steps
)

sample_batches: list[Tensor] = sampler.sample(
sample_batches: list[Tensor] = sampler(
model=model,
ns=ns,
chw_dims=chw_dims,
Expand Down
45 changes: 45 additions & 0 deletions src/diffumon/sample_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pathlib import Path

import hydra
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf
import torch

from diffumon.config.sample_config import SampleConfig
from diffumon.diffusion.sampler import SamplerType, p_sampler_to_images
from diffumon.utils import get_device, load_unet_checkpoint

# Register the SampleConfig schema with Hydra's ConfigStore under the name
# "sample_config", the same name the @hydra.main decorator on main() asks for.
cs = ConfigStore.instance()
cs.store(name="sample_config", node=SampleConfig)


@hydra.main(config_name="sample_config", version_base=None)
def main(cfg: SampleConfig) -> None:
    """Hydra entry point: load a checkpoint and run the configured sampler.

    Prints the resolved config as YAML, picks a device, loads the model and
    noise schedule from ``cfg.checkpoint_path``, makes sure the output
    directory exists, then hands everything to ``p_sampler_to_images``.

    Args:
        cfg: Config composed by Hydra from the ``sample_config`` schema plus
            any command-line overrides.
    """
    print("Sampling with config:")
    print(OmegaConf.to_yaml(cfg))

    # An explicit device string wins; otherwise defer to the project default.
    if cfg.device:
        target_device = torch.device(cfg.device)
    else:
        target_device = get_device()

    unet, schedule, _, stored_dims = load_unet_checkpoint(
        cfg.checkpoint_path, device=target_device
    )

    samples_dir = Path(cfg.output_dir)
    samples_dir.mkdir(parents=True, exist_ok=True)

    # A non-empty override replaces the image dims stored in the checkpoint.
    image_dims = cfg.chw_dims_override or stored_dims
    sampler_cfg = cfg.sampler

    p_sampler_to_images(
        model=unet,
        ns=schedule,
        num_samples=cfg.num_samples,
        chw_dims=image_dims,
        seed=cfg.seed,
        output_dir=samples_dir,
        sampler_type=sampler_cfg.sampler_type,
        eta=sampler_cfg.eta,
        num_inference_steps=sampler_cfg.num_inference_steps,
        save_every_k_time_steps=sampler_cfg.save_every_k_time_steps,
        device=target_device,
    )


# Run the Hydra app when the module is executed directly, e.g.
# `python -m diffumon.sample_app`.
if __name__ == "__main__":
main()