From fd9f1a95af68f7c47a2774833183f132da1ae3a2 Mon Sep 17 00:00:00 2001 From: robertvava Date: Mon, 23 Mar 2026 00:34:45 +0200 Subject: [PATCH] V2 - new models, architecture, config --- configs/config.yaml | 9 + configs/data/things_eeg2.yaml | 45 +++ configs/evaluation/metrics.yaml | 12 + configs/experiment/debug.yaml | 22 ++ configs/experiment/default.yaml | 10 + configs/model/eeg_encoder.yaml | 26 ++ configs/training/contrastive.yaml | 21 ++ configs/training/diffusion_finetune.yaml | 18 + configs/wandb/default.yaml | 4 + {src => legacy}/config.py | 0 {src => legacy}/data/images/.gitkeep | 0 .../dataloading_utils/main_load.py | 0 .../dataloading_utils/misc_load.py | 0 {src => legacy}/main.py | 0 {src => legacy}/misc_utils.py | 0 {src => legacy}/models/VAE/vae.py | 0 .../models/alignment/alignment_model.py | 0 {src => legacy}/models/autoencoders/eeg_ae.py | 0 {src => legacy}/models/autoencoders/img_ae.py | 0 {src => legacy}/models/commons/Decoder.py | 0 {src => legacy}/models/commons/Encoder.py | 0 {src => legacy}/models/diff/diffusion.py | 0 .../models/joint_model/joint_model.py | 0 {src => legacy}/models/no_gen/logreg.py | 0 .../processing/eeg/pre/__init__.py | 0 {src => legacy}/processing/eeg/pre/fft.py | 0 {src => legacy}/processing/eeg/pre/gaf.py | 0 {src => legacy}/processing/img/post/denorm.py | 0 {src => legacy}/processing/img/pre/transf.py | 0 {src => legacy}/run_pipeline.py | 0 {src => legacy}/setup.py | 0 .../best_aligned_eeg_encoder128.pt | Bin .../best_aligned_eeg_encoder256.pt | Bin .../best_aligned_eeg_encoder32.pt | Bin .../best_aligned_eeg_encoder512.pt | Bin .../best_aligned_eeg_encoder64.pt | Bin .../best_aligned_image_encoder128.pt | Bin .../best_aligned_image_encoder256.pt | Bin .../best_aligned_image_encoder32.pt | Bin .../best_aligned_image_encoder512.pt | Bin .../best_aligned_image_encoder64.pt | Bin {src => legacy}/trainers/alignment_trainer.py | 0 {src => legacy}/trainers/eeg_ae_trainer.py | 0 {src => legacy}/trainers/img_ae_trainer.py | 0 {src => legacy}/trainers/joint_trainer.py | 0 {src => legacy}/trainers/reg_trainer.py | 0 pyproject.toml | 51 +++ src/eegvix/__init__.py | 3 + src/eegvix/data/__init__.py | 4 + src/eegvix/data/channel_info.py | 45 +++ src/eegvix/data/datamodule.py | 130 +++++++ src/eegvix/data/dataset.py | 182 +++++++++ src/eegvix/data/transforms.py | 137 +++++++ src/eegvix/evaluation/__init__.py | 10 + src/eegvix/evaluation/generation_metrics.py | 73 ++++ src/eegvix/evaluation/retrieval.py | 125 ++++++ src/eegvix/evaluation/rsa.py | 53 +++ src/eegvix/generation/__init__.py | 3 + src/eegvix/generation/pipeline.py | 138 +++++++ src/eegvix/losses/__init__.py | 3 + src/eegvix/losses/contrastive.py | 71 ++++ src/eegvix/models/__init__.py | 6 + src/eegvix/models/clip_wrapper.py | 81 ++++ src/eegvix/models/diffusion_wrapper.py | 126 +++++++ src/eegvix/models/eeg_encoder.py | 357 ++++++++++++++++++ src/eegvix/models/projection_head.py | 32 ++ src/eegvix/models/subject_embedding.py | 27 ++ src/eegvix/training/__init__.py | 3 + src/eegvix/training/callbacks.py | 108 ++++++ src/eegvix/training/contrastive_module.py | 164 ++++++++ src/eegvix/training/diffusion_module.py | 119 ++++++ src/eegvix/utils/__init__.py | 3 + src/eegvix/utils/io.py | 13 + src/eegvix/utils/seed.py | 13 + src/models/VAE/__pycache__/vae.cpython-37.pyc | Bin 3982 -> 0 bytes .../no_gen/__pycache__/logreg.cpython-310.pyc | Bin 1199 -> 0 bytes .../no_gen/__pycache__/logreg.cpython-37.pyc | Bin 1294 -> 0 bytes .../pre/__pycache__/transf.cpython-310.pyc | Bin 477 -> 0 bytes 
.../img/pre/__pycache__/transf.cpython-37.pyc | Bin 465 -> 0 bytes src/scripts/evaluate.py | 114 ++++++ src/scripts/generate.py | 72 ++++ src/scripts/precompute_clip.py | 56 +++ src/scripts/train_contrastive.py | 107 ++++++ src/scripts/train_diffusion.py | 64 ++++ tests/__init__.py | 0 tests/conftest.py | 54 +++ tests/test_losses.py | 52 +++ tests/test_models.py | 91 +++++ tests/test_retrieval.py | 42 +++ 89 files changed, 2899 insertions(+) create mode 100644 configs/config.yaml create mode 100644 configs/data/things_eeg2.yaml create mode 100644 configs/evaluation/metrics.yaml create mode 100644 configs/experiment/debug.yaml create mode 100644 configs/experiment/default.yaml create mode 100644 configs/model/eeg_encoder.yaml create mode 100644 configs/training/contrastive.yaml create mode 100644 configs/training/diffusion_finetune.yaml create mode 100644 configs/wandb/default.yaml rename {src => legacy}/config.py (100%) rename {src => legacy}/data/images/.gitkeep (100%) rename {src => legacy}/dataloading_utils/main_load.py (100%) rename {src => legacy}/dataloading_utils/misc_load.py (100%) rename {src => legacy}/main.py (100%) rename {src => legacy}/misc_utils.py (100%) rename {src => legacy}/models/VAE/vae.py (100%) rename {src => legacy}/models/alignment/alignment_model.py (100%) rename {src => legacy}/models/autoencoders/eeg_ae.py (100%) rename {src => legacy}/models/autoencoders/img_ae.py (100%) rename {src => legacy}/models/commons/Decoder.py (100%) rename {src => legacy}/models/commons/Encoder.py (100%) rename {src => legacy}/models/diff/diffusion.py (100%) rename {src => legacy}/models/joint_model/joint_model.py (100%) rename {src => legacy}/models/no_gen/logreg.py (100%) rename {src => legacy}/processing/eeg/pre/__init__.py (100%) rename {src => legacy}/processing/eeg/pre/fft.py (100%) rename {src => legacy}/processing/eeg/pre/gaf.py (100%) rename {src => legacy}/processing/img/post/denorm.py (100%) rename {src => legacy}/processing/img/pre/transf.py (100%) rename {src => legacy}/run_pipeline.py (100%) rename {src => legacy}/setup.py (100%) rename {src => legacy}/trained_models/best_aligned_eeg_encoder128.pt (100%) rename {src => legacy}/trained_models/best_aligned_eeg_encoder256.pt (100%) rename {src => legacy}/trained_models/best_aligned_eeg_encoder32.pt (100%) rename {src => legacy}/trained_models/best_aligned_eeg_encoder512.pt (100%) rename {src => legacy}/trained_models/best_aligned_eeg_encoder64.pt (100%) rename {src => legacy}/trained_models/best_aligned_image_encoder128.pt (100%) rename {src => legacy}/trained_models/best_aligned_image_encoder256.pt (100%) rename {src => legacy}/trained_models/best_aligned_image_encoder32.pt (100%) rename {src => legacy}/trained_models/best_aligned_image_encoder512.pt (100%) rename {src => legacy}/trained_models/best_aligned_image_encoder64.pt (100%) rename {src => legacy}/trainers/alignment_trainer.py (100%) rename {src => legacy}/trainers/eeg_ae_trainer.py (100%) rename {src => legacy}/trainers/img_ae_trainer.py (100%) rename {src => legacy}/trainers/joint_trainer.py (100%) rename {src => legacy}/trainers/reg_trainer.py (100%) create mode 100644 pyproject.toml create mode 100644 src/eegvix/__init__.py create mode 100644 src/eegvix/data/__init__.py create mode 100644 src/eegvix/data/channel_info.py create mode 100644 src/eegvix/data/datamodule.py create mode 100644 src/eegvix/data/dataset.py create mode 100644 src/eegvix/data/transforms.py create mode 100644 src/eegvix/evaluation/__init__.py create mode 100644 
src/eegvix/evaluation/generation_metrics.py create mode 100644 src/eegvix/evaluation/retrieval.py create mode 100644 src/eegvix/evaluation/rsa.py create mode 100644 src/eegvix/generation/__init__.py create mode 100644 src/eegvix/generation/pipeline.py create mode 100644 src/eegvix/losses/__init__.py create mode 100644 src/eegvix/losses/contrastive.py create mode 100644 src/eegvix/models/__init__.py create mode 100644 src/eegvix/models/clip_wrapper.py create mode 100644 src/eegvix/models/diffusion_wrapper.py create mode 100644 src/eegvix/models/eeg_encoder.py create mode 100644 src/eegvix/models/projection_head.py create mode 100644 src/eegvix/models/subject_embedding.py create mode 100644 src/eegvix/training/__init__.py create mode 100644 src/eegvix/training/callbacks.py create mode 100644 src/eegvix/training/contrastive_module.py create mode 100644 src/eegvix/training/diffusion_module.py create mode 100644 src/eegvix/utils/__init__.py create mode 100644 src/eegvix/utils/io.py create mode 100644 src/eegvix/utils/seed.py delete mode 100644 src/models/VAE/__pycache__/vae.cpython-37.pyc delete mode 100644 src/models/no_gen/__pycache__/logreg.cpython-310.pyc delete mode 100644 src/models/no_gen/__pycache__/logreg.cpython-37.pyc delete mode 100644 src/processing/img/pre/__pycache__/transf.cpython-310.pyc delete mode 100644 src/processing/img/pre/__pycache__/transf.cpython-37.pyc create mode 100644 src/scripts/evaluate.py create mode 100644 src/scripts/generate.py create mode 100644 src/scripts/precompute_clip.py create mode 100644 src/scripts/train_contrastive.py create mode 100644 src/scripts/train_diffusion.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_losses.py create mode 100644 tests/test_models.py create mode 100644 tests/test_retrieval.py diff --git a/configs/config.yaml b/configs/config.yaml new file mode 100644 index 0000000..e1c6d57 --- /dev/null +++ b/configs/config.yaml @@ -0,0 +1,9 @@ +defaults: + - data: things_eeg2 + - model: eeg_encoder + - training: contrastive + - evaluation: metrics + - wandb: default + - _self_ + +seed: 42 diff --git a/configs/data/things_eeg2.yaml b/configs/data/things_eeg2.yaml new file mode 100644 index 0000000..9772a71 --- /dev/null +++ b/configs/data/things_eeg2.yaml @@ -0,0 +1,45 @@ +data_dir: ${oc.env:EEGVIX_DATA_DIR,eeg_dataset} +preprocessed_dir: ${data.data_dir}/preprocessed + +n_subjects: 10 +n_channels: 17 +n_timepoints: 100 +sampling_rate: 100 # Hz (downsampled from 1000) +time_start: -0.2 # seconds relative to stimulus onset +time_end: 0.8 + +n_train_images: 16540 +n_test_images: 200 +n_train_reps: 4 +n_test_reps: 80 +images_per_concept: 10 + +# Channels: occipital and parietal +channels: + - O1 + - Oz + - O2 + - PO7 + - PO3 + - POz + - PO4 + - PO8 + - P7 + - P5 + - P3 + - P1 + - Pz + - P2 + - P4 + - P6 + - P8 + +# Data loading +batch_size: 256 +num_workers: 4 +average_repetitions: true +val_n_concepts: 150 +val_random_state: 42 + +# Precomputed embeddings +clip_embeddings_dir: ${data.data_dir}/clip_embeddings diff --git a/configs/evaluation/metrics.yaml b/configs/evaluation/metrics.yaml new file mode 100644 index 0000000..d8094f0 --- /dev/null +++ b/configs/evaluation/metrics.yaml @@ -0,0 +1,12 @@ +retrieval: + top_k: [1, 5, 10, 50, 200] + +zero_shot: + n_test_images: ${data.n_test_images} + n_test_reps: ${data.n_test_reps} + +generation: + compute_fid: true + compute_ssim: true + compute_lpips: true + n_generated_per_condition: 5 diff --git a/configs/experiment/debug.yaml 
b/configs/experiment/debug.yaml new file mode 100644 index 0000000..4d57810 --- /dev/null +++ b/configs/experiment/debug.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +defaults: + - /data: things_eeg2 + - /model: eeg_encoder + - /training: contrastive + - /evaluation: metrics + - /wandb: default + +seed: 42 + +data: + batch_size: 16 + num_workers: 0 + n_subjects: 1 + +training: + max_epochs: 5 + early_stopping_patience: 3 + +wandb: + mode: disabled diff --git a/configs/experiment/default.yaml b/configs/experiment/default.yaml new file mode 100644 index 0000000..d4440d4 --- /dev/null +++ b/configs/experiment/default.yaml @@ -0,0 +1,10 @@ +# @package _global_ + +defaults: + - /data: things_eeg2 + - /model: eeg_encoder + - /training: contrastive + - /evaluation: metrics + - /wandb: default + +seed: 42 diff --git a/configs/model/eeg_encoder.yaml b/configs/model/eeg_encoder.yaml new file mode 100644 index 0000000..c72b469 --- /dev/null +++ b/configs/model/eeg_encoder.yaml @@ -0,0 +1,26 @@ +# Spatiotemporal Transformer EEG Encoder +eeg_encoder: + n_channels: ${data.n_channels} + n_timepoints: ${data.n_timepoints} + embed_dim: 512 + num_temporal_conv_layers: 3 + temporal_kernel_sizes: [7, 5, 3] + n_spatial_heads: 4 + n_temporal_transformer_layers: 4 + n_temporal_heads: 8 + dropout: 0.1 + use_frequency_branch: true + output_dim: 768 # CLIP ViT-L/14 embedding dimension + +subject_embedding: + n_subjects: ${data.n_subjects} + embed_dim: ${model.eeg_encoder.embed_dim} + +projection_head: + input_dim: ${model.eeg_encoder.output_dim} + hidden_dim: 2048 + output_dim: 768 + +clip: + model_name: ViT-L-14 + pretrained: openai diff --git a/configs/training/contrastive.yaml b/configs/training/contrastive.yaml new file mode 100644 index 0000000..0d72f66 --- /dev/null +++ b/configs/training/contrastive.yaml @@ -0,0 +1,21 @@ +learning_rate: 3e-4 +weight_decay: 0.01 +warmup_epochs: 10 +max_epochs: 200 +early_stopping_patience: 20 + +optimizer: adamw +scheduler: cosine_warmup + +# InfoNCE loss +init_temperature: 0.07 +learnable_temperature: true + +# Trainer +precision: 16-mixed +gradient_clip_val: 1.0 +accumulate_grad_batches: 1 +check_val_every_n_epoch: 1 + +# Subject embedding warmup +subject_embedding_warmup_epochs: 5 diff --git a/configs/training/diffusion_finetune.yaml b/configs/training/diffusion_finetune.yaml new file mode 100644 index 0000000..e4b06cb --- /dev/null +++ b/configs/training/diffusion_finetune.yaml @@ -0,0 +1,18 @@ +learning_rate: 1e-5 +weight_decay: 0.01 +max_epochs: 50 + +# Stable Diffusion +sd_model: stabilityai/stable-diffusion-2-1 +use_ip_adapter: true +use_lora: true +lora_rank: 16 +lora_alpha: 32 + +# Generation +num_inference_steps: 50 +guidance_scale: 7.5 +image_resolution: 512 + +precision: 16-mixed +gradient_clip_val: 1.0 diff --git a/configs/wandb/default.yaml b/configs/wandb/default.yaml new file mode 100644 index 0000000..5383fcb --- /dev/null +++ b/configs/wandb/default.yaml @@ -0,0 +1,4 @@ +project: eegvix-v2 +entity: null +tags: [] +mode: online # online, offline, disabled diff --git a/src/config.py b/legacy/config.py similarity index 100% rename from src/config.py rename to legacy/config.py diff --git a/src/data/images/.gitkeep b/legacy/data/images/.gitkeep similarity index 100% rename from src/data/images/.gitkeep rename to legacy/data/images/.gitkeep diff --git a/src/dataloading_utils/main_load.py b/legacy/dataloading_utils/main_load.py similarity index 100% rename from src/dataloading_utils/main_load.py rename to legacy/dataloading_utils/main_load.py diff --git 
a/src/dataloading_utils/misc_load.py b/legacy/dataloading_utils/misc_load.py similarity index 100% rename from src/dataloading_utils/misc_load.py rename to legacy/dataloading_utils/misc_load.py diff --git a/src/main.py b/legacy/main.py similarity index 100% rename from src/main.py rename to legacy/main.py diff --git a/src/misc_utils.py b/legacy/misc_utils.py similarity index 100% rename from src/misc_utils.py rename to legacy/misc_utils.py diff --git a/src/models/VAE/vae.py b/legacy/models/VAE/vae.py similarity index 100% rename from src/models/VAE/vae.py rename to legacy/models/VAE/vae.py diff --git a/src/models/alignment/alignment_model.py b/legacy/models/alignment/alignment_model.py similarity index 100% rename from src/models/alignment/alignment_model.py rename to legacy/models/alignment/alignment_model.py diff --git a/src/models/autoencoders/eeg_ae.py b/legacy/models/autoencoders/eeg_ae.py similarity index 100% rename from src/models/autoencoders/eeg_ae.py rename to legacy/models/autoencoders/eeg_ae.py diff --git a/src/models/autoencoders/img_ae.py b/legacy/models/autoencoders/img_ae.py similarity index 100% rename from src/models/autoencoders/img_ae.py rename to legacy/models/autoencoders/img_ae.py diff --git a/src/models/commons/Decoder.py b/legacy/models/commons/Decoder.py similarity index 100% rename from src/models/commons/Decoder.py rename to legacy/models/commons/Decoder.py diff --git a/src/models/commons/Encoder.py b/legacy/models/commons/Encoder.py similarity index 100% rename from src/models/commons/Encoder.py rename to legacy/models/commons/Encoder.py diff --git a/src/models/diff/diffusion.py b/legacy/models/diff/diffusion.py similarity index 100% rename from src/models/diff/diffusion.py rename to legacy/models/diff/diffusion.py diff --git a/src/models/joint_model/joint_model.py b/legacy/models/joint_model/joint_model.py similarity index 100% rename from src/models/joint_model/joint_model.py rename to legacy/models/joint_model/joint_model.py diff --git a/src/models/no_gen/logreg.py b/legacy/models/no_gen/logreg.py similarity index 100% rename from src/models/no_gen/logreg.py rename to legacy/models/no_gen/logreg.py diff --git a/src/processing/eeg/pre/__init__.py b/legacy/processing/eeg/pre/__init__.py similarity index 100% rename from src/processing/eeg/pre/__init__.py rename to legacy/processing/eeg/pre/__init__.py diff --git a/src/processing/eeg/pre/fft.py b/legacy/processing/eeg/pre/fft.py similarity index 100% rename from src/processing/eeg/pre/fft.py rename to legacy/processing/eeg/pre/fft.py diff --git a/src/processing/eeg/pre/gaf.py b/legacy/processing/eeg/pre/gaf.py similarity index 100% rename from src/processing/eeg/pre/gaf.py rename to legacy/processing/eeg/pre/gaf.py diff --git a/src/processing/img/post/denorm.py b/legacy/processing/img/post/denorm.py similarity index 100% rename from src/processing/img/post/denorm.py rename to legacy/processing/img/post/denorm.py diff --git a/src/processing/img/pre/transf.py b/legacy/processing/img/pre/transf.py similarity index 100% rename from src/processing/img/pre/transf.py rename to legacy/processing/img/pre/transf.py diff --git a/src/run_pipeline.py b/legacy/run_pipeline.py similarity index 100% rename from src/run_pipeline.py rename to legacy/run_pipeline.py diff --git a/src/setup.py b/legacy/setup.py similarity index 100% rename from src/setup.py rename to legacy/setup.py diff --git a/src/trained_models/best_aligned_eeg_encoder128.pt b/legacy/trained_models/best_aligned_eeg_encoder128.pt similarity index 100% rename 
from src/trained_models/best_aligned_eeg_encoder128.pt rename to legacy/trained_models/best_aligned_eeg_encoder128.pt diff --git a/src/trained_models/best_aligned_eeg_encoder256.pt b/legacy/trained_models/best_aligned_eeg_encoder256.pt similarity index 100% rename from src/trained_models/best_aligned_eeg_encoder256.pt rename to legacy/trained_models/best_aligned_eeg_encoder256.pt diff --git a/src/trained_models/best_aligned_eeg_encoder32.pt b/legacy/trained_models/best_aligned_eeg_encoder32.pt similarity index 100% rename from src/trained_models/best_aligned_eeg_encoder32.pt rename to legacy/trained_models/best_aligned_eeg_encoder32.pt diff --git a/src/trained_models/best_aligned_eeg_encoder512.pt b/legacy/trained_models/best_aligned_eeg_encoder512.pt similarity index 100% rename from src/trained_models/best_aligned_eeg_encoder512.pt rename to legacy/trained_models/best_aligned_eeg_encoder512.pt diff --git a/src/trained_models/best_aligned_eeg_encoder64.pt b/legacy/trained_models/best_aligned_eeg_encoder64.pt similarity index 100% rename from src/trained_models/best_aligned_eeg_encoder64.pt rename to legacy/trained_models/best_aligned_eeg_encoder64.pt diff --git a/src/trained_models/best_aligned_image_encoder128.pt b/legacy/trained_models/best_aligned_image_encoder128.pt similarity index 100% rename from src/trained_models/best_aligned_image_encoder128.pt rename to legacy/trained_models/best_aligned_image_encoder128.pt diff --git a/src/trained_models/best_aligned_image_encoder256.pt b/legacy/trained_models/best_aligned_image_encoder256.pt similarity index 100% rename from src/trained_models/best_aligned_image_encoder256.pt rename to legacy/trained_models/best_aligned_image_encoder256.pt diff --git a/src/trained_models/best_aligned_image_encoder32.pt b/legacy/trained_models/best_aligned_image_encoder32.pt similarity index 100% rename from src/trained_models/best_aligned_image_encoder32.pt rename to legacy/trained_models/best_aligned_image_encoder32.pt diff --git a/src/trained_models/best_aligned_image_encoder512.pt b/legacy/trained_models/best_aligned_image_encoder512.pt similarity index 100% rename from src/trained_models/best_aligned_image_encoder512.pt rename to legacy/trained_models/best_aligned_image_encoder512.pt diff --git a/src/trained_models/best_aligned_image_encoder64.pt b/legacy/trained_models/best_aligned_image_encoder64.pt similarity index 100% rename from src/trained_models/best_aligned_image_encoder64.pt rename to legacy/trained_models/best_aligned_image_encoder64.pt diff --git a/src/trainers/alignment_trainer.py b/legacy/trainers/alignment_trainer.py similarity index 100% rename from src/trainers/alignment_trainer.py rename to legacy/trainers/alignment_trainer.py diff --git a/src/trainers/eeg_ae_trainer.py b/legacy/trainers/eeg_ae_trainer.py similarity index 100% rename from src/trainers/eeg_ae_trainer.py rename to legacy/trainers/eeg_ae_trainer.py diff --git a/src/trainers/img_ae_trainer.py b/legacy/trainers/img_ae_trainer.py similarity index 100% rename from src/trainers/img_ae_trainer.py rename to legacy/trainers/img_ae_trainer.py diff --git a/src/trainers/joint_trainer.py b/legacy/trainers/joint_trainer.py similarity index 100% rename from src/trainers/joint_trainer.py rename to legacy/trainers/joint_trainer.py diff --git a/src/trainers/reg_trainer.py b/legacy/trainers/reg_trainer.py similarity index 100% rename from src/trainers/reg_trainer.py rename to legacy/trainers/reg_trainer.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 
0000000..65ee24a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,51 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "eegvix" +version = "2.0.0" +description = "EEG-to-Image generation via CLIP-aligned contrastive learning and diffusion models." +requires-python = ">=3.10" +license = {text = "MIT"} +authors = [{name = "Robert Vava", email = "vavarobert10@gmail.com"}] + +dependencies = [ + "torch>=2.1", + "torchvision>=0.16", + "lightning>=2.1", + "open-clip-torch>=2.24", + "diffusers>=0.25", + "transformers>=4.36", + "peft>=0.7", + "hydra-core>=1.3", + "omegaconf>=2.3", + "wandb>=0.16", + "numpy>=1.24", + "scipy>=1.11", + "scikit-learn>=1.3", + "pillow>=10.0", + "torchmetrics>=1.2", + "lpips>=0.1.4", + "tqdm>=4.66", + "einops>=0.7", + "matplotlib>=3.8", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4", + "pytest-cov>=4.1", + "ruff>=0.1", + "mypy>=1.7", +] + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.ruff] +line-length = 120 +target-version = "py310" + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/src/eegvix/__init__.py b/src/eegvix/__init__.py new file mode 100644 index 0000000..e646052 --- /dev/null +++ b/src/eegvix/__init__.py @@ -0,0 +1,3 @@ +"""EEGVIX: EEG-to-Image generation via CLIP-aligned contrastive learning and diffusion models.""" + +__version__ = "2.0.0" diff --git a/src/eegvix/data/__init__.py b/src/eegvix/data/__init__.py new file mode 100644 index 0000000..eef3966 --- /dev/null +++ b/src/eegvix/data/__init__.py @@ -0,0 +1,4 @@ +from eegvix.data.dataset import ThingsEEG2Dataset +from eegvix.data.datamodule import ThingsEEG2DataModule + +__all__ = ["ThingsEEG2Dataset", "ThingsEEG2DataModule"] diff --git a/src/eegvix/data/channel_info.py b/src/eegvix/data/channel_info.py new file mode 100644 index 0000000..dbddd49 --- /dev/null +++ b/src/eegvix/data/channel_info.py @@ -0,0 +1,45 @@ +"""Electrode positions for the 17 occipital/parietal channels used in THINGS-EEG2. + +2D positions are approximate projections onto a unit circle (top-down view of scalp). +Coordinates follow the standard 10-10 system layout used in the dataset. +""" + +import torch + +# Channel names in the order they appear in the preprocessed data +CHANNEL_NAMES: list[str] = [ + "O1", "Oz", "O2", + "PO7", "PO3", "POz", "PO4", "PO8", + "P7", "P5", "P3", "P1", "Pz", "P2", "P4", "P6", "P8", +] + +# Approximate 2D scalp positions (x, y) in normalized coordinates [-1, 1]. +# x: left(-) to right(+), y: posterior(-) to anterior(+) +# These follow the standard 10-10 montage layout. 
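+# For example, Oz sits on the posterior midline at (0.00, -1.00) and P8 at the
+# right posterior edge at (0.81, -0.59); get_channel_positions_tensor() below
+# packs all 17 of these into a (17, 2) tensor in CHANNEL_NAMES order.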
+CHANNEL_POSITIONS_2D: dict[str, tuple[float, float]] = { + "O1": (-0.31, -0.95), + "Oz": ( 0.00, -1.00), + "O2": ( 0.31, -0.95), + "PO7": (-0.59, -0.81), + "PO3": (-0.31, -0.81), + "POz": ( 0.00, -0.81), + "PO4": ( 0.31, -0.81), + "PO8": ( 0.59, -0.81), + "P7": (-0.81, -0.59), + "P5": (-0.59, -0.59), + "P3": (-0.39, -0.59), + "P1": (-0.19, -0.59), + "Pz": ( 0.00, -0.59), + "P2": ( 0.19, -0.59), + "P4": ( 0.39, -0.59), + "P6": ( 0.59, -0.59), + "P8": ( 0.81, -0.59), +} + +N_CHANNELS = len(CHANNEL_NAMES) + + +def get_channel_positions_tensor() -> torch.Tensor: + """Return channel positions as a (17, 2) float tensor, ordered by CHANNEL_NAMES.""" + positions = [CHANNEL_POSITIONS_2D[ch] for ch in CHANNEL_NAMES] + return torch.tensor(positions, dtype=torch.float32) diff --git a/src/eegvix/data/datamodule.py b/src/eegvix/data/datamodule.py new file mode 100644 index 0000000..39c4269 --- /dev/null +++ b/src/eegvix/data/datamodule.py @@ -0,0 +1,130 @@ +"""PyTorch Lightning DataModule for THINGS-EEG2.""" + +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import lightning as L +from sklearn.utils import resample +from torch.utils.data import DataLoader + +from eegvix.data.dataset import ThingsEEG2Dataset + + +class ThingsEEG2DataModule(L.LightningDataModule): + """Lightning DataModule that handles train/val/test splits for THINGS-EEG2. + + Uses concept-based stratified splitting: entire concepts (10 images each) + are assigned to validation, preserving semantic structure. + """ + + def __init__( + self, + data_dir: str | Path, + batch_size: int = 256, + num_workers: int = 4, + subjects: list[int] | None = None, + average_repetitions: bool = True, + val_n_concepts: int = 150, + val_random_state: int = 42, + clip_embeddings_path: str | Path | None = None, + eeg_transform: Callable | None = None, + augment_train: bool = True, + ): + super().__init__() + self.save_hyperparameters(ignore=["eeg_transform"]) + self.data_dir = Path(data_dir) + self.batch_size = batch_size + self.num_workers = num_workers + self.subjects = subjects + self.average_repetitions = average_repetitions + self.val_n_concepts = val_n_concepts + self.val_random_state = val_random_state + self.clip_embeddings_path = clip_embeddings_path + self.eeg_transform = eeg_transform + self.augment_train = augment_train + + def _get_val_indices(self) -> np.ndarray: + """Create concept-based validation split (same logic as original codebase).""" + n_concepts = 1654 + images_per_concept = 10 + n_conditions = n_concepts * images_per_concept # 16540 + + val_concepts = np.sort( + resample( + np.arange(n_concepts), + replace=False, + n_samples=self.val_n_concepts, + random_state=self.val_random_state, + ) + ) + + idx_val = np.zeros(n_conditions, dtype=bool) + for c in val_concepts: + idx_val[c * images_per_concept : (c + 1) * images_per_concept] = True + + return idx_val + + def setup(self, stage: str | None = None) -> None: + val_indices = self._get_val_indices() + + if stage in ("fit", None): + self.train_dataset = ThingsEEG2Dataset( + data_dir=self.data_dir, + subjects=self.subjects, + split="train", + indices=val_indices, + average_repetitions=self.average_repetitions, + clip_embeddings_path=self.clip_embeddings_path, + eeg_transform=self.eeg_transform, + augment=self.augment_train, + ) + self.val_dataset = ThingsEEG2Dataset( + data_dir=self.data_dir, + subjects=self.subjects, + split="val", + indices=val_indices, + average_repetitions=self.average_repetitions, + clip_embeddings_path=self.clip_embeddings_path, 
+ eeg_transform=self.eeg_transform, + augment=False, + ) + + if stage in ("test", None): + self.test_dataset = ThingsEEG2Dataset( + data_dir=self.data_dir, + subjects=self.subjects, + split="test", + average_repetitions=self.average_repetitions, + clip_embeddings_path=self.clip_embeddings_path, + eeg_transform=self.eeg_transform, + augment=False, + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + pin_memory=True, + drop_last=True, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=True, + ) + + def test_dataloader(self) -> DataLoader: + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=True, + ) diff --git a/src/eegvix/data/dataset.py b/src/eegvix/data/dataset.py new file mode 100644 index 0000000..fdc6786 --- /dev/null +++ b/src/eegvix/data/dataset.py @@ -0,0 +1,182 @@ +"""THINGS-EEG2 dataset for paired EEG-image loading across multiple subjects.""" + +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class ThingsEEG2Dataset(Dataset): + """PyTorch Dataset for THINGS-EEG2 paired EEG-image data. + + Each sample returns: + eeg: (17, 100) float tensor — EEG signal + clip_embedding: (768,) float tensor — precomputed CLIP image embedding (if available) + image: (3, 224, 224) float tensor — CLIP-preprocessed image (if clip_embeddings not precomputed) + subject_id: int — subject index (0-9) + image_id: int — unique image condition index + """ + + def __init__( + self, + data_dir: str | Path, + subjects: list[int] | None = None, + split: str = "train", + indices: np.ndarray | None = None, + average_repetitions: bool = True, + clip_embeddings_path: str | Path | None = None, + image_transform: transforms.Compose | None = None, + eeg_transform: Callable | None = None, + augment: bool = False, + ): + """ + Args: + data_dir: Root directory containing preprocessed/ and images/ subdirectories. + subjects: List of subject indices (0-9) to load. None = all 10. + split: "train", "val", or "test". + indices: For train/val, boolean mask or integer indices into the 16540 training conditions. + average_repetitions: If True, average across EEG repetitions for cleaner signal. + clip_embeddings_path: Path to precomputed CLIP embeddings (.pt file). + image_transform: Torchvision transforms for images (used if no precomputed embeddings). + eeg_transform: Callable transform for EEG data. + augment: Whether to apply data augmentation to EEG. 
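+
+        Example (paths and values are illustrative; assumes the directory
+        layout described above):
+
+            ds = ThingsEEG2Dataset("eeg_dataset", subjects=[0], split="test")
+            sample = ds[0]
+            sample["eeg"].shape   # torch.Size([17, 100]) with averaged repetitions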
+ """ + self.data_dir = Path(data_dir) + self.split = split + self.average_repetitions = average_repetitions + self.eeg_transform = eeg_transform + self.augment = augment + + if subjects is None: + subjects = list(range(10)) + self.subjects = subjects + + # Default CLIP-compatible image transform + if image_transform is None: + self.image_transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], # CLIP normalization + std=[0.26862954, 0.26130258, 0.27577711], + ), + ]) + else: + self.image_transform = image_transform + + # Load precomputed CLIP embeddings if available + self.clip_embeddings = None + if clip_embeddings_path is not None and Path(clip_embeddings_path).exists(): + self.clip_embeddings = torch.load(clip_embeddings_path, weights_only=True) + + # Load EEG data for all subjects + self._load_data(indices) + + def _load_data(self, indices: np.ndarray | None) -> None: + """Load EEG data and build image path mappings.""" + self.eeg_data: list[torch.Tensor] = [] + self.subject_ids: list[int] = [] + self.image_ids: list[int] = [] + + for subj_idx in self.subjects: + subj_dir = self.data_dir / "preprocessed" / f"sub-{subj_idx + 1:02d}" + + if self.split == "test": + raw = np.load( + subj_dir / "preprocessed_eeg_test.npy", allow_pickle=True + ).item() + eeg = raw["preprocessed_eeg_data"] # (200, 80, 17, 100) + + if self.average_repetitions: + eeg = np.mean(eeg, axis=1) # (200, 17, 100) + for img_idx in range(eeg.shape[0]): + self.eeg_data.append(torch.from_numpy(eeg[img_idx].astype(np.float32))) + self.subject_ids.append(subj_idx) + self.image_ids.append(img_idx) + else: + for img_idx in range(eeg.shape[0]): + for rep in range(eeg.shape[1]): + self.eeg_data.append(torch.from_numpy(eeg[img_idx, rep].astype(np.float32))) + self.subject_ids.append(subj_idx) + self.image_ids.append(img_idx) + else: + raw = np.load( + subj_dir / "preprocessed_eeg_training.npy", allow_pickle=True + ).item() + eeg = raw["preprocessed_eeg_data"] # (16540, 4, 17, 100) + + if self.average_repetitions: + eeg = np.mean(eeg, axis=1) # (16540, 17, 100) + + # Apply train/val split + if indices is not None: + if self.split == "val": + selected = indices + else: # train + selected = ~indices if indices.dtype == bool else np.setdiff1d(np.arange(eeg.shape[0]), indices) + else: + selected = np.arange(eeg.shape[0] if self.average_repetitions else eeg.shape[0]) + + if self.average_repetitions: + for img_idx in (selected if selected.dtype != bool else np.where(selected)[0]): + self.eeg_data.append(torch.from_numpy(eeg[img_idx].astype(np.float32))) + self.subject_ids.append(subj_idx) + self.image_ids.append(int(img_idx)) + else: + for img_idx in (selected if selected.dtype != bool else np.where(selected)[0]): + for rep in range(eeg.shape[1]): + self.eeg_data.append(torch.from_numpy(eeg[img_idx, rep].astype(np.float32))) + self.subject_ids.append(subj_idx) + self.image_ids.append(int(img_idx)) + + # Build image paths + if self.split == "test": + self.image_dir = self.data_dir / "images" / "test_images" + else: + self.image_dir = self.data_dir / "images" / "training_images" + + self._image_paths = self._build_image_paths() + + def _build_image_paths(self) -> list[Path]: + """Collect and sort image paths from the image directory.""" + paths = sorted(self.image_dir.rglob("*.jpg")) + if not paths: + paths = sorted(self.image_dir.rglob("*.JPEG")) + if not paths: + paths = sorted(self.image_dir.rglob("*.png")) + return paths + + def 
__len__(self) -> int: + return len(self.eeg_data) + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor | int]: + eeg = self.eeg_data[idx] + subject_id = self.subject_ids[idx] + image_id = self.image_ids[idx] + + if self.eeg_transform is not None: + eeg = self.eeg_transform(eeg) + + sample = { + "eeg": eeg, + "subject_id": subject_id, + "image_id": image_id, + } + + # Prefer precomputed CLIP embeddings; fall back to loading image + if self.clip_embeddings is not None: + key = f"{self.split}_{image_id}" + if key in self.clip_embeddings: + sample["clip_embedding"] = self.clip_embeddings[key] + elif image_id < len(self.clip_embeddings.get("embeddings", [])): + sample["clip_embedding"] = self.clip_embeddings["embeddings"][image_id] + else: + if image_id < len(self._image_paths): + img = Image.open(self._image_paths[image_id]).convert("RGB") + sample["image"] = self.image_transform(img) + + return sample diff --git a/src/eegvix/data/transforms.py b/src/eegvix/data/transforms.py new file mode 100644 index 0000000..088f50d --- /dev/null +++ b/src/eegvix/data/transforms.py @@ -0,0 +1,137 @@ +"""EEG preprocessing transforms and image transforms for CLIP compatibility.""" + +import numpy as np +import torch + + +class BaselineCorrection: + """Subtract mean of pre-stimulus interval from each channel. + + The dataset has 100 timepoints at 100Hz spanning -200ms to 800ms. + Pre-stimulus interval: timepoints 0-19 (indices for -200ms to 0ms). + """ + + def __init__(self, n_pre_stimulus: int = 20): + self.n_pre_stimulus = n_pre_stimulus + + def __call__(self, eeg: torch.Tensor) -> torch.Tensor: + # eeg shape: (n_channels, n_timepoints) + baseline = eeg[:, :self.n_pre_stimulus].mean(dim=1, keepdim=True) + return eeg - baseline + + +class ChannelWiseZScore: + """Z-score normalization per channel using precomputed statistics. + + Expects statistics computed across the training set for each channel. + """ + + def __init__(self, mean: torch.Tensor, std: torch.Tensor): + # mean, std shape: (n_channels, 1) or (n_channels,) + self.mean = mean.view(-1, 1) if mean.dim() == 1 else mean + self.std = std.view(-1, 1) if std.dim() == 1 else std + + def __call__(self, eeg: torch.Tensor) -> torch.Tensor: + return (eeg - self.mean.to(eeg.device)) / (self.std.to(eeg.device) + 1e-8) + + +class RobustScaler: + """Scale EEG using median and IQR per channel. 
    More robust to artifacts than z-score."""
+
+    def __init__(self, median: torch.Tensor, iqr: torch.Tensor):
+        self.median = median.view(-1, 1)
+        self.iqr = iqr.view(-1, 1)
+
+    def __call__(self, eeg: torch.Tensor) -> torch.Tensor:
+        return (eeg - self.median.to(eeg.device)) / (self.iqr.to(eeg.device) + 1e-8)
+
+
+class TemporalJitter:
+    """Data augmentation: randomly shift the EEG signal by a few timepoints."""
+
+    def __init__(self, max_shift: int = 3):
+        self.max_shift = max_shift
+
+    def __call__(self, eeg: torch.Tensor) -> torch.Tensor:
+        shift = torch.randint(-self.max_shift, self.max_shift + 1, (1,)).item()
+        if shift == 0:
+            return eeg
+        if shift > 0:
+            return torch.cat([eeg[:, shift:], torch.zeros_like(eeg[:, :shift])], dim=1)
+        return torch.cat([torch.zeros_like(eeg[:, :abs(shift)]), eeg[:, :shift]], dim=1)
+
+
+class GaussianNoise:
+    """Data augmentation: add Gaussian noise to EEG."""
+
+    def __init__(self, std: float = 0.01):
+        self.std = std
+
+    def __call__(self, eeg: torch.Tensor) -> torch.Tensor:
+        return eeg + torch.randn_like(eeg) * self.std
+
+
+class ChannelDropout:
+    """Data augmentation: randomly zero out entire channels."""
+
+    def __init__(self, p: float = 0.1):
+        self.p = p
+
+    def __call__(self, eeg: torch.Tensor) -> torch.Tensor:
+        # Create the mask on the same device as the input so the transform
+        # also works on GPU tensors
+        mask = torch.bernoulli(torch.full((eeg.size(0), 1), 1.0 - self.p, device=eeg.device))
+        return eeg * mask
+
+
+class BandpowerFeatures:
+    """Extract power in standard EEG frequency bands via FFT.
+
+    Bands: delta (1-4Hz), theta (4-8Hz), alpha (8-13Hz), beta (13-30Hz), gamma (30-50Hz).
+    At 100Hz sampling rate with 100 timepoints, frequency resolution is 1Hz.
+    """
+
+    BANDS: dict[str, tuple[float, float]] = {
+        "delta": (1.0, 4.0),
+        "theta": (4.0, 8.0),
+        "alpha": (8.0, 13.0),
+        "beta": (13.0, 30.0),
+        "gamma": (30.0, 50.0),
+    }
+
+    def __init__(self, sampling_rate: int = 100, n_timepoints: int = 100):
+        self.sampling_rate = sampling_rate
+        self.n_timepoints = n_timepoints
+        self.freqs = np.fft.rfftfreq(n_timepoints, d=1.0 / sampling_rate)
+
+    def __call__(self, eeg: torch.Tensor) -> torch.Tensor:
+        """Extract bandpower features.
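+
+        With the defaults (100 Hz, 100 samples) the rFFT bins fall on integer
+        frequencies, so e.g. the alpha band (8-13 Hz) covers exactly the five
+        bins at 8..12 Hz.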
+
+        Args:
+            eeg: (n_channels, n_timepoints)
+
+        Returns:
+            bandpowers: (n_channels, n_bands) where n_bands=5
+        """
+        fft_vals = torch.fft.rfft(eeg, dim=-1)
+        power = (fft_vals.real ** 2 + fft_vals.imag ** 2)
+
+        bandpowers = []
+        for low, high in self.BANDS.values():
+            mask = torch.tensor(
+                (self.freqs >= low) & (self.freqs < high), dtype=torch.float32
+            ).to(eeg.device)
+            # Average power over the bins inside the band, not the whole spectrum,
+            # so bands of different widths are on a comparable scale
+            band_power = (power * mask).sum(dim=-1) / mask.sum().clamp(min=1.0)
+            bandpowers.append(band_power)
+
+        return torch.stack(bandpowers, dim=-1)  # (n_channels, 5)
+
+
+class ComposeEEGTransforms:
+    """Compose multiple EEG transforms sequentially."""
+
+    def __init__(self, transforms: list):
+        self.transforms = transforms
+
+    def __call__(self, eeg: torch.Tensor) -> torch.Tensor:
+        for t in self.transforms:
+            eeg = t(eeg)
+        return eeg
diff --git a/src/eegvix/evaluation/__init__.py b/src/eegvix/evaluation/__init__.py
new file mode 100644
index 0000000..88bf71f
--- /dev/null
+++ b/src/eegvix/evaluation/__init__.py
@@ -0,0 +1,10 @@
+from eegvix.evaluation.retrieval import top_k_accuracy, zero_shot_identification
+from eegvix.evaluation.generation_metrics import compute_fid, compute_ssim, compute_lpips
+
+__all__ = [
+    "top_k_accuracy",
+    "zero_shot_identification",
+    "compute_fid",
+    "compute_ssim",
+    "compute_lpips",
+]
diff --git a/src/eegvix/evaluation/generation_metrics.py b/src/eegvix/evaluation/generation_metrics.py
new file mode 100644
index 0000000..53f0f6f
--- /dev/null
+++ b/src/eegvix/evaluation/generation_metrics.py
@@ -0,0 +1,73 @@
+"""Generation quality metrics: FID, SSIM, LPIPS."""
+
+import torch
+from torchmetrics.image.fid import FrechetInceptionDistance
+from torchmetrics.image import StructuralSimilarityIndexMeasure
+
+
+def compute_fid(
+    real_images: torch.Tensor,
+    generated_images: torch.Tensor,
+    feature_dim: int = 2048,
+) -> float:
+    """Compute Frechet Inception Distance between real and generated images.
+
+    Args:
+        real_images: (n, 3, h, w) real images as floats in [0, 1]
+        generated_images: (n, 3, h, w) generated images as floats in [0, 1]
+        feature_dim: InceptionV3 feature dimension
+
+    Returns:
+        FID score (lower is better)
+    """
+    # normalize=True tells torchmetrics to expect float images in [0, 1]
+    fid = FrechetInceptionDistance(feature=feature_dim, normalize=True)
+    fid.update(real_images, real=True)
+    fid.update(generated_images, real=False)
+    return fid.compute().item()
+
+
+def compute_ssim(
+    real_images: torch.Tensor,
+    generated_images: torch.Tensor,
+) -> float:
+    """Compute Structural Similarity Index between paired images.
+
+    Args:
+        real_images: (n, 3, h, w) in [0, 1]
+        generated_images: (n, 3, h, w) in [0, 1]
+
+    Returns:
+        Mean SSIM (higher is better)
+    """
+    ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
+    return ssim(generated_images, real_images).item()
+
+
+def compute_lpips(
+    real_images: torch.Tensor,
+    generated_images: torch.Tensor,
+    net: str = "alex",
+) -> float:
+    """Compute Learned Perceptual Image Patch Similarity.
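+
+    Example (random tensors, for illustration only):
+
+        d = compute_lpips(torch.rand(4, 3, 256, 256), torch.rand(4, 3, 256, 256))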
+ + Args: + real_images: (n, 3, h, w) in [0, 1] + generated_images: (n, 3, h, w) in [0, 1] + net: Backbone network ("alex", "vgg", "squeeze") + + Returns: + Mean LPIPS distance (lower is better) + """ + import lpips + + loss_fn = lpips.LPIPS(net=net) + loss_fn.eval() + + # LPIPS expects [-1, 1] range + real_scaled = real_images * 2 - 1 + gen_scaled = generated_images * 2 - 1 + + with torch.no_grad(): + distances = loss_fn(gen_scaled, real_scaled) + + return distances.mean().item() diff --git a/src/eegvix/evaluation/retrieval.py b/src/eegvix/evaluation/retrieval.py new file mode 100644 index 0000000..56e50ce --- /dev/null +++ b/src/eegvix/evaluation/retrieval.py @@ -0,0 +1,125 @@ +"""Retrieval accuracy and zero-shot identification metrics.""" + +import torch +import numpy as np + + +def top_k_accuracy( + eeg_embeds: torch.Tensor, + image_embeds: torch.Tensor, + k_values: list[int] | None = None, +) -> dict[str, float]: + """Compute top-k retrieval accuracy: for each EEG, check if the correct image is in top-k. + + Args: + eeg_embeds: (n, d) L2-normalized EEG embeddings + image_embeds: (n, d) L2-normalized image embeddings (same ordering as eeg_embeds) + k_values: List of k values to evaluate + + Returns: + Dict mapping "top_k" to accuracy for each k + """ + if k_values is None: + k_values = [1, 5, 10, 50, 200] + + similarity = eeg_embeds @ image_embeds.T # (n, n) + labels = torch.arange(similarity.size(0), device=similarity.device) + + results = {} + for k in k_values: + if k > similarity.size(1): + continue + top_k_preds = similarity.topk(k, dim=1).indices + correct = (top_k_preds == labels.unsqueeze(1)).any(dim=1).float().mean() + results[f"top_{k}"] = correct.item() + + return results + + +def zero_shot_identification( + bio_test_embeds: torch.Tensor, + syn_test_embeds: torch.Tensor, + distractor_embeds: torch.Tensor | None = None, +) -> dict[str, float]: + """Zero-shot identification following the THINGS-EEG2 paper protocol. + + For each biological test EEG embedding, check if the correlation with its + matching synthetic embedding is higher than with all other candidates. + + Args: + bio_test_embeds: (200, d) biological test EEG embeddings (averaged across 80 reps) + syn_test_embeds: (200, d) synthetic test embeddings (from CLIP image encoder) + distractor_embeds: (n_distractors, d) optional distractor embeddings + + Returns: + Dict with identification accuracy and per-condition results + """ + n_test = bio_test_embeds.size(0) + + # Build candidate pool + if distractor_embeds is not None: + candidates = torch.cat([syn_test_embeds, distractor_embeds], dim=0) + else: + candidates = syn_test_embeds + + # Correlation matrix + similarity = bio_test_embeds @ candidates.T # (200, n_candidates) + + # For each test condition, check if the correct candidate has the highest similarity + correct = 0 + for i in range(n_test): + # The correct match is at index i in candidates (first 200 are syn_test) + if similarity[i].argmax().item() == i: + correct += 1 + + accuracy = correct / n_test + + return { + "accuracy": accuracy, + "n_correct": correct, + "n_total": n_test, + "n_candidates": candidates.size(0), + } + + +def zero_shot_with_varying_distractors( + bio_test_embeds: torch.Tensor, + syn_test_embeds: torch.Tensor, + distractor_embeds: torch.Tensor, + set_sizes: list[int] | None = None, + n_iterations: int = 100, +) -> dict[int, float]: + """Run zero-shot identification with varying numbers of distractors. + + Replicates the paper's protocol of gradually increasing candidate set size. 
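+
+    Example (the embedding tensors here are illustrative):
+
+        curve = zero_shot_with_varying_distractors(
+            bio, syn, pool, set_sizes=[0, 1000, 5000], n_iterations=10
+        )
+        curve[1000]   # mean accuracy over 10 draws of 200 + 1000 candidates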
+ + Args: + bio_test_embeds: (200, d) biological test embeddings + syn_test_embeds: (200, d) synthetic test embeddings + distractor_embeds: (n, d) pool of distractor embeddings + set_sizes: List of distractor set sizes to evaluate + n_iterations: Number of random iterations per set size + + Returns: + Dict mapping set_size to mean accuracy + """ + if set_sizes is None: + set_sizes = list(range(0, min(distractor_embeds.size(0), 150001), 1000)) + + results = {} + n_distractors = distractor_embeds.size(0) + + for size in set_sizes: + accuracies = [] + for _ in range(n_iterations): + if size == 0: + result = zero_shot_identification(bio_test_embeds, syn_test_embeds) + else: + idx = torch.randperm(n_distractors)[:size] + distractors = distractor_embeds[idx] + result = zero_shot_identification(bio_test_embeds, syn_test_embeds, distractors) + accuracies.append(result["accuracy"]) + + results[size] = float(np.mean(accuracies)) + + return results diff --git a/src/eegvix/evaluation/rsa.py b/src/eegvix/evaluation/rsa.py new file mode 100644 index 0000000..29b328c --- /dev/null +++ b/src/eegvix/evaluation/rsa.py @@ -0,0 +1,53 @@ +"""Representational Similarity Analysis (RSA) for comparing EEG and image representations.""" + +import torch +import numpy as np +from scipy.stats import spearmanr + + +def compute_rdm(embeddings: torch.Tensor) -> torch.Tensor: + """Compute Representational Dissimilarity Matrix (1 - cosine similarity). + + Args: + embeddings: (n, d) L2-normalized embeddings + + Returns: + (n, n) dissimilarity matrix + """ + similarity = embeddings @ embeddings.T + return 1.0 - similarity + + +def rsa_correlation( + eeg_embeds: torch.Tensor, + image_embeds: torch.Tensor, +) -> dict[str, float]: + """Compute RSA between EEG and image representational spaces. + + Compares the geometry of the two embedding spaces by correlating their + representational dissimilarity matrices (RDMs). + + Args: + eeg_embeds: (n, d) L2-normalized EEG embeddings + image_embeds: (n, d) L2-normalized image embeddings + + Returns: + Dict with Pearson and Spearman correlations between RDMs + """ + eeg_rdm = compute_rdm(eeg_embeds).cpu().numpy() + img_rdm = compute_rdm(image_embeds).cpu().numpy() + + # Extract upper triangle (excluding diagonal) + n = eeg_rdm.shape[0] + triu_idx = np.triu_indices(n, k=1) + eeg_upper = eeg_rdm[triu_idx] + img_upper = img_rdm[triu_idx] + + pearson_r = float(np.corrcoef(eeg_upper, img_upper)[0, 1]) + spearman_r, spearman_p = spearmanr(eeg_upper, img_upper) + + return { + "pearson_r": pearson_r, + "spearman_r": float(spearman_r), + "spearman_p": float(spearman_p), + } diff --git a/src/eegvix/generation/__init__.py b/src/eegvix/generation/__init__.py new file mode 100644 index 0000000..4d4df95 --- /dev/null +++ b/src/eegvix/generation/__init__.py @@ -0,0 +1,3 @@ +from eegvix.generation.pipeline import EEGToImagePipeline + +__all__ = ["EEGToImagePipeline"] diff --git a/src/eegvix/generation/pipeline.py b/src/eegvix/generation/pipeline.py new file mode 100644 index 0000000..3d2a6d2 --- /dev/null +++ b/src/eegvix/generation/pipeline.py @@ -0,0 +1,138 @@ +"""End-to-end EEG-to-image generation pipeline. 
+ +Usage: + pipeline = EEGToImagePipeline.from_pretrained("path/to/checkpoint") + images = pipeline.generate(eeg_tensor, subject_id=0) +""" + +from pathlib import Path + +import torch +from PIL import Image + +from eegvix.models.eeg_encoder import EEGEncoder +from eegvix.models.projection_head import ProjectionHead +from eegvix.models.diffusion_wrapper import DiffusionWrapper + + +class EEGToImagePipeline: + """Complete pipeline: raw EEG → CLIP embedding → Stable Diffusion → image.""" + + def __init__( + self, + eeg_encoder: EEGEncoder, + projection_head: ProjectionHead, + diffusion: DiffusionWrapper, + device: str = "cuda", + ): + self.device = device + self.eeg_encoder = eeg_encoder.to(device).eval() + self.projection_head = projection_head.to(device).eval() + self.diffusion = diffusion + + @classmethod + def from_pretrained( + cls, + contrastive_checkpoint: str | Path, + sd_model: str = "stabilityai/stable-diffusion-2-1", + use_ip_adapter: bool = True, + lora_path: str | Path | None = None, + device: str = "cuda", + ) -> "EEGToImagePipeline": + """Load a pretrained pipeline from a contrastive training checkpoint. + + Args: + contrastive_checkpoint: Path to the Lightning checkpoint from contrastive training. + sd_model: Stable Diffusion model identifier. + use_ip_adapter: Whether to use IP-Adapter for conditioning. + lora_path: Optional path to LoRA weights for the diffusion model. + device: Device to load models on. + """ + checkpoint = torch.load(contrastive_checkpoint, map_location="cpu", weights_only=False) + state = checkpoint.get("state_dict", checkpoint) + + # Extract EEG encoder state + encoder_state = { + k.replace("eeg_encoder.", ""): v + for k, v in state.items() + if k.startswith("eeg_encoder.") + } + proj_state = { + k.replace("projection_head.", ""): v + for k, v in state.items() + if k.startswith("projection_head.") + } + + eeg_encoder = EEGEncoder() + eeg_encoder.load_state_dict(encoder_state) + + projection_head = ProjectionHead() + projection_head.load_state_dict(proj_state) + + diffusion = DiffusionWrapper( + sd_model=sd_model, + use_ip_adapter=use_ip_adapter, + use_lora=lora_path is not None, + device=device, + ) + if lora_path is not None: + diffusion.load_lora(lora_path) + + return cls(eeg_encoder, projection_head, diffusion, device) + + @torch.no_grad() + def encode_eeg( + self, + eeg: torch.Tensor, + subject_id: int | torch.Tensor = 0, + ) -> torch.Tensor: + """Encode raw EEG to a CLIP-space embedding. + + Args: + eeg: (batch, 17, 100) or (17, 100) raw EEG signal + subject_id: Subject index or tensor of indices + + Returns: + (batch, 768) L2-normalized CLIP-space embeddings + """ + if eeg.dim() == 2: + eeg = eeg.unsqueeze(0) + eeg = eeg.to(self.device) + + if isinstance(subject_id, int): + subject_ids = torch.full((eeg.size(0),), subject_id, device=self.device, dtype=torch.long) + else: + subject_ids = subject_id.to(self.device) + + features = self.eeg_encoder(eeg, subject_ids) + return self.projection_head(features) + + @torch.no_grad() + def generate( + self, + eeg: torch.Tensor, + subject_id: int | torch.Tensor = 0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images: int = 1, + ) -> list[Image.Image]: + """Generate images from raw EEG signals. 
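+
+        Example (tensor and output path are illustrative):
+
+            images = pipeline.generate(eeg, subject_id=0, num_images=2)
+            images[0].save("generated.png")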
+
+        Args:
+            eeg: (batch, 17, 100) or (17, 100) raw EEG signal
+            subject_id: Subject index
+            num_inference_steps: Diffusion denoising steps
+            guidance_scale: Classifier-free guidance scale
+            num_images: Number of images to generate per EEG sample
+
+        Returns:
+            List of PIL images
+        """
+        clip_embedding = self.encode_eeg(eeg, subject_id)
+
+        return self.diffusion.generate(
+            clip_embedding=clip_embedding,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            num_images_per_prompt=num_images,
+        )
diff --git a/src/eegvix/losses/__init__.py b/src/eegvix/losses/__init__.py
new file mode 100644
index 0000000..5e9204b
--- /dev/null
+++ b/src/eegvix/losses/__init__.py
@@ -0,0 +1,3 @@
+from eegvix.losses.contrastive import InfoNCELoss
+
+__all__ = ["InfoNCELoss"]
diff --git a/src/eegvix/losses/contrastive.py b/src/eegvix/losses/contrastive.py
new file mode 100644
index 0000000..50d573f
--- /dev/null
+++ b/src/eegvix/losses/contrastive.py
@@ -0,0 +1,71 @@
+"""InfoNCE / CLIP-style symmetric contrastive loss.
+
+The loss aligns EEG embeddings with their corresponding CLIP image embeddings
+in a shared space. Within each batch, matching EEG-image pairs are positives;
+all other combinations are negatives.
+
+L = 0.5 * (CE(eeg_to_img_logits, labels) + CE(img_to_eeg_logits, labels))
+
+where logits[i,j] = (eeg_embed[i] . img_embed[j]) / temperature
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class InfoNCELoss(nn.Module):
+    """Symmetric InfoNCE contrastive loss with learnable temperature.
+
+    Equivalent to the CLIP loss: matches EEG embeddings to image embeddings
+    bidirectionally within a batch.
+    """
+
+    def __init__(self, init_temperature: float = 0.07, learnable: bool = True):
+        super().__init__()
+        # temperature = exp(log_temperature); initialize at log(init_temperature)
+        # so the temperature property below starts at exactly init_temperature
+        log_temp = torch.tensor([init_temperature]).log()
+        if learnable:
+            self.log_temperature = nn.Parameter(log_temp)
+        else:
+            self.register_buffer("log_temperature", log_temp)
+
+    @property
+    def temperature(self) -> torch.Tensor:
+        # Clamp to avoid numerical instability
+        return self.log_temperature.exp().clamp(min=1e-4, max=100.0)
+
+    def forward(
+        self,
+        eeg_embeds: torch.Tensor,
+        image_embeds: torch.Tensor,
+    ) -> dict[str, torch.Tensor]:
+        """Compute symmetric contrastive loss.
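+
+        Example (random embeddings, illustration only; with batch size 8 the
+        loss starts near chance level, log(8) ~ 2.08):
+
+            loss_fn = InfoNCELoss()
+            e = F.normalize(torch.randn(8, 768), dim=-1)
+            v = F.normalize(torch.randn(8, 768), dim=-1)
+            out = loss_fn(e, v)   # out["loss"], out["eeg_to_img_acc"], ...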
+ + Args: + eeg_embeds: (batch, embed_dim) L2-normalized EEG embeddings + image_embeds: (batch, embed_dim) L2-normalized image embeddings + + Returns: + dict with keys: "loss", "eeg_to_img_acc", "img_to_eeg_acc", "temperature" + """ + # Cosine similarity scaled by temperature + logits = (eeg_embeds @ image_embeds.T) / self.temperature # (batch, batch) + labels = torch.arange(logits.size(0), device=logits.device) + + # Symmetric cross-entropy + loss_eeg_to_img = F.cross_entropy(logits, labels) + loss_img_to_eeg = F.cross_entropy(logits.T, labels) + loss = (loss_eeg_to_img + loss_img_to_eeg) / 2 + + # Accuracy metrics (for logging) + with torch.no_grad(): + eeg_to_img_acc = (logits.argmax(dim=1) == labels).float().mean() + img_to_eeg_acc = (logits.T.argmax(dim=1) == labels).float().mean() + + return { + "loss": loss, + "eeg_to_img_acc": eeg_to_img_acc, + "img_to_eeg_acc": img_to_eeg_acc, + "temperature": self.temperature.detach(), + } diff --git a/src/eegvix/models/__init__.py b/src/eegvix/models/__init__.py new file mode 100644 index 0000000..4e44867 --- /dev/null +++ b/src/eegvix/models/__init__.py @@ -0,0 +1,6 @@ +from eegvix.models.eeg_encoder import EEGEncoder +from eegvix.models.clip_wrapper import CLIPImageEncoder +from eegvix.models.projection_head import ProjectionHead +from eegvix.models.subject_embedding import SubjectEmbedding + +__all__ = ["EEGEncoder", "CLIPImageEncoder", "ProjectionHead", "SubjectEmbedding"] diff --git a/src/eegvix/models/clip_wrapper.py b/src/eegvix/models/clip_wrapper.py new file mode 100644 index 0000000..d105c08 --- /dev/null +++ b/src/eegvix/models/clip_wrapper.py @@ -0,0 +1,81 @@ +"""Frozen CLIP ViT-L/14 image encoder wrapper. + +Provides a simple interface to extract CLIP image embeddings, +either online during training or as a precomputation step. +""" + +from pathlib import Path + +import torch +import torch.nn as nn +import open_clip + + +class CLIPImageEncoder(nn.Module): + """Wraps OpenCLIP's image encoder with all parameters frozen. + + The forward pass returns L2-normalized image embeddings + in the same space that the EEG encoder is trained to match. + """ + + def __init__(self, model_name: str = "ViT-L-14", pretrained: str = "openai"): + super().__init__() + model, _, self.preprocess = open_clip.create_model_and_transforms( + model_name, pretrained=pretrained + ) + self.visual = model.visual + self.embed_dim = model.visual.output_dim + + # Freeze all parameters + for param in self.visual.parameters(): + param.requires_grad = False + + @torch.no_grad() + def forward(self, images: torch.Tensor) -> torch.Tensor: + """Extract CLIP image embeddings. + + Args: + images: (batch, 3, 224, 224) preprocessed images + + Returns: + (batch, embed_dim) L2-normalized image embeddings + """ + features = self.visual(images) + return nn.functional.normalize(features, dim=-1) + + @torch.no_grad() + def precompute_embeddings( + self, + image_paths: list[Path], + batch_size: int = 64, + device: str = "cuda", + ) -> torch.Tensor: + """Precompute CLIP embeddings for a list of image files. + + Args: + image_paths: List of paths to images. + batch_size: Processing batch size. + device: Device to run inference on. + + Returns: + (n_images, embed_dim) tensor of embeddings. 
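+
+        Example (paths illustrative; 768 is the ViT-L/14 embedding dim):
+
+            enc = CLIPImageEncoder()
+            paths = sorted(Path("eeg_dataset/images/training_images").rglob("*.jpg"))
+            embeds = enc.precompute_embeddings(paths)   # (len(paths), 768)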
+ """ + from PIL import Image + from tqdm import tqdm + + self.to(device) + self.eval() + + all_embeddings = [] + for i in tqdm(range(0, len(image_paths), batch_size), desc="Precomputing CLIP embeddings"): + batch_paths = image_paths[i : i + batch_size] + batch_images = [] + for p in batch_paths: + img = Image.open(p).convert("RGB") + batch_images.append(self.preprocess(img)) + + batch_tensor = torch.stack(batch_images).to(device) + embeddings = self.forward(batch_tensor) + all_embeddings.append(embeddings.cpu()) + + return torch.cat(all_embeddings, dim=0) diff --git a/src/eegvix/models/diffusion_wrapper.py b/src/eegvix/models/diffusion_wrapper.py new file mode 100644 index 0000000..fab332b --- /dev/null +++ b/src/eegvix/models/diffusion_wrapper.py @@ -0,0 +1,126 @@ +"""Stable Diffusion wrapper with IP-Adapter and optional LoRA for EEG-conditioned generation. + +IP-Adapter injects image embeddings (or in our case, EEG-derived CLIP embeddings) into +the cross-attention layers of the UNet, enabling image-conditioned generation without +modifying the diffusion model's architecture. +""" + +from pathlib import Path + +import torch +import torch.nn as nn +from diffusers import StableDiffusionPipeline, DDIMScheduler + + +class DiffusionWrapper(nn.Module): + """Wraps Stable Diffusion for CLIP-embedding-conditioned image generation. + + Supports: + - IP-Adapter conditioning (primary approach) + - Optional LoRA fine-tuning of cross-attention layers + """ + + def __init__( + self, + sd_model: str = "stabilityai/stable-diffusion-2-1", + use_ip_adapter: bool = True, + use_lora: bool = False, + lora_rank: int = 16, + lora_alpha: int = 32, + device: str = "cuda", + ): + super().__init__() + self.device = device + self.use_ip_adapter = use_ip_adapter + self.use_lora = use_lora + + # Load Stable Diffusion pipeline + self.pipe = StableDiffusionPipeline.from_pretrained( + sd_model, + torch_dtype=torch.float16, + safety_checker=None, + ) + self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) + self.pipe = self.pipe.to(device) + + # Load IP-Adapter if requested + if use_ip_adapter: + self.pipe.load_ip_adapter( + "h94/IP-Adapter", + subfolder="models", + weight_name="ip-adapter_sd15.bin", + ) + self.pipe.set_ip_adapter_scale(0.8) + + # Apply LoRA if requested + if use_lora: + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + target_modules=["to_k", "to_v", "to_q", "to_out.0"], + lora_dropout=0.05, + ) + self.pipe.unet = get_peft_model(self.pipe.unet, lora_config) + + @torch.no_grad() + def generate( + self, + clip_embedding: torch.Tensor, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + num_images_per_prompt: int = 1, + height: int = 512, + width: int = 512, + ) -> list: + """Generate images conditioned on CLIP embeddings (from EEG encoder). + + Args: + clip_embedding: (batch, 768) CLIP-space embeddings from the EEG encoder. + num_inference_steps: Number of diffusion denoising steps. + guidance_scale: Classifier-free guidance scale. + num_images_per_prompt: Number of images to generate per embedding. + height: Output image height. + width: Output image width. + + Returns: + List of PIL images. 
+ """ + if self.use_ip_adapter: + # IP-Adapter expects image embeddings + # We pass EEG-derived CLIP embeddings as if they were image embeddings + output = self.pipe( + prompt="", + ip_adapter_image_embeds=[clip_embedding.to(self.device, dtype=torch.float16)], + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + height=height, + width=width, + ) + else: + # Direct text-encoder replacement: project CLIP embedding into text encoder space + # This is a simpler but less effective approach + prompt_embeds = clip_embedding.unsqueeze(1).expand(-1, 77, -1) + prompt_embeds = prompt_embeds.to(self.device, dtype=torch.float16) + + output = self.pipe( + prompt_embeds=prompt_embeds, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + height=height, + width=width, + ) + + return output.images + + def save_lora(self, path: str | Path) -> None: + if self.use_lora: + self.pipe.unet.save_pretrained(path) + + def load_lora(self, path: str | Path) -> None: + if self.use_lora: + from peft import PeftModel + self.pipe.unet = PeftModel.from_pretrained(self.pipe.unet, path) diff --git a/src/eegvix/models/eeg_encoder.py b/src/eegvix/models/eeg_encoder.py new file mode 100644 index 0000000..94f9836 --- /dev/null +++ b/src/eegvix/models/eeg_encoder.py @@ -0,0 +1,357 @@ +"""Spatiotemporal Transformer EEG Encoder. + +Architecture: + 1. Temporal convolution stem: 1D convolutions across time per channel with residual + connections, reducing 100 timepoints to ~25 temporal tokens. + 2. Spatial attention: Multi-head attention across 17 channels at each temporal position, + using learned positional embeddings from 2D electrode coordinates. + 3. Temporal transformer: Stack of transformer encoder layers over the temporal dimension, + with a prepended learnable CLS token. + 4. Optional frequency branch: Parallel FFT-based pathway that extracts spectral features + and concatenates them before the final projection. + 5. Subject conditioning: Per-subject embedding added to CLS token. + 6. Final linear projection to output_dim (768 for CLIP ViT-L/14 compatibility). +""" + +import math + +import torch +import torch.nn as nn +from einops import rearrange + +from eegvix.data.channel_info import get_channel_positions_tensor, N_CHANNELS +from eegvix.models.subject_embedding import SubjectEmbedding + + +class TemporalConvStem(nn.Module): + """Multi-layer 1D temporal convolution with residual connections. + + Processes each channel's time series, reducing temporal resolution + while increasing feature dimensionality. 
+ """ + + def __init__( + self, + in_channels: int = 1, + embed_dim: int = 512, + kernel_sizes: list[int] | None = None, + dropout: float = 0.1, + ): + super().__init__() + if kernel_sizes is None: + kernel_sizes = [7, 5, 3] + + layers = [] + current_channels = in_channels + dims = self._compute_layer_dims(in_channels, embed_dim, len(kernel_sizes)) + + for i, (k, out_dim) in enumerate(zip(kernel_sizes, dims)): + layers.append(TemporalConvBlock(current_channels, out_dim, k, dropout)) + current_channels = out_dim + + self.layers = nn.ModuleList(layers) + self.out_dim = dims[-1] + + @staticmethod + def _compute_layer_dims(in_dim: int, out_dim: int, n_layers: int) -> list[int]: + """Linearly interpolate channel dimensions across layers.""" + if n_layers == 1: + return [out_dim] + dims = [] + for i in range(n_layers): + d = in_dim + (out_dim - in_dim) * (i + 1) / n_layers + dims.append(int(d)) + dims[-1] = out_dim # Ensure exact output dim + return dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (batch * n_channels, 1, n_timepoints) + + Returns: + (batch * n_channels, embed_dim, reduced_timepoints) + """ + for layer in self.layers: + x = layer(x) + return x + + +class TemporalConvBlock(nn.Module): + """Single temporal conv block with residual connection.""" + + def __init__(self, in_channels: int, out_channels: int, kernel_size: int, dropout: float = 0.1): + super().__init__() + padding = kernel_size // 2 + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=2, padding=padding) + self.norm = nn.LayerNorm(out_channels) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + + # Residual projection if dimensions change + self.residual = ( + nn.Conv1d(in_channels, out_channels, 1, stride=2) + if in_channels != out_channels + else nn.AvgPool1d(kernel_size=2, stride=2) + if in_channels == out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = self.residual(x) + x = self.conv(x) + # LayerNorm expects (batch, time, channels), so transpose + x = self.norm(x.transpose(1, 2)).transpose(1, 2) + x = self.act(x) + x = self.dropout(x) + # Ensure matching temporal dimensions for residual + min_t = min(x.size(2), residual.size(2)) + return x[:, :, :min_t] + residual[:, :, :min_t] + + +class SpatialChannelAttention(nn.Module): + """Multi-head attention across EEG channels, informed by electrode topology. + + Uses 2D scalp positions of electrodes as positional embeddings, learned through + a small MLP that maps (x, y) coordinates to the embedding dimension. 
+ """ + + def __init__(self, embed_dim: int = 512, n_heads: int = 4, dropout: float = 0.1): + super().__init__() + self.attention = nn.MultiheadAttention( + embed_dim=embed_dim, num_heads=n_heads, dropout=dropout, batch_first=True + ) + self.norm = nn.LayerNorm(embed_dim) + + # Learn spatial positional embeddings from 2D electrode coordinates + self.pos_encoder = nn.Sequential( + nn.Linear(2, embed_dim // 2), + nn.GELU(), + nn.Linear(embed_dim // 2, embed_dim), + ) + + # Register electrode positions as buffer (not a parameter) + self.register_buffer("channel_positions", get_channel_positions_tensor()) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (batch, n_channels, embed_dim) + + Returns: + (batch, n_channels, embed_dim) + """ + # Add spatial positional encoding + pos_embed = self.pos_encoder(self.channel_positions) # (17, embed_dim) + x_pos = x + pos_embed.unsqueeze(0) + + # Self-attention across channels + attended, _ = self.attention(x_pos, x_pos, x_pos) + return self.norm(x + attended) + + +class FrequencyBranch(nn.Module): + """Parallel branch that extracts frequency-domain features from raw EEG. + + Computes FFT and learns features from the power spectrum of each channel. + """ + + def __init__(self, n_channels: int = 17, n_timepoints: int = 100, output_dim: int = 128): + super().__init__() + n_freq_bins = n_timepoints // 2 + 1 # 51 for 100 timepoints + + self.freq_encoder = nn.Sequential( + nn.Linear(n_freq_bins, 256), + nn.GELU(), + nn.LayerNorm(256), + nn.Linear(256, output_dim), + nn.GELU(), + ) + self.channel_pool = nn.Sequential( + nn.Linear(n_channels * output_dim, output_dim * 2), + nn.GELU(), + nn.LayerNorm(output_dim * 2), + nn.Linear(output_dim * 2, output_dim), + ) + + def forward(self, eeg: torch.Tensor) -> torch.Tensor: + """ + Args: + eeg: (batch, n_channels, n_timepoints) raw EEG signal + + Returns: + (batch, output_dim) frequency features + """ + fft = torch.fft.rfft(eeg, dim=-1) + power = fft.real ** 2 + fft.imag ** 2 # (batch, 17, 51) + + # Per-channel frequency encoding + freq_features = self.freq_encoder(power) # (batch, 17, output_dim) + + # Pool across channels + pooled = freq_features.reshape(freq_features.size(0), -1) # (batch, 17 * output_dim) + return self.channel_pool(pooled) # (batch, output_dim) + + +class EEGEncoder(nn.Module): + """Full spatiotemporal transformer encoder for EEG signals. + + Combines temporal convolutions, spatial attention, temporal transformer, + optional frequency features, and per-subject embeddings to produce + CLIP-compatible embeddings. + """ + + def __init__( + self, + n_channels: int = 17, + n_timepoints: int = 100, + embed_dim: int = 512, + num_temporal_conv_layers: int = 3, + temporal_kernel_sizes: list[int] | None = None, + n_spatial_heads: int = 4, + n_temporal_transformer_layers: int = 4, + n_temporal_heads: int = 8, + dropout: float = 0.1, + use_frequency_branch: bool = True, + output_dim: int = 768, + n_subjects: int = 10, + ): + super().__init__() + self.n_channels = n_channels + self.n_timepoints = n_timepoints + self.embed_dim = embed_dim + self.use_frequency_branch = use_frequency_branch + + # 1. Temporal convolution stem (per channel) + if temporal_kernel_sizes is None: + temporal_kernel_sizes = [7, 5, 3][:num_temporal_conv_layers] + self.temporal_stem = TemporalConvStem( + in_channels=1, embed_dim=embed_dim, kernel_sizes=temporal_kernel_sizes, dropout=dropout + ) + + # 2. 
Spatial attention across channels
+        self.spatial_attention = SpatialChannelAttention(
+            embed_dim=embed_dim, n_heads=n_spatial_heads, dropout=dropout
+        )
+
+        # 3. Temporal transformer with CLS token
+        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)
+        # Temporal positional embeddings are sinusoidal and computed on the fly in
+        # _get_temporal_pos_embed, since the sequence length depends on the stem.
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=n_temporal_heads,
+            dim_feedforward=embed_dim * 4,
+            dropout=dropout,
+            activation="gelu",
+            batch_first=True,
+            norm_first=True,
+        )
+        self.temporal_transformer = nn.TransformerEncoder(
+            encoder_layer, num_layers=n_temporal_transformer_layers
+        )
+
+        # 4. Frequency branch (optional)
+        freq_dim = 128 if use_frequency_branch else 0
+        if use_frequency_branch:
+            self.frequency_branch = FrequencyBranch(
+                n_channels=n_channels, n_timepoints=n_timepoints, output_dim=freq_dim
+            )
+
+        # 5. Subject embedding
+        self.subject_embedding = SubjectEmbedding(n_subjects=n_subjects, embed_dim=embed_dim)
+
+        # 6. Final projection
+        self.norm = nn.LayerNorm(embed_dim + freq_dim)
+        self.projection = nn.Linear(embed_dim + freq_dim, output_dim)
+
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.trunc_normal_(m.weight, std=0.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+
+    def _get_temporal_pos_embed(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        """Sinusoidal positional embeddings for the temporal transformer."""
+        position = torch.arange(seq_len, device=device).unsqueeze(1).float()
+        div_term = torch.exp(
+            torch.arange(0, self.embed_dim, 2, device=device).float()
+            * (-math.log(10000.0) / self.embed_dim)
+        )
+        pe = torch.zeros(1, seq_len, self.embed_dim, device=device)
+        pe[0, :, 0::2] = torch.sin(position * div_term)
+        pe[0, :, 1::2] = torch.cos(position * div_term)
+        return pe
+
+    def forward(
+        self,
+        eeg: torch.Tensor,
+        subject_ids: torch.Tensor,
+        enable_subject_embed: bool = True,
+    ) -> torch.Tensor:
+        """
+        Args:
+            eeg: (batch, 17, 100) raw EEG signal
+            subject_ids: (batch,) integer subject indices
+            enable_subject_embed: Set False during warmup to disable subject conditioning
+
+        Returns:
+            (batch, output_dim) EEG embedding in CLIP-compatible space
+        """
+        batch_size = eeg.size(0)
+
+        # --- Temporal convolution stem (per channel) ---
+        # Reshape to process each channel independently
+        x = rearrange(eeg, "b c t -> (b c) 1 t")  # (batch*17, 1, 100)
+        x = self.temporal_stem(x)  # (batch*17, embed_dim, ~13)
+        reduced_time = x.size(2)
+
+        # Reshape back to (batch, 17, reduced_time, embed_dim)
+        x = rearrange(x, "(b c) d t -> b c t d", b=batch_size, c=self.n_channels)
+
+        # --- Spatial attention at each temporal position ---
+        # Process each time step: attention across 17 channels
+        spatial_out = []
+        for t in range(reduced_time):
+            channel_features = x[:, :, t, :]  # (batch, 17, embed_dim)
+            attended = self.spatial_attention(channel_features)  # (batch, 17, embed_dim)
+            # Pool across channels -> (batch, embed_dim)
+            pooled = attended.mean(dim=1)
+            spatial_out.append(pooled)
+
+        # Stack temporal tokens: (batch, reduced_time, embed_dim)
+        temporal_tokens = torch.stack(spatial_out, dim=1)
+
+        # --- Prepend CLS token with subject embedding ---
+        cls = self.cls_token.expand(batch_size, -1, -1)  # (batch, 1, embed_dim)
+        if enable_subject_embed:
+            subj_embed =
self.subject_embedding(subject_ids) # (batch, embed_dim) + cls = cls + subj_embed.unsqueeze(1) + + tokens = torch.cat([cls, temporal_tokens], dim=1) # (batch, 1+reduced_time, embed_dim) + + # Add sinusoidal positional embeddings + pos_embed = self._get_temporal_pos_embed(tokens.size(1), tokens.device) + tokens = tokens + pos_embed + + # --- Temporal transformer --- + transformer_out = self.temporal_transformer(tokens) # (batch, 1+reduced_time, embed_dim) + cls_out = transformer_out[:, 0, :] # (batch, embed_dim) — CLS token output + + # --- Frequency branch (optional) --- + if self.use_frequency_branch: + freq_features = self.frequency_branch(eeg) # (batch, 128) + cls_out = torch.cat([cls_out, freq_features], dim=-1) # (batch, embed_dim + 128) + + # --- Final projection --- + cls_out = self.norm(cls_out) + output = self.projection(cls_out) # (batch, output_dim) + + return output diff --git a/src/eegvix/models/projection_head.py b/src/eegvix/models/projection_head.py new file mode 100644 index 0000000..2077c8f --- /dev/null +++ b/src/eegvix/models/projection_head.py @@ -0,0 +1,32 @@ +"""Projection head for contrastive learning. + +Maps encoder output to a normalized embedding space where the contrastive loss operates. +Standard practice in CLIP/SimCLR: the projection head separates the representation space +(used for downstream tasks) from the contrastive loss space. +""" + +import torch +import torch.nn as nn + + +class ProjectionHead(nn.Module): + def __init__(self, input_dim: int = 768, hidden_dim: int = 2048, output_dim: int = 768): + super().__init__() + self.net = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, output_dim), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Project and L2-normalize. + + Args: + x: (batch, input_dim) + + Returns: + (batch, output_dim) L2-normalized embeddings + """ + projected = self.net(x) + return nn.functional.normalize(projected, dim=-1) diff --git a/src/eegvix/models/subject_embedding.py b/src/eegvix/models/subject_embedding.py new file mode 100644 index 0000000..e3f27a6 --- /dev/null +++ b/src/eegvix/models/subject_embedding.py @@ -0,0 +1,27 @@ +"""Per-subject learnable embedding for multi-subject EEG encoding.""" + +import torch +import torch.nn as nn + + +class SubjectEmbedding(nn.Module): + """Learnable embedding vector per subject, added to the EEG encoder's CLS token. + + During a warmup phase early in training, embeddings can be zeroed out + to let the encoder learn subject-invariant features first. 
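+
+    Example (illustrative)::
+
+        emb = SubjectEmbedding(n_subjects=10, embed_dim=512)
+        vec = emb(torch.tensor([0, 3]))  # -> (2, 512)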
+ """ + + def __init__(self, n_subjects: int = 10, embed_dim: int = 512): + super().__init__() + self.embedding = nn.Embedding(n_subjects, embed_dim) + nn.init.normal_(self.embedding.weight, std=0.02) + + def forward(self, subject_ids: torch.Tensor) -> torch.Tensor: + """ + Args: + subject_ids: (batch,) integer tensor with subject indices + + Returns: + (batch, embed_dim) subject embedding vectors + """ + return self.embedding(subject_ids) diff --git a/src/eegvix/training/__init__.py b/src/eegvix/training/__init__.py new file mode 100644 index 0000000..b8ee11e --- /dev/null +++ b/src/eegvix/training/__init__.py @@ -0,0 +1,3 @@ +from eegvix.training.contrastive_module import ContrastiveAlignmentModule + +__all__ = ["ContrastiveAlignmentModule"] diff --git a/src/eegvix/training/callbacks.py b/src/eegvix/training/callbacks.py new file mode 100644 index 0000000..791c554 --- /dev/null +++ b/src/eegvix/training/callbacks.py @@ -0,0 +1,108 @@ +"""Custom Lightning callbacks for training visualization and monitoring.""" + +import torch +import lightning as L +import wandb +import numpy as np + + +class EmbeddingVisualizationCallback(L.Callback): + """Log UMAP/t-SNE projections of EEG and image embeddings to wandb.""" + + def __init__(self, every_n_epochs: int = 10, max_samples: int = 500): + self.every_n_epochs = every_n_epochs + self.max_samples = max_samples + + def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None: + if trainer.current_epoch % self.every_n_epochs != 0: + return + + try: + from sklearn.manifold import TSNE + except ImportError: + return + + eeg_embeds = [] + img_embeds = [] + image_ids = [] + + with torch.no_grad(): + for batch in trainer.val_dataloaders: + eeg = batch["eeg"].to(pl_module.device) + subject_ids = batch["subject_id"].to(pl_module.device) + + eeg_embed = pl_module(eeg, subject_ids) + eeg_embeds.append(eeg_embed.cpu()) + + if "clip_embedding" in batch: + img_embeds.append(batch["clip_embedding"]) + image_ids.extend(batch["image_id"].tolist()) + + if len(eeg_embeds) * eeg.size(0) >= self.max_samples: + break + + eeg_all = torch.cat(eeg_embeds, dim=0)[:self.max_samples].numpy() + if img_embeds: + img_all = torch.cat(img_embeds, dim=0)[:self.max_samples].numpy() + combined = np.concatenate([eeg_all, img_all], axis=0) + labels = ["eeg"] * len(eeg_all) + ["image"] * len(img_all) + else: + combined = eeg_all + labels = ["eeg"] * len(eeg_all) + + tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(combined) - 1)) + projected = tsne.fit_transform(combined) + + table = wandb.Table(columns=["x", "y", "type"]) + for (x, y), label in zip(projected, labels): + table.add_data(float(x), float(y), label) + + wandb.log({ + "embedding_space": wandb.plot.scatter( + table, "x", "y", groupKeys="type", + title=f"Embedding Space (epoch {trainer.current_epoch})" + ) + }) + + +class RetrievalAccuracyCallback(L.Callback): + """Compute top-k retrieval accuracy on the validation set.""" + + def __init__(self, top_k: list[int] | None = None, every_n_epochs: int = 5): + self.top_k = top_k or [1, 5, 10] + self.every_n_epochs = every_n_epochs + + def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None: + if trainer.current_epoch % self.every_n_epochs != 0: + return + + eeg_embeds = [] + img_embeds = [] + + with torch.no_grad(): + for batch in trainer.val_dataloaders: + eeg = batch["eeg"].to(pl_module.device) + subject_ids = batch["subject_id"].to(pl_module.device) + + eeg_embed = pl_module(eeg, subject_ids) + 
eeg_embeds.append(eeg_embed.cpu()) + + if "clip_embedding" in batch: + img_embeds.append(batch["clip_embedding"]) + + if not img_embeds: + return + + eeg_all = torch.cat(eeg_embeds, dim=0) + img_all = torch.cat(img_embeds, dim=0) + + # Cosine similarity matrix + similarity = eeg_all @ img_all.T + labels = torch.arange(similarity.size(0)) + + for k in self.top_k: + if k > similarity.size(1): + continue + top_k_preds = similarity.topk(k, dim=1).indices + correct = (top_k_preds == labels.unsqueeze(1)).any(dim=1).float().mean() + pl_module.log(f"val/top{k}_acc", correct, sync_dist=True) diff --git a/src/eegvix/training/contrastive_module.py b/src/eegvix/training/contrastive_module.py new file mode 100644 index 0000000..b4ffe72 --- /dev/null +++ b/src/eegvix/training/contrastive_module.py @@ -0,0 +1,164 @@ +"""PyTorch Lightning module for contrastive EEG-CLIP alignment training.""" + +import torch +import lightning as L +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR + +from eegvix.models.eeg_encoder import EEGEncoder +from eegvix.models.clip_wrapper import CLIPImageEncoder +from eegvix.models.projection_head import ProjectionHead +from eegvix.losses.contrastive import InfoNCELoss + + +class ContrastiveAlignmentModule(L.LightningModule): + """Trains the EEG encoder to produce CLIP-aligned embeddings via InfoNCE loss. + + The CLIP image encoder is frozen. Only the EEG encoder and projection head + are trained. Optionally uses precomputed CLIP embeddings to save memory. + """ + + def __init__( + self, + # EEG encoder config + n_channels: int = 17, + n_timepoints: int = 100, + embed_dim: int = 512, + num_temporal_conv_layers: int = 3, + temporal_kernel_sizes: list[int] | None = None, + n_spatial_heads: int = 4, + n_temporal_transformer_layers: int = 4, + n_temporal_heads: int = 8, + dropout: float = 0.1, + use_frequency_branch: bool = True, + output_dim: int = 768, + n_subjects: int = 10, + # Projection head config + proj_hidden_dim: int = 2048, + # Loss config + init_temperature: float = 0.07, + learnable_temperature: bool = True, + # Training config + learning_rate: float = 3e-4, + weight_decay: float = 0.01, + warmup_epochs: int = 10, + max_epochs: int = 200, + subject_embedding_warmup_epochs: int = 5, + # CLIP config + clip_model_name: str = "ViT-L-14", + clip_pretrained: str = "openai", + use_precomputed_clip: bool = True, + ): + super().__init__() + self.save_hyperparameters() + + # EEG encoder (trainable) + self.eeg_encoder = EEGEncoder( + n_channels=n_channels, + n_timepoints=n_timepoints, + embed_dim=embed_dim, + num_temporal_conv_layers=num_temporal_conv_layers, + temporal_kernel_sizes=temporal_kernel_sizes, + n_spatial_heads=n_spatial_heads, + n_temporal_transformer_layers=n_temporal_transformer_layers, + n_temporal_heads=n_temporal_heads, + dropout=dropout, + use_frequency_branch=use_frequency_branch, + output_dim=output_dim, + n_subjects=n_subjects, + ) + + # Projection head (trainable) + self.projection_head = ProjectionHead( + input_dim=output_dim, + hidden_dim=proj_hidden_dim, + output_dim=output_dim, + ) + + # CLIP image encoder (frozen, only needed if not using precomputed) + self.use_precomputed_clip = use_precomputed_clip + if not use_precomputed_clip: + self.clip_encoder = CLIPImageEncoder( + model_name=clip_model_name, pretrained=clip_pretrained + ) + + # Loss + self.criterion = InfoNCELoss( + init_temperature=init_temperature, + learnable=learnable_temperature, + ) + + # Config + self.learning_rate = 
learning_rate + self.weight_decay = weight_decay + self.warmup_epochs = warmup_epochs + self.max_epochs = max_epochs + self.subject_embedding_warmup_epochs = subject_embedding_warmup_epochs + + def forward(self, eeg: torch.Tensor, subject_ids: torch.Tensor) -> torch.Tensor: + """Encode EEG and project to CLIP space. + + Returns L2-normalized embeddings suitable for contrastive loss or retrieval. + """ + enable_subject = self.current_epoch >= self.subject_embedding_warmup_epochs + eeg_features = self.eeg_encoder(eeg, subject_ids, enable_subject_embed=enable_subject) + return self.projection_head(eeg_features) + + def _shared_step(self, batch: dict, stage: str) -> torch.Tensor: + eeg = batch["eeg"] + subject_ids = batch["subject_id"] + + # Get EEG embeddings + eeg_embeds = self.forward(eeg, subject_ids) + + # Get image embeddings + if self.use_precomputed_clip and "clip_embedding" in batch: + image_embeds = batch["clip_embedding"] + image_embeds = torch.nn.functional.normalize(image_embeds, dim=-1) + else: + image_embeds = self.clip_encoder(batch["image"]) + + # Contrastive loss + loss_dict = self.criterion(eeg_embeds, image_embeds) + + # Log metrics + self.log(f"{stage}/loss", loss_dict["loss"], prog_bar=True, sync_dist=True) + self.log(f"{stage}/eeg_to_img_acc", loss_dict["eeg_to_img_acc"], sync_dist=True) + self.log(f"{stage}/img_to_eeg_acc", loss_dict["img_to_eeg_acc"], sync_dist=True) + self.log(f"{stage}/temperature", loss_dict["temperature"], sync_dist=True) + + return loss_dict["loss"] + + def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + return self._shared_step(batch, "train") + + def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + return self._shared_step(batch, "val") + + def test_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + return self._shared_step(batch, "test") + + def configure_optimizers(self): + # Separate parameter groups: encoder, projection head, temperature + param_groups = [ + {"params": self.eeg_encoder.parameters(), "lr": self.learning_rate}, + {"params": self.projection_head.parameters(), "lr": self.learning_rate}, + {"params": self.criterion.parameters(), "lr": self.learning_rate * 10}, + ] + + optimizer = AdamW(param_groups, weight_decay=self.weight_decay) + + # Linear warmup + cosine decay + warmup_scheduler = LinearLR( + optimizer, start_factor=0.01, total_iters=self.warmup_epochs + ) + cosine_scheduler = CosineAnnealingWarmRestarts( + optimizer, T_0=self.max_epochs - self.warmup_epochs, T_mult=1 + ) + scheduler = SequentialLR( + optimizer, + schedulers=[warmup_scheduler, cosine_scheduler], + milestones=[self.warmup_epochs], + ) + + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}} diff --git a/src/eegvix/training/diffusion_module.py b/src/eegvix/training/diffusion_module.py new file mode 100644 index 0000000..f9286c0 --- /dev/null +++ b/src/eegvix/training/diffusion_module.py @@ -0,0 +1,119 @@ +"""PyTorch Lightning module for optional LoRA fine-tuning of Stable Diffusion. + +This module fine-tunes the cross-attention layers of the UNet using LoRA, +conditioned on EEG-derived CLIP embeddings paired with their ground truth images. 
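+
+Training minimizes the standard DDPM noise-prediction objective,
+
+    L = E[ || eps - eps_theta(z_t, t, c) ||^2 ],
+
+where the conditioning c is the frozen, EEG-derived CLIP embedding tiled across
+the UNet's cross-attention token positions (see training_step below).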
+""" + +import torch +import lightning as L +from torch.optim import AdamW +from diffusers import DDPMScheduler + +from eegvix.models.eeg_encoder import EEGEncoder +from eegvix.models.projection_head import ProjectionHead + + +class DiffusionFineTuneModule(L.LightningModule): + """Fine-tunes Stable Diffusion's UNet cross-attention via LoRA. + + Loads a pretrained EEG encoder (from contrastive training) and uses it + to produce conditioning embeddings for the diffusion model. + """ + + def __init__( + self, + eeg_encoder_checkpoint: str, + sd_model: str = "stabilityai/stable-diffusion-2-1", + lora_rank: int = 16, + lora_alpha: int = 32, + learning_rate: float = 1e-5, + weight_decay: float = 0.01, + ): + super().__init__() + self.save_hyperparameters() + self.learning_rate = learning_rate + self.weight_decay = weight_decay + + # Load frozen EEG encoder from contrastive training + self.eeg_encoder = EEGEncoder() + self.projection_head = ProjectionHead() + checkpoint = torch.load(eeg_encoder_checkpoint, map_location="cpu", weights_only=False) + if "state_dict" in checkpoint: + state = checkpoint["state_dict"] + encoder_state = {k.replace("eeg_encoder.", ""): v for k, v in state.items() if k.startswith("eeg_encoder.")} + proj_state = {k.replace("projection_head.", ""): v for k, v in state.items() if k.startswith("projection_head.")} + self.eeg_encoder.load_state_dict(encoder_state) + self.projection_head.load_state_dict(proj_state) + + for param in self.eeg_encoder.parameters(): + param.requires_grad = False + for param in self.projection_head.parameters(): + param.requires_grad = False + self.eeg_encoder.eval() + self.projection_head.eval() + + # Noise scheduler for training + self.noise_scheduler = DDPMScheduler.from_pretrained(sd_model, subfolder="scheduler") + + # The actual UNet + LoRA setup happens in setup() to defer heavy model loading + self.sd_model = sd_model + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha + self._unet = None + self._vae = None + + def setup(self, stage: str | None = None) -> None: + if self._unet is not None: + return + + from diffusers import AutoencoderKL, UNet2DConditionModel + from peft import LoraConfig, get_peft_model + + self._vae = AutoencoderKL.from_pretrained(self.sd_model, subfolder="vae") + self._vae.requires_grad_(False) + self._vae.eval() + + self._unet = UNet2DConditionModel.from_pretrained(self.sd_model, subfolder="unet") + lora_config = LoraConfig( + r=self.lora_rank, + lora_alpha=self.lora_alpha, + target_modules=["to_k", "to_v", "to_q", "to_out.0"], + lora_dropout=0.05, + ) + self._unet = get_peft_model(self._unet, lora_config) + + def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: + images = batch["image"] # (batch, 3, 512, 512) + eeg = batch["eeg"] + subject_ids = batch["subject_id"] + + # Get EEG-derived CLIP embeddings (frozen) + with torch.no_grad(): + eeg_features = self.eeg_encoder(eeg, subject_ids) + eeg_embeds = self.projection_head(eeg_features) + + # Encode images to latent space + with torch.no_grad(): + latents = self._vae.encode(images).latent_dist.sample() + latents = latents * self._vae.config.scaling_factor + + # Add noise + noise = torch.randn_like(latents) + timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (latents.size(0),), device=self.device) + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict noise conditioned on EEG embeddings + # Reshape embeddings to match expected cross-attention input: (batch, seq_len, dim) + 
encoder_hidden_states = eeg_embeds.unsqueeze(1).expand(-1, 77, -1) + + noise_pred = self._unet(noisy_latents, timesteps, encoder_hidden_states).sample + loss = torch.nn.functional.mse_loss(noise_pred, noise) + + self.log("train/diffusion_loss", loss, prog_bar=True) + return loss + + def configure_optimizers(self): + # Only optimize LoRA parameters + trainable_params = [p for p in self._unet.parameters() if p.requires_grad] + optimizer = AdamW(trainable_params, lr=self.learning_rate, weight_decay=self.weight_decay) + return optimizer diff --git a/src/eegvix/utils/__init__.py b/src/eegvix/utils/__init__.py new file mode 100644 index 0000000..4acac57 --- /dev/null +++ b/src/eegvix/utils/__init__.py @@ -0,0 +1,3 @@ +from eegvix.utils.seed import seed_everything + +__all__ = ["seed_everything"] diff --git a/src/eegvix/utils/io.py b/src/eegvix/utils/io.py new file mode 100644 index 0000000..c5fa7fe --- /dev/null +++ b/src/eegvix/utils/io.py @@ -0,0 +1,13 @@ +from pathlib import Path + +import torch + + +def save_checkpoint(state: dict, path: str | Path) -> None: + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + torch.save(state, path) + + +def load_checkpoint(path: str | Path, map_location: str = "cpu") -> dict: + return torch.load(path, map_location=map_location, weights_only=False) diff --git a/src/eegvix/utils/seed.py b/src/eegvix/utils/seed.py new file mode 100644 index 0000000..6376bd6 --- /dev/null +++ b/src/eegvix/utils/seed.py @@ -0,0 +1,13 @@ +import random + +import numpy as np +import torch +import lightning as L + + +def seed_everything(seed: int = 42) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + L.seed_everything(seed, workers=True) diff --git a/src/models/VAE/__pycache__/vae.cpython-37.pyc b/src/models/VAE/__pycache__/vae.cpython-37.pyc deleted file mode 100644 index 98b0d97c19f6416654c21a1d123b9b38c7ba2347..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 3982 zcmbVPTW{n@6}HQ_?oQ`IHgkc6Wms6C0VUeuB5sUkZg+2$ zI}@6Ygp`r?Y5qg<$P40?r~M!GvK0DlUfrSS3zuV2xI?e+yUt@L^-}+tllqxB z49>iM2u+ZN(1f)nf+k8^(6nkz3z~M?fu>Vy+Nt}#(_6a47oFXnt2=v>a!}=?GARbs zHHTLZ)VDCmoW7eneesdA=caBdp0)a(^``GzANs*Zj_>qC%pK|IQ8^f;nH-3E5?&*m zoK#IADyMR1&Y3tD7mn*x-l5NFIVFC<^I+!0LFJu?skne$92pCdhOmg{7A<2D8jCi@ zowPM`F9htDuH1EJ&iOLGlR9v(y?gl?cyM{2f(~7}5Is*X9c8l2idf~-OoytHd79~H zoTO=9?qB|qmRv{fd=};_I#82wCM9RktvJrhyozHj%2J09N9FO`sSXCb>ZJ!sH8|KE z$su&!lk#o*;4a5^b-0t4St50?H;9K59pZ9~+hsf9(=6XVs8mmAUuDIf?i5Ltl~tVP zL*%X<$AcnKDvtl;{QKV4!DyIm$#S4vbCM5qiQtX+Rl}NE@76pvTXl2KiL{0 zKZV-*^}CO@j+1P2{Cl~Ag@icO8ieDnh?{QL<++02bgfS5X0y_HZV6)M_90Hyhp3!$ z=?}wtuEZ_ zp8v+e+~3`Y1LoNO!^WQ<|84s09l~gsC8hRMmG-(i0tzNYrgSG)d8w+T9At78@5q}J zj9R}9Kq5-Sq)&|Ns05OeBG#b=r)~r0e7B0V|0GY8UOh_kA!;VBEqkI4`tYSVOH>l8 z$x#F!FbA}4;I}S-xEoEBuU10DVk?tjhGOf)v0VYi-8lYmk`(rd0kixn<}{x`HQ7QL zWJM9j3om-E{5*}AbW_Oz28tvn{4akMXcTgp2H&G#k{0mLVsD^VNp&B{q*UXP z%D`_?inW6EW=NIkmWFd~0}a!)I(DDJJ8YmVwVF8u z7zn+^vI>IaoeFdW%RLSl2VI3faLkLAKKMse5%`D9V{BfS#{~bvKlNrFo1A4L6W+_a z_`3We1!H{qFNj`8exHgrDZWf`hvG{VhOK^xkq)Yn931FK<@>`?p2{CVIj5+pWI%Lq z5+5Yh0<#vZ#P&CF*6h?l)sUaw1UoemtJ$d+O5@+RVJ*KyVe-fc$rKj9>9^}jE&u|0*tsN2o1{J-Ckr7 z`T(!lm2oB=Wp)cGB~gtt3}of9`uUL#)$+iR+&V{`vL6czN2NW)Ay?;^9D}L&R4LBo2HO{i*f+m z+0MBJfi-Y7K2{UVkOwqnmI89%NiJXhlwunKm}8qY05%JdKjH0tigzjIm|LKkK~SV5 z8*VtgM5h}RN7tg*eDTwuOlT1O4Y3EUKvFlrIX{|wDwgea#Gd(f$9v}XgL=;kF6M5# z9~rW5Ar@^N!ReimQu9p#HwT=Yr@902GPv#D_c=iw7*Kn=XwVb0!IN1!F}rg_)DZE{ zFu{U|_3GnfqEw!gyE{Als_yRBinuHa6g1P&)b&-ieGgAeaWmf2wc70kzy+T3;FV3{ 
zAx`xT2n%omI~6nc7U6 zM^7dFL60EZEn2Q(L@yr|DA?mn?xVCbCSDmAz3Gj{_bV1x*k!%|#)AtmNQV0uV|$5a zfCX4A`!9I+3s~#+sLYftNlPCV*4N-zzb4uJ@Ylg64I3w=IEVFAOkZz2y~b`pCgH&= zPZMx%TH&iR^H6u(5@Qbue}!QYRGyH(goS(`q9F{Dkgc+}I5L%YHI|YU!&j_d=oyV) zO|`YJt2W)mHm}MDAenl)++Ab7bq*kL`}|9?Q-&M|{uiiuE_v*Vfnn3o6D$aR2}S diff --git a/src/models/no_gen/__pycache__/logreg.cpython-310.pyc b/src/models/no_gen/__pycache__/logreg.cpython-310.pyc deleted file mode 100644 index c03711130cb081537b362cb3a036b5a33bb98a31..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1199 zcmZuwy>AmS6t{gJxm*%jDgg!236?PQ#EKBAXeE|5LakVCv5weHuf5!*_9gvLkSeuO zr;2|eH+ExXf&arRV#tm_VuANAX-O+=`Mu}gdw%x&aMoz}1jhRJui~Rk$T!@~4#3C? zEd30KAc7_&oN(G@A*19C5lmQTL|BsTTBlZM3%W&`_89C*tBrJaSksMu3W6jd6(nTm zWSa^q*qIkv!V>lw3GH)Ylh6TAj+xt(9VNAnvvea9GHKHpk0>`_p=138NKSG(B&Te| z3PMS4jqKdob(Hd}H)NbIlDP9xx60a4c^i&rF!fmw1cOi=L3K9%zj1e|1Q`?g&l?m7Ju{5sE zRV<{b^dceRv@>RC#8d^C^6=CNkocqygn(`gsUY#E&kQawj_&uQQh3C8oX2ULb8c9g z8h=af_hp*LQDWTHEFC--#@>|c?@aA=l(%b*Sc z1#7X^fJHf!K9I6Ah!2Ai@5WC@YVcBT#@ zC7!@W<2Ux{0`=)UwdwJL+fhwzn=Y0-jk*#_?sE>=>?bJKIp6O`$wX5oQ=wnem{zJS zlZ5k`Cp-eJ!grL7gkfk*kVooteyPOZ>pFR++B?Q8*>akyIWRFUd}Fw^G@9{PhF3=B Py_~%373S&M?S1PnT{R~1 diff --git a/src/models/no_gen/__pycache__/logreg.cpython-37.pyc b/src/models/no_gen/__pycache__/logreg.cpython-37.pyc deleted file mode 100644 index ea9408480a45f2614762c8846a81e742d43e73ab..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1294 zcmZux%}?7f6t|r;P1907gkYQYEp1YEpn_ee2?5eBfF{r`)g&vcO$nq)vD3m(Cn15v ziSw?_aevEp-No0Pdf6Ya%l2McT0nE;pP!%Kd$!;EIcrl>9)b4b=Xde7K*%qgTsIqp zRp|OTfFOcKq($qLl6OQfVOu)H+P!twhLZunOJ{27nF65_Ai?{tb|ll=jFu z>$8kd(gT@WsogK6_JJeveP9pRW#CE>tVykeORUf$S?30L6B?s~cZ z@hT6!`V~AI! z%)+e|-hYq-1?GgU-a7hTT)!m|nhrgQ+Q zBmS%@=uTTIh1SO9JdDGXbK^ldlBknHMGT8$WwI#NB0K>>auJ!Iam6 zw6VXPs1~S7TfxcuB#A&QG|*}|tQcoAjAfvVwbxiiSTV)ju-QsNfwR{>Gq#q|o|=YZ z_|%j*Z$yFCoPQ&KHfsAxOV(8KNvgCPbc5Q4)Q4%(u5E-`gD${h)}(B9!;@MI38ZUr z!kaR#MF|kAxzzrwst_63sb2vQR;3QJ=^S;aOXsYK1bSMzMS`z8QIn<#OpL~)!GJp42I+s?T)lOp&o1?_|P-F@HVy?gJSsMQ<PW~)b+t5W6(9|8wF;1+TMTYcoaxidv|@26JHUYE z0A~g_GYrM8mTOEtAxD}3uF0L3yLihPd;B?lx>+J`9=~c_h9BRr9t4*=bYr|X@_2u; zvTi>NJw6&1=D8)Qr^)9UB8r7#d3m3y;GE@lTeMkE37J<~(77Lic#ov5pVaWROO8dv zh!lrRCaItLq{-AtB4W}Elmg990?{Lk@iaIiu@qgVRM6ujh&WJ2rdwHy`}xA(f{GkH zkZ@~HzPq$L$(cHh?E6TMOEV$ DmR^7V diff --git a/src/processing/img/pre/__pycache__/transf.cpython-37.pyc b/src/processing/img/pre/__pycache__/transf.cpython-37.pyc deleted file mode 100644 index 455782d203f67975a2cef35adf2f020228d9edc2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 465 zcmYLGy-Ncz6i@E@aXqbb5nP2^8l3$&v0VhIpq;v%O^k+I(4s>Mn0Z_rt2 zXE7ihM5McuF1}o;K6t;p{9fK?cItH(f$V)e(FR87-CX8kfw2v$8UPeg9HAKZFhuJW7Rj*Tugx$sA>Xa=oW)o6c5d54-Hc?32BMjv`j0s z%I%>AiXpkv@vP|=#7P@~yDxX?raSTZbN+a{jtvn)73TS zVdV4is1V;POEZbCwnUr?#dPI>sqlhnrz^TFQ9|lk2f7X-An%cMOgRJJyW~W~j7V|B zWR?ecK-x^5W+El+P$}SoEEEZ0jOXDwNu}sBrNV@hFy?S&WVVrJI?zjh%PC^aG|JRK wRtY<}f~DC~$Fuc9CS?C44;5rC`Yt^GP0g$tyw=V}Ea+Lpw&3HCGBYk(zYs=#^8f$< diff --git a/src/scripts/evaluate.py b/src/scripts/evaluate.py new file mode 100644 index 0000000..ef6458e --- /dev/null +++ b/src/scripts/evaluate.py @@ -0,0 +1,114 @@ +"""Run the full evaluation suite on a trained contrastive model. 
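+
+Reports per-subject top-k retrieval accuracy, 200-way zero-shot identification,
+and RSA correlation against the CLIP image embeddings, then averages the
+metrics across subjects.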
+ +Usage: + python scripts/evaluate.py --checkpoint path/to/checkpoint.ckpt --data_dir eeg_dataset +""" + +import argparse +from pathlib import Path + +import torch +import numpy as np +from tqdm import tqdm + +from eegvix.models.eeg_encoder import EEGEncoder +from eegvix.models.projection_head import ProjectionHead +from eegvix.models.clip_wrapper import CLIPImageEncoder +from eegvix.evaluation.retrieval import top_k_accuracy, zero_shot_identification +from eegvix.evaluation.rsa import rsa_correlation + + +def load_model(checkpoint_path: str, device: str = "cuda"): + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + state = checkpoint.get("state_dict", checkpoint) + + encoder_state = {k.replace("eeg_encoder.", ""): v for k, v in state.items() if k.startswith("eeg_encoder.")} + proj_state = {k.replace("projection_head.", ""): v for k, v in state.items() if k.startswith("projection_head.")} + + eeg_encoder = EEGEncoder() + eeg_encoder.load_state_dict(encoder_state) + eeg_encoder.to(device).eval() + + projection_head = ProjectionHead() + projection_head.load_state_dict(proj_state) + projection_head.to(device).eval() + + return eeg_encoder, projection_head + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", type=str, required=True) + parser.add_argument("--data_dir", type=str, default="eeg_dataset") + parser.add_argument("--clip_embeddings", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--subjects", type=int, nargs="+", default=None) + args = parser.parse_args() + + device = args.device + data_dir = Path(args.data_dir) + subjects = args.subjects or list(range(10)) + + eeg_encoder, projection_head = load_model(args.checkpoint, device) + + # Load CLIP image embeddings + if args.clip_embeddings: + clip_data = torch.load(args.clip_embeddings, weights_only=True) + test_image_embeds = clip_data["embeddings"] + else: + print("Computing CLIP embeddings for test images...") + clip = CLIPImageEncoder() + image_dir = data_dir / "images" / "test_images" + image_paths = sorted(image_dir.rglob("*.jpg")) + if not image_paths: + image_paths = sorted(image_dir.rglob("*.JPEG")) + test_image_embeds = clip.precompute_embeddings(image_paths, device=device) + + print(f"\nEvaluating on {len(subjects)} subjects...") + print("=" * 60) + + all_results = {} + + for subj_idx in subjects: + subj_dir = data_dir / "preprocessed" / f"sub-{subj_idx + 1:02d}" + raw = np.load(subj_dir / "preprocessed_eeg_test.npy", allow_pickle=True).item() + eeg_data = raw["preprocessed_eeg_data"] # (200, 80, 17, 100) + eeg_averaged = np.mean(eeg_data, axis=1) # (200, 17, 100) + + # Encode EEG + eeg_tensor = torch.from_numpy(eeg_averaged.astype(np.float32)).to(device) + subject_ids = torch.full((eeg_tensor.size(0),), subj_idx, device=device, dtype=torch.long) + + with torch.no_grad(): + features = eeg_encoder(eeg_tensor, subject_ids) + eeg_embeds = projection_head(features) + + img_embeds = test_image_embeds.to(device) + + # Top-k retrieval + retrieval = top_k_accuracy(eeg_embeds, img_embeds) + + # Zero-shot identification (200-way) + zs = zero_shot_identification(eeg_embeds, img_embeds) + + # RSA + rsa = rsa_correlation(eeg_embeds, img_embeds) + + results = {**retrieval, "zero_shot_acc": zs["accuracy"], **rsa} + all_results[f"sub-{subj_idx + 1:02d}"] = results + + print(f"\nSubject {subj_idx + 1:02d}:") + for k, v in results.items(): + print(f" {k}: {v:.4f}") + + # Average across subjects + print("\n" + "=" * 
60) + print("Average across subjects:") + avg_keys = list(next(iter(all_results.values())).keys()) + for k in avg_keys: + values = [r[k] for r in all_results.values()] + print(f" {k}: {np.mean(values):.4f} ± {np.std(values):.4f}") + + +if __name__ == "__main__": + main() diff --git a/src/scripts/generate.py b/src/scripts/generate.py new file mode 100644 index 0000000..b76f89a --- /dev/null +++ b/src/scripts/generate.py @@ -0,0 +1,72 @@ +"""Generate images from EEG signals using the full pipeline. + +Usage: + python scripts/generate.py --checkpoint path/to/contrastive_checkpoint.ckpt --data_dir eeg_dataset +""" + +import argparse +from pathlib import Path + +import torch +import numpy as np +from tqdm import tqdm + +from eegvix.generation.pipeline import EEGToImagePipeline + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", type=str, required=True, help="Contrastive training checkpoint") + parser.add_argument("--data_dir", type=str, default="eeg_dataset") + parser.add_argument("--output_dir", type=str, default="generated_images") + parser.add_argument("--subject", type=int, default=0) + parser.add_argument("--split", type=str, default="test", choices=["train", "test"]) + parser.add_argument("--n_images", type=int, default=5, help="Images per EEG condition") + parser.add_argument("--steps", type=int, default=50) + parser.add_argument("--guidance_scale", type=float, default=7.5) + parser.add_argument("--sd_model", type=str, default="stabilityai/stable-diffusion-2-1") + parser.add_argument("--lora_path", type=str, default=None) + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + pipeline = EEGToImagePipeline.from_pretrained( + contrastive_checkpoint=args.checkpoint, + sd_model=args.sd_model, + lora_path=args.lora_path, + device=args.device, + ) + + # Load test EEG data + data_dir = Path(args.data_dir) + subj_dir = data_dir / "preprocessed" / f"sub-{args.subject + 1:02d}" + + if args.split == "test": + raw = np.load(subj_dir / "preprocessed_eeg_test.npy", allow_pickle=True).item() + else: + raw = np.load(subj_dir / "preprocessed_eeg_training.npy", allow_pickle=True).item() + + eeg_data = raw["preprocessed_eeg_data"] + # Average across repetitions for cleaner signal + eeg_data = np.mean(eeg_data, axis=1) + + print(f"Generating images for {eeg_data.shape[0]} conditions...") + + for i in tqdm(range(eeg_data.shape[0])): + eeg_tensor = torch.from_numpy(eeg_data[i].astype(np.float32)) + images = pipeline.generate( + eeg_tensor, + subject_id=args.subject, + num_inference_steps=args.steps, + guidance_scale=args.guidance_scale, + num_images=args.n_images, + ) + + for j, img in enumerate(images): + img.save(output_dir / f"condition_{i:05d}_sample_{j}.png") + + +if __name__ == "__main__": + main() diff --git a/src/scripts/precompute_clip.py b/src/scripts/precompute_clip.py new file mode 100644 index 0000000..8e3bfd1 --- /dev/null +++ b/src/scripts/precompute_clip.py @@ -0,0 +1,56 @@ +"""Precompute CLIP image embeddings for all training and test images. + +Saves embeddings as .pt files to avoid running CLIP during contrastive training. 
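+Each .pt file stores the embedding matrix together with the source image paths
+(keys "embeddings" and "paths"), so downstream code can map rows back to stimuli.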
+ +Usage: + python scripts/precompute_clip.py --data_dir eeg_dataset --output_dir eeg_dataset/clip_embeddings +""" + +import argparse +from pathlib import Path + +import torch + +from eegvix.models.clip_wrapper import CLIPImageEncoder + + +def main(): + parser = argparse.ArgumentParser(description="Precompute CLIP image embeddings") + parser.add_argument("--data_dir", type=str, default="eeg_dataset") + parser.add_argument("--output_dir", type=str, default="eeg_dataset/clip_embeddings") + parser.add_argument("--model_name", type=str, default="ViT-L-14") + parser.add_argument("--pretrained", type=str, default="openai") + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--device", type=str, default="cuda") + args = parser.parse_args() + + data_dir = Path(args.data_dir) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + clip = CLIPImageEncoder(model_name=args.model_name, pretrained=args.pretrained) + + for split, subdir in [("train", "training_images"), ("test", "test_images")]: + image_dir = data_dir / "images" / subdir + image_paths = sorted(image_dir.rglob("*.jpg")) + if not image_paths: + image_paths = sorted(image_dir.rglob("*.JPEG")) + if not image_paths: + image_paths = sorted(image_dir.rglob("*.png")) + + if not image_paths: + print(f"No images found in {image_dir}") + continue + + print(f"Computing {split} embeddings for {len(image_paths)} images...") + embeddings = clip.precompute_embeddings( + image_paths, batch_size=args.batch_size, device=args.device + ) + + output_path = output_dir / f"{split}_clip_embeddings.pt" + torch.save({"embeddings": embeddings, "paths": [str(p) for p in image_paths]}, output_path) + print(f"Saved {split} embeddings to {output_path} — shape: {embeddings.shape}") + + +if __name__ == "__main__": + main() diff --git a/src/scripts/train_contrastive.py b/src/scripts/train_contrastive.py new file mode 100644 index 0000000..4ba8d6f --- /dev/null +++ b/src/scripts/train_contrastive.py @@ -0,0 +1,107 @@ +"""Train the EEG encoder via contrastive alignment with CLIP. 
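+
+Configuration is composed by Hydra from configs/config.yaml; any value can be
+overridden on the command line, as the usage examples below show.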
+ +Usage: + python scripts/train_contrastive.py + python scripts/train_contrastive.py +experiment=debug + python scripts/train_contrastive.py training.max_epochs=100 data.batch_size=128 +""" + +import hydra +from omegaconf import DictConfig +import lightning as L +from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor +from lightning.pytorch.loggers import WandbLogger + +from eegvix.utils.seed import seed_everything +from eegvix.data.datamodule import ThingsEEG2DataModule +from eegvix.training.contrastive_module import ContrastiveAlignmentModule +from eegvix.training.callbacks import EmbeddingVisualizationCallback, RetrievalAccuracyCallback + + +@hydra.main(config_path="../../configs", config_name="config", version_base="1.3") +def main(cfg: DictConfig) -> None: + seed_everything(cfg.seed) + + # Data + datamodule = ThingsEEG2DataModule( + data_dir=cfg.data.data_dir, + batch_size=cfg.data.batch_size, + num_workers=cfg.data.num_workers, + average_repetitions=cfg.data.average_repetitions, + val_n_concepts=cfg.data.val_n_concepts, + val_random_state=cfg.data.val_random_state, + clip_embeddings_path=f"{cfg.data.clip_embeddings_dir}/train_clip_embeddings.pt", + ) + + # Model + model = ContrastiveAlignmentModule( + n_channels=cfg.data.n_channels, + n_timepoints=cfg.data.n_timepoints, + embed_dim=cfg.model.eeg_encoder.embed_dim, + num_temporal_conv_layers=cfg.model.eeg_encoder.num_temporal_conv_layers, + temporal_kernel_sizes=list(cfg.model.eeg_encoder.temporal_kernel_sizes), + n_spatial_heads=cfg.model.eeg_encoder.n_spatial_heads, + n_temporal_transformer_layers=cfg.model.eeg_encoder.n_temporal_transformer_layers, + n_temporal_heads=cfg.model.eeg_encoder.n_temporal_heads, + dropout=cfg.model.eeg_encoder.dropout, + use_frequency_branch=cfg.model.eeg_encoder.use_frequency_branch, + output_dim=cfg.model.eeg_encoder.output_dim, + n_subjects=cfg.data.n_subjects, + proj_hidden_dim=cfg.model.projection_head.hidden_dim, + init_temperature=cfg.training.init_temperature, + learnable_temperature=cfg.training.learnable_temperature, + learning_rate=cfg.training.learning_rate, + weight_decay=cfg.training.weight_decay, + warmup_epochs=cfg.training.warmup_epochs, + max_epochs=cfg.training.max_epochs, + subject_embedding_warmup_epochs=cfg.training.subject_embedding_warmup_epochs, + clip_model_name=cfg.model.clip.model_name, + clip_pretrained=cfg.model.clip.pretrained, + use_precomputed_clip=True, + ) + + # Callbacks + callbacks = [ + ModelCheckpoint( + monitor="val/loss", + mode="min", + save_top_k=3, + filename="epoch={epoch}-val_loss={val/loss:.4f}", + auto_insert_metric_name=False, + ), + EarlyStopping( + monitor="val/loss", + patience=cfg.training.early_stopping_patience, + mode="min", + ), + LearningRateMonitor(logging_interval="epoch"), + RetrievalAccuracyCallback(top_k=[1, 5, 10], every_n_epochs=5), + EmbeddingVisualizationCallback(every_n_epochs=10), + ] + + # Logger + logger = WandbLogger( + project=cfg.wandb.project, + entity=cfg.wandb.entity, + tags=cfg.wandb.tags, + mode=cfg.wandb.mode, + ) + + # Trainer + trainer = L.Trainer( + max_epochs=cfg.training.max_epochs, + precision=cfg.training.precision, + gradient_clip_val=cfg.training.gradient_clip_val, + accumulate_grad_batches=cfg.training.accumulate_grad_batches, + check_val_every_n_epoch=cfg.training.check_val_every_n_epoch, + callbacks=callbacks, + logger=logger, + deterministic=True, + ) + + trainer.fit(model, datamodule) + trainer.test(model, datamodule) + + +if __name__ == "__main__": + main() diff --git 
a/src/scripts/train_diffusion.py b/src/scripts/train_diffusion.py new file mode 100644 index 0000000..7251749 --- /dev/null +++ b/src/scripts/train_diffusion.py @@ -0,0 +1,64 @@ +"""Fine-tune Stable Diffusion with LoRA for EEG-conditioned generation. + +Usage: + python scripts/train_diffusion.py --checkpoint path/to/contrastive_checkpoint.ckpt +""" + +import argparse + +import lightning as L +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger + +from eegvix.utils.seed import seed_everything +from eegvix.data.datamodule import ThingsEEG2DataModule +from eegvix.training.diffusion_module import DiffusionFineTuneModule + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", type=str, required=True, help="Contrastive training checkpoint") + parser.add_argument("--data_dir", type=str, default="eeg_dataset") + parser.add_argument("--sd_model", type=str, default="stabilityai/stable-diffusion-2-1") + parser.add_argument("--max_epochs", type=int, default=50) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--lora_rank", type=int, default=16) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + seed_everything(args.seed) + + datamodule = ThingsEEG2DataModule( + data_dir=args.data_dir, + batch_size=args.batch_size, + num_workers=4, + average_repetitions=True, + ) + + model = DiffusionFineTuneModule( + eeg_encoder_checkpoint=args.checkpoint, + sd_model=args.sd_model, + lora_rank=args.lora_rank, + learning_rate=args.lr, + ) + + callbacks = [ + ModelCheckpoint(monitor="train/diffusion_loss", mode="min", save_top_k=3), + ] + + logger = WandbLogger(project="eegvix-v2-diffusion") + + trainer = L.Trainer( + max_epochs=args.max_epochs, + precision="16-mixed", + gradient_clip_val=1.0, + callbacks=callbacks, + logger=logger, + ) + + trainer.fit(model, datamodule) + + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..dcf0fa2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,54 @@ +"""Shared test fixtures for EEGVIX tests.""" + +import pytest +import torch + + +@pytest.fixture +def batch_size(): + return 4 + + +@pytest.fixture +def n_channels(): + return 17 + + +@pytest.fixture +def n_timepoints(): + return 100 + + +@pytest.fixture +def embed_dim(): + return 512 + + +@pytest.fixture +def output_dim(): + return 768 + + +@pytest.fixture +def dummy_eeg(batch_size, n_channels, n_timepoints): + """Random EEG tensor simulating a batch.""" + return torch.randn(batch_size, n_channels, n_timepoints) + + +@pytest.fixture +def dummy_subject_ids(batch_size): + """Random subject IDs in [0, 9].""" + return torch.randint(0, 10, (batch_size,)) + + +@pytest.fixture +def dummy_images(batch_size): + """Random image tensor (CLIP-preprocessed size).""" + return torch.randn(batch_size, 3, 224, 224) + + +@pytest.fixture +def dummy_embeddings(batch_size, output_dim): + """Random L2-normalized embeddings.""" + x = torch.randn(batch_size, output_dim) + return torch.nn.functional.normalize(x, dim=-1) diff --git a/tests/test_losses.py b/tests/test_losses.py new file mode 100644 index 0000000..5ec8249 --- /dev/null +++ b/tests/test_losses.py @@ -0,0 +1,52 @@ +"""Tests for contrastive loss.""" + +import pytest +import torch + +from eegvix.losses.contrastive import 
InfoNCELoss + + +class TestInfoNCELoss: + def test_perfect_alignment(self): + """When EEG and image embeddings are identical, accuracy should be 1.0.""" + loss_fn = InfoNCELoss(init_temperature=0.07, learnable=False) + embeds = torch.nn.functional.normalize(torch.randn(8, 768), dim=-1) + result = loss_fn(embeds, embeds) + # Loss won't be near zero because cross-entropy with cosine ~1/temp is still nonzero, + # but accuracy must be perfect since the diagonal is the max + assert result["eeg_to_img_acc"].item() == 1.0 + assert result["img_to_eeg_acc"].item() == 1.0 + + def test_random_embeddings(self): + """With random embeddings, accuracy should be near chance (1/batch_size).""" + loss_fn = InfoNCELoss(init_temperature=0.07, learnable=False) + eeg = torch.nn.functional.normalize(torch.randn(64, 768), dim=-1) + img = torch.nn.functional.normalize(torch.randn(64, 768), dim=-1) + result = loss_fn(eeg, img) + # Loss should be high (close to log(64) ≈ 4.16) + assert result["loss"].item() > 3.0 + + def test_gradient_flows_to_temperature(self): + loss_fn = InfoNCELoss(init_temperature=0.07, learnable=True) + eeg = torch.nn.functional.normalize(torch.randn(8, 768), dim=-1) + img = torch.nn.functional.normalize(torch.randn(8, 768), dim=-1) + result = loss_fn(eeg, img) + result["loss"].backward() + assert loss_fn.log_temperature.grad is not None + + def test_symmetric(self): + """Loss should be symmetric: L(a, b) == L(b, a).""" + loss_fn = InfoNCELoss(init_temperature=0.07, learnable=False) + eeg = torch.nn.functional.normalize(torch.randn(8, 768), dim=-1) + img = torch.nn.functional.normalize(torch.randn(8, 768), dim=-1) + r1 = loss_fn(eeg, img) + r2 = loss_fn(img, eeg) + assert torch.allclose(r1["loss"], r2["loss"], atol=1e-5) + + def test_batch_size_1(self): + """Should not crash with batch size 1.""" + loss_fn = InfoNCELoss(init_temperature=0.07, learnable=False) + eeg = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1) + img = torch.nn.functional.normalize(torch.randn(1, 768), dim=-1) + result = loss_fn(eeg, img) + assert not torch.isnan(result["loss"]) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..a1cb09e --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,91 @@ +"""Tests for EEG encoder, projection head, and subject embedding.""" + +import pytest +import torch + +from eegvix.models.eeg_encoder import EEGEncoder, TemporalConvStem, SpatialChannelAttention, FrequencyBranch +from eegvix.models.projection_head import ProjectionHead +from eegvix.models.subject_embedding import SubjectEmbedding + + +class TestTemporalConvStem: + def test_output_shape(self): + stem = TemporalConvStem(in_channels=1, embed_dim=256, kernel_sizes=[7, 5]) + x = torch.randn(17, 1, 100) # 17 channels, 1 input feature, 100 timepoints + out = stem(x) + assert out.shape[0] == 17 + assert out.shape[1] == 256 + assert out.shape[2] > 0 # Reduced temporal dimension + assert out.shape[2] < 100 + + +class TestSpatialChannelAttention: + def test_output_shape(self): + attn = SpatialChannelAttention(embed_dim=256, n_heads=4) + x = torch.randn(2, 17, 256) # batch=2, 17 channels, embed_dim=256 + out = attn(x) + assert out.shape == (2, 17, 256) + + +class TestFrequencyBranch: + def test_output_shape(self): + branch = FrequencyBranch(n_channels=17, n_timepoints=100, output_dim=128) + eeg = torch.randn(2, 17, 100) + out = branch(eeg) + assert out.shape == (2, 128) + + +class TestEEGEncoder: + def test_output_shape(self, dummy_eeg, dummy_subject_ids, output_dim): + encoder = 
EEGEncoder(output_dim=output_dim, embed_dim=128, n_temporal_transformer_layers=2) + out = encoder(dummy_eeg, dummy_subject_ids) + assert out.shape == (dummy_eeg.size(0), output_dim) + + def test_gradient_flow(self, dummy_eeg, dummy_subject_ids): + encoder = EEGEncoder(embed_dim=128, n_temporal_transformer_layers=2) + out = encoder(dummy_eeg, dummy_subject_ids) + loss = out.sum() + loss.backward() + # Check that gradients exist for at least some parameters + has_grad = any(p.grad is not None and p.grad.abs().sum() > 0 for p in encoder.parameters()) + assert has_grad + + def test_subject_embed_disable(self, dummy_eeg, dummy_subject_ids): + encoder = EEGEncoder(embed_dim=128, n_temporal_transformer_layers=2) + out_with = encoder(dummy_eeg, dummy_subject_ids, enable_subject_embed=True) + out_without = encoder(dummy_eeg, dummy_subject_ids, enable_subject_embed=False) + # Outputs should differ when subject embedding is toggled + assert not torch.allclose(out_with, out_without, atol=1e-4) + + def test_without_frequency_branch(self, dummy_eeg, dummy_subject_ids): + encoder = EEGEncoder(embed_dim=128, n_temporal_transformer_layers=2, use_frequency_branch=False) + out = encoder(dummy_eeg, dummy_subject_ids) + assert out.shape == (dummy_eeg.size(0), 768) + + +class TestProjectionHead: + def test_output_normalized(self, dummy_embeddings): + head = ProjectionHead(input_dim=768, hidden_dim=1024, output_dim=768) + out = head(dummy_embeddings) + norms = out.norm(dim=-1) + assert torch.allclose(norms, torch.ones_like(norms), atol=1e-5) + + def test_output_shape(self): + head = ProjectionHead(input_dim=512, hidden_dim=1024, output_dim=768) + x = torch.randn(4, 512) + out = head(x) + assert out.shape == (4, 768) + + +class TestSubjectEmbedding: + def test_output_shape(self): + emb = SubjectEmbedding(n_subjects=10, embed_dim=256) + ids = torch.tensor([0, 3, 7, 9]) + out = emb(ids) + assert out.shape == (4, 256) + + def test_different_subjects_different_embeddings(self): + emb = SubjectEmbedding(n_subjects=10, embed_dim=256) + e0 = emb(torch.tensor([0])) + e1 = emb(torch.tensor([1])) + assert not torch.allclose(e0, e1) diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py new file mode 100644 index 0000000..8d91795 --- /dev/null +++ b/tests/test_retrieval.py @@ -0,0 +1,42 @@ +"""Tests for evaluation metrics.""" + +import torch +from eegvix.evaluation.retrieval import top_k_accuracy, zero_shot_identification + + +class TestTopKAccuracy: + def test_perfect_match(self): + """When embeddings are identical, top-1 accuracy should be 1.0.""" + embeds = torch.nn.functional.normalize(torch.randn(20, 768), dim=-1) + results = top_k_accuracy(embeds, embeds, k_values=[1, 5]) + assert results["top_1"] == 1.0 + assert results["top_5"] == 1.0 + + def test_random_chance(self): + """With random embeddings, top-1 accuracy should be near 1/n.""" + torch.manual_seed(42) + n = 200 + eeg = torch.nn.functional.normalize(torch.randn(n, 768), dim=-1) + img = torch.nn.functional.normalize(torch.randn(n, 768), dim=-1) + results = top_k_accuracy(eeg, img, k_values=[1]) + # Should be near chance (0.5%) but with some variance + assert results["top_1"] < 0.1 + + +class TestZeroShotIdentification: + def test_perfect_identification(self): + """With identical embeddings, all conditions should be correctly identified.""" + embeds = torch.nn.functional.normalize(torch.randn(50, 768), dim=-1) + result = zero_shot_identification(embeds, embeds) + assert result["accuracy"] == 1.0 + assert result["n_correct"] == 50 + + def 
test_with_distractors(self):
+        """Identification should still work with well-separated embeddings + distractors."""
+        torch.manual_seed(42)
+        embeds = torch.nn.functional.normalize(torch.randn(10, 768), dim=-1)
+        distractors = torch.nn.functional.normalize(torch.randn(100, 768), dim=-1)
+        result = zero_shot_identification(embeds, embeds, distractors)
+        # With identical query/gallery embeddings and random distractors,
+        # most conditions should be correctly identified
+        assert result["accuracy"] >= 0.5
+        assert result["n_candidates"] == 110
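+
+    def test_chance_with_mismatched_queries(self):
+        """Illustrative addition (not part of the original suite): with random,
+        unrelated query and gallery embeddings plus many distractors, accuracy
+        should collapse toward chance, roughly 1 / n_candidates."""
+        torch.manual_seed(0)
+        eeg = torch.nn.functional.normalize(torch.randn(10, 768), dim=-1)
+        img = torch.nn.functional.normalize(torch.randn(10, 768), dim=-1)
+        distractors = torch.nn.functional.normalize(torch.randn(100, 768), dim=-1)
+        result = zero_shot_identification(eeg, img, distractors)
+        assert result["accuracy"] <= 0.3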