This is the repository for "Riemannian Metric Learning for Alignment of Spatial Multiomics" (MGW), a technique that:
- Performs Riemannian metric learning across spatial modalities (multiomics, transcriptomics, and so on) using the Riemannian pull-back metric.
- Infers Riemannian (geodesic) distances.
- Aligns Riemannian distances with Gromov-Wasserstein Optimal Transport.
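For intuition about the pull-back metric: if a neural field `phi` maps spatial coordinates x in R^2 into a feature space, the metric it induces on the tissue is G(x) = J(x)^T J(x), where J is the Jacobian of `phi`, so a small step dx has length sqrt(dx^T G(x) dx). The NumPy sketch below estimates G with central finite differences on a toy linear map. The `eps * I` term is our assumption of how a stability epsilon such as `geodesic_eps` might enter, and `pullback_metric` is a hypothetical helper, not part of the `mgw` package.

```python
import numpy as np

def pullback_metric(phi, x, eps=1e-2, h=1e-5):
    """Approximate G(x) = J(x)^T J(x) + eps * I for a map phi: R^dim -> R^d.

    J is estimated column by column with central finite differences.
    The eps * I term is an assumed regularizer that keeps G well conditioned.
    """
    x = np.asarray(x, dtype=float)
    dim = x.shape[0]
    cols = []
    for a in range(dim):
        step = np.zeros(dim)
        step[a] = h
        # Partial derivative of phi along coordinate a.
        cols.append((phi(x + step) - phi(x - step)) / (2 * h))
    J = np.stack(cols, axis=1)  # shape (d, dim)
    return J.T @ J + eps * np.eye(dim)

# Toy linear "neural field" phi(x) = A x, so J = A at every point.
A = np.array([[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]])
G = pullback_metric(lambda x: A @ x, np.array([0.3, -0.7]), eps=1e-2)
```

For a real MLP field the Jacobian varies with x, so G changes across the tissue; this is what makes the resulting geodesic distances differ from plain Euclidean distances.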
The sections below detail the usage of MGW and complement the simple demo notebooks:
- [demo_mgw_y7.ipynb](demo_mgw_y7.ipynb)
- [mouse_align.ipynb](mouse_align.ipynb)

The repository is organized as follows:

- `mgw/`
  - `mgw.py`: main solver/class for MGW
  - `geometry.py`: metric-tensor, geodesic-distance, k-NN graph, and all-pairs shortest path (APSP) utilities
  - `models.py`: neural field models (MLP)
  - `metric.py`: evaluation metrics (e.g. migration, AMI, cosine similarity)
  - `plotting.py`: visualization utilities
  - `utils.py`: miscellaneous helpers, barycentric projection
- `validation/`
  - `dopamine.py`: validation utilities for dopamine experiments (AUROC, AUPRC)
  - `run_methods.py`: code for running other methods (moscot Translation, SCOT, SCOTv2, PASTE2 FGW spatial, POT FGW spatial only)
- `demos/`
  - `demo_mgw_y7.ipynb`: demo notebook for running MGW on the Y_7 ccRCC slice (Hu '24)
  - `mouse_align.ipynb`: demo notebook for aligning spatiotemporal transcriptomics with MGW on the E9.5-10.5 mouse embryo timepoint pair (Chen '22)
  - `riemannian_mouse_geodesics.ipynb`: demo visualization of the geodesics in the Riemannian pull-back metric of the E9.5-10.5 mouse embryo (Chen '22)
- `experiments/`
  - Reproducible experiment notebooks on the Stereo-Seq mouse embryo, Visium-Xenium alignment of colorectal cancer, MALDI-MSI metabolomics and Visium transcriptomics alignment of the human striatum, and AFADESI-MSI and Visium alignment of renal cancer.
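Load two AnnData objects, e.g. spatial transcriptomics (`st`) and spatial metabolomics (`msi`), after appropriate filtering.

```python
import anndata as ad
ST_PATH = "path/to/st.h5ad"    # placeholder: your transcriptomics file
MSI_PATH = "path/to/msi.h5ad"  # placeholder: your metabolomics file
st = ad.read_h5ad(ST_PATH)
msi = ad.read_h5ad(MSI_PATH)
```

Then call `mgw.mgw_preprocess` on the two AnnData objects.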
`mgw_preprocess` can run PCA with `PCA_comp` components (reusing a pre-computed PCA if one is already present) and, for multimodal data, an additional CCA step. Set `use_cca_feeler=True` to enable this CCA step, which relies on a basic, coarse "feeler" alignment: pass `spatial_only=True` for a spatial-only feeler, `feature_only=True` for a feature-only feeler, or leave both `False` for a basic spatial-feature feeler. The CCA step retains the feature dimensions that are correlated across modalities, and `CCA_comp` sets the number of final CCA dimensions.
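To illustrate the idea behind the CCA step, here is a NumPy sketch assuming a coarse feeler coupling `P0` (n cells of modality 1 by m cells of modality 2) is already available. Each row of `X` is paired with the barycentric average of `Z` under `P0`, and regularized CCA is solved by whitening both covariances and taking an SVD; `n_comp` plays the role of `CCA_comp`. The names `cca_from_coupling`, `P0`, and `reg` are hypothetical, and the repository's actual feeler and CCA implementation may differ.

```python
import numpy as np

def _inv_sqrt(C):
    # Inverse square root of a symmetric positive-definite matrix.
    w, V = np.linalg.eigh(C)
    return V @ np.diag(1.0 / np.sqrt(w)) @ V.T

def cca_from_coupling(X, Z, P0, n_comp=2, reg=1e-3):
    """Project X (n x p) and Z (m x q) onto n_comp canonical directions.

    P0 (n x m) is a coarse coupling used to pair each row of X with a
    barycentric average of the rows of Z. Returns both embeddings and
    the top canonical correlations.
    """
    Z_paired = (P0 @ Z) / P0.sum(axis=1, keepdims=True)
    mx, mz = X.mean(axis=0), Z_paired.mean(axis=0)
    Xc, Zc = X - mx, Z_paired - mz
    n = X.shape[0]
    Cxx = Xc.T @ Xc / n + reg * np.eye(X.shape[1])  # regularized auto-covariances
    Czz = Zc.T @ Zc / n + reg * np.eye(Z.shape[1])
    Cxz = Xc.T @ Zc / n                              # cross-covariance
    Wx, Wz = _inv_sqrt(Cxx), _inv_sqrt(Czz)
    U, s, Vt = np.linalg.svd(Wx @ Cxz @ Wz)
    A = Wx @ U[:, :n_comp]     # canonical directions for X
    B = Wz @ Vt.T[:, :n_comp]  # canonical directions for Z
    # Z is centered with the paired mean so both embeddings share one frame.
    return (X - mx) @ A, (Z - mz) @ B, s[:n_comp]
```

Because only correlated directions survive, features with no counterpart in the other modality are down-weighted before any geometry is learned.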
To run on the raw `st.X` and `msi.X` as-is, without processing, set `use_cca_feeler=False`, `use_pca_X=False`, `use_pca_Z=False`, `log1p_X=False`, and `log1p_Z=False`. In general, we do not assume shared (joint) features across modalities, so PCA is run independently on each modality. For unimodal data (e.g. transcriptomics to transcriptomics), we recommend an external joint PCA; see e.g. `experiments/mgw_mouse_embryo.ipynb` for an example of this preprocessing.
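A joint PCA for the unimodal case can be sketched in NumPy: fit one set of principal axes on the stacked slices so both embeddings live in the same coordinate system. This assumes both matrices share identical, identically normalized feature columns (convert sparse `.X` with `.toarray()` first). `joint_pca` is a hypothetical helper; where the embedding should be stored so that `mgw_preprocess` picks it up is not specified here, so follow the experiment notebook above for the repository's actual route.

```python
import numpy as np

def joint_pca(X1, X2, n_comp=30):
    """Fit one PCA on the stacked data and return each slice's embedding.

    X1 (n1 x g) and X2 (n2 x g) must share the same feature columns.
    """
    stacked = np.vstack([X1, X2])
    mean = stacked.mean(axis=0)
    # Right singular vectors of the centered data are the principal axes.
    _, _, Vt = np.linalg.svd(stacked - mean, full_matrices=False)
    comps = Vt[:n_comp]
    return (X1 - mean) @ comps.T, (X2 - mean) @ comps.T
```

Independent PCAs would give each slice its own arbitrary axes, which is acceptable for unrelated modalities but discards the shared-gene correspondence in the unimodal case.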
```python
import mgw.mgw as mgw

PCA_components = 50  # illustrative value; tune for your data
CCA_components = 20  # illustrative value; tune for your data

pre = mgw.mgw_preprocess(
    st, msi,
    PCA_comp=PCA_components,
    CCA_comp=CCA_components,
    use_cca_feeler=True,
    use_pca_X=True,
    use_pca_Z=False,  # False if the second modality's features are intensities, for which running PCA does not make sense
    log1p_X=True,
    log1p_Z=False,  # False if the second modality's features are not counts, for which log1p does not make sense
    verbose=True,
)
```

Next, we run `mgw.mgw_align_core` on `pre` to infer the neural fields, learn the metric tensors, and align the result with Gromov-Wasserstein.
```python
PHI_ARC = (128, 256, 256, 128)
KNN_K = 12
DEFAULT_GW_PARAMS = dict(verbose=True, inner_maxit=3000, outer_maxit=3000,
                         inner_tol=1e-7, outer_tol=1e-7, epsilon=1e-4)
DEFAULT_LR = 1e-3
DEFAULT_EPS = 1e-2
DEFAULT_ITER = 20_000
EXP_PATH = "your_path"
EXP_TAG = "your_tag"

out = mgw.mgw_align_core(
    pre,
    widths=PHI_ARC,
    lr=DEFAULT_LR,
    niter=DEFAULT_ITER,
    knn_k=KNN_K,
    geodesic_eps=DEFAULT_EPS,
    save_dir=EXP_PATH,
    tag=EXP_TAG,
    verbose=True,
    plot_net=True,  # zoom in to visually check whether the two modalities show similar patterns
    use_cca=True,  # recommended for multimodal data
    gw_params=DEFAULT_GW_PARAMS,
)
```

Here, the key parameters are:
- `PHI_ARC`: layer widths of the MLP neural fields.
- `KNN_K`: number of neighbors in the k-nearest-neighbor graph used for Riemannian geodesics (i.e. the graph's resolution).
- `DEFAULT_EPS`: epsilon for the numerical stability of the Jacobian. This is generally not an issue; smaller values yield more faithful Riemannian geodesics (e.g. 1e-5 for the mouse embryo).
- `DEFAULT_GW_PARAMS`: default parameters for the OTT-JAX optimal transport solver.
- `DEFAULT_LR`: learning rate for the network.
- `DEFAULT_ITER`: number of training iterations for the network.
- `save_dir`: directory where outputs are saved.
- `tag`: tag for generated files.
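To make the roles of `KNN_K` and the metric tensors concrete, here is a sketch of graph-based Riemannian geodesics: connect each point to its k nearest neighbors, weight each edge by its length under the metric, and run Dijkstra from every node (APSP). Averaging the two endpoint metrics for each edge is our assumed approximation; `riemannian_knn_geodesics` is a hypothetical helper and not the code in `geometry.py`.

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import shortest_path
from scipy.spatial import cKDTree

def riemannian_knn_geodesics(xs, G, k=12):
    """All-pairs geodesic distances on a k-NN graph of the points xs.

    xs: (n, 2) coordinates; G: (n, 2, 2) metric tensors at each point.
    Edge i-j gets length sqrt(d^T G_mid d), with d = xs[j] - xs[i] and
    G_mid the average of the two endpoint metrics.
    """
    n = xs.shape[0]
    _, idx = cKDTree(xs).query(xs, k=k + 1)  # column 0 is the point itself
    rows, cols, vals = [], [], []
    for i in range(n):
        for j in idx[i, 1:]:
            d = xs[j] - xs[i]
            G_mid = 0.5 * (G[i] + G[j])
            rows.append(i)
            cols.append(j)
            vals.append(np.sqrt(d @ G_mid @ d))
    W = csr_matrix((vals, (rows, cols)), shape=(n, n))
    # Dijkstra from every node; directed=False lets each edge be traversed both ways.
    return shortest_path(W, method="D", directed=False)
```

Assuming `out["xs"]` has shape (n, 2) and `out["G_M"]` has shape (n, 2, 2), a call such as `riemannian_knn_geodesics(np.asarray(out["xs"]), np.asarray(out["G_M"]), k=KNN_K)` would give a distance matrix comparable in spirit to `C_M`. A larger k smooths the paths; too small a k can disconnect the graph and produce infinite distances.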
Several outputs can be accessed from `out`:
- `P`: MGW coupling (alignment)
- `xs`: spatial coordinates of dataset 1 (normalized)
- `xs2`: spatial coordinates of dataset 2 (normalized)
- `phi`: neural field mapping into modality 1
- `psi`: neural field mapping into modality 2
- `G_M` / `G_N`: pull-back metric tensor fields evaluated at the coordinates
- `C_M` / `C_N`: MGW Riemannian distance matrices
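The coupling and the distance matrices are tied together by the Gromov-Wasserstein objective. The sketch below evaluates the squared-loss GW cost, sum over i, j, k, l of (C_M[i,k] - C_N[j,l])^2 P[i,j] P[k,l], without materializing the 4-index tensor. Using the square loss is an assumption about the solver's configuration, and an entropically regularized solver's reported objective will also include the entropy term, so values need not match MGW's logs. `gw_square_loss` is a hypothetical helper.

```python
import numpy as np

def gw_square_loss(C_M, C_N, P):
    """Squared-loss GW cost of coupling P (n x m) between C_M (n x n) and C_N (m x m)."""
    p, q = P.sum(axis=1), P.sum(axis=0)  # marginals of the coupling
    term_m = p @ (C_M ** 2) @ p          # sum_ik C_M[i,k]^2 p_i p_k
    term_n = q @ (C_N ** 2) @ q          # sum_jl C_N[j,l]^2 q_j q_l
    cross = np.sum((C_M @ P @ C_N.T) * P)  # sum_ijkl C_M[i,k] C_N[j,l] P_ij P_kl
    return term_m + term_n - 2.0 * cross
```

For example, `gw_square_loss(np.asarray(out["C_M"]), np.asarray(out["C_N"]), np.asarray(out["P"]))` would score the returned alignment, assuming the three arrays have matching shapes. The matrix products cost O(n^2 m + n m^2), which is fine for a sanity check but slow for very large slices.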
As an example, let us retrieve the alignment and barycentrically project each modality onto the other.
```python
from mgw.utils import bary_proj  # utils.py holds the barycentric projection

P = out["P"]
adata_sm2st = bary_proj(st, msi, P)    # metabolomics -> transcriptomics
adata_st2sm = bary_proj(msi, st, P.T)  # transcriptomics -> metabolomics
```

`P` is the MGW alignment, `adata_sm2st` is the metabolomics-to-transcriptomics projection (added to `st` as a metabolite annotation), and `adata_st2sm` is the transcriptomics-to-metabolomics projection (added to `msi` as a transcriptomics annotation).
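Barycentric projection maps each source spot to the coupling-weighted average of the target's features: row i becomes sum_j P[i,j] Z[j] / sum_j P[i,j]. The NumPy sketch below shows only this core computation on dense matrices; `bary_proj` additionally handles the AnnData bookkeeping, and its exact normalization may differ. `barycentric_projection` is a hypothetical helper.

```python
import numpy as np

def barycentric_projection(P, Z):
    """Map each source row i onto sum_j P[i, j] Z[j] / sum_j P[i, j]."""
    P = np.asarray(P, dtype=float)
    row_mass = P.sum(axis=1, keepdims=True)  # total mass sent by each source spot
    return (P @ np.asarray(Z, dtype=float)) / row_mass
```

Assuming `P` has shape (number of `st` spots, number of `msi` spots), `barycentric_projection(np.asarray(P), msi.X.toarray())` yields one metabolite profile per transcriptomics spot, mirroring the direction of `bary_proj(st, msi, P)`; passing `P.T` with `st.X` gives the reverse direction.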