Archimedes/Athena RC · valeo.ai · National Technical University of Athens · University of Crete · IACM-Forth
TL;DR: We propose EQ-VAE, a straightforward regularization objective that promotes equivariance in the latent space of pretrained autoencoders under scaling and rotation. This leads to a more structured latent distribution, which accelerates generative model training and improves performance.
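In spirit, the regularization asks the decoder to map a transformed latent to the correspondingly transformed image. Below is a minimal sketch of that idea, assuming generic `encoder`/`decoder` callables; the actual training objective (see train_eqvae and the paper) also keeps the standard reconstruction, KL, and adversarial terms.

```python
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import rotate

def eq_vae_reg_loss(encoder, decoder, x):
    """Sketch of the EQ-VAE idea: decoding a transformed latent should
    match the same transform applied to the input image. `encoder` and
    `decoder` are placeholders for the two halves of your VAE."""
    z = encoder(x)  # latent, shape (B, C, h, w)

    # Random downscale, applied to both the latent and the target image.
    # Sizes line up as long as the VAE's downsample factor divides evenly.
    s = float(torch.empty(1).uniform_(0.25, 1.0))
    z_s = F.interpolate(z, scale_factor=s, mode="bilinear", align_corners=False)
    x_s = F.interpolate(x, scale_factor=s, mode="bilinear", align_corners=False)
    loss_scale = F.mse_loss(decoder(z_s), x_s)

    # Random rotation, applied to the latent and the target alike.
    angle = float(torch.empty(1).uniform_(0.0, 360.0))
    loss_rot = F.mse_loss(decoder(rotate(z, angle)), rotate(x, angle))

    return loss_scale + loss_rot
```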
If you just want to use EQ-VAE to speed up 🚀 the training of your diffusion model, you can use our HuggingFace checkpoints 🤗. We provide two models: eq-vae and eq-vae-ema.
| Model | Base model | Dataset | Epochs | rFID | PSNR | LPIPS | SSIM |
|---|---|---|---|---|---|---|---|
| eq-vae | SD-VAE | OpenImages | 5 | 0.82 | 25.95 | 0.141 | 0.72 |
| eq-vae-ema | SD-VAE | ImageNet | 44 | 0.55 | 26.15 | 0.133 | 0.72 |
```python
from diffusers import AutoencoderKL

eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae")
```

If you are looking for the weights in the original LDM format, you can find them here: eq-vae-ldm, eq-vae-ema-ldm.
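The checkpoint behaves like any other diffusers `AutoencoderKL`. A minimal round-trip sketch (the file name and the 0.18215 latent scaling factor, SD-VAE's usual scale, are illustrative assumptions):

```python
import numpy as np
import torch
from diffusers import AutoencoderKL
from PIL import Image

eqvae = AutoencoderKL.from_pretrained("zelaki/eq-vae").eval()

# Load an image and normalize to [-1, 1], NCHW float32.
img = Image.open("input.png").convert("RGB").resize((256, 256))
x = torch.from_numpy(np.array(img)).float().div(127.5).sub(1.0)
x = x.permute(2, 0, 1).unsqueeze(0)

with torch.no_grad():
    z = eqvae.encode(x).latent_dist.sample() * 0.18215  # image -> latent
    rec = eqvae.decode(z / 0.18215).sample              # latent -> image
```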
```bash
conda env create -f environment.yml
conda activate eqvae
```

We provide a training script to finetune SD-VAE with EQ-VAE regularization. For a detailed guide, see train_eqvae.
To evaluate the reconstruction quality of EQ-VAE, compute rFID, LPIPS, SSIM, and PSNR on a validation set (we use the ImageNet validation set in our paper) with the following:
```bash
torchrun --nproc_per_node=8 eval.py \
  --data_path /path/to/imagenet/validation \
  --output_path results \
  --ckpt_path /path/to/your/ckpt
```
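For reference, PSNR, the simplest of the reported metrics, can be computed as below for images scaled to [0, 1]; `eval.py`'s exact implementation may differ.

```python
import torch

def psnr(x: torch.Tensor, rec: torch.Tensor) -> float:
    """PSNR in dB between an image and its reconstruction, both in [0, 1]."""
    mse = torch.mean((x - rec) ** 2)
    return float(10.0 * torch.log10(1.0 / mse))
```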
To train a DiT model with EQ-VAE on ImageNet:

- First, extract the latent representations:

```bash
torchrun --nnodes=1 --nproc_per_node=8 train_gen/extract_features.py \
  --data-path /path/to/imagenet/train \
  --features-path /path/to/latents \
  --vae-ckpt /path/to/eqvae.ckpt \
  --vae-config configs/eqvae_config.yaml
```
- Then train DiT on the precomputed latents (a loader sketch follows this list):

```bash
accelerate launch --mixed_precision fp16 train_gen/train.py \
  --model DiT-XL/2 \
  --feature-path /path/to/latents \
  --results-dir results
```
- Evaluate generation as follows:

```bash
torchrun --nnodes=1 --nproc_per_node=8 sample_ddp.py \
  --model DiT-XL/2 \
  --num-fid-samples 50000 \
  --ckpt /path/to/dit.ckpt \
  --sample-dir samples \
  --vae-ckpt /path/to/eqvae.ckpt \
  --vae-config configs/eqvae_config.yaml \
  --ddpm True \
  --cfg-scale 1.0
```

This script generates a folder of 50K samples as well as a .npz file that can be used directly with ADM's TensorFlow evaluation suite to compute gFID.
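For orientation, `train_gen/train.py` consumes the latents written by `extract_features.py` in the first step above. A hypothetical loader sketch, assuming one `.npy` file per sample with matching label files (the repo's actual on-disk layout may differ):

```python
import os
import numpy as np
import torch
from torch.utils.data import Dataset

class LatentDataset(Dataset):
    """Loads precomputed VAE latents (hypothetical one-.npy-per-sample layout)."""

    def __init__(self, features_dir: str, labels_dir: str):
        self.features_dir = features_dir
        self.labels_dir = labels_dir
        self.files = sorted(os.listdir(features_dir))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        z = np.load(os.path.join(self.features_dir, self.files[idx]))
        y = np.load(os.path.join(self.labels_dir, self.files[idx]))
        return torch.from_numpy(z), torch.from_numpy(y)
```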
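`sample_ddp.py` builds the .npz for you, but if you ever need to rebuild it from a folder of PNG samples, the following sketch produces the single uint8 array layout (stored under NumPy's default `arr_0` key) that ADM's evaluator reads; the directory and file names are illustrative.

```python
import os
import numpy as np
from PIL import Image

def pack_samples(sample_dir: str, out_path: str) -> None:
    """Stack PNG samples into one uint8 array of shape (N, H, W, 3)."""
    files = sorted(f for f in os.listdir(sample_dir) if f.endswith(".png"))
    arr = np.stack([
        np.asarray(Image.open(os.path.join(sample_dir, f)).convert("RGB"))
        for f in files
    ])
    np.savez(out_path, arr)  # stored under the default key "arr_0"

pack_samples("samples", "samples.npz")
```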
This code is mainly built upon LDM and fastDiT.
```bibtex
@inproceedings{
kouzelis2025eqvae,
title={{EQ}-{VAE}: Equivariance Regularized Latent Space for Improved Generative Image Modeling},
author={Theodoros Kouzelis and Ioannis Kakogeorgiou and Spyros Gidaris and Nikos Komodakis},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=UWhW5YYLo6}
}
```
