SpatialGT is a graph transformer model for spatial transcriptomics data analysis. It leverages spatial context through neighbor-aware attention mechanisms to enable:
- 🗺️ Spatial Context Learning: Pre-train on large-scale spatial transcriptomics data
- 🧬 Gene Expression Reconstruction: Predict masked gene expression from spatial context
- 🔬 Perturbation Simulation: Simulate transcriptomic responses to virtual perturbations
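Neighbor-aware attention depends on some definition of each spot's spatial neighbors. As a rough illustration, the sketch below picks the k nearest spots by Euclidean distance on 2D coordinates. It is not SpatialGT's own graph construction: `spatial_knn` is a hypothetical helper, and k = 8 only mirrors the `--n_neighbors 8` value used in the preprocessing example.

```python
import math


def spatial_knn(coords, k=8):
    """Return, for each spot, the indices of its k nearest other spots.

    Hypothetical helper: brute-force Euclidean k-nearest neighbors on 2D
    (x, y) coordinates. Ties are broken by the lower spot index.
    """
    n = len(coords)
    if not 0 < k < n:
        raise ValueError("k must be between 1 and the number of spots minus 1")
    neighbors = []
    for i, (xi, yi) in enumerate(coords):
        # (distance, index) pairs for every other spot; sorting the tuples
        # orders by distance first, then by index.
        candidates = sorted(
            (math.hypot(xi - xj, yi - yj), j)
            for j, (xj, yj) in enumerate(coords)
            if j != i
        )
        neighbors.append([j for _, j in candidates[:k]])
    return neighbors


# Usage: a 3 x 3 grid of spots, indexed row by row.
grid = [(x, y) for y in range(3) for x in range(3)]
print(spatial_knn(grid, k=4)[4])  # center spot -> [1, 3, 5, 7]
```

For real datasets with thousands of spots, a spatial index such as `scipy.spatial.cKDTree` avoids this O(n²) scan.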
```bash
git clone https://github.com/ai4nucleome/SpatialGT.git
cd SpatialGT

# Create and activate environment
conda env create -f env/environment.yml
conda activate spatialgt
```

Alternatively, install with pip:

```bash
git clone https://github.com/ai4nucleome/SpatialGT.git
cd SpatialGT

# Install PyTorch with CUDA 11.8
pip install torch==2.2.2+cu118 torchvision==0.17.2+cu118 torchaudio==2.2.2+cu118 \
    -f https://download.pytorch.org/whl/torch_stable.html

# Install dependencies
pip install -r requirements.txt
```

See `env/INSTALL.md` for detailed installation instructions.
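After a pip installation it can help to confirm that the pinned CUDA 11.8 builds of PyTorch are the ones actually installed, in case another step (such as installing the remaining requirements) pulled in a different torch build. The sketch below compares installed distribution versions, read with the standard-library `importlib.metadata`, against the pins from the pip command. `check_versions` and `EXPECTED` are hypothetical names, and the check assumes the pip route; a conda environment may report different version strings.

```python
from importlib import metadata

# Pins copied from the pip installation command.
EXPECTED = {
    "torch": "2.2.2+cu118",
    "torchvision": "0.17.2+cu118",
    "torchaudio": "2.2.2+cu118",
}


def check_versions(expected=EXPECTED, lookup=metadata.version):
    """Return {package: (expected, installed)} for every mismatch.

    `installed` is None when the distribution is not installed at all.
    """
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = lookup(package)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, installed) in check_versions().items():
        print(f"{package}: expected {wanted}, found {installed}")
    try:
        import torch

        print("CUDA available:", torch.cuda.is_available())
    except ImportError:
        print("torch is not importable")
```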
```python
import torch

from pretrain.model_spatialpt import SpatialNeighborTransformer
from pretrain.Config import Config

# Load configuration
config = Config()

# Initialize model
model = SpatialNeighborTransformer(config)

# Load pretrained weights (map to CPU so this also works without a GPU)
checkpoint = torch.load("path/to/checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint)

# Your spatial transcriptomics data
# ...
```
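The snippet above passes the loaded object straight to `load_state_dict`, which works when the file holds a bare state dict. If a checkpoint from your own pretraining or finetuning run instead nests the weights in a dictionary, or was saved from a `DistributedDataParallel` model (whose parameter names carry a `module.` prefix), a small unwrapping step can help. The helper below is a hypothetical sketch: the wrapper keys it looks for are common PyTorch conventions, not a documented SpatialGT checkpoint layout.

```python
from collections.abc import Mapping

# Common (assumed) keys under which training scripts nest the model weights.
WRAPPER_KEYS = ("model_state_dict", "state_dict", "model")


def extract_state_dict(checkpoint, wrapper_keys=WRAPPER_KEYS):
    """Return a flat {parameter_name: tensor} mapping from a loaded checkpoint."""
    state = checkpoint
    # Unwrap one level if the weights are nested under a known key.
    for key in wrapper_keys:
        if isinstance(state, Mapping) and isinstance(state.get(key), Mapping):
            state = state[key]
            break
    if not isinstance(state, Mapping):
        raise TypeError("checkpoint does not contain a state dict")
    # Strip the "module." prefix added by DistributedDataParallel.
    prefix = "module."
    return {
        (name[len(prefix):] if name.startswith(prefix) else name): value
        for name, value in state.items()
    }


# Usage (requires torch and the model from the quick-start snippet):
# checkpoint = torch.load("path/to/checkpoint.pt", map_location="cpu")
# model.load_state_dict(extract_state_dict(checkpoint))
```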
```text
SpatialGT/
├── pretrain/                     # Pretraining module
│   ├── Config.py                 # Configuration
│   ├── model_spatialpt.py        # Model architecture
│   ├── spatial_databank.py       # Data loading utilities
│   ├── run_pretrain.py           # Training script
│   └── run.sh                    # Launch script
│
├── finetune/                     # Finetuning module
│   ├── Config.py                 # Finetuning configuration
│   ├── finetune.py               # Finetuning script
│   └── finetune.sh               # Launch script
│
├── reconstruction/               # Expression reconstruction
│   ├── spatialgt_reconstruction.py
│   ├── knn_reconstruction.py     # KNN baseline
│   ├── sedr_reconstruction.py    # SEDR baseline
│   └── run_reconstruction.sh
│
├── perturbation/                 # Perturbation simulation
│   ├── mouse_stroke/             # Mouse stroke case study
│   └── human_colitis/            # Human colitis case study
│
├── gene_embedding/               # Pretrained gene embeddings (download from Hugging Face)
│   ├── vocab.json
│   ├── id_to_gene.json
│   └── pretrained_gene_embeddings.pt
│
├── baseline/                     # Baseline methods
│   └── SEDR/                     # SEDR implementation
│
├── env/                          # Environment setup
│   ├── environment.yml
│   └── INSTALL.md
│
└── requirements.txt
```
Your spatial transcriptomics data should be in AnnData format (`.h5ad`) with:
- `adata.X`: Gene expression matrix (cells × genes)
- `adata.obsm['spatial']`: Spatial coordinates
- `adata.var_names`: Gene symbols
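Before preprocessing, a quick structural check can catch files that do not match this layout. The sketch below is hypothetical: `check_layout` and `check_spatial_adata` are illustrative names, and the rules that coordinates need at least two columns and that gene symbols should be unique are assumptions (duplicate symbols would make a lookup in a gene vocabulary such as `gene_embedding/vocab.json` ambiguous). SpatialGT's own loader may apply different checks.

```python
def check_layout(x_shape, coords_shape, var_names, min_coord_dims=2):
    """Return a list of human-readable problems (empty if the layout looks fine).

    x_shape:      shape of adata.X, expected (n_cells, n_genes)
    coords_shape: shape of adata.obsm['spatial'], or None if it is missing
    var_names:    iterable of gene symbols (adata.var_names)
    """
    problems = []
    n_cells, n_genes = x_shape
    if coords_shape is None:
        problems.append("missing adata.obsm['spatial']")
    elif len(coords_shape) != 2 or coords_shape[1] < min_coord_dims:
        problems.append(
            f"spatial coordinates have shape {coords_shape}, "
            f"expected (n_cells, >= {min_coord_dims})"
        )
    elif coords_shape[0] != n_cells:
        problems.append(f"{coords_shape[0]} coordinate rows for {n_cells} cells")
    names = [str(name) for name in var_names]
    if len(names) != n_genes:
        problems.append(f"{len(names)} gene symbols for {n_genes} genes")
    if len(set(names)) != len(names):
        problems.append("adata.var_names contains duplicate gene symbols")
    return problems


def check_spatial_adata(adata):
    """Apply check_layout to an AnnData object (e.g. from anndata.read_h5ad)."""
    coords = adata.obsm["spatial"] if "spatial" in adata.obsm else None
    return check_layout(
        adata.X.shape,
        None if coords is None else coords.shape,
        adata.var_names,
    )
```

For example, `check_spatial_adata(anndata.read_h5ad("sample.h5ad"))` returns an empty list for a file that passes these checks; duplicate symbols can be resolved with `adata.var_names_make_unique()`.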
```bash
# Preprocess your data for training
python pretrain/preprocess.py \
    --dataset_list /path/to/datalist.txt \
    --cache_dir /path/to/cache \
    --n_neighbors 8
```

The `datalist.txt` file should contain paths to your H5AD files, one per line.
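If your H5AD files live under one directory, the datalist can be generated rather than written by hand. This is a minimal sketch with hypothetical helper names (`format_datalist`, `write_datalist`). It writes absolute paths so the list stays valid regardless of the working directory from which `preprocess.py` is launched, which is a precaution rather than a documented requirement.

```python
from pathlib import Path


def format_datalist(paths):
    """Render paths one per line, sorted for a reproducible order."""
    return "".join(f"{path}\n" for path in sorted(str(p) for p in paths))


def write_datalist(data_dir, out_file="datalist.txt", pattern="*.h5ad"):
    """Find H5AD files under data_dir (recursively) and write them to out_file."""
    paths = [p.resolve() for p in Path(data_dir).rglob(pattern) if p.is_file()]
    if not paths:
        raise FileNotFoundError(f"no files matching {pattern!r} under {data_dir}")
    Path(out_file).write_text(format_datalist(paths))
    return len(paths)


# Usage (hypothetical directory):
# n_files = write_datalist("/path/to/h5ad_dir", "/path/to/datalist.txt")
```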
```bash
cd pretrain

# Single GPU
python run_pretrain.py \
    --dataset_list /path/to/datalist.txt \
    --output_dir /path/to/output

# Multi-GPU (distributed)
bash run.sh
```

```bash
cd finetune

# Finetune on your dataset
bash finetune.sh \
    --base_ckpt /path/to/pretrained/checkpoint \
    --cache_dir /path/to/your/data \
    --output_dir /path/to/output
```

Key parameters:
- `--unfreeze_last_n`: Number of transformer layers to unfreeze (default: 8, all layers)
- `--num_epochs`: Training epochs (default: 100)
- `--learning_rate`: Learning rate (default: 1e-4)
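`--unfreeze_last_n` controls how many of the final transformer layers are trained. A generic way to express that selection in PyTorch is to match parameter names against a layer index. The sketch below assumes, hypothetically, that layer parameters are named like `encoder.layers.<i>.…` (inspect `model.named_parameters()` on `SpatialNeighborTransformer` for the real names) and selects only transformer-layer parameters; whether `finetune.py` also trains embeddings or output heads is not stated here.

```python
import re

# Hypothetical naming scheme: "<prefix>.layers.<index>.<rest>".
LAYER_PATTERN = re.compile(r"(?:^|\.)layers\.(\d+)\.")


def trainable_param_names(param_names, num_layers, unfreeze_last_n):
    """Return the names of parameters in the last `unfreeze_last_n` layers."""
    if not 0 <= unfreeze_last_n <= num_layers:
        raise ValueError("unfreeze_last_n must be between 0 and num_layers")
    first_trainable = num_layers - unfreeze_last_n
    selected = []
    for name in param_names:
        match = LAYER_PATTERN.search(name)
        if match and int(match.group(1)) >= first_trainable:
            selected.append(name)
    return selected


# Usage with a PyTorch model (requires torch):
# keep = set(trainable_param_names([n for n, _ in model.named_parameters()], 8, 2))
# for name, param in model.named_parameters():
#     param.requires_grad = name in keep
```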
```bash
cd reconstruction

# SpatialGT reconstruction (10 steps)
bash run_reconstruction.sh --method spatialgt --n_spots 100

# SEDR baseline (1 step)
bash run_reconstruction.sh --method sedr --n_spots 100

# KNN baseline (10 steps)
bash run_reconstruction.sh --method knn --n_spots 100
```

```bash
cd perturbation/mouse_stroke

# Run perturbation with ICA region
bash run_perturbation.sh output_name \
    --perturb_mode random \
    --n_spots 80 \
    --steps 10
```

```bash
cd perturbation/human_colitis

# Run perturbation on activated MNPs
python colitis_spatialgt_perturb_eval.py \
    --sample HS5_UC_R_0 \
    --perturb_target MNP_activated \
    --steps 10
```

We provide pretrained and finetuned model checkpoints on Hugging Face:
| Model | Description | Download |
|---|---|---|
| SpatialGT-Pretrained | Pretrained on a spatial transcriptomics atlas | 🤗 Hugging Face |
| SpatialGT-MouseStroke-Sham | Finetuned on mouse stroke Sham (control) | 🤗 Hugging Face |
| SpatialGT-MouseStroke-PT | Finetuned on mouse stroke PT (stroke) | 🤗 Hugging Face |
| SpatialGT-GeneEmbedding | Pretrained gene embeddings | 🤗 Hugging Face |
```bash
# Using huggingface-cli
huggingface-cli download Bgoood/SpatialGT-Pretrained --local-dir model/pretrain_ckpt
huggingface-cli download Bgoood/SpatialGT-MouseStroke-Sham --local-dir model/sham_1_ft
huggingface-cli download Bgoood/SpatialGT-MouseStroke-PT --local-dir model/pt_ft

# Download pretrained gene embeddings
huggingface-cli download Bgoood/SpatialGT-GeneEmbedding --local-dir gene_embedding
```

Or using Python:

```python
from huggingface_hub import snapshot_download

# Download pretrained model
snapshot_download(repo_id="Bgoood/SpatialGT-Pretrained", local_dir="model/pretrain_ckpt")
```

This project is licensed under the MIT License - see the LICENSE file for details.
For questions and issues, please open a GitHub issue or contact yxu662@connect.hkust-gz.edu.cn.
- Scanpy for single-cell analysis tools
- Hugging Face Transformers for transformer implementations
- SEDR for the baseline method
