One-step edge-to-image generation using the MeanFlow framework in VAE latent space.
Figure 1: 1-NFE Generation Results on AFHQ v2. Rows show different samples.
Left: Input Edge, Middle: Generated (1-step), Right: Ground Truth.
This project extends the MeanFlow framework from class-conditional generation to spatial conditioning (Edge-to-Image), enabling high-quality image synthesis from edge maps with a single forward pass (1-NFE).
- **One-Step Generation**: Generate images from edge maps using a single network evaluation
- **Multi-scale Edge Encoder**
  - Stem: 3× stride-2 conv (256→128→64→32)
  - Multi-scale outputs: 32×32, 16×16, 8×8
  - Zero-init output layers for training stability
- **Latent U-Net**
  - 3 resolution levels (32→16→8)
  - ResBlocks with ChannelLayerNorm for JVP stability (see the sketch after this list)
  - Self-attention at 8×8 bottleneck
  - Dual time embedding (t, t-r)
- **MeanFlow Training**
  - JVP-based target computation
  - Adaptive loss weighting
  - Training-time CFG
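Several of these pieces are small enough to sketch. The snippet below is illustrative only: module and function names are assumptions, and the real implementations live in `src/models/unet.py`.

```python
import torch
import torch.nn as nn


class ChannelLayerNorm(nn.Module):
    """LayerNorm over the channel axis of NCHW feature maps.

    Written with plain tensor ops (no fused norm kernels), so that
    torch.func.jvp can differentiate through it in forward mode,
    which the MeanFlow objective requires.
    """

    def __init__(self, channels: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True, unbiased=False)
        x = (x - mean) * torch.rsqrt(var + self.eps)
        return x * self.weight[:, None, None] + self.bias[:, None, None]


def zero_conv(c_in: int, c_out: int) -> nn.Conv2d:
    """Zero-initialized output conv for the edge encoder."""
    conv = nn.Conv2d(c_in, c_out, kernel_size=1)
    nn.init.zeros_(conv.weight)
    nn.init.zeros_(conv.bias)
    return conv


def dual_time_embedding(embed: nn.Module, t: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
    """MeanFlow conditions on both t and the interval length t - r."""
    return torch.cat([embed(t), embed(t - r)], dim=-1)
```

The zero-initialized convolutions follow ControlNet: at initialization the conditioning branch contributes nothing, so the edge signal is blended in gradually as training progresses.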
The model demonstrates strong structural alignment with input edges while hallucinating realistic fur textures in a single step.
High-fidelity samples across categories (Cat / Dog / Wild) generated in 1-NFE.
```
MeanFlow-Edge2Image/
├── src/
│   ├── config.py               # Configuration settings
│   ├── train.py                # Training script
│   ├── inference.py            # Inference & evaluation
│   ├── models/
│   │   ├── unet.py             # U-Net with multi-scale edge encoder
│   │   └── vae.py              # VAE wrapper for SD-VAE
│   ├── utils/
│   │   ├── meanflow.py         # MeanFlow loss & sampling
│   │   └── training.py         # Training utilities (EMA, checkpoints)
│   └── datasets/
│       └── latent_dataset.py
├── tools/
│   ├── preprocess.py           # Dataset preprocessing
│   └── compute_fid_stats.py    # FID statistics computation
├── data/                       # Dataset directory
├── checkpoints/                # Training outputs
└── fid_stats/                  # FID reference statistics
```
```bash
cd <directory-name>
```

It is highly recommended to use a virtual environment (e.g., venv or conda) to manage dependencies. This project requires Python >= 3.10.
```bash
# Using conda
conda create -n meanflow python=3.10
conda activate meanflow
```

PyTorch installation depends on your system's CUDA version. It is intentionally excluded from requirements.txt so that you install a build matching your hardware. Please visit the official PyTorch website to find the appropriate command for your setup.
Example for CUDA 12.4:

```bash
conda install mkl==2023.1.0 pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
```

Once PyTorch is installed, install the remaining packages (including pytorch-fid) using the provided requirements.txt:
```bash
pip install -r requirements.txt
```

```bash
# Download AFHQ v2 (Animal Faces HQ)
# Option 1: From official source
wget https://github.com/clovaai/stargan-v2/raw/master/download.sh
bash download.sh afhq-v2-dataset

# Expected structure:
# data/afhq_v2_raw/
# ├── train/
# │   ├── cat/
# │   ├── dog/
# │   └── wild/
# └── test/
#     ├── cat/
#     ├── dog/
#     └── wild/
```

The preprocessing script performs:
- VAE Encoding: Encode images to latent space (256×256 → 32×32×4)
- PiDiNet Edge Extraction: Extract semantic edges with Safe Mode quantization
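For orientation, the VAE step amounts to roughly the following. This is a sketch assuming the diffusers `AutoencoderKL`; the model id and the standard SD scaling factor are assumptions, and the project's actual wrapper is `src/models/vae.py`.

```python
import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()

@torch.no_grad()
def encode(img: torch.Tensor) -> torch.Tensor:
    """img: [B, 3, 256, 256] scaled to [-1, 1] -> latent: [B, 4, 32, 32]."""
    return vae.encode(img.to(device)).latent_dist.sample() * 0.18215

@torch.no_grad()
def decode(latent: torch.Tensor) -> torch.Tensor:
    """Inverse mapping, used at inference time to turn latents into images."""
    return vae.decode(latent / 0.18215).sample
```

The full preprocessing pipeline (including PiDiNet edge extraction) is invoked via: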
```bash
python -m tools.preprocess \
    --input data/afhq_v2_raw \
    --output data/afhq_v2_processed \
    --image_size 256 \
    --batch_size 32 \
    --num_workers 8
```

Options:
| Argument | Default | Description |
|---|---|---|
| `--input` | Required | Path to raw AFHQ dataset |
| `--output` | Required | Path to save processed data |
| `--image_size` | 256 | Target image size |
| `--batch_size` | 32 | Processing batch size |
| `--num_workers` | 8 | Number of CPU workers |
| `--skip_latent` | False | Skip VAE encoding (edge only) |
| `--save_pngs` | False | Save visualization PNGs |
Output Structure:

```
data/afhq_v2_processed/
├── train/
│   ├── cat/
│   │   ├── flickr_cat_000001_latent.pt   # {latent: [4,32,32], edge: [1,256,256]}
│   │   └── ...
│   ├── dog/
│   └── wild/
└── val/
    └── ...
```
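Each `*_latent.pt` record is a small dict of tensors, so a loader reduces to a few lines. A minimal sketch (the real implementation is `src/datasets/latent_dataset.py`; key names follow the structure shown above):

```python
from pathlib import Path

import torch
from torch.utils.data import Dataset


class LatentEdgeDataset(Dataset):
    """Yields (latent [4,32,32], edge [1,256,256]) pairs from preprocessed files."""

    def __init__(self, root: str, split: str = "train"):
        self.files = sorted(Path(root, split).rglob("*_latent.pt"))

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int):
        record = torch.load(self.files[idx], map_location="cpu")
        return record["latent"], record["edge"]
```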
Start a new run:

```bash
python -m src.train --name my_experiment
```

Resume from a checkpoint:

```bash
python -m src.train --resume checkpoints/my_experiment/models/latest.pt
```

Point training at a custom data directory:

```bash
python -m src.train --name exp1 --data_dir /path/to/processed_data
```

Key hyperparameters can be modified in src/config.py:
```python
CONFIG = {
    "training": {
        "total_epochs": 480,
        "batch_size": 64,
        "accum_steps": 4,             # Effective batch size: 64 × 4 = 256
        "lr": 1e-4,
        "warmup_epochs": 10,
        "ema_decay": 0.999,
        # MeanFlow specific
        "boundary_ratio": 0.25,       # Ratio of r=t samples
        "time_dist": "logit_normal",
        "time_mean": -0.4,
        "time_std": 1.0,
        "adaptive_weight_p": 1.0,
        # CFG
        "cond_drop_prob": 0.1,        # Condition dropout probability
        "cfg_omega": 2.0,             # Guidance scale
    },
    "model": {
        "base_channels": 128,
        "channel_mults": (1, 2, 4),
        "num_res_blocks": 2,
        "attention_levels": [2],      # Attention at 8×8 resolution
        "dropout": 0.1,
    },
}
```

Training outputs are written to:

```
checkpoints/my_experiment/
├── config.json                  # Saved configuration
├── logs/
│   ├── training.log             # Training log
│   └── metrics.csv              # Epoch-wise metrics
├── images/
│   ├── epoch_0010_ema_s1.png    # Validation samples (1-step)
│   ├── epoch_0010_ema_s2.png    # Validation samples (2-step)
│   └── ...
└── models/
    ├── best.pt                  # Best checkpoint (lowest MSE)
    ├── latest.pt                # Latest checkpoint
    └── epoch_0030.pt            # Periodic checkpoints
```
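The `epoch_*_ema_*` samples are produced with exponentially averaged weights (`ema_decay: 0.999` in the config). The update is the standard one; a sketch (the project's version lives in `src/utils/training.py`):

```python
import torch

@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    """theta_ema <- decay * theta_ema + (1 - decay) * theta, after each optimizer step."""
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.lerp_(p, 1.0 - decay)
```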
Key metrics to monitor in `metrics.csv`:

| Metric | Description | Ideal |
|---|---|---|
| `raw_mse` | Mean squared error | ↓ Lower is better |
| `u_pred_rms` | RMS of predicted velocity | Should increase |
| `target_rms` | RMS of target velocity | Reference value |
| `u_pred_rms / target_rms` | Velocity ratio | → 1.0 |
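These metrics come straight out of the MeanFlow objective: `raw_mse` is the regression error, while `u_pred_rms` and `target_rms` are the RMS magnitudes of the prediction and the JVP-based target. Below is a condensed sketch of the objective, following the MeanFlow paper; the repo's full version (with CFG and condition dropout) is `src/utils/meanflow.py`, and the `model(z, edge, r, t)` signature is an assumption.

```python
import torch


def meanflow_loss(model, x, edge, p=1.0, boundary_ratio=0.25,
                  time_mean=-0.4, time_std=1.0, c=1e-3):
    """MeanFlow loss on latents x [B,4,32,32] with edge conditioning.

    The network u(z_t, r, t) predicts the average velocity over [r, t];
    the MeanFlow identity gives the regression target
        u_tgt = v - (t - r) * d/dt u(z_t, r, t),
    with the total derivative obtained from one forward-mode JVP.
    """
    b, dev = x.shape[0], x.device

    # Logit-normal time sampling (time_dist / time_mean / time_std in config).
    s1 = torch.sigmoid(torch.randn(b, device=dev) * time_std + time_mean)
    s2 = torch.sigmoid(torch.randn(b, device=dev) * time_std + time_mean)
    t, r = torch.maximum(s1, s2), torch.minimum(s1, s2)
    # A boundary_ratio fraction of samples uses r = t, where the target reduces to v.
    on_boundary = torch.rand(b, device=dev) < boundary_ratio
    r = torch.where(on_boundary, t, r)

    view = lambda a: a.view(-1, 1, 1, 1)
    e = torch.randn_like(x)
    z = (1 - view(t)) * x + view(t) * e   # linear path, t=1 is pure noise
    v = e - x                             # instantaneous velocity dz/dt

    # JVP along (dz/dt, dr/dt, dt/dt) = (v, 0, 1) yields d/dt u.
    u, dudt = torch.func.jvp(
        lambda z, r, t: model(z, edge, r, t),
        (z, r, t),
        (v, torch.zeros_like(r), torch.ones_like(t)),
    )
    u_tgt = (v - view(t - r) * dudt).detach()

    err = (u - u_tgt).pow(2).mean(dim=(1, 2, 3))   # raw_mse per sample
    w = (err.detach() + c).pow(-p)                 # adaptive weighting (adaptive_weight_p)
    return (w * err).mean()
```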
Unlike many diffusion models requiring industrial-grade GPUs, this project is optimized for consumer hardware.
| Resource | Specification | Usage Stats |
|---|---|---|
| GPU | NVIDIA RTX 4070 | 12GB VRAM |
| Training VRAM | ~11 GB | Batch size 64 (AMP enabled) |
| Inference Speed | < 0.1s / image | 1-NFE @ 256x256 |
| Training Time | ~8 hours* | For 300 epochs |
*Estimated time. Thanks to Latent Space training and JVP optimization, MeanFlow converges efficiently on a single card.
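Training fits in 12 GB largely thanks to AMP and gradient accumulation. A sketch of an epoch loop wiring the pieces together (`meanflow_loss` and `ema_update` are the sketches from earlier; everything else is assumed to be set up by `src/train.py`):

```python
import torch


def train_epoch(model, ema_model, optimizer, scaler, loader, accum_steps=4):
    """One epoch with mixed precision and gradient accumulation
    (batch_size 64 × accum_steps 4 = effective batch 256)."""
    optimizer.zero_grad(set_to_none=True)
    for step, (latent, edge) in enumerate(loader):
        with torch.amp.autocast("cuda"):
            loss = meanflow_loss(model, latent.cuda(), edge.cuda()) / accum_steps
        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            ema_update(ema_model, model)  # keep the EMA weights in sync


# scaler = torch.amp.GradScaler("cuda") is created once and reused across epochs.
```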
Generate samples with an `[Edge | Prediction | Ground Truth]` grid:

```bash
python -m src.inference \
    --ckpt checkpoints/my_experiment/models/best.pt \
    --split val \
    --num_samples 20 \
    --steps 1
```

Compare 1-step, 2-step, and 4-step generation:
```bash
python -m src.inference \
    --ckpt checkpoints/my_experiment/models/best.pt \
    --compare_steps
```

Output: `[Edge | GT | 1-step | 2-step | 4-step]` grid
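Under the hood, multi-step sampling with an average-velocity model is just a chain of interval updates, with one network evaluation per interval. A sketch following the MeanFlow sampling rule (same assumed `model(z, edge, r, t)` signature as above):

```python
import torch


@torch.no_grad()
def sample(model, edge, steps=1, shape=(1, 4, 32, 32)):
    """Integrate from noise (t=1) to data (t=0) in `steps` intervals:
        z_r = z_t - (t - r) * u(z_t, r, t)
    steps=1 is the 1-NFE sampler used throughout this project."""
    z = torch.randn(shape, device=edge.device)
    times = torch.linspace(1.0, 0.0, steps + 1, device=edge.device)
    for t, r in zip(times[:-1], times[1:]):
        u = model(z, edge, r.expand(shape[0]), t.expand(shape[0]))
        z = z - (t - r) * u
    return z  # decode with the SD-VAE to obtain the final image
```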
Generate all validation images for FID calculation:
```bash
python -m src.inference \
    --ckpt checkpoints/my_experiment/models/best.pt \
    --fid \
    --steps 1
```

Options:
| Argument | Default | Description |
|---|---|---|
| `--ckpt` | Required | Path to checkpoint |
| `--split` | val | Dataset split |
| `--category` | all | Category (cat/dog/wild/all) |
| `--num_samples` | 20 | Samples per category |
| `--steps` | 1 | Number of sampling steps |
| `--compare_steps` | False | Compare 1/2/4 steps |
| `--fid` | False | FID generation mode |
First, compute FID statistics for the real images:
```bash
python -m tools.compute_fid_stats \
    --input data/afhq_v2_raw \
    --split test \
    --categories all \
    --image_size 256
```

Output: `fid_stats/afhq_all_val_stats.npz`
Then generate samples with the trained model:

```bash
python -m src.inference \
    --ckpt checkpoints/my_experiment/models/best.pt \
    --fid \
    --steps 1
```

Output: `checkpoints/my_experiment/results/fid_samples_s1/`
Finally, compute the FID score:

```bash
python -m pytorch_fid \
    checkpoints/my_experiment/results/fid_samples_s1 \
    fid_stats/afhq_all_val_stats.npz \
    --device cuda
```

We evaluated generation quality on the AFHQ v2 validation set. Surprisingly, while more steps improve FID, they degrade perceptual quality (see the table below).
| Steps (NFE) | FID (↓) | Perceptual Quality |
|---|---|---|
| 1 | 44.06 | Clean, smooth textures (Best for Visuals) |
| 2 | 31.59 | Sharper, occasional high-contrast artifacts |
| 4 | 28.64 | Oversaturated ("Burn" artifacts) |
While MeanFlow theoretically supports multi-step integration, we observed a "Velocity Undershoot" phenomenon: the learned vector field is conservative in magnitude.
- 1-Step: Produces natural, painting-like results.
- 4-Steps: Forcing multi-step integration on this specific vector field leads to accumulated errors, resulting in high-contrast "burn" artifacts.
Conclusion: The model is effectively a specialized One-Step Solver.
- MeanFlow - Original paper
- ControlNet - Zero-convolution and edge preprocessing insights
- Stable Diffusion VAE - Pre-trained VAE
- AFHQ - Dataset



