Knowledge-Guided Adaptation of Pathology Foundation Models Improves Cross-domain Generalization and Demographic Fairness

Overview

The advent of foundation models has ushered in a transformative era in computational pathology, enabling the extraction of rich, transferable image features for a broad range of downstream pathology tasks. However, site-specific signatures and demographic biases persist in these features, leading to shortcut learning and unfair predictions, ultimately compromising model generalizability and fairness across diverse clinical sites and demographic groups.

This repository implements FLEX, a novel framework that enhances cross-domain generalization and demographic fairness of pathology foundation models, thus facilitating accurate diagnosis across diverse pathology tasks. FLEX employs a task-specific information bottleneck, informed by visual and textual domain knowledge, to promote:

  • Generalizability across clinical settings
  • Fairness across demographic groups
  • Adaptability to specific pathology tasks

Figure: Overview of the FLEX framework.

Features

  • Cross-domain generalization: Significantly improves diagnostic performance on data from unseen sites
  • Demographic fairness: Reduces performance gaps between demographic groups
  • Versatility: Compatible with various vision-language models
  • Scalability: Adaptable to varying training data sizes
  • Seamless integration: Works with multiple instance learning frameworks

Installation

Setup

  1. Clone the repository:

    git clone https://github.com/HKU-MedAI/FLEX
    cd FLEX
  2. Create and activate the conda environment (this installs all dependencies):

    conda env create -f environment.yml
    conda activate flex

Instructions for Use

Data Preparation

Prepare your data in the following structure:

Dataset/
├── TCGA-BRCA/
│   ├── features/
│   │   ├── ...
│   ├── tcga-brca_label.csv
│   ├── tcga-brca_label_her2.csv
│   └── ...
├── TCGA-NSCLC/
└── ...
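
The label CSV files pair each slide with its task label. A quick check like the one below can confirm that a label file loads correctly and show its column layout; the file path is taken from the structure above, while the column layout itself is something to verify against your own CSVs.

    import pandas as pd

    # Inspect a task label file; column names vary by task, so verify them
    # against your own CSVs before training.
    labels = pd.read_csv("Dataset/TCGA-BRCA/tcga-brca_label.csv")
    print(labels.columns.tolist())
    print(labels.head())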

Visual Prompts

Organize visual prompts in the following structure:

prompts/
├── BRCA/
│   ├── 0/
│   │   ├── image1.png
│   │   └── ...
│   └── 1/
│       ├── image1.png
│       └── ...
├── BRCA_HER2/
└── ...
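
Before training, it can help to confirm that every class subfolder contains at least one prompt image. The sketch below walks the structure above; the task name and .png extension are illustrative assumptions.

    from pathlib import Path

    # Count visual prompt images per class folder (task name and file
    # extension are illustrative assumptions).
    task_dir = Path("prompts/BRCA")
    for class_dir in sorted(p for p in task_dir.iterdir() if p.is_dir()):
        n_images = len(list(class_dir.glob("*.png")))
        print(f"class {class_dir.name}: {n_images} prompt image(s)")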

Running on Your Data

Generate Data Splits

Generate site-preserved Monte Carlo Cross-Validation (SP-MCCV) splits for your dataset:

python generate_sitepreserved_splits.py
python generate_sp_mccv_splits.py
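
Conceptually, site-preserved splitting keeps all slides from the same contributing site in the same fold, so site-specific signatures cannot leak between training and test data. The snippet below is a simplified illustration using the tissue source site (TSS) code embedded in TCGA barcodes; it is not the logic of the provided scripts, and the slide IDs are made up.

    # Simplified illustration of site-preserved splitting (not the provided
    # scripts' exact logic): group slides by TCGA tissue source site.
    from sklearn.model_selection import GroupShuffleSplit

    def tss_code(slide_id: str) -> str:
        # TCGA barcodes look like TCGA-XX-YYYY-...; "XX" identifies the site.
        return slide_id.split("-")[1]

    slide_ids = ["TCGA-A1-0001", "TCGA-A1-0002", "TCGA-B2-0003", "TCGA-C3-0004"]
    sites = [tss_code(s) for s in slide_ids]

    splitter = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=0)
    train_idx, test_idx = next(splitter.split(slide_ids, groups=sites))
    print([slide_ids[i] for i in train_idx], [slide_ids[i] for i in test_idx])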

Extract Features

Due to the large size of whole-slide images (WSIs), patch-level features must be extracted before training. We recommend using established pipelines such as CLAM or TRIDENT; this is a computationally intensive step. Place the extracted features (e.g., in .h5 format) in the features/ subdirectory of each dataset.
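
If your features are stored as HDF5 files, a quick sanity check such as the one below can confirm their shape before training. The dataset key names ("features", "coords") follow a common CLAM-style layout and are assumptions; adjust them to match your extraction pipeline.

    import h5py

    # Inspect one extracted feature file (path and key names are illustrative).
    with h5py.File("Dataset/TCGA-BRCA/features/example_slide.h5", "r") as f:
        print(list(f.keys()))          # e.g., ['coords', 'features']
        feats = f["features"][:]       # (num_patches, feature_dim), assumed key
        coords = f["coords"][:]        # (num_patches, 2), assumed key
        print(feats.shape, coords.shape)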

Prepare Visual and Textual Concepts

  1. Visual Prompts: As described in our paper, visual prompts are representative patches for each class. We provide the visual prompts used in our experiments in the prompts/ directory. For custom tasks, you will need to generate your own.
  2. Textual Prompts: Textual concepts are defined in the code/configuration files and are crucial for guiding the information bottleneck. Please refer to config.py (or a similar configuration file) to see how task-specific prompts such as "invasive ductal carcinoma" are defined; a minimal sketch of such a configuration follows this list.
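
The following is a minimal sketch of how class-level textual concepts could be organized; the dictionary name, structure, and exact phrasing are illustrative assumptions, not the repository's actual configuration.

    # Illustrative mapping from task -> class index -> textual concepts.
    # Names, structure, and wording are assumptions; see config.py for the
    # real definitions.
    TEXT_PROMPTS = {
        "BRCA": {
            0: ["a histopathology image of invasive ductal carcinoma"],
            1: ["a histopathology image of invasive lobular carcinoma"],
        },
        "BRCA_HER2": {
            0: ["HER2-negative breast carcinoma tissue"],
            1: ["HER2-positive breast carcinoma tissue"],
        },
    }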

Train the FLEX model

Train the FLEX model and evaluate the performance:

bash ./scripts/train_flex.sh

Key Parameters

  • --task: Task name (e.g., BRCA, NSCLC, STAD_LAUREN)
  • --data_root_dir: Path to the data directory
  • --split_suffix: Split suffix (e.g., sitepre5_fold3)
  • --exp_code: Experiment code for logging and saving results
  • --model_type: Model type (default: flex)
  • --base_mil: Base MIL framework (default: abmil)
  • --slide_align: Whether to perform alignment at the slide level (default: 1)
  • --w_infonce: Weight for InfoNCE loss (default: 14)
  • --w_kl: Weight for KL loss (default: 14)
  • --len_prompt: Number of learnable textual prompt tokens
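
To make the roles of --w_infonce and --w_kl concrete, the sketch below shows a generic way such weighted auxiliary terms are combined with a task loss; the function and variable names are illustrative and do not reproduce the repository's implementation.

    # Generic illustration of weighted auxiliary losses (not the actual FLEX code).
    def total_loss(task_loss, infonce_loss, kl_loss, w_infonce=14.0, w_kl=14.0):
        # Slide-level task loss plus a knowledge-guided alignment term (InfoNCE)
        # and an information-bottleneck regularizer (KL), each with its own weight.
        return task_loss + w_infonce * infonce_loss + w_kl * kl_loss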

Evaluation Results

FLEX has been evaluated on 16 clinically relevant tasks and demonstrates:

  • Improved performance on unseen clinical sites
  • Reduced performance gap between seen and unseen sites
  • Enhanced fairness across demographic groups

For detailed results, refer to our paper.

License

This project is licensed under the Apache-2.0 license.

Acknowledgments

This project was built on top of several excellent works, including CLAM, CONCH, QuiltNet, PathGen-CLIP, and PreservedSiteCV. We thank the authors for their great work.
