This repository contains the official PyTorch implementation of BaRISTA (Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity).
We recommend setting up a virtual environment to manage dependencies.
```bash
# 1. Create and activate a virtual environment
python -m venv barista_venv
source barista_venv/bin/activate

# 2. Install the package in editable mode
python -m pip install -e .
```
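To confirm the editable install succeeded, here is a minimal check, assuming the repo's top-level package is importable as `barista`:

```python
# Minimal install sanity check; assumes the top-level package is named
# `barista` (adjust if the importable module name differs).
import barista

print(barista.__file__)  # should point into this repository, not site-packages
```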
1. Download the data from the Brain Treebank website. You will also need the `clean_laplacian.json` file from the PopT codebase.
2. Update the `dataset_dir` config in `barista/config/braintreebank.yaml` to point to the raw data path.
The data directory should have the following structure:
```
braintreebank_data
|__ corrupted_elec.json
|__ clean_laplacian.json
|__ all_subject_data
|   |__ sub_1_trial000.h5
|   |__ sub_1_trial001.h5
|   |__ sub_1_trial002.h5
|   |__ sub_2_trial000.h5
|   ...
|
|__ electrode_labels
|   |__ sub_1
|   |   |__ electrode_labels.json
|   |__ sub_2
|   |   |__ electrode_labels.json
|   ...
|
|__ localization
|   |__ elec_coords_full.csv
|   |__ sub_1
|   |   |__ depth-wm.csv
|   |__ sub_2
|   |   |__ depth-wm.csv
|   ...
|
|__ subject_metadata
|   |__ sub_1_trial000_metadata.json
|   |__ sub_1_trial001_metadata.json
|   |__ sub_1_trial002_metadata.json
|   |__ sub_2_trial000_metadata.json
|   ...
|
|__ subject_timings
|   |__ sub_1_trial000_timings.csv
|   |__ sub_1_trial001_timings.csv
|   |__ sub_1_trial002_timings.csv
|   |__ sub_2_trial000_timings.csv
|   ...
|
|__ transcripts
    |__ ant-man
    |   |__ features.csv
    |__ aquaman
    |   |__ features.csv
    ...
```
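Before segmenting, it can help to verify this layout programmatically. Below is a minimal sketch; the entry names come from the tree above, while the path itself is an assumption and should match the `dataset_dir` value in your config:

```python
# Sketch: verify the expected Brain Treebank layout exists. The entry names
# come from the directory tree above; the path is an assumption and should
# match dataset_dir in barista/config/braintreebank.yaml.
from pathlib import Path

dataset_dir = Path("braintreebank_data")
expected = [
    "corrupted_elec.json",
    "clean_laplacian.json",
    "all_subject_data",
    "electrode_labels",
    "localization",
    "subject_metadata",
    "subject_timings",
    "transcripts",
]
missing = [name for name in expected if not (dataset_dir / name).exists()]
print("Missing entries:", missing or "none")
```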
You must segment the data before training. The required arguments depend on the experiment:
| Experiment Type | `force_nonoverlap` | `experiment` options |
|---|---|---|
| 1. Random splits, non-overlapping neural segments (Main Analysis in the paper) | `True` | `sentence_onset`, `speech_vs_nonspeech` |
| 2. Chronological splits, increased labels (Appendix K in the paper) | `False` | `sentence_onset_time`, `speech_vs_nonspeech_time`, `volume`, `optical_flow` |
To generate the random splits with non-overlapping neural segments, as used for the main analysis (Section 4), you will need to run the following:
```bash
python barista/prepare_segments.py \
    --config barista/config/braintreebank.yaml \
    --experiment <sentence_onset|speech_vs_nonspeech>
```
⚠️ Ensure `force_nonoverlap` in `barista/config/braintreebank.yaml` is set to `True` for this experiment. Incorrect settings will produce invalid splits. This setting should only be used with the `sentence_onset` and `speech_vs_nonspeech` experiments.
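One way to guard against an incorrect setting is to check the flag before running segmentation. This is a minimal sketch, assuming `force_nonoverlap` is a top-level key in the YAML file:

```python
# Sketch: fail fast if force_nonoverlap is not set correctly. Assumes the
# flag is a top-level key in barista/config/braintreebank.yaml.
import yaml

with open("barista/config/braintreebank.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg.get("force_nonoverlap") is True, (
    "Set force_nonoverlap: True before segmenting for sentence_onset or "
    "speech_vs_nonspeech"
)
```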
We can also generate chronological splits (splitting sessions based on time rather than by random shuffling). This approach increases the number of labeled segments available for finetuning by allowing overlap between segments within the same split, while preventing information leakage (i.e., no overlapping neural segments) between the train and test splits. To generate the chronological splits used for the evaluation in Appendix K, follow the two steps below.
First, you will need to segment the data using the following command:
```bash
python barista/prepare_segments.py \
    --config barista/config/braintreebank.yaml \
    --experiment <sentence_onset_time|speech_vs_nonspeech_time|volume|optical_flow>
```
⚠️ Ensure `force_nonoverlap` in `barista/config/braintreebank.yaml` is set to `False` for this experiment. Incorrect settings will produce invalid splits. This setting should only be used with the `sentence_onset_time`, `speech_vs_nonspeech_time`, `volume`, and `optical_flow` experiments.
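If you need segments for all four chronological experiments, one option is to loop over them. The sketch below simply shells out to `prepare_segments.py` with the same flags shown above:

```python
# Sketch: run segmentation for every chronological experiment in one pass.
# Uses only the CLI flags documented above; nothing else is assumed.
import subprocess

EXPERIMENTS = [
    "sentence_onset_time",
    "speech_vs_nonspeech_time",
    "volume",
    "optical_flow",
]

for exp in EXPERIMENTS:
    subprocess.run(
        [
            "python", "barista/prepare_segments.py",
            "--config", "barista/config/braintreebank.yaml",
            "--experiment", exp,
        ],
        check=True,  # stop immediately if any segmentation run fails
    )
```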
Second, you will need to generate the 5 chronological folds to use during evaluation. To create these folds, we use the `data/generate_chronological_folds.ipynb` notebook. This notebook will automatically generate 5 different train/valid/test splits across time, while ensuring that all generated splits contain both positive and negative labels. To use the notebook, take the following steps:
1. Open `generate_chronological_folds.ipynb`.
2. Update the `_METADATA_FNAMES` variable with the metadata hash string produced by the previous step.
3. Run the notebook to generate the 5 train/valid/test fold pickle files.
The notebook outputs a pickle file in the same directory as the specified metadata file; this file is loaded dynamically at train/eval time to ensure the correct chronological split fold is used.
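If you want to sanity-check the folds before training, you can load the pickle directly. The filename below is a placeholder (use the file the notebook actually wrote), and the internal structure is whatever the notebook produced:

```python
# Sketch: inspect a generated fold file. The path is a placeholder; the
# pickle's internal structure is defined by the notebook, not assumed here.
import pickle

with open("path/to/generated_folds.pkl", "rb") as f:  # replace with real path
    folds = pickle.load(f)

print(type(folds))
```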
To finetune the model:

1. Set the `finetune_sessions` field in `barista/config/braintreebank.yaml` to the desired finetuning session.
2. Run finetuning with the following command:

```bash
python barista/train.py
```

It is important that the `braintreebank.yaml` fields match precisely the config used during segment generation, including the `experiment` field. Otherwise, the metadata hash string will not match and the experiment will fail. For the chronological folds, the experiment will also fail if the pickle file described in the second step of Generating chronological splits with increased label data has not been generated.
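Before launching a run, it can be worth printing the config fields that feed into the metadata hash. A minimal sketch, assuming `experiment` and `finetune_sessions` are top-level keys in the YAML:

```python
# Sketch: print the config fields that must match the segmentation run.
# Assumes experiment and finetune_sessions are top-level keys in the YAML.
import yaml

with open("barista/config/braintreebank.yaml") as f:
    cfg = yaml.safe_load(f)

print("experiment:", cfg.get("experiment"))
print("finetune_sessions:", cfg.get("finetune_sessions"))
```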
Pretrained models are available under `pretrained_models/`. Set the `checkpoint_path` in `barista/config/train.yaml` to the specific pretrained model path, e.g. `checkpoint_path: pretrained_models/parcels_chans.ckpt`.
⚠️ You also need to set the `tokenizer.spatial_grouping` in `barista/config/model.yaml` accordingly for each of the models.
| Checkpoint Name | `tokenizer.spatial_grouping` |
|---|---|
| `chans_chans.ckpt` | `coords` |
| `parcels_chans.ckpt` | `destrieux` |
| `lobes_chans.ckpt` | `lobes` |
Alternatively, you can pass these as extra arguments to the train command. Example finetuning command for the parcel-level model:
```bash
python barista/train.py \
    --override \
    tokenizer.spatial_grouping="destrieux" \
    checkpoint_path="pretrained_models/parcels_chans.ckpt"
```

You can also use the scripts under `barista/utility_scripts` to run the model for a specific setting across different finetuning seeds.
The run outputs are saved in the results directory specified in the script and can be easily aggregated across different subjects, models, and folds using `aggregate_runs.py`.
Example usage for random splits:

```bash
./barista/utility_scripts/run_finetune_random_splits.sh \
    --spe destrieux \
    --checkpoint "pretrained_models/parcels_chans.ckpt" \
    --session HOLDSUBJ_1_HS1_1 \
    --gpu 0 \
    --exp sentence_onset
```

Example usage for a chronological fold:
```bash
./barista/utility_scripts/run_finetune_folds.sh \
    --spe destrieux \
    --checkpoint "pretrained_models/parcels_chans.ckpt" \
    --session HOLDSUBJ_1_HS1_1 \
    --gpu 0 \
    --fold 0 \
    --exp sentence_onset_time
```

You can use `utility_scripts/aggregate_runs.py` to get the average results as a markdown table:
```bash
python barista/utility_scripts/aggregate_runs.py \
    --results_dir <results|results_folds>
```

Citation
```bibtex
@inproceedings{
oganesian2025barista,
title={BaRISTA: Brain Scale Informed Spatiotemporal Representation of Human Intracranial Neural Activity},
author={Oganesian, Lucine L. and Hashemi, Saba and Shanechi, Maryam M.},
booktitle={Advances in Neural Information Processing Systems},
year={2025},
url={https://openreview.net/pdf?id=LDjBDk3Czb}
}
```
Copyright (c) 2025 University of Southern California
See the full notice in `LICENSE.md`.
Lucine L. Oganesian, Saba Hashemi, and Maryam M. Shanechi
Shanechi Lab, University of Southern California
