Goal: compare different architectures on generating images from the MNIST dataset using flow matching
architectures to compare:
- UNet model (from the course) - unet.py
- Diffusion Transformer - dit.py
- Mamba - mamba.py
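All three models plug into the same training loop, so presumably they share the ConditionalVectorField interface mentioned under code organization. A minimal sketch of what that interface might look like (the actual signature in common.py may differ; the argument names here are assumptions):

```python
from abc import ABC, abstractmethod


class ConditionalVectorField(ABC):
    """Sketch of the shared interface (assumed signature) that the UNet,
    DiT, and Mamba models in unet.py / dit.py / mamba.py would each
    implement so the trainer can swap architectures freely."""

    @abstractmethod
    def forward(self, x, t, y):
        """Predict the velocity at noised image x, time t, with class label y."""
        ...
```

Any architecture satisfying this contract can then be dropped into the trainer without changes elsewhere.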
training run and example generations can be seen in experiment.ipynb
- upon visual inspection, the Mamba-generated samples are noticeably worse than the other models': they have more artifacts and the digits are less clear
- DiT and UNet seem pretty comparable in generation quality
- interesting to note that the DiT (16.8M params) has 14x more parameters than the UNet (1.2M params), yet its training speed is only 5% slower and its generation quality is not noticeably better
avg loss of final 500 steps:
- UNet: 131.15
- DiT: 123.94
- Mamba: 154.31
the avg loss values are pretty consistent when aggregating over the final 1000, 500, 100, and 50 steps
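The window-averaging described above is simple to reproduce from a per-step loss history; a small sketch (`loss_history` is a hypothetical list of per-step losses; the real values come from the run in experiment.ipynb):

```python
def avg_final(loss_history, n):
    """Mean of the last n entries of a per-step loss curve."""
    tail = loss_history[-n:]
    return sum(tail) / len(tail)


# checking stability across window sizes, as in the notes:
# for n in (1000, 500, 100, 50):
#     print(n, avg_final(loss_history, n))
```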
code organization:
- common.py contains shared abstract classes Sampleable and ConditionalVectorField
- gaussian_probability_path.py contains code for GaussianConditionalProbabilityPath, which is used to add noise to MNIST images during training
- simulator_utils.py contains the definitions for ODE (ordinary differential equation) and the ODE simulator
- mnist_sampler.py contains the implementation of the MNIST dataloader / sampler
- CFGTrainer.py contains code for the CFGTrainer class
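For intuition on how the probability path adds noise during training: a scalar sketch of a Gaussian conditional probability path, assuming the linear (optimal-transport) schedule common in flow matching. The actual schedule lives in gaussian_probability_path.py and may differ; this helper is hypothetical:

```python
import random


def sample_conditional_path(x1, t, rng=random):
    """Scalar sketch of a Gaussian conditional probability path.

    x1: clean data value (e.g. an MNIST pixel), t: time in [0, 1].
    Assumes the linear interpolant x_t = (1 - t) * x0 + t * x1 with
    Gaussian source x0; returns the noised sample and the conditional
    velocity target the model regresses onto."""
    x0 = rng.gauss(0.0, 1.0)      # sample from the Gaussian source
    xt = (1.0 - t) * x0 + t * x1  # interpolate noise -> data
    u = x1 - x0                   # conditional velocity d(x_t)/dt
    return xt, u
```

At t=0 the sample is pure noise, at t=1 it is the clean image; training regresses the model's output onto u at random t.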
much of the scaffolding/utility code was based on assignments from MIT's Introduction to Flow Matching and Diffusion Models course
- credit to Peter E. Holderrieth and Ezra Erives
next steps:
- try a harder dataset than MNIST to see if the performance gap between DiT and UNet widens
- try training the model to do ε-prediction (diffusion) rather than velocity prediction (flow matching)
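The two targets are closely related: under the linear path x_t = (1 - t) * ε + t * x1, the velocity is u = x1 - ε, which rearranges to u = (x_t - ε) / t. A hedged sketch of the conversion (hypothetical helper, assuming that linear schedule):

```python
def velocity_from_eps(xt, eps, t):
    """Convert an eps-prediction into a velocity estimate, assuming the
    linear path xt = (1 - t) * eps + t * x1 (hypothetical helper).
    From that identity: u = x1 - eps = (xt - eps) / t."""
    return (xt - eps) / t
```

So an ε-prediction model can be sampled with the same ODE simulator by converting its output at each step (modulo numerical care near t = 0).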
