这个项目实现了Flow Matching (FM) 技术,用于图像生成任务。项目包含了DiT (Diffusion Transformer) 模型的实现,并在MNIST数据集上进行了训练和推理。
train.py: 训练脚本,使用MNIST数据集训练模型dit.py: DiT (Diffusion Transformer) 模型的实现flowmatching.py: Flow Matching 技术的实现infer.ipynb: 推理脚本,用于生成新图像coupling.ipynb: 耦合数据集训练示例 (MNIST到FashionMNIST)output.png: 训练结果示例图像traj.png: 图像生成轨迹示例
项目实现了DiT模型,包含以下主要组件:
DiTBlock: DiT的基本构建块,包含注意力机制和MLPFinalLayer: 最终输出层PatchEmbd: 图像块嵌入层- 位置编码和时间编码函数
Flow Matching技术通过学习从噪声到数据的向量场来生成图像。主要功能包括:
get_train_tuple: 生成训练数据对sample_ode: 通过ODE求解器采样生成图像
python train.py --epochs 30 --batch_size 16 --lr 2e-4训练脚本将使用MNIST数据集训练DiT模型,并将检查点保存在./checkpoints/目录中。
使用infer.ipynb Jupyter笔记本进行推理,加载训练好的模型并生成新图像。
- PyTorch
- torchvision
- timm
- tqdm
- tensorboard


