PGSFormer is a transformer-based model for alcohol use disorder prediction from polygenic scores and covariates. In this implementation, each PGS value is treated as a token, projected into a latent embedding space, and passed through a stack of transformer encoder layers. The encoded token sequence is aggregated by attention pooling to produce a global PGS representation.
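The per-token projection and attention pooling described above can be sketched as follows. This is a minimal illustration, not the repository code: the class names `PGSTokenizer` and `AttentionPool`, the learned positional embedding, and the single-linear scoring function are assumptions.

```python
import torch
import torch.nn as nn

class PGSTokenizer(nn.Module):
    """Turn each scalar PGS value into a token embedding (sketch; names assumed)."""
    def __init__(self, n_pgs: int, d_model: int):
        super().__init__()
        self.proj = nn.Linear(1, d_model)                 # shared scalar-to-embedding projection
        self.pos = nn.Parameter(torch.zeros(n_pgs, d_model))  # per-PGS learned position embedding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_pgs) -> (batch, n_pgs, d_model)
        return self.proj(x.unsqueeze(-1)) + self.pos

class AttentionPool(nn.Module):
    """Aggregate a token sequence into one vector via learned attention weights."""
    def __init__(self, d_model: int):
        super().__init__()
        self.score = nn.Linear(d_model, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # h: (batch, n_tokens, d_model)
        w = torch.softmax(self.score(h), dim=1)  # attention weights over tokens
        return (w * h).sum(dim=1)                # (batch, d_model)
```

In the full model, the tokenized sequence would pass through the transformer encoder stack between these two modules.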
To preserve direct signal from the original PGS vector, PGSFormer includes a residual shortcut branch. The raw PGS input is mapped by a linear layer and added to the pooled transformer representation through a learnable positive scaling gate. An auxiliary head on the pooled PGS representation produces an auxiliary PGS score, while the final AUD prediction is made after concatenating the PGS representation with non-EEG covariates.
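The shortcut branch, gate, and heads might look like the sketch below. The class name `PGSFormerHead`, the use of softplus to keep the gate positive, and the exact inputs to each head are assumptions drawn from the description above, not the repository implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PGSFormerHead(nn.Module):
    """Sketch of the residual shortcut, gate, and prediction heads (names assumed)."""
    def __init__(self, n_pgs: int, n_cov: int, d_model: int):
        super().__init__()
        self.shortcut = nn.Linear(n_pgs, d_model)      # linear map of the raw PGS vector
        self.gate = nn.Parameter(torch.zeros(1))       # learnable scalar scaling gate
        self.aux_head = nn.Linear(d_model, 1)          # auxiliary PGS score
        self.cls_head = nn.Linear(d_model + n_cov, 1)  # final AUD logit

    def forward(self, pooled, raw_pgs, covariates):
        # softplus keeps the learned gate strictly positive (assumption)
        rep = pooled + F.softplus(self.gate) * self.shortcut(raw_pgs)
        aux_logit = self.aux_head(rep)  # auxiliary score from the PGS representation
        logit = self.cls_head(torch.cat([rep, covariates], dim=-1))
        return logit, aux_logit
```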
This repository contains the final no-EEG configuration used for the main PGSFormer experiment:
- Inputs: 11 PGS features plus the SEX and CONTROL covariates
- Backbone: transformer encoder with attention pooling
- Residual shortcut: enabled
- Auxiliary PGS head: enabled
- Reconstruction decoder: disabled
Hyperparameters: `d_model=64`, `n_layers=1`, `n_heads=4`, `dropout=0.25`, `lr=1e-4`, `weight_decay=1e-5`, `batch_size=512`, `epochs=200`, `patience=50`.
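The hyperparameters can be collected in a config dict, and `patience` implies early stopping that halts after that many epochs without a new best validation loss. The `early_stop_epoch` helper below is an illustrative sketch, not the repository's implementation:

```python
# Final hyperparameters, copied from the configuration above.
HPARAMS = dict(d_model=64, n_layers=1, n_heads=4, dropout=0.25,
               lr=1e-4, weight_decay=1e-5, batch_size=512,
               epochs=200, patience=50)

def early_stop_epoch(val_losses, patience):
    """Epoch at which patience-based early stopping would halt:
    stop after `patience` consecutive epochs without a new best loss."""
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses) - 1
```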
The main training objective is focal loss on the AUD classifier output. The final model also includes an auxiliary binary cross-entropy loss on the auxiliary PGS head and a residual binary cross-entropy loss on the shortcut branch.
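A sketch of how these three terms could combine is shown below. The focal-loss `gamma`/`alpha` defaults and the auxiliary/residual loss weights `w_aux`/`w_res` are placeholders, not values from the repository:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    """Binary focal loss on logits; gamma/alpha defaults are assumptions."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

def total_loss(main_logit, aux_logit, res_logit, y, w_aux=0.5, w_res=0.5):
    """Focal loss on the AUD classifier plus auxiliary and residual BCE terms.
    The weights are illustrative placeholders."""
    return (focal_loss(main_logit, y)
            + w_aux * F.binary_cross_entropy_with_logits(aux_logit, y)
            + w_res * F.binary_cross_entropy_with_logits(res_logit, y))
```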
- `train_pgsformer_no_eeg.py`: main training script
- `model/pgsformer.py`: PGSFormer model definition
- `data/coga.py`: dataset loader
- `utils/config.py`: fold-aware path expansion
- `utils/misc.py`: random seed helper
- `run_train.sh`: five-fold launcher for the final configuration
The training script expects fold-specific data under a root directory:
```
DATA_ROOT/fold_0/learning_set.csv
DATA_ROOT/fold_0/validation_set.csv
DATA_ROOT/fold_1/learning_set.csv
DATA_ROOT/fold_1/validation_set.csv
...
```
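Fold-aware path expansion of this layout (as `utils/config.py` provides) might look like the following sketch; the function name `fold_paths` is an assumption:

```python
from pathlib import Path

def fold_paths(data_root: str, fold: int):
    """Expand DATA_ROOT/fold_<k>/{learning_set,validation_set}.csv (sketch)."""
    root = Path(data_root) / f"fold_{fold}"
    return root / "learning_set.csv", root / "validation_set.csv"
```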
Within `learning_set.csv`, rows labeled `subset=train` are used for training and rows labeled `subset=test` are used for validation. The external evaluation split is read from `validation_set.csv`.
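The split logic described above can be sketched with pandas; the helper name `load_splits` is an assumption, while the `subset` column and file names come from this README:

```python
import pandas as pd

def load_splits(fold_dir):
    """Split learning_set.csv by its `subset` column and read
    validation_set.csv as the external evaluation split (sketch)."""
    learning = pd.read_csv(f"{fold_dir}/learning_set.csv")
    train = learning[learning["subset"] == "train"]
    val = learning[learning["subset"] == "test"]  # internal validation rows
    external = pd.read_csv(f"{fold_dir}/validation_set.csv")
    return train, val, external
```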
Set DATA_ROOT in run_train.sh, then run:
```
bash run_train.sh
```