|
| 1 | +"""Utilities for joint RNA + protein analysis on Xenium AnnData objects. |
| 2 | +
|
| 3 | +This module provides a high-level pipeline that clusters cells using the RNA |
| 4 | +expression matrix and, for each protein marker, trains a small neural network |
| 5 | +to explain high-vs-low protein states within every RNA-defined cluster. |
| 6 | +
|
| 7 | +The function is intentionally self-contained and only relies on scikit-learn |
| 8 | +primitives so that it can operate on large Xenium datasets without requiring |
| 9 | +extra dependencies (e.g. Scanpy). It works directly on the data structures |
| 10 | +produced by :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`. |
| 11 | +""" |
| 12 | + |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +from dataclasses import dataclass |
| 16 | +from typing import Dict, List, Optional, Tuple |
| 17 | + |
| 18 | +import numpy as np |
| 19 | +import pandas as pd |
| 20 | +from anndata import AnnData |
| 21 | +from scipy import sparse |
| 22 | +from sklearn.cluster import KMeans |
| 23 | +from sklearn.decomposition import TruncatedSVD |
| 24 | +from sklearn.metrics import accuracy_score, roc_auc_score |
| 25 | +from sklearn.model_selection import train_test_split |
| 26 | +from sklearn.neural_network import MLPClassifier |
| 27 | +from sklearn.preprocessing import StandardScaler |
| 28 | + |
| 29 | + |
| 30 | +@dataclass |
| 31 | +class ProteinModelResult: |
| 32 | + """Container for the protein classification model of one cluster.""" |
| 33 | + |
| 34 | + protein: str |
| 35 | + cluster: str |
| 36 | + threshold: float |
| 37 | + n_cells: int |
| 38 | + n_high: int |
| 39 | + n_low: int |
| 40 | + train_accuracy: float |
| 41 | + test_accuracy: float |
| 42 | + test_auc: float |
| 43 | + model: MLPClassifier |
| 44 | + scaler: StandardScaler |
| 45 | + |
| 46 | + |
| 47 | +def _get_rna_matrix(adata: AnnData): |
| 48 | + """Return the raw RNA matrix from ``adata`` as CSR sparse matrix.""" |
| 49 | + |
| 50 | + X = adata.layers["rna"] if "rna" in adata.layers else adata.X |
| 51 | + return X.tocsr() if sparse.issparse(X) else sparse.csr_matrix(np.asarray(X)) |
| 52 | + |
| 53 | + |
| 54 | +def _normalize_log1p(matrix: sparse.csr_matrix, target_sum: float = 1e4) -> sparse.csr_matrix: |
| 55 | + """Library-size normalisation followed by log1p for sparse matrices.""" |
| 56 | + |
| 57 | + matrix = matrix.astype(np.float32) |
| 58 | + cell_sums = np.array(matrix.sum(axis=1)).ravel() |
| 59 | + cell_sums[cell_sums == 0] = 1.0 |
| 60 | + inv = sparse.diags((target_sum / cell_sums).astype(np.float32)) |
| 61 | + norm = inv @ matrix |
| 62 | + norm.data = np.log1p(norm.data) |
| 63 | + return norm |
| 64 | + |
| 65 | + |
| 66 | +def _fit_pcs( |
| 67 | + matrix: sparse.csr_matrix, |
| 68 | + n_components: int, |
| 69 | + random_state: Optional[int], |
| 70 | +) -> np.ndarray: |
| 71 | + """Fit TruncatedSVD on the RNA matrix and return dense principal components.""" |
| 72 | + |
| 73 | + n_features = matrix.shape[1] |
| 74 | + if n_features <= 1: |
| 75 | + raise ValueError("RNA matrix must contain at least two genes for SVD.") |
| 76 | + |
| 77 | + n_components = max(1, min(n_components, n_features - 1)) |
| 78 | + svd = TruncatedSVD(n_components=n_components, random_state=random_state) |
| 79 | + pcs = svd.fit_transform(matrix) |
| 80 | + return pcs.astype(np.float32) |
| 81 | + |
| 82 | + |
| 83 | +def _resolve_protein_frame(adata: AnnData) -> pd.DataFrame: |
| 84 | + """Return the protein matrix stored in ``adata.obsm['protein']`` as DataFrame.""" |
| 85 | + |
| 86 | + if "protein" not in adata.obsm: |
| 87 | + raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].") |
| 88 | + |
| 89 | + protein_df = adata.obsm["protein"] |
| 90 | + if isinstance(protein_df, pd.DataFrame): |
| 91 | + return protein_df |
| 92 | + |
| 93 | + return pd.DataFrame(np.asarray(protein_df), index=adata.obs_names) |
| 94 | + |
| 95 | + |
| 96 | +def rna_protein_cluster_analysis( |
| 97 | + adata: AnnData, |
| 98 | + *, |
| 99 | + n_clusters: int = 12, |
| 100 | + n_pcs: int = 30, |
| 101 | + cluster_key: str = "rna_cluster", |
| 102 | + random_state: Optional[int] = 0, |
| 103 | + target_sum: float = 1e4, |
| 104 | + min_cells_per_cluster: int = 50, |
| 105 | + min_cells_per_group: int = 20, |
| 106 | + protein_split_method: str = "median", |
| 107 | + protein_quantile: float = 0.75, |
| 108 | + test_size: float = 0.2, |
| 109 | + hidden_layer_sizes: Tuple[int, ...] = (64, 32), |
| 110 | + max_iter: int = 200, |
| 111 | + early_stopping: bool = True, |
| 112 | +) -> Tuple[pd.DataFrame, Dict[str, Dict[str, ProteinModelResult]]]: |
| 113 | + """Joint RNA/protein analysis for Xenium AnnData objects. |
| 114 | +
|
| 115 | + The pipeline performs three consecutive steps: |
| 116 | +
|
| 117 | + 1. **RNA preprocessing** – library-size normalisation (counts per ``target_sum``) |
| 118 | + followed by ``log1p``. A :class:`~sklearn.decomposition.TruncatedSVD` |
| 119 | + is fitted to obtain ``n_pcs`` latent dimensions. |
| 120 | + 2. **Clustering** – :class:`~sklearn.cluster.KMeans` is applied on the latent |
| 121 | + representation to create ``n_clusters`` RNA-driven cell groups. Cluster |
| 122 | + assignments are stored in ``adata.obs[cluster_key]`` and the latent space |
| 123 | + in ``adata.obsm['X_rna_pca']``. |
| 124 | + 3. **Protein explanation** – for every cluster and every protein marker, the |
| 125 | + cells are divided into "high" vs. "low" groups (median split by default). |
| 126 | + A small neural network (:class:`~sklearn.neural_network.MLPClassifier`) |
| 127 | + is trained to predict the binary labels from the RNA latent features. The |
| 128 | + training/test accuracies and optional ROC-AUC are reported. |
| 129 | +
|
| 130 | + Parameters |
| 131 | + ---------- |
| 132 | + adata: |
| 133 | + AnnData object returned by |
| 134 | + :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`. |
| 135 | + Requires ``adata.layers['rna']`` (or ``adata.X``) and |
| 136 | + ``adata.obsm['protein']``. |
| 137 | + n_clusters: |
| 138 | + Number of RNA clusters to compute with KMeans. |
| 139 | + n_pcs: |
| 140 | + Number of latent components extracted with TruncatedSVD. The value is |
| 141 | + automatically capped at ``n_genes - 1``. |
| 142 | + cluster_key: |
| 143 | + Column name added to ``adata.obs`` that stores cluster labels. |
| 144 | + random_state: |
| 145 | + Seed for the SVD, KMeans and neural networks. Use ``None`` for random |
| 146 | + initialisation. |
| 147 | + target_sum: |
| 148 | + Target library size after normalisation (Counts Per ``target_sum``). |
| 149 | + min_cells_per_cluster: |
| 150 | + Clusters with fewer cells are skipped entirely. |
| 151 | + min_cells_per_group: |
| 152 | + Minimum number of cells required in both "high" and "low" protein |
| 153 | + groups to train a neural network. |
| 154 | + protein_split_method: |
| 155 | + Either ``"median"`` (default) for a median split or ``"quantile"`` to |
| 156 | + keep only the top ``protein_quantile`` and bottom ``1 - protein_quantile`` |
| 157 | + fractions of cells (discarding the middle portion). |
| 158 | + protein_quantile: |
| 159 | + Quantile used when ``protein_split_method='quantile'``. |
| 160 | + test_size: |
| 161 | + Fraction of the cluster reserved for the test split when training the |
| 162 | + neural network. |
| 163 | + hidden_layer_sizes: |
| 164 | + Hidden-layer configuration passed to :class:`MLPClassifier`. |
| 165 | + max_iter: |
| 166 | + Maximum number of training iterations for the neural network. |
| 167 | + early_stopping: |
| 168 | + Whether to use early stopping in :class:`MLPClassifier`. |
| 169 | +
|
| 170 | + Returns |
| 171 | + ------- |
| 172 | + summary: |
| 173 | + :class:`pandas.DataFrame` summarising the trained models. Columns are |
| 174 | + ``['cluster', 'protein', 'threshold', 'n_cells', 'n_high', 'n_low', |
| 175 | + 'train_accuracy', 'test_accuracy', 'test_auc']``. |
| 176 | + models: |
| 177 | + Nested dictionary ``{cluster -> {protein -> ProteinModelResult}}`` |
| 178 | + containing the fitted neural networks and scalers for downstream use. |
| 179 | +
|
| 180 | + Examples |
| 181 | + -------- |
| 182 | + >>> from pyXenium.analysis import rna_protein_cluster_analysis |
| 183 | + >>> summary, models = rna_protein_cluster_analysis(adata, n_clusters=8) |
| 184 | + >>> summary.head() |
| 185 | + cluster protein threshold n_cells ... test_accuracy test_auc |
| 186 | + 0 cluster_0 EPCAM (µm) 0.563100 512 ... 0.84 0.91 |
| 187 | + 1 cluster_0 Podocin (µm^2) 0.118775 512 ... 0.79 0.87 |
| 188 | + """ |
| 189 | + |
| 190 | + if adata.n_obs == 0: |
| 191 | + raise ValueError("AnnData contains no cells (n_obs == 0).") |
| 192 | + |
| 193 | + protein_df = _resolve_protein_frame(adata) |
| 194 | + if protein_df.shape[1] == 0: |
| 195 | + raise ValueError("AnnData.obsm['protein'] is empty – nothing to analyse.") |
| 196 | + |
| 197 | + rna_csr = _get_rna_matrix(adata) |
| 198 | + if rna_csr.shape[1] < 2: |
| 199 | + raise ValueError("RNA modality must have at least two genes for clustering.") |
| 200 | + |
| 201 | + log_norm = _normalize_log1p(rna_csr, target_sum=target_sum) |
| 202 | + pcs = _fit_pcs(log_norm, n_components=n_pcs, random_state=random_state) |
| 203 | + adata.obsm["X_rna_pca"] = pcs |
| 204 | + |
| 205 | + kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10) |
| 206 | + cluster_labels = kmeans.fit_predict(pcs) |
| 207 | + cluster_names = np.array([f"cluster_{i}" for i in cluster_labels]) |
| 208 | + adata.obs[cluster_key] = cluster_names |
| 209 | + |
| 210 | + results: List[Dict[str, float]] = [] |
| 211 | + models: Dict[str, Dict[str, ProteinModelResult]] = {} |
| 212 | + |
| 213 | + unique_clusters = pd.Index(np.unique(cluster_names)) |
| 214 | + for cluster in unique_clusters: |
| 215 | + idx = np.where(cluster_names == cluster)[0] |
| 216 | + if idx.size < min_cells_per_cluster: |
| 217 | + continue |
| 218 | + |
| 219 | + cluster_pcs = pcs[idx] |
| 220 | + cluster_protein = protein_df.iloc[idx] |
| 221 | + models.setdefault(cluster, {}) |
| 222 | + |
| 223 | + for protein in cluster_protein.columns: |
| 224 | + values_all = cluster_protein[protein].to_numpy(dtype=np.float32) |
| 225 | + finite_mask = np.isfinite(values_all) |
| 226 | + if finite_mask.sum() < min_cells_per_group * 2: |
| 227 | + continue |
| 228 | + |
| 229 | + values = values_all[finite_mask] |
| 230 | + X_cluster = cluster_pcs[finite_mask] |
| 231 | + |
| 232 | + if protein_split_method == "median": |
| 233 | + threshold = float(np.median(values)) |
| 234 | + labels = (values >= threshold).astype(int) |
| 235 | + elif protein_split_method == "quantile": |
| 236 | + q = float(protein_quantile) |
| 237 | + if not 0.5 < q < 1.0: |
| 238 | + raise ValueError("protein_quantile must be between 0.5 and 1.0 (exclusive).") |
| 239 | + high_thr = np.quantile(values, q) |
| 240 | + low_mask = values <= np.quantile(values, 1.0 - q) |
| 241 | + high_mask = values >= high_thr |
| 242 | + selected_mask = high_mask | low_mask |
| 243 | + if selected_mask.sum() < min_cells_per_group * 2: |
| 244 | + continue |
| 245 | + values = values[selected_mask] |
| 246 | + X_cluster = X_cluster[selected_mask] |
| 247 | + labels = high_mask[selected_mask].astype(int) |
| 248 | + threshold = float(high_thr) |
| 249 | + else: |
| 250 | + raise ValueError("protein_split_method must be 'median' or 'quantile'.") |
| 251 | + |
| 252 | + n_selected = labels.size |
| 253 | + n_high = int(labels.sum()) |
| 254 | + n_low = int(n_selected - n_high) |
| 255 | + |
| 256 | + if n_high < min_cells_per_group or n_low < min_cells_per_group: |
| 257 | + continue |
| 258 | + |
| 259 | + try: |
| 260 | + X_train, X_test, y_train, y_test = train_test_split( |
| 261 | + X_cluster, |
| 262 | + labels, |
| 263 | + test_size=test_size, |
| 264 | + random_state=random_state, |
| 265 | + stratify=labels, |
| 266 | + ) |
| 267 | + except ValueError: |
| 268 | + # Not enough samples to stratify. |
| 269 | + continue |
| 270 | + |
| 271 | + scaler = StandardScaler() |
| 272 | + X_train_scaled = scaler.fit_transform(X_train) |
| 273 | + X_test_scaled = scaler.transform(X_test) |
| 274 | + |
| 275 | + clf = MLPClassifier( |
| 276 | + hidden_layer_sizes=hidden_layer_sizes, |
| 277 | + random_state=random_state, |
| 278 | + max_iter=max_iter, |
| 279 | + early_stopping=early_stopping, |
| 280 | + ) |
| 281 | + |
| 282 | + try: |
| 283 | + clf.fit(X_train_scaled, y_train) |
| 284 | + except Exception: |
| 285 | + continue |
| 286 | + |
| 287 | + train_acc = float(accuracy_score(y_train, clf.predict(X_train_scaled))) |
| 288 | + test_pred = clf.predict(X_test_scaled) |
| 289 | + test_acc = float(accuracy_score(y_test, test_pred)) |
| 290 | + |
| 291 | + if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2: |
| 292 | + probs = clf.predict_proba(X_test_scaled)[:, 1] |
| 293 | + try: |
| 294 | + test_auc = float(roc_auc_score(y_test, probs)) |
| 295 | + except ValueError: |
| 296 | + test_auc = float("nan") |
| 297 | + else: |
| 298 | + test_auc = float("nan") |
| 299 | + |
| 300 | + result = ProteinModelResult( |
| 301 | + protein=protein, |
| 302 | + cluster=cluster, |
| 303 | + threshold=threshold, |
| 304 | + n_cells=n_selected, |
| 305 | + n_high=n_high, |
| 306 | + n_low=n_low, |
| 307 | + train_accuracy=train_acc, |
| 308 | + test_accuracy=test_acc, |
| 309 | + test_auc=test_auc, |
| 310 | + model=clf, |
| 311 | + scaler=scaler, |
| 312 | + ) |
| 313 | + |
| 314 | + models[cluster][protein] = result |
| 315 | + results.append( |
| 316 | + { |
| 317 | + "cluster": cluster, |
| 318 | + "protein": protein, |
| 319 | + "threshold": threshold, |
| 320 | + "n_cells": n_selected, |
| 321 | + "n_high": n_high, |
| 322 | + "n_low": n_low, |
| 323 | + "train_accuracy": train_acc, |
| 324 | + "test_accuracy": test_acc, |
| 325 | + "test_auc": test_auc, |
| 326 | + } |
| 327 | + ) |
| 328 | + |
| 329 | + summary = pd.DataFrame(results, columns=[ |
| 330 | + "cluster", |
| 331 | + "protein", |
| 332 | + "threshold", |
| 333 | + "n_cells", |
| 334 | + "n_high", |
| 335 | + "n_low", |
| 336 | + "train_accuracy", |
| 337 | + "test_accuracy", |
| 338 | + "test_auc", |
| 339 | + ]) |
| 340 | + |
| 341 | + return summary, models |
| 342 | + |
| 343 | + |
| 344 | +__all__ = ["rna_protein_cluster_analysis", "ProteinModelResult"] |
| 345 | + |
0 commit comments