Skip to content

Commit 14f4190

Browse files
authored
Merge pull request #1 from hutaobo/codex/add-joint-rna-and-protein-analysis-function
Add RNA/protein joint analysis pipeline
2 parents 422d8fd + 016f0ed commit 14f4190

3 files changed

Lines changed: 374 additions & 1 deletion

File tree

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Xenium runs that include both RNA and protein measurements.
88
> - [Partial loading (incomplete exports)](#partial-loading-incomplete-exports)
99
> - [RNA + Protein loader](#rna--protein-loader)
1010
> - [Gene–protein correlation](#gene–protein-correlation)
11+
> - [RNA/protein joint analysis](#rnaprotein-joint-analysis)
1112
1213
## Features
1314

@@ -106,6 +107,31 @@ summary = protein_gene_correlation(
106107
print(summary)
107108
```
108109

110+
### RNA/protein joint analysis
111+
112+
Cluster cells using RNA expression, then explain within-cluster protein
113+
heterogeneity by training neural network classifiers on the RNA latent space.
114+
115+
```python
116+
from pyXenium.analysis import rna_protein_cluster_analysis
117+
118+
summary, models = rna_protein_cluster_analysis(
119+
adata,
120+
n_clusters=12,
121+
n_pcs=30,
122+
min_cells_per_cluster=100,
123+
min_cells_per_group=30,
124+
hidden_layer_sizes=(128, 64),
125+
)
126+
127+
# Inspect metrics for the first few cluster × protein combinations
128+
print(summary.head())
129+
130+
# Retrieve the fitted model for a specific cluster and protein
131+
podocin_model = models["cluster_3"]["Podocin"]
132+
print(podocin_model.test_accuracy)
133+
```
134+
109135
## Data format expectations
110136

111137
- **Cell-feature matrix (MEX)** under `cell_feature_matrix/`:

src/pyXenium/analysis/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .protein_gene_correlation import protein_gene_correlation
2+
from .rna_protein_joint import ProteinModelResult, rna_protein_cluster_analysis
23

34
__all__ = [
4-
*globals().get("__all__", []),
55
"protein_gene_correlation",
6+
"rna_protein_cluster_analysis",
7+
"ProteinModelResult",
68
]
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
"""Utilities for joint RNA + protein analysis on Xenium AnnData objects.
2+
3+
This module provides a high-level pipeline that clusters cells using the RNA
4+
expression matrix and, for each protein marker, trains a small neural network
5+
to explain high-vs-low protein states within every RNA-defined cluster.
6+
7+
The function is intentionally self-contained and only relies on scikit-learn
8+
primitives so that it can operate on large Xenium datasets without requiring
9+
extra dependencies (e.g. Scanpy). It works directly on the data structures
10+
produced by :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
from dataclasses import dataclass
16+
from typing import Dict, List, Optional, Tuple
17+
18+
import numpy as np
19+
import pandas as pd
20+
from anndata import AnnData
21+
from scipy import sparse
22+
from sklearn.cluster import KMeans
23+
from sklearn.decomposition import TruncatedSVD
24+
from sklearn.metrics import accuracy_score, roc_auc_score
25+
from sklearn.model_selection import train_test_split
26+
from sklearn.neural_network import MLPClassifier
27+
from sklearn.preprocessing import StandardScaler
28+
29+
30+
@dataclass
class ProteinModelResult:
    """Fitted high-vs-low classifier for one protein inside one RNA cluster.

    One instance is created by :func:`rna_protein_cluster_analysis` for every
    cluster × protein combination for which a model was successfully trained.
    It bundles the evaluation metrics with the fitted estimator and the scaler
    that must be applied to features before calling the estimator.
    """

    # Name (column label) of the protein marker the model was trained for.
    protein: str
    # Label of the RNA-defined cluster, e.g. "cluster_3".
    cluster: str
    # Protein value at or above which a cell is labelled "high".
    threshold: float
    # Number of cells that entered the model (after filtering), and how many
    # of them fell into the "high" and "low" groups respectively.
    n_cells: int
    n_high: int
    n_low: int
    # Accuracy on the training split and on the held-out test split.
    train_accuracy: float
    test_accuracy: float
    # ROC-AUC on the held-out test split; NaN when it could not be computed.
    test_auc: float
    # Fitted neural network and the StandardScaler fitted on its training data.
    model: MLPClassifier
    scaler: StandardScaler
45+
46+
47+
def _get_rna_matrix(adata: AnnData):
    """Fetch the RNA matrix of ``adata`` in CSR format.

    ``adata.layers['rna']`` is preferred; ``adata.X`` is the fallback when that
    layer does not exist. Sparse inputs are converted with ``tocsr()``; anything
    else is passed through ``numpy.asarray`` and wrapped in a CSR matrix.
    """
    if "rna" in adata.layers:
        counts = adata.layers["rna"]
    else:
        counts = adata.X

    if sparse.issparse(counts):
        return counts.tocsr()
    return sparse.csr_matrix(np.asarray(counts))
52+
53+
54+
def _normalize_log1p(matrix: sparse.csr_matrix, target_sum: float = 1e4) -> sparse.csr_matrix:
    """Scale every row to ``target_sum`` total counts, then apply ``log1p``.

    Parameters
    ----------
    matrix:
        Sparse cell × gene matrix. It is cast to ``float32`` first; the input
        object itself is not modified.
    target_sum:
        Total each non-empty row is rescaled to before the log transform.

    Returns
    -------
    Sparse ``float32`` matrix holding ``log1p`` of the rescaled values. Rows
    whose sum is zero are left all-zero.
    """
    counts = matrix.astype(np.float32)

    # Per-cell library size as a flat vector.
    library_sizes = np.array(counts.sum(axis=1)).ravel()
    # Empty cells get a divisor of 1 so they stay zero instead of dividing by zero.
    empty_cells = library_sizes == 0
    library_sizes[empty_cells] = 1.0

    # Left-multiplying by a diagonal matrix rescales each row without densifying.
    row_scale = (target_sum / library_sizes).astype(np.float32)
    scaled = sparse.diags(row_scale) @ counts

    # log1p(0) == 0, so transforming only the stored entries is sufficient.
    scaled.data = np.log1p(scaled.data)
    return scaled
64+
65+
66+
def _fit_pcs(
    matrix: sparse.csr_matrix,
    n_components: int,
    random_state: Optional[int],
) -> np.ndarray:
    """Project the RNA matrix onto its leading TruncatedSVD components.

    The requested ``n_components`` is clamped to ``[1, n_genes - 1]`` before the
    decomposition is fitted. The resulting scores are returned as a dense
    ``float32`` array.

    Raises
    ------
    ValueError
        If ``matrix`` has fewer than two columns (genes).
    """
    n_genes = matrix.shape[1]
    if n_genes < 2:
        raise ValueError("RNA matrix must contain at least two genes for SVD.")

    # Never ask for more than n_genes - 1 components, and never fewer than one.
    capped = min(n_components, n_genes - 1)
    n_used = max(1, capped)

    reducer = TruncatedSVD(n_components=n_used, random_state=random_state)
    scores = reducer.fit_transform(matrix)
    return scores.astype(np.float32)
81+
82+
83+
def _resolve_protein_frame(adata: AnnData) -> pd.DataFrame:
    """Return ``adata.obsm['protein']`` as a cell × protein DataFrame.

    A DataFrame stored in ``obsm`` is returned unchanged (the same object, not
    a copy). Any other array-like is wrapped in a new DataFrame indexed by
    ``adata.obs_names``; its columns get positional integer labels. SciPy
    sparse matrices are densified first, because ``numpy.asarray`` does not
    expand them into a 2-D numeric array.

    Raises
    ------
    KeyError
        If ``adata.obsm`` has no ``'protein'`` entry.
    """
    if "protein" not in adata.obsm:
        raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].")

    protein = adata.obsm["protein"]
    if isinstance(protein, pd.DataFrame):
        return protein

    if sparse.issparse(protein):
        # np.asarray(sparse) would produce a 0-d object array, not the values.
        dense = protein.toarray()
    else:
        dense = np.asarray(protein)
    return pd.DataFrame(dense, index=adata.obs_names)
94+
95+
96+
def rna_protein_cluster_analysis(
    adata: AnnData,
    *,
    n_clusters: int = 12,
    n_pcs: int = 30,
    cluster_key: str = "rna_cluster",
    random_state: Optional[int] = 0,
    target_sum: float = 1e4,
    min_cells_per_cluster: int = 50,
    min_cells_per_group: int = 20,
    protein_split_method: str = "median",
    protein_quantile: float = 0.75,
    test_size: float = 0.2,
    hidden_layer_sizes: Tuple[int, ...] = (64, 32),
    max_iter: int = 200,
    early_stopping: bool = True,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, ProteinModelResult]]]:
    """Joint RNA/protein analysis for Xenium AnnData objects.

    The pipeline performs three consecutive steps:

    1. **RNA preprocessing** – library-size normalisation (counts per ``target_sum``)
       followed by ``log1p``. A :class:`~sklearn.decomposition.TruncatedSVD`
       is fitted to obtain ``n_pcs`` latent dimensions.
    2. **Clustering** – :class:`~sklearn.cluster.KMeans` is applied on the latent
       representation to create ``n_clusters`` RNA-driven cell groups. Cluster
       assignments are stored in ``adata.obs[cluster_key]`` and the latent space
       in ``adata.obsm['X_rna_pca']`` (``adata`` is modified in place).
    3. **Protein explanation** – for every cluster and every protein marker, the
       cells are divided into "high" vs. "low" groups (median split by default).
       A small neural network (:class:`~sklearn.neural_network.MLPClassifier`)
       is trained to predict the binary labels from the standardised RNA latent
       features. The training/test accuracies and the test ROC-AUC are reported.

    Parameters
    ----------
    adata:
        AnnData object returned by
        :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`.
        Requires ``adata.layers['rna']`` (or ``adata.X``) and
        ``adata.obsm['protein']``.
    n_clusters:
        Number of RNA clusters to compute with KMeans.
    n_pcs:
        Number of latent components extracted with TruncatedSVD. The value is
        automatically capped at ``n_genes - 1``.
    cluster_key:
        Column name added to ``adata.obs`` that stores cluster labels
        (strings of the form ``"cluster_<i>"``).
    random_state:
        Seed for the SVD, KMeans, the train/test split and the neural networks.
        Use ``None`` for random initialisation.
    target_sum:
        Target library size after normalisation (counts per ``target_sum``).
    min_cells_per_cluster:
        Clusters with fewer cells are skipped entirely and do not appear in
        the returned ``models`` dictionary.
    min_cells_per_group:
        Minimum number of cells required in both the "high" and the "low"
        protein group to train a neural network.
    protein_split_method:
        ``"median"`` (default) labels cells with a value at or above the
        within-cluster median as "high" and all others as "low".
        ``"quantile"`` keeps only the two tails: cells at or above the
        ``protein_quantile`` quantile are "high", cells at or below the
        ``1 - protein_quantile`` quantile are "low", and the cells in between
        are discarded.
    protein_quantile:
        Quantile used when ``protein_split_method='quantile'``; must lie
        strictly between 0.5 and 1.0. Ignored for the median split.
    test_size:
        Fraction of the selected cells reserved for the (stratified) test
        split when training the neural network.
    hidden_layer_sizes:
        Hidden-layer configuration passed to :class:`MLPClassifier`.
    max_iter:
        Maximum number of training iterations for the neural network.
    early_stopping:
        Whether to use early stopping in :class:`MLPClassifier`.

    Returns
    -------
    summary:
        :class:`pandas.DataFrame` summarising the trained models. Columns are
        ``['cluster', 'protein', 'threshold', 'n_cells', 'n_high', 'n_low',
        'train_accuracy', 'test_accuracy', 'test_auc']``. ``n_cells`` counts
        the cells that entered the model, i.e. after dropping non-finite
        protein values and, for the quantile split, the discarded middle.
    models:
        Nested dictionary ``{cluster -> {protein -> ProteinModelResult}}``
        containing the fitted neural networks and scalers for downstream use.
        A cluster that was analysed but yielded no model maps to an empty
        dictionary.

    Raises
    ------
    ValueError
        If ``protein_split_method`` or ``protein_quantile`` is invalid (checked
        before any computation and before ``adata`` is modified), if ``adata``
        has no cells, if the protein table has no columns, or if the RNA
        matrix has fewer than two genes.
    KeyError
        If ``adata.obsm`` has no ``'protein'`` entry.

    Warns
    -----
    RuntimeWarning
        When fitting the classifier for one cluster × protein pair fails; that
        pair is skipped and the analysis continues.

    Examples
    --------
    >>> from pyXenium.analysis import rna_protein_cluster_analysis
    >>> summary, models = rna_protein_cluster_analysis(adata, n_clusters=8)
    >>> summary.head()
         cluster         protein  threshold  n_cells  ...  test_accuracy  test_auc
    0  cluster_0      EPCAM (µm)   0.563100      512  ...           0.84      0.91
    1  cluster_0  Podocin (µm^2)   0.118775      512  ...           0.79      0.87
    """
    # Local import keeps this function self-contained with respect to the
    # module's import block.
    import warnings

    # Validate the split configuration up front so that bad arguments fail
    # immediately, independent of whether any cluster later qualifies.
    if protein_split_method not in ("median", "quantile"):
        raise ValueError("protein_split_method must be 'median' or 'quantile'.")
    quantile: Optional[float] = None
    if protein_split_method == "quantile":
        quantile = float(protein_quantile)
        if not 0.5 < quantile < 1.0:
            raise ValueError("protein_quantile must be between 0.5 and 1.0 (exclusive).")

    if adata.n_obs == 0:
        raise ValueError("AnnData contains no cells (n_obs == 0).")

    protein_df = _resolve_protein_frame(adata)
    if protein_df.shape[1] == 0:
        raise ValueError("AnnData.obsm['protein'] is empty – nothing to analyse.")

    rna_csr = _get_rna_matrix(adata)
    if rna_csr.shape[1] < 2:
        raise ValueError("RNA modality must have at least two genes for clustering.")

    # Step 1: normalise, log-transform and embed the RNA counts.
    log_norm = _normalize_log1p(rna_csr, target_sum=target_sum)
    pcs = _fit_pcs(log_norm, n_components=n_pcs, random_state=random_state)
    adata.obsm["X_rna_pca"] = pcs

    # Step 2: cluster the cells in the latent space.
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(pcs)
    cluster_names = np.array([f"cluster_{i}" for i in cluster_labels])
    adata.obs[cluster_key] = cluster_names

    summary_columns = [
        "cluster",
        "protein",
        "threshold",
        "n_cells",
        "n_high",
        "n_low",
        "train_accuracy",
        "test_accuracy",
        "test_auc",
    ]
    results: List[Dict[str, object]] = []
    models: Dict[str, Dict[str, ProteinModelResult]] = {}

    # Step 3: one classifier per (cluster, protein) pair.
    for cluster in pd.Index(np.unique(cluster_names)):
        idx = np.where(cluster_names == cluster)[0]
        if idx.size < min_cells_per_cluster:
            continue

        cluster_pcs = pcs[idx]
        # Positional indexing: the protein table rows are assumed to be in the
        # same order as the cells of ``adata``.
        cluster_protein = protein_df.iloc[idx]
        # Registered even if no protein model succeeds, so every analysed
        # cluster has a key in ``models``.
        models.setdefault(cluster, {})

        for protein in cluster_protein.columns:
            values_all = cluster_protein[protein].to_numpy(dtype=np.float32)
            finite_mask = np.isfinite(values_all)
            if finite_mask.sum() < min_cells_per_group * 2:
                continue

            values = values_all[finite_mask]
            X_cluster = cluster_pcs[finite_mask]

            if protein_split_method == "median":
                threshold = float(np.median(values))
                # Ties at the median are assigned to the "high" group.
                labels = (values >= threshold).astype(int)
            else:  # "quantile": keep only the two tails
                high_thr = np.quantile(values, quantile)
                low_mask = values <= np.quantile(values, 1.0 - quantile)
                high_mask = values >= high_thr
                selected_mask = high_mask | low_mask
                if selected_mask.sum() < min_cells_per_group * 2:
                    continue
                X_cluster = X_cluster[selected_mask]
                labels = high_mask[selected_mask].astype(int)
                threshold = float(high_thr)

            n_selected = labels.size
            n_high = int(labels.sum())
            n_low = int(n_selected - n_high)

            if n_high < min_cells_per_group or n_low < min_cells_per_group:
                continue

            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    X_cluster,
                    labels,
                    test_size=test_size,
                    random_state=random_state,
                    stratify=labels,
                )
            except ValueError:
                # Not enough samples to stratify.
                continue

            # The scaler is fitted on the training split only to avoid leaking
            # test-set statistics into the model.
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)

            clf = MLPClassifier(
                hidden_layer_sizes=hidden_layer_sizes,
                random_state=random_state,
                max_iter=max_iter,
                early_stopping=early_stopping,
            )

            try:
                clf.fit(X_train_scaled, y_train)
            except Exception as exc:
                # Best-effort: one failing model must not abort the whole
                # sweep, but the failure should be visible to the caller.
                warnings.warn(
                    f"Skipping protein {protein!r} in {cluster}: "
                    f"MLPClassifier.fit failed ({exc}).",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue

            train_acc = float(accuracy_score(y_train, clf.predict(X_train_scaled)))
            test_acc = float(accuracy_score(y_test, clf.predict(X_test_scaled)))

            # ROC-AUC is only defined when both classes occur in the test split.
            test_auc = float("nan")
            if len(np.unique(y_test)) == 2:
                probs = clf.predict_proba(X_test_scaled)[:, 1]
                try:
                    test_auc = float(roc_auc_score(y_test, probs))
                except ValueError:
                    test_auc = float("nan")

            row = {
                "cluster": cluster,
                "protein": protein,
                "threshold": threshold,
                "n_cells": n_selected,
                "n_high": n_high,
                "n_low": n_low,
                "train_accuracy": train_acc,
                "test_accuracy": test_acc,
                "test_auc": test_auc,
            }
            # The summary row and the result object share the same metric
            # fields; the result additionally carries the fitted objects.
            models[cluster][protein] = ProteinModelResult(model=clf, scaler=scaler, **row)
            results.append(row)

    summary = pd.DataFrame(results, columns=summary_columns)
    return summary, models
342+
343+
344+
# Public API of this module; the underscore-prefixed helpers above stay private.
__all__ = ["rna_protein_cluster_analysis", "ProteinModelResult"]
345+

0 commit comments

Comments
 (0)