Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Root Hydra config: composes one option from each config group below.
defaults:
- data: things_eeg2        # dataset layout and dataloader settings
- model: eeg_encoder       # encoder / projection-head hyperparameters
- training: contrastive    # contrastive (InfoNCE) training schedule
- evaluation: metrics      # retrieval / generation evaluation settings
- wandb: default           # experiment-tracking settings
- _self_                   # last: keys in this file override the groups above

seed: 42                   # global RNG seed
45 changes: 45 additions & 0 deletions configs/data/things_eeg2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Dataset root: taken from the EEGVIX_DATA_DIR env var, falling back to ./eeg_dataset
data_dir: ${oc.env:EEGVIX_DATA_DIR,eeg_dataset}
preprocessed_dir: ${data.data_dir}/preprocessed

# Recording dimensions (per trial); n_channels must match the `channels` list below
n_subjects: 10
n_channels: 17
n_timepoints: 100
sampling_rate: 100 # Hz (downsampled from 1000)
time_start: -0.2 # seconds relative to stimulus onset
time_end: 0.8

# Stimulus counts; reps = EEG repetitions recorded per image
n_train_images: 16540
n_test_images: 200
n_train_reps: 4
n_test_reps: 80
images_per_concept: 10

# Channels: occipital and parietal
channels:
- O1
- Oz
- O2
- PO7
- PO3
- POz
- PO4
- PO8
- P7
- P5
- P3
- P1
- Pz
- P2
- P4
- P6
- P8

# Data loading
batch_size: 256
num_workers: 4
average_repetitions: true
val_n_concepts: 150
val_random_state: 42

# Precomputed embeddings
clip_embeddings_dir: ${data.data_dir}/clip_embeddings
12 changes: 12 additions & 0 deletions configs/evaluation/metrics.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
retrieval:
top_k: [1, 5, 10, 50, 200]

zero_shot:
n_test_images: ${data.n_test_images}
n_test_reps: ${data.n_test_reps}

generation:
compute_fid: true
compute_ssim: true
compute_lpips: true
n_generated_per_condition: 5
22 changes: 22 additions & 0 deletions configs/experiment/debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# @package _global_

defaults:
- /data: things_eeg2
- /model: eeg_encoder
- /training: contrastive
- /evaluation: metrics
- /wandb: default

seed: 42

data:
batch_size: 16
num_workers: 0
n_subjects: 1

training:
max_epochs: 5
early_stopping_patience: 3

wandb:
mode: disabled
10 changes: 10 additions & 0 deletions configs/experiment/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# @package _global_

defaults:
- /data: things_eeg2
- /model: eeg_encoder
- /training: contrastive
- /evaluation: metrics
- /wandb: default

seed: 42
26 changes: 26 additions & 0 deletions configs/model/eeg_encoder.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Spatiotemporal Transformer EEG Encoder
eeg_encoder:
n_channels: ${data.n_channels}
n_timepoints: ${data.n_timepoints}
embed_dim: 512
num_temporal_conv_layers: 3
temporal_kernel_sizes: [7, 5, 3]
n_spatial_heads: 4
n_temporal_transformer_layers: 4
n_temporal_heads: 8
dropout: 0.1
use_frequency_branch: true
output_dim: 768 # CLIP ViT-L/14 embedding dimension

subject_embedding:
n_subjects: ${data.n_subjects}
embed_dim: ${model.eeg_encoder.embed_dim}

projection_head:
input_dim: ${model.eeg_encoder.output_dim}
hidden_dim: 2048
output_dim: 768

clip:
model_name: ViT-L-14
pretrained: openai
21 changes: 21 additions & 0 deletions configs/training/contrastive.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
learning_rate: 3e-4
weight_decay: 0.01
warmup_epochs: 10
max_epochs: 200
early_stopping_patience: 20

optimizer: adamw
scheduler: cosine_warmup

# InfoNCE loss
init_temperature: 0.07
learnable_temperature: true

# Trainer
precision: 16-mixed
gradient_clip_val: 1.0
accumulate_grad_batches: 1
check_val_every_n_epoch: 1

# Subject embedding warmup
subject_embedding_warmup_epochs: 5
18 changes: 18 additions & 0 deletions configs/training/diffusion_finetune.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
learning_rate: 1e-5
weight_decay: 0.01
max_epochs: 50

# Stable Diffusion
sd_model: stabilityai/stable-diffusion-2-1
use_ip_adapter: true
use_lora: true
lora_rank: 16
lora_alpha: 32

