
Auditory Attention

The code repository for "Optimized feature gains explain and predict successes and failures of human selective listening" by Ian M. Griffith, R. Preston Hess, and Josh H. McDermott (in press, Nature Human Behavior).


Figure 1


Data directories

To use the repository, first download the model checkpoint, participant data, model simulation results, and demo stimuli from our OSF project site (https://doi.org/10.17605/OSF.IO/WJZVU). The archives correspond to attn_cue_models, data, and demo_stimuli; a minimal extraction sketch follows the list below.

  • attn_cue_models.zip holds a checkpoint for our best model (e.g., word_task_v10_main_feature_gain_config).
  • demo_stimuli.zip provides the example .wav files used in the quick-start code below (male/female cues, targets, and mixtures).
  • data.tar contains the processed experiment results: CSV/PKL tables with the aggregated model and human data used by the figure scripts for all experiments and model simulations.
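
Once the archives are downloaded, they can be unpacked with the Python standard library; the snippet below is a minimal sketch that assumes the filenames above and that the archives sit in the repository root.

import tarfile
import zipfile
from pathlib import Path

repo_root = Path(".")  # assumes the archives were downloaded into the repository root

# unpack the two zip archives
for name in ["attn_cue_models.zip", "demo_stimuli.zip"]:
    with zipfile.ZipFile(repo_root / name) as zf:
        zf.extractall(repo_root)

# unpack the data tarball
with tarfile.open(repo_root / "data.tar") as tf:
    tf.extractall(repo_root)

# sanity check: the rest of the README expects these directories to exist
for d in ["attn_cue_models", "data", "demo_stimuli"]:
    assert (repo_root / d).is_dir(), f"missing {d}/ after extraction"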

Repository map

  • attn_cue_models/ – Pretrained checkpoint for the best feature-gain model analyzed in the paper.
  • config/ – YAML files describing model architectures, datasets, and hyperparameters.
  • corpus/ – Dataset/dataloader definitions for model training and behavioral simulations.
  • data/ – CSV/PKL tables with the aggregated model and human results used by the figure scripts (see the loading sketch after this list).
  • demo_stimuli/ – Male/female cue-target .wav files plus mixtures so you can run the model.
  • notebooks/ – Jupyter notebooks for exploratory analysis and figure generation.
    • notebooks/Final_Figures/ contains both .ipynb and .py counterparts for every main and supplementary figure. Run python notebooks/Final_Figures/run_all_figure_gen.py to regenerate all figures and associated statistics.
  • src/ – PyTorch Lightning modules, cochlear front-end implementations, audio transforms, utilities, and every experiment entrypoint (e.g., src/eval_*.py). Run them with python -m src.<module_name>.
  • scripts/ – SLURM-ready job scripts that change into the repo root, export PYTHONPATH, and invoke the corresponding src/... modules. Use them as templates for MIT OpenMind or adapt them to your own scheduler.
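
The tables in data/ are plain CSV/PKL files, so they can be inspected directly; the snippet below is a sketch with a hypothetical filename, so substitute any table shipped in data.tar.

import pandas as pd

# "example_experiment_results.csv" is a hypothetical name for illustration;
# point this at any of the CSV tables extracted from data.tar
df = pd.read_csv("data/example_experiment_results.csv")
print(df.head())
print(df.columns.tolist())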

Requirements

  1. Install dependencies
    • Python 3.11.5
    • PyTorch 2.1+
    • PyTorch Lightning 2.1+
    • Additional packages listed in requirements.txt (recommended: create a Conda environment and run pip install -r requirements.txt); see the version check after this list
  2. Hardware expectations
    Model training used a DDP setup with 4×A100-80GB GPUs, 100 GB of host RAM, and 4 CPUs feeding each GPU. Models took roughly 7-10 days to converge, depending on architecture size.
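
A quick way to confirm the environment matches these requirements is the version check below (a sketch; the import path for Lightning depends on which distribution of the package is installed).

import sys
import torch

try:
    import pytorch_lightning as pl   # classic "pytorch-lightning" package
except ImportError:
    import lightning.pytorch as pl   # unified "lightning" package (2.x)

print("Python:", sys.version.split()[0])       # trained with 3.11.5
print("PyTorch:", torch.__version__)           # 2.1+
print("PyTorch Lightning:", pl.__version__)    # 2.1+
print("CUDA available:", torch.cuda.is_available())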

Experiment mapping

  • Per-experiment simulations (call with python -m src.<module_name>)
    • src/eval_swc_mono_stim.py – Experiment 1 (main diotic conditions; distractor sex & language)
    • src/eval_swc_popham_2024.py – Experiment 2 (talker harmonicity)
    • src/eval_texture_backgrounds.py – Experiment 3 (Saddler & McDermott 2024 background textures)
    • src/eval_symmetric_distractors.py – Experiment 4 (Byrne et al. 2023) and Extended Data Figure 4 (all spatial configurations)
    • src/eval_precedence.py – Experiment 5 (simulate Freyman et al. 1999)
    • src/eval_sim_array_threshold_experiment_v02.py – Experiment 6 (thresholds)
    • src/eval_sim_array_spotlight_experiment_v02.py – Experiment 7 (spotlight task)
    • src/eval_cue_duration.py – Experiment 1b (cue duration)
    • src/get_acts_for_tuning_and_selection_analysis.py – Activations for Figure 5 / Extended Data Figure 5
    • src/get_acts_for_tuning_anova_jsin.py – Activations for Extended Data Figure 7
    • src/unit_tuning_anova_parallel_jsin.py – ANOVAs for Extended Data Figure 7
  • Cluster execution
    Use the scripts in scripts/ (e.g., scripts/run_unit_tuning_anova_parallel.sh) as templates for your scheduler; they capture the exact resource settings we used on OpenMind. A minimal programmatic launch sketch follows this list.
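
The simulations can also be launched programmatically instead of via SLURM; the snippet below is a sketch that simply wraps the python -m invocation described above (run it from the repository root; any module-specific flags would go after the module name).

import subprocess
import sys

# equivalent to running "python -m src.eval_swc_mono_stim" from the repo root
result = subprocess.run(
    [sys.executable, "-m", "src.eval_swc_mono_stim"],
    cwd=".",      # repository root, so that src/, config/, and data/ resolve
    check=False,
)
print("exit code:", result.returncode)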

Quick-start: load a checkpoint and run the demo stimuli

import yaml
import pickle 
from pathlib import Path
from src.spatial_attn_lightning import BinauralAttentionModule 
import src.audio_transforms as at
import soundfile as sf 

config_path = "config/binaural_attn/word_task_v10_main_feature_gain_config.yaml"
with open(config_path, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# set checkpoint path
ckpt_path = 'attn_cue_models/word_task_v10_main_feature_gain_config/checkpoints/epoch=1-step=24679-v1.ckpt'

# load model from checkpoint and freeze with .eval()
model = BinauralAttentionModule.load_from_checkpoint(checkpoint_path=ckpt_path, config=config, strict=False).eval()

# send to gpu
model = model.cuda()

# get cochleagram 
coch_gram = model.coch_gram.cuda()

# define audio transforms
SNR = 0 # signal-to-noise ratio in dB for CombineWithRandomDBSNR; setting low_snr and high_snr to the same value fixes the SNR at that value
audio_transforms = at.AudioCompose([
                        at.AudioToTensor(),
                        at.CombineWithRandomDBSNR(low_snr=SNR, high_snr=SNR), 
                        at.RMSNormalizeForegroundAndBackground(rms_level=0.02),
                        at.DuplicateChannel(),
                        at.UnsqueezeAudio(dim=0),
                        ])

# Load word dictionary 
with open("./cv_800_word_label_to_int_dict.pkl", "rb") as f:
    word_to_ix_dict = pickle.load(f) 

# Map for class ix to word labels
class_ix_to_word = {v: k for k, v in word_to_ix_dict.items()}

# Load audio demo stimuli
outdir = Path("demo_stimuli")

female_cue, _ = sf.read(outdir / "female_cue.wav")
male_cue, _ = sf.read(outdir / "male_cue.wav")

female_target, _ = sf.read(outdir / "female_target_above.wav")
male_target, _ = sf.read(outdir / "male_target_about.wav" )

# use demo labels 
female_target_word = 'above'
male_target_word = 'about'

# transform audio
mixture, _ = audio_transforms(female_target, male_target) # will combine first and second signal at specified dB SNR 
female_cue, _ = audio_transforms(female_cue, None) # can pass None if not processing distractor 
male_cue, _ = audio_transforms(male_cue, None)

# get cochleagrams 
female_cue_cgram, male_cue_cgram = coch_gram(female_cue.cuda().float(), male_cue.cuda().float())
mixture_cgram, _ = coch_gram(mixture.cuda().float(), None)

# get model prediction when cueing male talker
model_logits = model(male_cue_cgram, mixture_cgram)
male_word_pred = model_logits.softmax(-1).argmax(dim=1).item()
print(f"Male cue -> True word: {male_target_word}. Predicted word: {class_ix_to_word[male_word_pred]}")
# should print "True word: about. Predicted word: about"

# get model predictions when cueing female talker in same mixture
model_logits = model(female_cue_cgram, mixture_cgram)
female_word_pred = model_logits.softmax(-1).argmax(dim=1).item()
print(f"Female cue -> True word: {female_target_word}. Predicted word: {class_ix_to_word[female_word_pred]}")
# should print "True word: above. Predicted word: above"

This example relies entirely on tracked assets (config/, attn_cue_models/, demo_stimuli/, cv_800_word_label_to_int_dict.pkl). After confirming it runs end-to-end, you can swap in your own stimuli, adjust the audio transforms, or fine-tune the models with different configs. For deeper dives, inspect the notebooks in notebooks/Final_Figures/ or the evaluation scripts listed earlier.
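
For example, to run the model on your own recordings, the same pipeline applies. The sketch below assumes model, coch_gram, audio_transforms, and class_ix_to_word from the quick-start are already defined, that the hypothetical .wav paths are replaced with your files, and that your audio matches the sample rate of the demo stimuli.

import soundfile as sf

# hypothetical paths for illustration: a cue clip of the target talker,
# plus a target and a distractor to mix together
my_cue, _ = sf.read("path/to/your_cue.wav")
my_target, _ = sf.read("path/to/your_target.wav")
my_distractor, _ = sf.read("path/to/your_distractor.wav")

# combine target and distractor at the SNR configured in audio_transforms
my_mixture, _ = audio_transforms(my_target, my_distractor)
my_cue, _ = audio_transforms(my_cue, None)

# cochleagrams, then a cued word prediction, exactly as in the demo above
my_cue_cgram, _ = coch_gram(my_cue.cuda().float(), None)
my_mixture_cgram, _ = coch_gram(my_mixture.cuda().float(), None)
pred_ix = model(my_cue_cgram, my_mixture_cgram).softmax(-1).argmax(dim=1).item()
print("Predicted word:", class_ix_to_word[pred_ix])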
