The advent of foundation models has ushered in a transformative era in computational pathology, enabling the extraction of rich, transferable image features for a broad range of downstream pathology tasks. However, site-specific signatures and demographic biases persist in these features, leading to shortcut learning and unfair predictions that ultimately compromise model generalizability and fairness across diverse clinical sites and demographic groups.
This repository implements FLEX, a novel framework that enhances cross-domain generalization and demographic fairness of pathology foundation models, thus facilitating accurate diagnosis across diverse pathology tasks. FLEX employs a task-specific information bottleneck, informed by visual and textual domain knowledge, to promote:
- Generalizability across clinical settings
- Fairness across demographic groups
- Adaptability to specific pathology tasks
Key features:

- Cross-domain generalization: Significantly improves diagnostic performance on data from unseen sites
- Demographic fairness: Reduces performance gaps between demographic groups
- Versatility: Compatible with various vision-language models
- Scalability: Adaptable to varying training data sizes
- Seamless integration: Works with multiple instance learning frameworks
Clone the repository:
```bash
git clone https://github.com/HKU-MedAI/FLEX
cd FLEX
```
Create and activate a virtual environment, and install the dependencies:
```bash
conda env create -f environment.yml
conda activate flex
```
Prepare your data in the following structure:
```
Dataset/
├── TCGA-BRCA/
│   ├── features/
│   │   └── ...
│   ├── tcga-brca_label.csv
│   ├── tcga-brca_label_her2.csv
│   └── ...
├── TCGA-NSCLC/
└── ...
```
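With features and labels in place, a quick consistency check can catch missing files early. Below is a minimal sketch, assuming the label CSV has a `slide_id` column and one `.h5` feature file per slide; both are assumptions about your data, not guarantees of the repository's format:

```python
# Minimal sanity check of the dataset layout sketched above.
# Assumes (not guaranteed): a "slide_id" column in the label CSV and
# one <slide_id>.h5 feature file per slide under features/.
from pathlib import Path

import pandas as pd

data_root = Path("Dataset/TCGA-BRCA")
labels = pd.read_csv(data_root / "tcga-brca_label.csv")

features_dir = data_root / "features"
missing = [sid for sid in labels["slide_id"]
           if not (features_dir / f"{sid}.h5").exists()]
print(f"{len(labels)} labeled slides, {len(missing)} missing feature files")
```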
Organize visual prompts in the following structure:
```
prompts/
├── BRCA/
│   ├── 0/
│   │   ├── image1.png
│   │   └── ...
│   └── 1/
│       ├── image1.png
│       └── ...
├── BRCA_HER2/
└── ...
```
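If you need to inspect or assemble custom prompt sets, the sketch below reads this layout back into per-class image lists. The helper name is ours for illustration and is not part of the repository:

```python
# Illustrative helper (not part of FLEX) that loads class-wise visual
# prompts from the directory layout shown above: one subfolder per class
# label ("0", "1", ...), each holding PNG patches.
from pathlib import Path

from PIL import Image

def load_visual_prompts(task_dir: str) -> dict[int, list[Image.Image]]:
    prompts = {}
    for class_dir in sorted(Path(task_dir).iterdir()):
        if class_dir.is_dir():
            prompts[int(class_dir.name)] = [
                Image.open(p).convert("RGB")
                for p in sorted(class_dir.glob("*.png"))
            ]
    return prompts

prompts = load_visual_prompts("prompts/BRCA")
print({label: len(images) for label, images in prompts.items()})
```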
Generate site-preserved Monte Carlo Cross-Validation (SP-MCCV) splits for your dataset:
```bash
python generate_sitepreserved_splits.py
python generate_sp_mccv_splits.py
```

Due to the large size of WSIs, patch-level features must be extracted first. We recommend using established pipelines such as CLAM or TRIDENT; this is a computationally intensive step. Extracted features (e.g., in `.h5` format) should be placed in the `features/` subdirectory of each dataset.
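To verify an extracted file, you can open it with h5py. The sketch below assumes CLAM-style files that store patch embeddings under a `features` key and patch locations under `coords`; the filename and keys are illustrative and may differ in your pipeline:

```python
# Inspect one extracted feature file. The path and the "features"/"coords"
# keys follow CLAM's common convention and are assumptions; check the keys
# your extraction pipeline actually writes.
import h5py

with h5py.File("Dataset/TCGA-BRCA/features/example_slide.h5", "r") as f:
    print(list(f.keys()))            # e.g., ['coords', 'features']
    feats = f["features"][:]         # (num_patches, feature_dim)
    coords = f["coords"][:]          # (num_patches, 2) patch coordinates
print(feats.shape, coords.shape)
```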
- Visual Prompts: As described in our paper, visual prompts are representative patches for each class. We provide the visual prompts used in our experiments in the `prompts/` directory. For custom tasks, you will need to generate your own.
- Textual Prompts: Textual concepts are defined within the code/configuration files and are crucial for guiding the information bottleneck. Please refer to `config.py` (or a similar file) to see how task-specific prompts such as "invasive ductal carcinoma" are defined.
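For orientation, a task-to-concept mapping in a config file might look like the hypothetical sketch below; the variable name and the second class's concept are our illustrative assumptions, so consult `config.py` for the actual definitions:

```python
# Hypothetical illustration of task-specific textual concepts; not the
# repository's actual config. "invasive ductal carcinoma" is the example
# named in this README; the other entries are assumed for illustration.
TASK_PROMPTS = {
    "BRCA": {
        0: ["invasive ductal carcinoma"],
        1: ["invasive lobular carcinoma"],  # assumed counterpart class
    },
    # "NSCLC": {...},
}
```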
Train the FLEX model and evaluate its performance:

```bash
bash ./scripts/train_flex.sh
```

Key arguments:

- `--task`: Task name (e.g., BRCA, NSCLC, STAD_LAUREN)
- `--data_root_dir`: Path to the data directory
- `--split_suffix`: Split suffix (e.g., sitepre5_fold3)
- `--exp_code`: Experiment code for logging and saving results
- `--model_type`: Model type (default: flex)
- `--base_mil`: Base MIL framework (default: abmil)
- `--slide_align`: Whether to align at the slide level (default: 1)
- `--w_infonce`: Weight for the InfoNCE loss (default: 14)
- `--w_kl`: Weight for the KL loss (default: 14)
- `--len_prompt`: Number of learnable textual prompt tokens
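To make the roles of `--w_infonce` and `--w_kl` concrete, the sketch below shows one plausible way such weights enter a training objective: a base classification loss plus a weighted InfoNCE alignment term and a weighted KL term. This is our illustration of the flags' meaning, not the repository's actual implementation:

```python
# Illustrative (not FLEX's actual code): how w_infonce and w_kl could
# weight auxiliary terms on top of the base MIL classification loss.
import torch
import torch.nn.functional as F

def infonce(slide_emb: torch.Tensor, text_emb: torch.Tensor, tau: float = 0.07):
    """Standard InfoNCE: row i of each side forms the positive pair."""
    slide_emb = F.normalize(slide_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = slide_emb @ text_emb.t() / tau
    targets = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, targets)

def total_loss(cls_loss, slide_emb, text_emb, log_p, q,
               w_infonce=14.0, w_kl=14.0):
    # log_p: log-probabilities, q: target probabilities (per F.kl_div's API)
    kl = F.kl_div(log_p, q, reduction="batchmean")
    return cls_loss + w_infonce * infonce(slide_emb, text_emb) + w_kl * kl
```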
FLEX has been evaluated on 16 clinically relevant tasks and demonstrates:
- Improved performance on unseen clinical sites
- Reduced performance gap between seen and unseen sites
- Enhanced fairness across demographic groups
For detailed results, refer to our paper.
This project is licensed under the Apache-2.0 license.
This project was built on top of amazing works, including CLAM, CONCH, QuiltNet, PathGen-CLIP, and PreservedSiteCV. We thank the authors for their great work.