# Generation
num_inference_steps: 50
guidance_scale: 7.5
image_resolution: 512

precision: 16-mixed
gradient_clip_val: 1.0
4 changes: 4 additions & 0 deletions configs/wandb/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
project: eegvix-v2
entity: null
tags: []
mode: online # online, offline, disabled
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
51 changes: 51 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
[build-system]
requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "eegvix"
version = "2.0.0"
description = "EEG-to-Image generation via CLIP-aligned contrastive learning and diffusion models."
requires-python = ">=3.10"
license = {text = "MIT"}
authors = [{name = "Robert Vava", email = "vavarobert10@gmail.com"}]

dependencies = [
"torch>=2.1",
"torchvision>=0.16",
"lightning>=2.1",
"open-clip-torch>=2.24",
"diffusers>=0.25",
"transformers>=4.36",
"peft>=0.7",
"hydra-core>=1.3",
"omegaconf>=2.3",
"wandb>=0.16",
"numpy>=1.24",
"scipy>=1.11",
"scikit-learn>=1.3",
"pillow>=10.0",
"torchmetrics>=1.2",
"lpips>=0.1.4",
"tqdm>=4.66",
"einops>=0.7",
"matplotlib>=3.8",
]

[project.optional-dependencies]
dev = [
"pytest>=7.4",
"pytest-cov>=4.1",
"ruff>=0.1",
"mypy>=1.7",
]

[tool.setuptools.packages.find]
where = ["src"]

[tool.ruff]
line-length = 120
target-version = "py310"

[tool.pytest.ini_options]
testpaths = ["tests"]
3 changes: 3 additions & 0 deletions src/eegvix/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""EEGVIX: EEG-to-Image generation via CLIP-aligned contrastive learning and diffusion models."""

__version__ = "2.0.0"
4 changes: 4 additions & 0 deletions src/eegvix/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Public data API: re-exports the THINGS-EEG2 dataset and datamodule classes."""

from eegvix.data.dataset import ThingsEEG2Dataset
from eegvix.data.datamodule import ThingsEEG2DataModule

__all__ = ["ThingsEEG2Dataset", "ThingsEEG2DataModule"]
45 changes: 45 additions & 0 deletions src/eegvix/data/channel_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Electrode positions for the 17 occipital/parietal channels used in THINGS-EEG2.

2D positions are approximate projections onto a unit circle (top-down view of scalp).
Coordinates follow the standard 10-10 system layout used in the dataset.
"""

import torch

# Channel names, ordered exactly as the channel axis of the preprocessed data.
CHANNEL_NAMES: list[str] = [
    "O1", "Oz", "O2",
    "PO7", "PO3", "POz", "PO4", "PO8",
    "P7", "P5", "P3", "P1", "Pz", "P2", "P4", "P6", "P8",
]

# Approximate (x, y) scalp coordinates in [-1, 1] (standard 10-10 montage,
# top-down view): x runs left(-) to right(+), y runs posterior(-) to anterior(+).
# Listed in the same order as CHANNEL_NAMES so the two can be zipped together.
_XY: list[tuple[float, float]] = [
    (-0.31, -0.95), (0.00, -1.00), (0.31, -0.95),                                  # O row
    (-0.59, -0.81), (-0.31, -0.81), (0.00, -0.81), (0.31, -0.81), (0.59, -0.81),   # PO row
    (-0.81, -0.59), (-0.59, -0.59), (-0.39, -0.59), (-0.19, -0.59),                # P row
    (0.00, -0.59), (0.19, -0.59), (0.39, -0.59), (0.59, -0.59), (0.81, -0.59),
]

# Name -> (x, y) lookup, built from the two parallel lists above.
CHANNEL_POSITIONS_2D: dict[str, tuple[float, float]] = dict(zip(CHANNEL_NAMES, _XY))

N_CHANNELS = len(CHANNEL_NAMES)


def get_channel_positions_tensor() -> torch.Tensor:
    """Return the 2D scalp coordinates as a (17, 2) float32 tensor.

    Row i holds the (x, y) position of CHANNEL_NAMES[i], so the tensor lines
    up with the channel axis of the preprocessed EEG data.
    """
    coords = torch.empty((len(CHANNEL_NAMES), 2), dtype=torch.float32)
    for row, name in enumerate(CHANNEL_NAMES):
        x, y = CHANNEL_POSITIONS_2D[name]
        coords[row, 0] = x
        coords[row, 1] = y
    return coords
Loading