BrainDiffusion: Reconstructing Visual Semantics from Non-Invasive Neural Activity Readings
The following files have been heavily modified or newly created:
code/sc_mbm/mae_for_eeg.py
code/eeg_ldm.py
code/dc_ldm/ldm_for_eeg.py
code/dc_ldm/models/diffusion/ddpm.py
code/cluster_analysis.py
code/gen_eval_eeg.py
code/eval_generations.py
To replicate the findings, please make sure all data has been downloaded from the respective links and organized according to the project directory structure shown at the bottom of this page.
- EEG Waves: Please download the EEG recordings and place them in the /datasets and /pretrains folders in the project root directory.
- Image Pairs: Please download the ImageNet subset of shown images and place it in /datasets/imageNet_images.
- Stable Diffusion 1.5 Checkpoint: Please download the SD 1.5 checkpoint and place it in /pretrains/models.
Once all of these dependencies have been downloaded and placed in their correct paths, follow the steps below to reproduce the results.
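Before running anything, you can sanity-check the layout with a short script like the one below. This is a minimal sketch; the paths are taken from the project directory structure at the bottom of this page:

```python
from pathlib import Path

# Paths taken from the project directory structure below.
required = [
    "pretrains/models/config.yaml",
    "pretrains/models/v1-5-pruned.ckpt",
    "pretrains/eeg_pretain/checkpoint.pth",
    "datasets/eeg_5_95_std.pth",
    "datasets/block_splits_by_image_all.pth",
    "datasets/block_splits_by_image_single.pth",
]
missing = [p for p in required if not Path(p).exists()]
print("All dependencies found." if not missing else f"Missing: {missing}")
```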
Create and activate a conda environment named dreamdiffusion from environment.yaml:

conda env create -f environment.yaml
conda activate dreamdiffusion

Fine-tune the diffusion model. In this stage, the cross-attention heads and the pre-trained EEG encoder are jointly optimized on EEG-image pairs (a sketch of the training step follows the commands below):
python code/eeg_ldm.py --dataset EEG --batch_size 10 --num_epoch 100 --lr 1e-5

Optionally, you can provide a checkpoint file to resume from:
python code/eeg_ldm.py --dataset EEG --batch_size 10 --num_epoch 100 --lr 1e-5 --checkpoint_path [CHECKPOINT_PATH]
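Conceptually, each fine-tuning step pushes gradients through both the U-Net's cross-attention layers and the EEG encoder. Below is a minimal sketch of that joint step; encode_to_latent, q_sample, unet, and the other members are hypothetical stand-ins, not the actual API of code/dc_ldm/ldm_for_eeg.py:

```python
import torch
import torch.nn.functional as F

def finetune_step(eeg_encoder, ldm, batch, optimizer):
    """One joint optimization step (hypothetical interface)."""
    eeg, image = batch["eeg"], batch["image"]
    cond = eeg_encoder(eeg)                    # EEG -> conditioning tokens
    z = ldm.encode_to_latent(image)            # VAE latent of the paired image
    t = torch.randint(0, ldm.num_timesteps, (z.shape[0],), device=z.device)
    noise = torch.randn_like(z)
    z_t = ldm.q_sample(z, t, noise)            # forward diffusion
    pred = ldm.unet(z_t, t, context=cond)      # EEG-conditioned noise prediction
    loss = F.mse_loss(pred, noise)
    optimizer.zero_grad()
    loss.backward()                            # gradients reach both modules
    optimizer.step()
    return loss.item()
```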
After fine-tuning, the EEG encoder's embedding space can be visualized with t-SNE dimensionality reduction:

python code/cluster_analysis.py --checkpoint_path [CHECKPOINT_PATH] -o [OUTPUT_PATH]
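If you want to reproduce this kind of plot outside of cluster_analysis.py, a minimal t-SNE sketch looks like the following, assuming embeddings is an (N, D) array of EEG-encoder outputs and labels holds integer class ids (the names are illustrative, not taken from the script):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_tsne(embeddings: np.ndarray, labels: np.ndarray, out_path: str):
    # Reduce the (N, D) embedding matrix to 2-D for plotting.
    coords = TSNE(n_components=2, perplexity=30, init="pca",
                  random_state=0).fit_transform(embeddings)
    plt.figure(figsize=(6, 6))
    plt.scatter(coords[:, 0], coords[:, 1], c=labels, s=8, cmap="tab20")
    plt.title("t-SNE of EEG encoder embeddings")
    plt.savefig(out_path, dpi=200)
```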
Plot the loss curves:

python code/visualize_loss.py --loss_path [LOSS_PATH]
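A minimal version of this plot, assuming the loss file is a plain list of per-step values (the actual format consumed by visualize_loss.py may differ):

```python
import matplotlib.pyplot as plt

def plot_loss(loss_path: str, out_path: str = "loss_curve.png"):
    # One float per line; adjust parsing to the file's real format.
    with open(loss_path) as f:
        losses = [float(line) for line in f if line.strip()]
    plt.plot(range(len(losses)), losses)
    plt.xlabel("step")
    plt.ylabel("loss")
    plt.savefig(out_path, dpi=200)
```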
Generate images from the held-out test EEG dataset:

python code/gen_eval_eeg.py --dataset EEG --model_path [MODEL_PATH]
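Conceptually, generation conditions a CompVis-style LDM sampler on the EEG embedding. The sketch below assumes the PLMS sampler shipped with the LDM code that dc_ldm adapts; the module path and method names are assumptions, so check code/gen_eval_eeg.py for the exact calls:

```python
import torch
from dc_ldm.models.diffusion.plms import PLMSSampler  # assumed module path

@torch.no_grad()
def generate(model, eeg_encoder, eeg, n_samples=1, steps=50):
    sampler = PLMSSampler(model)
    cond = eeg_encoder(eeg)                    # EEG conditioning embedding
    shape = [4, 64, 64]                        # SD 1.5 latent shape at 512x512
    samples, _ = sampler.sample(S=steps, conditioning=cond,
                                batch_size=n_samples, shape=shape,
                                verbose=False)
    images = model.decode_first_stage(samples)  # latents -> pixel space
    return images.clamp(-1, 1)
```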
Run the evaluation metrics on the images generated from the test set. A ViT ImageNet-1K classifier is used to compute the Top-1 and Top-3 accuracy scores:

python code/eval_generations.py --results_path [RESULTS_PATH]
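The Top-k metric itself is straightforward. The sketch below uses a torchvision ViT trained on ImageNet-1K (the repo's classifier checkpoint may differ) and counts a generation as correct if the ground-truth class appears in the classifier's top k predictions:

```python
import torch
from torchvision.models import vit_b_16, ViT_B_16_Weights

weights = ViT_B_16_Weights.IMAGENET1K_V1
classifier = vit_b_16(weights=weights).eval()
preprocess = weights.transforms()

@torch.no_grad()
def topk_hit(pil_image, true_class: int, k: int = 3) -> bool:
    # True if the ground-truth ImageNet class is in the top-k predictions.
    logits = classifier(preprocess(pil_image).unsqueeze(0))
    return true_class in logits.topk(k, dim=1).indices[0].tolist()
```

Top-1 and Top-3 accuracy are then the fraction of generated images for which topk_hit returns True with k=1 and k=3, respectively.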
Project Directory Structure:

/pretrains
┣ 📂 models
┃   ┗ 📜 config.yaml
┃   ┗ 📜 v1-5-pruned.ckpt
┣ 📂 generation
┃   ┗ 📜 checkpoint_best.pth
┣ 📂 eeg_pretain
┃   ┗ 📜 checkpoint.pth (pre-trained EEG encoder)

/datasets
┣ 📂 imageNet_images (subset of ImageNet)
┗ 📜 imagenet_label_map.csv
┗ 📜 block_splits_by_image_all.pth
┗ 📜 block_splits_by_image_single.pth
┗ 📜 eeg_5_95_std.pth

/code
┣ 📂 sc_mbm
┃   ┣ 📜 mae_for_eeg.py
┃   ┣ 📜 trainer.py
┃   ┗ 📜 utils.py
┣ 📂 dc_ldm
┃   ┣ 📜 ldm_for_eeg.py
┃   ┣ 📜 utils.py
┃   ┣ 📂 models
┃   ┃   ┗ (adopted from LDM)
┃   ┗ 📂 modules
┃       ┗ (adopted from LDM)
┣ 📜 stageA1_eeg_pretrain.py (main script for EEG pre-training)
┣ 📜 eeg_ldm.py (main script for fine-tuning stable diffusion)
┣ 📜 gen_eval_eeg.py (main script for generating images)
┣ 📜 eval_generations.py (main script for evaluating generated images)
┣ 📜 cluster_analysis.py (functions for embedding alignment analysis)
┣ 📜 visualize_loss.py (functions for visualizing losses)
┣ 📜 dataset.py (functions for loading datasets)
┣ 📜 eval_metrics.py (functions for evaluation metrics)
┗ 📜 config.py (configurations for the main scripts)
This code is built upon the publicly available DreamDiffusion codebase. We thank the authors for making their excellent work and code publicly available.