Code repository for the paper HEALNet: Multimodal Fusion for Heterogeneous Biomedical Data
[pdf | Installation | Experimental Data | Getting Started | Cite]
- Preserve modality-specific signal: HEALNet learns a separate set of weights per modality and projects each modality into a shared embedding. Positional encodings capture spatial signal for the specified number of spatial dimensions.
- Learn cross-modal interactions: By passing a latent embedding through the fusion layers (Figure B), we 1) iteratively encode each modality into the shared embedding, which 2) then becomes the context for the next modality. As such, the latent acts as a "learned query" that is updated in each layer pass (see the sketch after this list).
- Handling missing modalities: The model's iterative architecture allows skipping missing modalities for individual samples at train or inference time without adding much noise. This makes it possible to train on all available data rather than being restricted to the intersection of samples for which every modality is present.
- Model inspection: The model can be inspected through the modality-specific attention weights.
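For intuition, here is a minimal, self-contained sketch of the iterative fusion idea described above. This is not the actual HEALNet implementation: the class, its parameters (`token_dims`, `latent_tokens`, etc.), and the use of `nn.MultiheadAttention` are illustrative assumptions only.

```python
import torch
import torch.nn as nn

class IterativeFusionSketch(nn.Module):
    """Illustrative sketch only: a shared latent cross-attends to each modality
    in turn, and the updated latent becomes the query for the next modality and
    the next layer pass. Missing modalities (passed as None) are skipped."""

    def __init__(self, n_modalities, token_dims, latent_tokens=16, latent_dim=64, n_layers=2):
        super().__init__()
        self.latent = nn.Parameter(torch.randn(latent_tokens, latent_dim))
        # one projection and one cross-attention block per modality (shared across layer passes)
        self.proj = nn.ModuleList([nn.Linear(d, latent_dim) for d in token_dims])
        self.attn = nn.ModuleList(
            [nn.MultiheadAttention(latent_dim, num_heads=4, batch_first=True) for _ in range(n_modalities)]
        )
        self.n_layers = n_layers

    def forward(self, modalities):
        b = next(m.shape[0] for m in modalities if m is not None)
        latent = self.latent.unsqueeze(0).expand(b, -1, -1)  # (b, latent_tokens, latent_dim)
        for _ in range(self.n_layers):
            for i, x in enumerate(modalities):
                if x is None:  # missing modality for this sample: skip without adding noise
                    continue
                ctx = self.proj[i](x)  # (b, tokens_i, latent_dim)
                update, _ = self.attn[i](query=latent, key=ctx, value=ctx)
                latent = latent + update  # the latent is the "learned query", updated per modality
        return latent  # shared embedding, e.g. for a downstream prediction head

# usage: two modalities, with the second one missing for this batch
sketch = IterativeFusionSketch(n_modalities=2, token_dims=[2000, 3])
tab = torch.rand(4, 1, 2000)  # (batch, tokens, features)
print(sketch([tab, None]).shape)  # torch.Size([4, 16, 64])
```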
- 8/12/2024: Camera-ready release (v0.1.0) available!
- 25/09/2024: HEALNet has been accepted to NeurIPS 2024. Reach out to chat in Vancouver!
First, locally install HEALNet using pip.
```bash
git clone git@github.com:konst-int-i/healnet.git
cd healnet
conda create --name healnet python=3.9
```
We provide two sets of dependencies for installation:
- Lightweight: access to `healnet.models`
- All: access to the entire experimental pipeline
We recommend the lightweight installation if you only want to use `healnet.models` to build on top of HEALNet in a different pipeline.
```bash
pip install -e .
```
The full experiments require some further dependencies, which can be installed using
```bash
pip install -e .[all]
```
Note that you require the `.[all]` installation to run the tutorial.
You can test the installation by running the pytests:
```bash
pytest -v healnet/tests/
```
The snippet below shows a minimal usage example on synthetic data:
```python
from healnet import HealNet
from healnet.etl import MMDataset
import torch
import einops
# synthetic data example
n = 100 # number of samples
b = 4 # batch size
img_c = 3 # image channels
tab_c = 1 # tabular channels
tab_d = 2000 # tabular features
# 2D dims
h = 224 # image height
w = 224 # image width
# 3d dim
d = 12
tab_tensor = torch.rand(size=(n, tab_c, tab_d))
img_tensor_2d = torch.rand(size=(n, h, w, img_c)) # h w c
img_tensor_3d = torch.rand(size=(n, d, h, w, img_c)) # d h w c
dataset = MMDataset([tab_tensor, img_tensor_2d, img_tensor_3d])
[tab_sample, img_sample_2d, img_sample_3d] = dataset[0]
# batch dim for illustration purposes
tab_sample = einops.repeat(tab_sample, 'c d -> b c d', b=1) # spatial axis: None (pass as 1)
img_sample_2d = einops.repeat(img_sample_2d, 'h w c -> b h w c', b=1) # spatial axes: h w
img_sample_3d = einops.repeat(img_sample_3d, 'd h w c -> b d h w c', b=1) # spatial axes: d h w
tensors = [tab_sample, img_sample_2d, img_sample_3d]
model = HealNet(
    n_modalities=3,
    channel_dims=[2000, 3, 3],  # number of channels/tokens per modality
    num_spatial_axes=[1, 2, 3],  # number of spatial axes (positionally encoded to preserve spatial information)
    out_dims=4
)
# example forward pass
logits = model(tensors)
```
Please view our Getting Started Notebook for a more detailed example.
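The forward pass returns one logit vector of size `out_dims` per sample. As a quick follow-up (a minimal usage sketch, assuming a classification-style head), the logits can be checked and converted to probabilities as usual:

```python
import torch.nn.functional as F

print(logits.shape)  # expected: torch.Size([1, 4]) for batch size 1 and out_dims=4
probs = F.softmax(logits, dim=-1)  # class probabilities
print(probs.argmax(dim=-1))  # predicted class index
```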
If you want to reproduce the results in the paper instead of using HEALNet as a standalone module, you need to install a few more dependencies.
Install or update the conda/mamba environment using the command below and then activate it. For a faster installation, we recommend using mamba.
```bash
conda env update -f environment.yml
```
On Mac or Linux, you can install the remaining dependencies from the command line:
```bash
invoke install --system <system>
```
where `<system>` is either `linux` or `mac`.
This will auto-install the requirements below (OpenSlide and GDC client). Please follow the detailed instructions below if our pre-written installation fails.
Note that for openslide-python to work, you need to install openslide separately on your system.
See here for instructions.
To download the WSI data, you need to install the `gdc-client` for your respective platform.
We are using git-lfs to store the pre-processed mutation, CNV, and gene expression data.
```bash
sudo apt-get install git-lfs
git lfs install
git lfs pull
```
This will pull the data into `data/tcga/omic` and `data/tcga/omic_xena`.
From the root of the repository:
- Specify the path to the gdc-client executable in `main.yml` (this will likely be the repository root if you installed the dependencies using `invoke install`).
- Run `invoke download --dataset <dataset> --config_path <config>`, e.g., `invoke download --dataset brca`.
If you are unsure about which arguments are available, you can always run invoke download --help.
The script downloads the data using the manifest files in `data/tcga/gdc_manifests/full` and saves it in the data folder under `tcga/wsi/<dataset>`, with the following structure:
tcga/wsi/<dataset>/
├── slide_1.svs
├── slide_2.svs
└── ...
If a data manifest file is not available for a given cancer site, you can select the files and download the manifest using the NIH Genomic Data Commons Data Portal, filtering for the `.svs` tissue and diagnostic slide files.
To ensure comparability with the baselines, we want the option to run the model on the WSI patches and extracted features using the CLAM package.
To extract the patches, run
```bash
invoke preprocess --dataset <dataset> --config <config> --level <level>
```
which will extract to the following structure:
tcga/wsi/<dataset>_preprocessed/
├── masks
│   ├── slide_1.png
│   ├── slide_2.png
│   └── ...
├── patches
│   ├── slide_1.h5
│   ├── slide_2.h5
│   └── ...
├── stitches
│   ├── slide_1.png
│   ├── slide_2.png
│   └── ...
└── process_list_autogen.csv
Note that the slide `.h5` files contain the (x, y) coordinates of the patches, which are read in via OpenSlide.
On the first run of the pipeline, the script adds an additional folder called `patch_features`, which contains the ResNet50-extracted features after patch normalisation as 2048-dimensional tensors (using PyTorch serialisation); see the sketch after the directory tree below for how these files can be read.
├── patch_features
│   ├── slide_1.pt
│   ├── slide_2.pt
│   └── ...
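As an illustration, the patch coordinates and the serialised features for a single slide could be read roughly as follows. This is a hedged sketch: the slide and file names are placeholders, the h5 dataset key `coords` follows the CLAM convention and is an assumption here, and the read level and patch size should match the `--level` and patch settings used during preprocessing.

```python
import h5py
import openslide
import torch

# placeholder paths for one slide of one dataset
slide_path = "tcga/wsi/brca/slide_1.svs"
coords_path = "tcga/wsi/brca_preprocessed/patches/slide_1.h5"
features_path = "tcga/wsi/brca_preprocessed/patch_features/slide_1.pt"

# patch (x, y) coordinates written by the preprocessing step
# (the "coords" dataset key follows the CLAM convention and is an assumption here)
with h5py.File(coords_path, "r") as f:
    coords = f["coords"][:]  # shape (n_patches, 2)

# read one patch region from the raw slide via OpenSlide
slide = openslide.OpenSlide(slide_path)
x, y = (int(v) for v in coords[0])
patch = slide.read_region((x, y), level=0, size=(256, 256))  # PIL image; level/size are assumptions

# load the ResNet50 patch features serialised with PyTorch
features = torch.load(features_path)  # e.g. shape (n_patches, 2048)
print(coords.shape, features.shape)
```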
This repo contains the manifests and scripts to easily download data for the following 8 cancer sites from The Cancer Genome Atlas (TCGA). If you require additional data, you can use the GDC Data Access Tool and reuse the same scripts.
- BLCA: Urothelial Bladder Carcinoma
- BRCA: Breast Invasive Carcinoma
- UCEC: Uterine Corpus Endometrial Carcinoma
- KIRP: Kidney Renal Papillary Cell Carcinoma
- LUAD: Lung Adenocarcinoma
- LUSC: Lung Squamous Cell Carcinoma
- PAAD: Pancreatic Adenocarcinoma
- HNSC: Head and Neck Squamous Cell Carcinoma
To be added
Given the configuration in `config.yml`, you can launch a single run using the command below. Note that all of the following commands assume that you are in the repository root.
```bash
python3 healnet/main.py
```
To prevent import errors, you may have to add your local path to the PYTHONPATH:
```bash
export PYTHONPATH=<path_to_repository>:$PYTHONPATH
```
You can view the available command line arguments using
```bash
python3 healnet/main.py --help
```
To execute the configured run plan, use
```bash
python3 healnet/main.py --mode run_plan
```
You can launch a hyperparameter search by passing the `--hyperparameter_sweep` argument:
```bash
python3 healnet/main.py --hyperparameter_sweep
```
Note that the sweep parameters are specified in the `config/sweep.yaml` file. If a parameter is not specified as part of the parameter sweep, the program will default to whatever is configured in `config/main_gpu.yml`.
