This repository contains the official PyTorch implementation of our cross-modal transformer framework for weakly supervised prostate cancer grading from H&E whole slide images (WSIs). The model aligns slide-level visual representations with grade-specific textual descriptions generated by a frozen LLM and performs classification via cosine similarity in a shared latent space.
- Patch feature input (e.g., CLAM / ResNet features)
- Optional IRM attention scores for top-M patch refinement
- Frozen LLM prompt token embeddings (precomputed)
- Cross-attention transformer (text queries, visual keys/values)
- Convex aggregation for slide and text embeddings
- Cosine similarity + temperature scaling for grading
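As a concrete illustration of how these components compose, here is a minimal PyTorch sketch. It is illustrative only: the class name `CrossModalGradingHead`, the shared dimension `d`, and the learned pooling layers are assumptions for exposition, not the repository's actual modules.

```python
# Minimal sketch of the grading head; names and dimensions are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossModalGradingHead(nn.Module):  # hypothetical name, not the repo's API
    def __init__(self, d_patch, d_text_in, d=256, n_heads=4, top_m=512):
        super().__init__()
        self.top_m = top_m
        self.vis_proj = nn.Linear(d_patch, d)    # patch features -> shared space
        self.txt_proj = nn.Linear(d_text_in, d)  # prompt tokens  -> shared space
        # Text tokens act as queries; patch features act as keys/values.
        self.cross_attn = nn.MultiheadAttention(d, n_heads, batch_first=True)
        # Scorers for convex (softmax-normalized) aggregation weights.
        self.vis_pool = nn.Linear(d, 1)
        self.txt_pool = nn.Linear(d, 1)
        self.log_temp = nn.Parameter(torch.zeros(()))  # learnable temperature

    def convex_agg(self, x, scorer):
        w = scorer(x).softmax(dim=1)   # [B, T, 1], weights sum to 1 over tokens
        return (w * x).sum(dim=1)      # [B, d] convex combination

    def forward(self, feats, prompts, attn_scores=None):
        # feats: [N, d_patch]; prompts: [C, L, d_text_in]; attn_scores: [N]|None
        if attn_scores is not None and feats.size(0) > self.top_m:
            idx = attn_scores.topk(self.top_m).indices  # IRM top-M refinement
            feats = feats[idx]
        v = self.vis_proj(feats).unsqueeze(0)           # [1, M, d]
        slide = self.convex_agg(v, self.vis_pool)       # [1, d] slide embedding
        sims = []
        for c in range(prompts.size(0)):
            q = self.txt_proj(prompts[c]).unsqueeze(0)  # [1, L, d] text queries
            t, _ = self.cross_attn(q, v, v)             # attend text -> patches
            text = self.convex_agg(t, self.txt_pool)    # [1, d] class embedding
            sims.append(F.cosine_similarity(slide, text, dim=-1))
        return torch.cat(sims) * self.log_temp.exp()    # [C] scaled similarities
```

At training time the C temperature-scaled similarities would presumably feed a standard cross-entropy loss against the slide-level label.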
Expected per-slide feature files:
`patches_dir/<slide_id>.pt` containing:
- `features`: `Tensor[N, d_patch]`
- `attn` (optional): `Tensor[N]`
Labels file:
`labels.csv` with columns `slide_id,label` (label in {0, 1, 2, 3}).
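For reference, a minimal loading sketch under the layout above; the slide id `SLIDE_001` is a hypothetical placeholder, and the `attn` key is read defensively since it is optional.

```python
import torch
import pandas as pd

# Load one slide's precomputed feature file (slide id is a placeholder).
data = torch.load("patches_dir/SLIDE_001.pt", map_location="cpu")
features = data["features"]        # Tensor[N, d_patch]
attn = data.get("attn")            # Tensor[N], or None if not provided

# Look up the slide's grade label from labels.csv.
labels = pd.read_csv("labels.csv")  # columns: slide_id, label (0-3)
label = int(labels.set_index("slide_id").loc["SLIDE_001", "label"])
```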
- Precompute prompt token embeddings and save them as a `.pt` dict: `{class_id: Tensor[L, d_text_in]}` (a sketch of this step follows the train command below).
- Train:
  ```bash
  python scripts/train.py --cfg configs/default.yaml --fold 0 --prompt_pt /llm/prompts.pt
  ```
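A hedged sketch of the precompute step, using a generic frozen HuggingFace encoder as a stand-in for the frozen LLM. The model name and prompt texts are placeholders; only the saved dict format `{class_id: Tensor[L, d_text_in]}` matters to the trainer.

```python
import torch
from transformers import AutoTokenizer, AutoModel

MODEL = "bert-base-uncased"  # assumption: stand-in for the paper's frozen LLM
# Placeholder grade descriptions; replace with the actual per-class prompts.
prompts = {c: f"textual description of grade class {c}" for c in range(4)}

tok = AutoTokenizer.from_pretrained(MODEL)
enc = AutoModel.from_pretrained(MODEL).eval()  # keep the encoder frozen

out = {}
with torch.no_grad():
    for cid, text in prompts.items():
        inputs = tok(text, return_tensors="pt")
        hidden = enc(**inputs).last_hidden_state  # [1, L, d_text_in]
        out[cid] = hidden.squeeze(0)              # Tensor[L, d_text_in]; L may vary

torch.save(out, "/llm/prompts.pt")  # path matches the --prompt_pt flag above
```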