This repository provides the official PyTorch implementation for our paper: "Explainable Multi-modality Learning for Eye Disease Diagnosis with Missing Data"
Ensure your environment satisfies the following dependencies:
Python == 3.8.5
PyTorch == 1.8.1
TorchVision == 0.9.1
NumPy == 1.20.2
OpenCV (cv2) == 4.5.1
Install the dependencies via pip:
pip install torch==1.8.1 torchvision==0.9.1 numpy==1.20.2 "opencv-python==4.5.1.*"
Before running training or inference, configure the parameters in conf.py. These include:
- Model settings
- Modality control
- Missing data simulation options
- Training hyperparameters
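The exact fields defined in conf.py are not listed here. As a rough illustration only, a configuration covering these four groups might be organized as below. Every class, field name, modality name ("fundus", "oct") and default value is a hypothetical assumption, and the missing-modality sampler shows one common simulation scheme (drop each modality independently, keep at least one), not necessarily the one used in the paper.

```python
# Hypothetical sketch of a conf.py-style configuration; NOT the repository's actual conf.py.
# All field names, modality names and default values are illustrative assumptions.
import random
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ModelSettings:
    backbone: str = "resnet34"        # hypothetical backbone name
    num_classes: int = 2              # hypothetical number of diagnostic classes
    prototypes_per_class: int = 10    # hypothetical prototype count


@dataclass
class ModalityControl:
    # Hypothetical modality names; which modalities are used is an assumption.
    modalities: List[str] = field(default_factory=lambda: ["fundus", "oct"])


@dataclass
class MissingDataSimulation:
    enabled: bool = True
    missing_rate: float = 0.3         # hypothetical per-modality drop probability


@dataclass
class TrainingHyperparameters:
    batch_size: int = 16
    learning_rate: float = 1e-4
    epochs: int = 100
    seed: int = 42


@dataclass
class Config:
    model: ModelSettings = field(default_factory=ModelSettings)
    modality: ModalityControl = field(default_factory=ModalityControl)
    missing: MissingDataSimulation = field(default_factory=MissingDataSimulation)
    train: TrainingHyperparameters = field(default_factory=TrainingHyperparameters)


def sample_modality_mask(cfg: Config, rng: random.Random) -> Dict[str, bool]:
    """Return {modality: is_available} for one training sample.

    Assumed scheme: drop each modality independently with probability
    missing_rate, then re-enable one at random if all were dropped.
    """
    names = cfg.modality.modalities
    if not cfg.missing.enabled:
        return {name: True for name in names}
    # rng.random() is in [0, 1), so missing_rate=1.0 drops every modality here.
    mask = {name: rng.random() >= cfg.missing.missing_rate for name in names}
    if not any(mask.values()):
        # Keep the sample usable by restoring one randomly chosen modality.
        mask[rng.choice(names)] = True
    return mask
```

Usage: `cfg = Config()`, then `sample_modality_mask(cfg, random.Random(cfg.train.seed))` yields a per-sample availability mask. Guaranteeing at least one modality avoids producing samples with no input at all.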
To start training (using 2 GPUs), simply run:
bash run.sh
This internally calls:
python -m torch.distributed.launch --nproc_per_node=2 --master_port=21676 --use_env main.py ...
All training options (e.g., margin mode, prototype settings, seed) can be modified directly in run.sh.
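With --use_env, torch.distributed.launch passes each worker its local rank through the LOCAL_RANK environment variable instead of a --local_rank argument, and it also exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT. The sketch below shows one way a script could read these values. It is not taken from main.py; the helper names read_dist_env and init_distributed, and the single-process fallback defaults, are hypothetical.

```python
# Sketch: discovering process-group coordinates under `torch.distributed.launch --use_env`.
# How main.py actually does this is an assumption; helper names are hypothetical.
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistEnv:
    rank: int          # global rank across all processes
    local_rank: int    # rank on this node; typically used as the CUDA device index
    world_size: int    # total number of processes (2 for --nproc_per_node=2 on one node)
    master_addr: str
    master_port: int


def read_dist_env(env: Optional[Mapping[str, str]] = None) -> DistEnv:
    """Read the variables the launcher exports to each worker process.

    Falls back to single-process defaults (hypothetical) when a variable is absent.
    """
    env = os.environ if env is None else env
    return DistEnv(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
        master_addr=env.get("MASTER_ADDR", "127.0.0.1"),
        master_port=int(env.get("MASTER_PORT", "29500")),
    )


def init_distributed(dist_env: DistEnv) -> None:
    """Bind this process to its GPU and join the NCCL group (needs PyTorch and CUDA)."""
    import torch
    import torch.distributed as dist

    torch.cuda.set_device(dist_env.local_rank)
    # init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from os.environ.
    dist.init_process_group(backend="nccl", init_method="env://")
```

Reading the environment is kept separate from initialization so that the rank logic can be inspected or tested without GPUs; only init_distributed touches CUDA.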
To run inference on test data:
python local_test.py
Make sure that the paths and inference parameters are properly set in conf.py.
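Because inference depends on the paths in conf.py, a small pre-flight check can catch typos before the model loads. This is a hypothetical helper, not part of the repository; the setting names used in the usage note ("test_data_dir", "checkpoint_path") are assumptions and should be replaced with the actual fields in conf.py.

```python
# Hypothetical pre-flight check to run before local_test.py.
# Setting names passed in are assumptions, not the actual conf.py fields.
from pathlib import Path
from typing import Dict, List


def missing_paths(path_settings: Dict[str, str]) -> List[str]:
    """Return the names of settings whose paths do not exist on disk."""
    return [name for name, value in path_settings.items() if not Path(value).exists()]


def assert_paths_exist(path_settings: Dict[str, str]) -> None:
    """Raise FileNotFoundError listing every misconfigured path setting."""
    problems = missing_paths(path_settings)
    if problems:
        details = ", ".join(f"{name}={path_settings[name]!r}" for name in problems)
        raise FileNotFoundError(f"Fix these paths in conf.py before inference: {details}")
```

Usage: `assert_paths_exist({"test_data_dir": "./data/test", "checkpoint_path": "./checkpoints/best.pth"})`, with both names and paths hypothetical. Reporting all missing entries at once avoids fixing them one failed run at a time.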
This project partially builds upon the Deformable ProtoPNet framework. We thank the authors for their valuable open-source contributions, which helped inspire the development of our model.