SpatialGT: A Graph Transformer for Spatial Transcriptomics Analysis

SpatialGT: Spatial Graph Transformer for Spatial Transcriptomics

Python 3.8+ | PyTorch 2.0+ | License: MIT

Overview

SpatialGT is a graph transformer model for spatial transcriptomics data analysis. It leverages spatial context through neighbor-aware attention mechanisms to enable:

  • 🗺️ Spatial Context Learning: Pre-train on large-scale spatial transcriptomics data
  • 🧬 Gene Expression Reconstruction: Predict masked gene expression from spatial context
  • 🔬 Perturbation Simulation: Simulate transcriptomic responses to virtual perturbations

Installation

Option 1: Conda (Recommended)

git clone https://github.com/ai4nucleome/SpatialGT.git
cd SpatialGT

# Create and activate environment
conda env create -f env/environment.yml
conda activate spatialgt

Option 2: Pip

git clone https://github.com/ai4nucleome/SpatialGT.git
cd SpatialGT

# Install PyTorch with CUDA 11.8
pip install torch==2.2.2+cu118 torchvision==0.17.2+cu118 torchaudio==2.2.2+cu118 \
    -f https://download.pytorch.org/whl/torch_stable.html

# Install dependencies
pip install -r requirements.txt

See env/INSTALL.md for detailed installation instructions.

Quick Start

import torch
from pretrain.model_spatialpt import SpatialNeighborTransformer
from pretrain.Config import Config

# Load configuration
config = Config()

# Initialize model
model = SpatialNeighborTransformer(config)

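# Load pretrained weights (map to CPU so loading also works without a GPU)
checkpoint = torch.load("path/to/checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint)
model.eval()  # switch to inference mode (disables dropout)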

# Your spatial transcriptomics data
# ...

Repository Structure

SpatialGT/
├── pretrain/                 # Pretraining module
│   ├── Config.py            # Configuration
│   ├── model_spatialpt.py   # Model architecture
│   ├── spatial_databank.py  # Data loading utilities
│   ├── run_pretrain.py      # Training script
│   └── run.sh               # Launch script
│
├── finetune/                 # Finetuning module
│   ├── Config.py            # Finetuning configuration
│   ├── finetune.py          # Finetuning script
│   └── finetune.sh          # Launch script
│
├── reconstruction/           # Expression reconstruction
│   ├── spatialgt_reconstruction.py
│   ├── knn_reconstruction.py      # KNN baseline
│   ├── sedr_reconstruction.py     # SEDR baseline
│   └── run_reconstruction.sh
│
├── perturbation/             # Perturbation simulation
│   ├── mouse_stroke/        # Mouse stroke case study
│   └── human_colitis/       # Human colitis case study
│
├── gene_embedding/           # Pretrained gene embeddings (download from HuggingFace)
│   ├── vocab.json
│   ├── id_to_gene.json
│   └── pretrained_gene_embeddings.pt
│
├── baseline/                 # Baseline methods
│   └── SEDR/                # SEDR implementation
│
├── env/                      # Environment setup
│   ├── environment.yml
│   └── INSTALL.md
│
└── requirements.txt

Data Preparation

1. Prepare H5AD Files

Your spatial transcriptomics data should be in AnnData format (.h5ad) with:

  • adata.X: Gene expression matrix (cells × genes)
  • adata.obsm['spatial']: Spatial coordinates
  • adata.var_names: Gene symbols
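A quick check of these three fields can catch layout problems before preprocessing. The sketch below is illustrative only and is not part of the repository: `check_spatial_adata` is a hypothetical helper, it accepts any object exposing `X`, `obsm`, and `var_names` the way AnnData does, and the requirement that coordinates have at least two columns is an assumption.

```python
from types import SimpleNamespace

import numpy as np


def check_spatial_adata(adata):
    """Return a list of problems found; an empty list means the layout looks fine.

    Hypothetical helper (not part of SpatialGT). Works on an AnnData object or
    anything exposing the same `X`, `obsm`, and `var_names` attributes.
    """
    problems = []
    n_cells, n_genes = adata.X.shape

    # adata.var_names must provide one gene symbol per column of X.
    if len(adata.var_names) != n_genes:
        problems.append("var_names length does not match the number of genes in X")

    # adata.obsm['spatial'] must hold one coordinate row per cell/spot.
    if "spatial" not in adata.obsm:
        problems.append("obsm['spatial'] is missing")
    else:
        coords = np.asarray(adata.obsm["spatial"])
        # Assumption: coordinates are stored as (n_cells, n_dims) with n_dims >= 2.
        if coords.ndim != 2 or coords.shape[0] != n_cells or coords.shape[1] < 2:
            problems.append("obsm['spatial'] should have shape (n_cells, >=2)")
    return problems


# Toy stand-in for an AnnData object: 5 spots, 3 genes, 2D coordinates.
toy = SimpleNamespace(
    X=np.zeros((5, 3)),
    obsm={"spatial": np.random.default_rng(0).random((5, 2))},
    var_names=["Gfap", "Actb", "Cd74"],
)
print(check_spatial_adata(toy))
```

With real data, the same function could be called on the result of `anndata.read_h5ad("/path/to/sample.h5ad")`.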

2. Preprocess Data

# Preprocess your data for training
python pretrain/preprocess.py \
    --dataset_list /path/to/datalist.txt \
    --cache_dir /path/to/cache \
    --n_neighbors 8

The datalist.txt should contain paths to your H5AD files, one per line.
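The list file can be generated instead of written by hand. This is a minimal sketch using only the standard library; `write_datalist` is a hypothetical helper, and writing absolute paths is an assumption, since the only stated requirement is one path per line.

```python
import tempfile
from pathlib import Path


def write_datalist(data_dir, out_file="datalist.txt"):
    """Write the path of every .h5ad file under data_dir to out_file, one per line."""
    # Assumption: absolute paths, sorted for a reproducible order.
    paths = sorted(p.resolve() for p in Path(data_dir).rglob("*.h5ad"))
    Path(out_file).write_text("".join(f"{p}\n" for p in paths))
    return paths


# Demo on a throwaway directory containing one H5AD file and one unrelated file.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "slice_a.h5ad").touch()
    (Path(tmp) / "notes.txt").touch()
    out = Path(tmp) / "datalist.txt"
    write_datalist(tmp, out)
    print(out.read_text(), end="")
```

The resulting file can then be passed to `--dataset_list`.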

Usage

Pretraining

cd pretrain

# Single GPU
python run_pretrain.py \
    --dataset_list /path/to/datalist.txt \
    --output_dir /path/to/output

# Multi-GPU (distributed)
bash run.sh

Finetuning

cd finetune

# Finetune on your dataset
bash finetune.sh \
    --base_ckpt /path/to/pretrained/checkpoint \
    --cache_dir /path/to/your/data \
    --output_dir /path/to/output

Key parameters:

  • --unfreeze_last_n: Number of final transformer layers to unfreeze; earlier layers stay frozen (default: 8, which unfreezes all layers)
  • --num_epochs: Training epochs (default: 100)
  • --learning_rate: Learning rate (default: 1e-4)
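To illustrate what unfreezing only the last layers means in PyTorch, the sketch below freezes a stack of layers and re-enables gradients for the final `n`. It is not the code in `finetune.py`: the helper name `unfreeze_last_n` and the stand-in `nn.TransformerEncoder` are assumptions, and SpatialGT's actual module names may differ.

```python
import torch.nn as nn


def unfreeze_last_n(layers, n):
    """Freeze every layer in `layers`, then re-enable gradients for the last n."""
    for layer in layers:
        for param in layer.parameters():
            param.requires_grad = False
    # layers[-0:] would select everything, so n == 0 is handled explicitly.
    trainable = list(layers)[-n:] if n > 0 else []
    for layer in trainable:
        for param in layer.parameters():
            param.requires_grad = True


# Stand-in 8-layer encoder (hypothetical sizes, chosen only to keep the demo small).
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=16, nhead=2, dim_feedforward=32),
    num_layers=8,
    enable_nested_tensor=False,
)
unfreeze_last_n(encoder.layers, n=2)

# Count how many layers still receive gradient updates.
n_trainable = sum(
    any(p.requires_grad for p in layer.parameters()) for layer in encoder.layers
)
print(n_trainable)
```

Unfreezing fewer layers reduces the number of trainable parameters, which can help when the finetuning dataset is small.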

Reconstruction

cd reconstruction

# SpatialGT reconstruction (10 steps)
bash run_reconstruction.sh --method spatialgt --n_spots 100

# SEDR baseline (1 step)
bash run_reconstruction.sh --method sedr --n_spots 100

# KNN baseline (10 steps)
bash run_reconstruction.sh --method knn --n_spots 100
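For intuition about the KNN baseline, the following sketch fills each masked spot with the mean expression of its nearest observed spots. It is a simplified single-pass illustration and not the contents of `knn_reconstruction.py`: the function name `knn_reconstruct`, the default `k=8`, and the use of Euclidean distance are all assumptions, and the multi-step procedure run by the script is not reproduced here.

```python
import numpy as np


def knn_reconstruct(expr, coords, masked, k=8):
    """Fill masked spots with the mean expression of their k nearest observed spots.

    expr:   (n_spots, n_genes) expression matrix
    coords: (n_spots, n_dims) spatial coordinates
    masked: boolean array, True where expression is hidden
    """
    expr = np.asarray(expr, dtype=float)
    coords = np.asarray(coords, dtype=float)
    masked = np.asarray(masked, dtype=bool)
    observed = np.flatnonzero(~masked)
    out = expr.copy()
    for i in np.flatnonzero(masked):
        # Euclidean distance from the masked spot to every observed spot.
        dist = np.linalg.norm(coords[observed] - coords[i], axis=1)
        nearest = observed[np.argsort(dist)[:k]]
        out[i] = expr[nearest].mean(axis=0)
    return out


# Toy example: 4 spots on a line; the spot at x=1 is masked.
coords = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [10.0, 0.0]])
expr = np.array([[1.0], [99.0], [3.0], [50.0]])
masked = np.array([False, True, False, False])
result = knn_reconstruct(expr, coords, masked, k=2)
print(result[1, 0])  # mean of the two neighbours at x=0 and x=2 -> 2.0
```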

Perturbation Simulation

Mouse Stroke Case

cd perturbation/mouse_stroke

# Run perturbation with ICA region
bash run_perturbation.sh output_name \
    --perturb_mode random \
    --n_spots 80 \
    --steps 10

Human Colitis Case

cd perturbation/human_colitis

# Run perturbation on activated MNPs
python colitis_spatialgt_perturb_eval.py \
    --sample HS5_UC_R_0 \
    --perturb_target MNP_activated \
    --steps 10

Pretrained Models

We provide pretrained and finetuned model checkpoints on Hugging Face:

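| Model | Description | Download |
|---|---|---|
| SpatialGT-Pretrained | Pretrained on spatial transcriptomics atlas | 🤗 Hugging Face (Bgoood/SpatialGT-Pretrained) |
| SpatialGT-MouseStroke-Sham | Finetuned on mouse stroke Sham (control) | 🤗 Hugging Face (Bgoood/SpatialGT-MouseStroke-Sham) |
| SpatialGT-MouseStroke-PT | Finetuned on mouse stroke PT (stroke) | 🤗 Hugging Face (Bgoood/SpatialGT-MouseStroke-PT) |
| SpatialGT-GeneEmbedding | Pretrained gene embeddings | 🤗 Hugging Face (Bgoood/SpatialGT-GeneEmbedding) |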

Download Models

# Using huggingface-cli
huggingface-cli download Bgoood/SpatialGT-Pretrained --local-dir model/pretrain_ckpt
huggingface-cli download Bgoood/SpatialGT-MouseStroke-Sham --local-dir model/sham_1_ft
huggingface-cli download Bgoood/SpatialGT-MouseStroke-PT --local-dir model/pt_ft

# Download pretrained gene embeddings
huggingface-cli download Bgoood/SpatialGT-GeneEmbedding --local-dir gene_embedding

Or using Python:

from huggingface_hub import snapshot_download

# Download pretrained model
snapshot_download(repo_id="Bgoood/SpatialGT-Pretrained", local_dir="model/pretrain_ckpt")
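To fetch all four repositories from Python in one go, the mapping below mirrors the repository IDs and target directories used in the huggingface-cli commands above. `CHECKPOINTS` and `download_all` are hypothetical names introduced for this sketch; only `snapshot_download` comes from `huggingface_hub`.

```python
# Repo IDs and local directories copied from the huggingface-cli commands.
CHECKPOINTS = {
    "Bgoood/SpatialGT-Pretrained": "model/pretrain_ckpt",
    "Bgoood/SpatialGT-MouseStroke-Sham": "model/sham_1_ft",
    "Bgoood/SpatialGT-MouseStroke-PT": "model/pt_ft",
    "Bgoood/SpatialGT-GeneEmbedding": "gene_embedding",
}


def download_all(checkpoints=CHECKPOINTS):
    """Download every repository in `checkpoints` and return the local paths."""
    # Imported lazily so the mapping can be inspected without huggingface_hub.
    from huggingface_hub import snapshot_download

    return {
        repo_id: snapshot_download(repo_id=repo_id, local_dir=local_dir)
        for repo_id, local_dir in checkpoints.items()
    }
```

Calling `download_all()` requires network access and downloads each repository into its listed directory.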

License

This project is licensed under the MIT License - see the LICENSE file for details.

Contact

For questions and issues, please open a GitHub issue or contact yxu662@connect.hkust-gz.edu.cn.

Acknowledgments
