-
Notifications
You must be signed in to change notification settings - Fork 0
Add RNA/protein joint analysis pipeline #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| from .protein_gene_correlation import protein_gene_correlation | ||
| from .rna_protein_joint import ProteinModelResult, rna_protein_cluster_analysis | ||
|
|
||
| __all__ = [ | ||
| *globals().get("__all__", []), | ||
| "protein_gene_correlation", | ||
| "rna_protein_cluster_analysis", | ||
| "ProteinModelResult", | ||
| ] |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,345 @@ | ||||||||||||||||||||||||||||||||||||||||||||||
| """Utilities for joint RNA + protein analysis on Xenium AnnData objects. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| This module provides a high-level pipeline that clusters cells using the RNA | ||||||||||||||||||||||||||||||||||||||||||||||
| expression matrix and, for each protein marker, trains a small neural network | ||||||||||||||||||||||||||||||||||||||||||||||
| to explain high-vs-low protein states within every RNA-defined cluster. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| The function is intentionally self-contained and only relies on scikit-learn | ||||||||||||||||||||||||||||||||||||||||||||||
| primitives so that it can operate on large Xenium datasets without requiring | ||||||||||||||||||||||||||||||||||||||||||||||
| extra dependencies (e.g. Scanpy). It works directly on the data structures | ||||||||||||||||||||||||||||||||||||||||||||||
| produced by :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`. | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| from dataclasses import dataclass | ||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Dict, List, Optional, Tuple | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||
| import pandas as pd | ||||||||||||||||||||||||||||||||||||||||||||||
| from anndata import AnnData | ||||||||||||||||||||||||||||||||||||||||||||||
| from scipy import sparse | ||||||||||||||||||||||||||||||||||||||||||||||
| from sklearn.cluster import KMeans | ||||||||||||||||||||||||||||||||||||||||||||||
| from sklearn.decomposition import TruncatedSVD | ||||||||||||||||||||||||||||||||||||||||||||||
| from sklearn.metrics import accuracy_score, roc_auc_score | ||||||||||||||||||||||||||||||||||||||||||||||
| from sklearn.model_selection import train_test_split | ||||||||||||||||||||||||||||||||||||||||||||||
| from sklearn.neural_network import MLPClassifier | ||||||||||||||||||||||||||||||||||||||||||||||
| from sklearn.preprocessing import StandardScaler | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| @dataclass | ||||||||||||||||||||||||||||||||||||||||||||||
| class ProteinModelResult: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Container for the protein classification model of one cluster.""" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| protein: str | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster: str | ||||||||||||||||||||||||||||||||||||||||||||||
| threshold: float | ||||||||||||||||||||||||||||||||||||||||||||||
| n_cells: int | ||||||||||||||||||||||||||||||||||||||||||||||
| n_high: int | ||||||||||||||||||||||||||||||||||||||||||||||
| n_low: int | ||||||||||||||||||||||||||||||||||||||||||||||
| train_accuracy: float | ||||||||||||||||||||||||||||||||||||||||||||||
| test_accuracy: float | ||||||||||||||||||||||||||||||||||||||||||||||
| test_auc: float | ||||||||||||||||||||||||||||||||||||||||||||||
| model: MLPClassifier | ||||||||||||||||||||||||||||||||||||||||||||||
| scaler: StandardScaler | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _get_rna_matrix(adata: AnnData): | ||||||||||||||||||||||||||||||||||||||||||||||
| """Return the raw RNA matrix from ``adata`` as CSR sparse matrix.""" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| X = adata.layers["rna"] if "rna" in adata.layers else adata.X | ||||||||||||||||||||||||||||||||||||||||||||||
| return X.tocsr() if sparse.issparse(X) else sparse.csr_matrix(np.asarray(X)) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _normalize_log1p(matrix: sparse.csr_matrix, target_sum: float = 1e4) -> sparse.csr_matrix: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Library-size normalisation followed by log1p for sparse matrices.""" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| matrix = matrix.astype(np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||
| cell_sums = np.array(matrix.sum(axis=1)).ravel() | ||||||||||||||||||||||||||||||||||||||||||||||
| cell_sums[cell_sums == 0] = 1.0 | ||||||||||||||||||||||||||||||||||||||||||||||
| inv = sparse.diags((target_sum / cell_sums).astype(np.float32)) | ||||||||||||||||||||||||||||||||||||||||||||||
| norm = inv @ matrix | ||||||||||||||||||||||||||||||||||||||||||||||
| norm.data = np.log1p(norm.data) | ||||||||||||||||||||||||||||||||||||||||||||||
| return norm | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _fit_pcs( | ||||||||||||||||||||||||||||||||||||||||||||||
| matrix: sparse.csr_matrix, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_components: int, | ||||||||||||||||||||||||||||||||||||||||||||||
| random_state: Optional[int], | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> np.ndarray: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Fit TruncatedSVD on the RNA matrix and return dense principal components.""" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| n_features = matrix.shape[1] | ||||||||||||||||||||||||||||||||||||||||||||||
| if n_features <= 1: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("RNA matrix must contain at least two genes for SVD.") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| n_components = max(1, min(n_components, n_features - 1)) | ||||||||||||||||||||||||||||||||||||||||||||||
| svd = TruncatedSVD(n_components=n_components, random_state=random_state) | ||||||||||||||||||||||||||||||||||||||||||||||
| pcs = svd.fit_transform(matrix) | ||||||||||||||||||||||||||||||||||||||||||||||
| return pcs.astype(np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def _resolve_protein_frame(adata: AnnData) -> pd.DataFrame: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Return the protein matrix stored in ``adata.obsm['protein']`` as DataFrame.""" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if "protein" not in adata.obsm: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| protein_df = adata.obsm["protein"] | ||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(protein_df, pd.DataFrame): | ||||||||||||||||||||||||||||||||||||||||||||||
| return protein_df | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return pd.DataFrame(np.asarray(protein_df), index=adata.obs_names) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def rna_protein_cluster_analysis( | ||||||||||||||||||||||||||||||||||||||||||||||
| adata: AnnData, | ||||||||||||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_clusters: int = 12, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_pcs: int = 30, | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster_key: str = "rna_cluster", | ||||||||||||||||||||||||||||||||||||||||||||||
| random_state: Optional[int] = 0, | ||||||||||||||||||||||||||||||||||||||||||||||
| target_sum: float = 1e4, | ||||||||||||||||||||||||||||||||||||||||||||||
| min_cells_per_cluster: int = 50, | ||||||||||||||||||||||||||||||||||||||||||||||
| min_cells_per_group: int = 20, | ||||||||||||||||||||||||||||||||||||||||||||||
| protein_split_method: str = "median", | ||||||||||||||||||||||||||||||||||||||||||||||
| protein_quantile: float = 0.75, | ||||||||||||||||||||||||||||||||||||||||||||||
| test_size: float = 0.2, | ||||||||||||||||||||||||||||||||||||||||||||||
| hidden_layer_sizes: Tuple[int, ...] = (64, 32), | ||||||||||||||||||||||||||||||||||||||||||||||
| max_iter: int = 200, | ||||||||||||||||||||||||||||||||||||||||||||||
| early_stopping: bool = True, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Tuple[pd.DataFrame, Dict[str, Dict[str, ProteinModelResult]]]: | ||||||||||||||||||||||||||||||||||||||||||||||
| """Joint RNA/protein analysis for Xenium AnnData objects. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| The pipeline performs three consecutive steps: | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| 1. **RNA preprocessing** – library-size normalisation (counts per ``target_sum``) | ||||||||||||||||||||||||||||||||||||||||||||||
| followed by ``log1p``. A :class:`~sklearn.decomposition.TruncatedSVD` | ||||||||||||||||||||||||||||||||||||||||||||||
| is fitted to obtain ``n_pcs`` latent dimensions. | ||||||||||||||||||||||||||||||||||||||||||||||
| 2. **Clustering** – :class:`~sklearn.cluster.KMeans` is applied on the latent | ||||||||||||||||||||||||||||||||||||||||||||||
| representation to create ``n_clusters`` RNA-driven cell groups. Cluster | ||||||||||||||||||||||||||||||||||||||||||||||
| assignments are stored in ``adata.obs[cluster_key]`` and the latent space | ||||||||||||||||||||||||||||||||||||||||||||||
| in ``adata.obsm['X_rna_pca']``. | ||||||||||||||||||||||||||||||||||||||||||||||
| 3. **Protein explanation** – for every cluster and every protein marker, the | ||||||||||||||||||||||||||||||||||||||||||||||
| cells are divided into "high" vs. "low" groups (median split by default). | ||||||||||||||||||||||||||||||||||||||||||||||
| A small neural network (:class:`~sklearn.neural_network.MLPClassifier`) | ||||||||||||||||||||||||||||||||||||||||||||||
| is trained to predict the binary labels from the RNA latent features. The | ||||||||||||||||||||||||||||||||||||||||||||||
| training/test accuracies and optional ROC-AUC are reported. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||||||||||||||
| adata: | ||||||||||||||||||||||||||||||||||||||||||||||
| AnnData object returned by | ||||||||||||||||||||||||||||||||||||||||||||||
| :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`. | ||||||||||||||||||||||||||||||||||||||||||||||
| Requires ``adata.layers['rna']`` (or ``adata.X``) and | ||||||||||||||||||||||||||||||||||||||||||||||
| ``adata.obsm['protein']``. | ||||||||||||||||||||||||||||||||||||||||||||||
| n_clusters: | ||||||||||||||||||||||||||||||||||||||||||||||
| Number of RNA clusters to compute with KMeans. | ||||||||||||||||||||||||||||||||||||||||||||||
| n_pcs: | ||||||||||||||||||||||||||||||||||||||||||||||
| Number of latent components extracted with TruncatedSVD. The value is | ||||||||||||||||||||||||||||||||||||||||||||||
| automatically capped at ``n_genes - 1``. | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster_key: | ||||||||||||||||||||||||||||||||||||||||||||||
| Column name added to ``adata.obs`` that stores cluster labels. | ||||||||||||||||||||||||||||||||||||||||||||||
| random_state: | ||||||||||||||||||||||||||||||||||||||||||||||
| Seed for the SVD, KMeans and neural networks. Use ``None`` for random | ||||||||||||||||||||||||||||||||||||||||||||||
| initialisation. | ||||||||||||||||||||||||||||||||||||||||||||||
| target_sum: | ||||||||||||||||||||||||||||||||||||||||||||||
| Target library size after normalisation (Counts Per ``target_sum``). | ||||||||||||||||||||||||||||||||||||||||||||||
| min_cells_per_cluster: | ||||||||||||||||||||||||||||||||||||||||||||||
| Clusters with fewer cells are skipped entirely. | ||||||||||||||||||||||||||||||||||||||||||||||
| min_cells_per_group: | ||||||||||||||||||||||||||||||||||||||||||||||
| Minimum number of cells required in both "high" and "low" protein | ||||||||||||||||||||||||||||||||||||||||||||||
| groups to train a neural network. | ||||||||||||||||||||||||||||||||||||||||||||||
| protein_split_method: | ||||||||||||||||||||||||||||||||||||||||||||||
| Either ``"median"`` (default) for a median split or ``"quantile"`` to | ||||||||||||||||||||||||||||||||||||||||||||||
| keep only the top ``protein_quantile`` and bottom ``1 - protein_quantile`` | ||||||||||||||||||||||||||||||||||||||||||||||
| fractions of cells (discarding the middle portion). | ||||||||||||||||||||||||||||||||||||||||||||||
| protein_quantile: | ||||||||||||||||||||||||||||||||||||||||||||||
| Quantile used when ``protein_split_method='quantile'``. | ||||||||||||||||||||||||||||||||||||||||||||||
| test_size: | ||||||||||||||||||||||||||||||||||||||||||||||
| Fraction of the cluster reserved for the test split when training the | ||||||||||||||||||||||||||||||||||||||||||||||
| neural network. | ||||||||||||||||||||||||||||||||||||||||||||||
| hidden_layer_sizes: | ||||||||||||||||||||||||||||||||||||||||||||||
| Hidden-layer configuration passed to :class:`MLPClassifier`. | ||||||||||||||||||||||||||||||||||||||||||||||
| max_iter: | ||||||||||||||||||||||||||||||||||||||||||||||
| Maximum number of training iterations for the neural network. | ||||||||||||||||||||||||||||||||||||||||||||||
| early_stopping: | ||||||||||||||||||||||||||||||||||||||||||||||
| Whether to use early stopping in :class:`MLPClassifier`. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Returns | ||||||||||||||||||||||||||||||||||||||||||||||
| ------- | ||||||||||||||||||||||||||||||||||||||||||||||
| summary: | ||||||||||||||||||||||||||||||||||||||||||||||
| :class:`pandas.DataFrame` summarising the trained models. Columns are | ||||||||||||||||||||||||||||||||||||||||||||||
| ``['cluster', 'protein', 'threshold', 'n_cells', 'n_high', 'n_low', | ||||||||||||||||||||||||||||||||||||||||||||||
| 'train_accuracy', 'test_accuracy', 'test_auc']``. | ||||||||||||||||||||||||||||||||||||||||||||||
| models: | ||||||||||||||||||||||||||||||||||||||||||||||
| Nested dictionary ``{cluster -> {protein -> ProteinModelResult}}`` | ||||||||||||||||||||||||||||||||||||||||||||||
| containing the fitted neural networks and scalers for downstream use. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Examples | ||||||||||||||||||||||||||||||||||||||||||||||
| -------- | ||||||||||||||||||||||||||||||||||||||||||||||
| >>> from pyXenium.analysis import rna_protein_cluster_analysis | ||||||||||||||||||||||||||||||||||||||||||||||
| >>> summary, models = rna_protein_cluster_analysis(adata, n_clusters=8) | ||||||||||||||||||||||||||||||||||||||||||||||
| >>> summary.head() | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster protein threshold n_cells ... test_accuracy test_auc | ||||||||||||||||||||||||||||||||||||||||||||||
| 0 cluster_0 EPCAM (µm) 0.563100 512 ... 0.84 0.91 | ||||||||||||||||||||||||||||||||||||||||||||||
| 1 cluster_0 Podocin (µm^2) 0.118775 512 ... 0.79 0.87 | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if adata.n_obs == 0: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("AnnData contains no cells (n_obs == 0).") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| protein_df = _resolve_protein_frame(adata) | ||||||||||||||||||||||||||||||||||||||||||||||
| if protein_df.shape[1] == 0: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("AnnData.obsm['protein'] is empty – nothing to analyse.") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| rna_csr = _get_rna_matrix(adata) | ||||||||||||||||||||||||||||||||||||||||||||||
| if rna_csr.shape[1] < 2: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("RNA modality must have at least two genes for clustering.") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| log_norm = _normalize_log1p(rna_csr, target_sum=target_sum) | ||||||||||||||||||||||||||||||||||||||||||||||
| pcs = _fit_pcs(log_norm, n_components=n_pcs, random_state=random_state) | ||||||||||||||||||||||||||||||||||||||||||||||
| adata.obsm["X_rna_pca"] = pcs | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10) | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster_labels = kmeans.fit_predict(pcs) | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster_names = np.array([f"cluster_{i}" for i in cluster_labels]) | ||||||||||||||||||||||||||||||||||||||||||||||
| adata.obs[cluster_key] = cluster_names | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| results: List[Dict[str, float]] = [] | ||||||||||||||||||||||||||||||||||||||||||||||
| models: Dict[str, Dict[str, ProteinModelResult]] = {} | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| unique_clusters = pd.Index(np.unique(cluster_names)) | ||||||||||||||||||||||||||||||||||||||||||||||
| for cluster in unique_clusters: | ||||||||||||||||||||||||||||||||||||||||||||||
| idx = np.where(cluster_names == cluster)[0] | ||||||||||||||||||||||||||||||||||||||||||||||
| if idx.size < min_cells_per_cluster: | ||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| cluster_pcs = pcs[idx] | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster_protein = protein_df.iloc[idx] | ||||||||||||||||||||||||||||||||||||||||||||||
| models.setdefault(cluster, {}) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| for protein in cluster_protein.columns: | ||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 问题 (复杂性): 考虑将蛋白质数据准备和分类器拟合逻辑提取到辅助函数中,以简化主循环并降低循环复杂度。 这里有两个你可以从内循环中提取出来的小型辅助函数,以简化嵌套、降低循环复杂度并保持所有功能: def _prepare_protein_data(
values: np.ndarray,
X_cluster: np.ndarray,
method: str,
quantile: float,
min_cells: int,
) -> Tuple[Optional[float], np.ndarray, np.ndarray]:
"""Return threshold, filtered X and labels or (None,_,_) if not enough cells."""
finite = np.isfinite(values)
if finite.sum() < min_cells * 2:
return None, None, None
vals = values[finite]
Xc = X_cluster[finite]
if method == "median":
thr = float(np.median(vals))
labels = (vals >= thr).astype(int)
elif method == "quantile":
q = float(quantile)
if not 0.5 < q < 1.0:
raise ValueError("protein_quantile must be between 0.5 and 1.0")
high_thr = np.quantile(vals, q)
low_mask = vals <= np.quantile(vals, 1 - q)
high_mask = vals >= high_thr
sel = low_mask | high_mask
if sel.sum() < min_cells * 2:
return None, None, None
vals = vals[sel]
Xc = Xc[sel]
labels = high_mask[sel].astype(int)
thr = float(high_thr)
else:
raise ValueError("protein_split_method must be 'median' or 'quantile'")
if labels.sum() < min_cells or (len(labels) - labels.sum()) < min_cells:
return None, None, None
return thr, Xc, labelsdef _fit_protein_classifier(
X: np.ndarray,
y: np.ndarray,
test_size: float,
random_state: Optional[int],
hidden_layer_sizes: Tuple[int, ...],
max_iter: int,
early_stopping: bool,
) -> Optional[ProteinModelResult]:
"""Train/test split + scaler + MLPClassifier, return result or None on error."""
try:
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=test_size, stratify=y, random_state=random_state
)
except ValueError:
return None
scaler = StandardScaler()
Xtr = scaler.fit_transform(X_train)
Xte = scaler.transform(X_test)
clf = MLPClassifier(
hidden_layer_sizes=hidden_layer_sizes,
random_state=random_state,
max_iter=max_iter,
early_stopping=early_stopping,
)
try:
clf.fit(Xtr, y_train)
except Exception:
return None
train_acc = accuracy_score(y_train, clf.predict(Xtr))
y_pred = clf.predict(Xte)
test_acc = accuracy_score(y_test, y_pred)
if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2:
try:
auc = roc_auc_score(y_test, clf.predict_proba(Xte)[:, 1])
except ValueError:
auc = float("nan")
else:
auc = float("nan")
return ProteinModelResult(
protein="",
cluster="",
threshold=0.0,
n_cells=len(y),
n_high=int(y.sum()),
n_low=int(len(y) - y.sum()),
train_accuracy=float(train_acc),
test_accuracy=float(test_acc),
test_auc=float(auc),
model=clf,
scaler=scaler,
)然后在主循环中将大块代码替换为: for protein in cluster_protein.columns:
vals = cluster_protein[protein].to_numpy(dtype=np.float32)
thr, X_sel, labels = _prepare_protein_data(
vals, cluster_pcs, protein_split_method, protein_quantile, min_cells_per_group
)
if thr is None:
continue
result = _fit_protein_classifier(
X_sel, labels, test_size, random_state,
hidden_layer_sizes, max_iter, early_stopping,
)
if result is None:
continue
result.protein = protein
result.cluster = cluster
result.threshold = thr
models[cluster][protein] = result
results.append({
"cluster": cluster,
"protein": protein,
"threshold": thr,
"n_cells": result.n_cells,
"n_high": result.n_high,
"n_low": result.n_low,
"train_accuracy": result.train_accuracy,
"test_accuracy": result.test_accuracy,
"test_auc": result.test_auc,
})这将大约 60 行嵌套逻辑提取到两个专注的辅助函数中,并简化了你的主循环。 Original comment in Englishissue (complexity): Consider extracting the protein data preparation and classifier fitting logic into helper functions to flatten the main loop and reduce cyclomatic complexity. Here are two small helpers you can extract from the inner loops to flatten that nesting, reduce cyclomatic complexity, and keep all functionality: def _prepare_protein_data(
values: np.ndarray,
X_cluster: np.ndarray,
method: str,
quantile: float,
min_cells: int,
) -> Tuple[Optional[float], np.ndarray, np.ndarray]:
"""Return threshold, filtered X and labels or (None,_,_) if not enough cells."""
finite = np.isfinite(values)
if finite.sum() < min_cells * 2:
return None, None, None
vals = values[finite]
Xc = X_cluster[finite]
if method == "median":
thr = float(np.median(vals))
labels = (vals >= thr).astype(int)
elif method == "quantile":
q = float(quantile)
if not 0.5 < q < 1.0:
raise ValueError("protein_quantile must be between 0.5 and 1.0")
high_thr = np.quantile(vals, q)
low_mask = vals <= np.quantile(vals, 1 - q)
high_mask = vals >= high_thr
sel = low_mask | high_mask
if sel.sum() < min_cells * 2:
return None, None, None
vals = vals[sel]
Xc = Xc[sel]
labels = high_mask[sel].astype(int)
thr = float(high_thr)
else:
raise ValueError("protein_split_method must be 'median' or 'quantile'")
if labels.sum() < min_cells or (len(labels) - labels.sum()) < min_cells:
return None, None, None
return thr, Xc, labelsdef _fit_protein_classifier(
X: np.ndarray,
y: np.ndarray,
test_size: float,
random_state: Optional[int],
hidden_layer_sizes: Tuple[int, ...],
max_iter: int,
early_stopping: bool,
) -> Optional[ProteinModelResult]:
"""Train/test split + scaler + MLPClassifier, return result or None on error."""
try:
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=test_size, stratify=y, random_state=random_state
)
except ValueError:
return None
scaler = StandardScaler()
Xtr = scaler.fit_transform(X_train)
Xte = scaler.transform(X_test)
clf = MLPClassifier(
hidden_layer_sizes=hidden_layer_sizes,
random_state=random_state,
max_iter=max_iter,
early_stopping=early_stopping,
)
try:
clf.fit(Xtr, y_train)
except Exception:
return None
train_acc = accuracy_score(y_train, clf.predict(Xtr))
y_pred = clf.predict(Xte)
test_acc = accuracy_score(y_test, y_pred)
if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2:
try:
auc = roc_auc_score(y_test, clf.predict_proba(Xte)[:, 1])
except ValueError:
auc = float("nan")
else:
auc = float("nan")
return ProteinModelResult(
protein="",
cluster="",
threshold=0.0,
n_cells=len(y),
n_high=int(y.sum()),
n_low=int(len(y) - y.sum()),
train_accuracy=float(train_acc),
test_accuracy=float(test_acc),
test_auc=float(auc),
model=clf,
scaler=scaler,
)Then in your main loop replace the big block with: for protein in cluster_protein.columns:
vals = cluster_protein[protein].to_numpy(dtype=np.float32)
thr, X_sel, labels = _prepare_protein_data(
vals, cluster_pcs, protein_split_method, protein_quantile, min_cells_per_group
)
if thr is None:
continue
result = _fit_protein_classifier(
X_sel, labels, test_size, random_state,
hidden_layer_sizes, max_iter, early_stopping,
)
if result is None:
continue
result.protein = protein
result.cluster = cluster
result.threshold = thr
models[cluster][protein] = result
results.append({
"cluster": cluster,
"protein": protein,
"threshold": thr,
"n_cells": result.n_cells,
"n_high": result.n_high,
"n_low": result.n_low,
"train_accuracy": result.train_accuracy,
"test_accuracy": result.test_accuracy,
"test_auc": result.test_auc,
})This pulls out ~60 lines of nested logic into two focused helpers and flattens your main loop. |
||||||||||||||||||||||||||||||||||||||||||||||
| values_all = cluster_protein[protein].to_numpy(dtype=np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||
| finite_mask = np.isfinite(values_all) | ||||||||||||||||||||||||||||||||||||||||||||||
| if finite_mask.sum() < min_cells_per_group * 2: | ||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| values = values_all[finite_mask] | ||||||||||||||||||||||||||||||||||||||||||||||
| X_cluster = cluster_pcs[finite_mask] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if protein_split_method == "median": | ||||||||||||||||||||||||||||||||||||||||||||||
| threshold = float(np.median(values)) | ||||||||||||||||||||||||||||||||||||||||||||||
| labels = (values >= threshold).astype(int) | ||||||||||||||||||||||||||||||||||||||||||||||
| elif protein_split_method == "quantile": | ||||||||||||||||||||||||||||||||||||||||||||||
| q = float(protein_quantile) | ||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议 (代码质量): 我们发现了这些问题:
Suggested change
解释此函数的质量得分低于 25% 的质量阈值。 你如何解决这个问题? 重构此函数以使其更短、更具可读性可能值得。
Original comment in Englishsuggestion (code-quality): We've found these issues:
Suggested change
ExplanationThe quality score for this function is below the quality threshold of 25%. How can you solve this? It might be worth refactoring this function to make it shorter and more readable.
|
||||||||||||||||||||||||||||||||||||||||||||||
| if not 0.5 < q < 1.0: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("protein_quantile must be between 0.5 and 1.0 (exclusive).") | ||||||||||||||||||||||||||||||||||||||||||||||
| high_thr = np.quantile(values, q) | ||||||||||||||||||||||||||||||||||||||||||||||
| low_mask = values <= np.quantile(values, 1.0 - q) | ||||||||||||||||||||||||||||||||||||||||||||||
| high_mask = values >= high_thr | ||||||||||||||||||||||||||||||||||||||||||||||
| selected_mask = high_mask | low_mask | ||||||||||||||||||||||||||||||||||||||||||||||
| if selected_mask.sum() < min_cells_per_group * 2: | ||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||
| values = values[selected_mask] | ||||||||||||||||||||||||||||||||||||||||||||||
| X_cluster = X_cluster[selected_mask] | ||||||||||||||||||||||||||||||||||||||||||||||
| labels = high_mask[selected_mask].astype(int) | ||||||||||||||||||||||||||||||||||||||||||||||
| threshold = float(high_thr) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError("protein_split_method must be 'median' or 'quantile'.") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| n_selected = labels.size | ||||||||||||||||||||||||||||||||||||||||||||||
| n_high = int(labels.sum()) | ||||||||||||||||||||||||||||||||||||||||||||||
| n_low = int(n_selected - n_high) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if n_high < min_cells_per_group or n_low < min_cells_per_group: | ||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
| X_train, X_test, y_train, y_test = train_test_split( | ||||||||||||||||||||||||||||||||||||||||||||||
| X_cluster, | ||||||||||||||||||||||||||||||||||||||||||||||
| labels, | ||||||||||||||||||||||||||||||||||||||||||||||
| test_size=test_size, | ||||||||||||||||||||||||||||||||||||||||||||||
| random_state=random_state, | ||||||||||||||||||||||||||||||||||||||||||||||
| stratify=labels, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| except ValueError: | ||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议 (bug_risk): 在 当由于分层样本不足而跳过集群时,添加日志记录以提高可追溯性。
Suggested change
Original comment in Englishsuggestion (bug_risk): Silently continuing on ValueError in train_test_split may skip clusters without notification. Add logging when clusters are skipped due to insufficient samples for stratification to improve traceability.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Not enough samples to stratify. | ||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| scaler = StandardScaler() | ||||||||||||||||||||||||||||||||||||||||||||||
| X_train_scaled = scaler.fit_transform(X_train) | ||||||||||||||||||||||||||||||||||||||||||||||
| X_test_scaled = scaler.transform(X_test) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| clf = MLPClassifier( | ||||||||||||||||||||||||||||||||||||||||||||||
| hidden_layer_sizes=hidden_layer_sizes, | ||||||||||||||||||||||||||||||||||||||||||||||
| random_state=random_state, | ||||||||||||||||||||||||||||||||||||||||||||||
| max_iter=max_iter, | ||||||||||||||||||||||||||||||||||||||||||||||
| early_stopping=early_stopping, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
| clf.fit(X_train_scaled, y_train) | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+282
to
+283
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议 (bug_risk): 在模型拟合期间捕获所有异常可能会掩盖潜在问题。 只捕获特定的异常,例如 ValueError 或 ConvergenceWarning,以确保在开发过程中不会遗漏关键错误。
Suggested change
Original comment in Englishsuggestion (bug_risk): Catching all exceptions during model fitting may obscure underlying issues. Catch only specific exceptions like ValueError or ConvergenceWarning to ensure critical errors are not missed during development.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||
| except Exception: | ||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| train_acc = float(accuracy_score(y_train, clf.predict(X_train_scaled))) | ||||||||||||||||||||||||||||||||||||||||||||||
| test_pred = clf.predict(X_test_scaled) | ||||||||||||||||||||||||||||||||||||||||||||||
| test_acc = float(accuracy_score(y_test, test_pred)) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2: | ||||||||||||||||||||||||||||||||||||||||||||||
| probs = clf.predict_proba(X_test_scaled)[:, 1] | ||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||
| test_auc = float(roc_auc_score(y_test, probs)) | ||||||||||||||||||||||||||||||||||||||||||||||
| except ValueError: | ||||||||||||||||||||||||||||||||||||||||||||||
| test_auc = float("nan") | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| test_auc = float("nan") | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| result = ProteinModelResult( | ||||||||||||||||||||||||||||||||||||||||||||||
| protein=protein, | ||||||||||||||||||||||||||||||||||||||||||||||
| cluster=cluster, | ||||||||||||||||||||||||||||||||||||||||||||||
| threshold=threshold, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_cells=n_selected, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_high=n_high, | ||||||||||||||||||||||||||||||||||||||||||||||
| n_low=n_low, | ||||||||||||||||||||||||||||||||||||||||||||||
| train_accuracy=train_acc, | ||||||||||||||||||||||||||||||||||||||||||||||
| test_accuracy=test_acc, | ||||||||||||||||||||||||||||||||||||||||||||||
| test_auc=test_auc, | ||||||||||||||||||||||||||||||||||||||||||||||
| model=clf, | ||||||||||||||||||||||||||||||||||||||||||||||
| scaler=scaler, | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| models[cluster][protein] = result | ||||||||||||||||||||||||||||||||||||||||||||||
| results.append( | ||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||
| "cluster": cluster, | ||||||||||||||||||||||||||||||||||||||||||||||
| "protein": protein, | ||||||||||||||||||||||||||||||||||||||||||||||
| "threshold": threshold, | ||||||||||||||||||||||||||||||||||||||||||||||
| "n_cells": n_selected, | ||||||||||||||||||||||||||||||||||||||||||||||
| "n_high": n_high, | ||||||||||||||||||||||||||||||||||||||||||||||
| "n_low": n_low, | ||||||||||||||||||||||||||||||||||||||||||||||
| "train_accuracy": train_acc, | ||||||||||||||||||||||||||||||||||||||||||||||
| "test_accuracy": test_acc, | ||||||||||||||||||||||||||||||||||||||||||||||
| "test_auc": test_auc, | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| summary = pd.DataFrame(results, columns=[ | ||||||||||||||||||||||||||||||||||||||||||||||
| "cluster", | ||||||||||||||||||||||||||||||||||||||||||||||
| "protein", | ||||||||||||||||||||||||||||||||||||||||||||||
| "threshold", | ||||||||||||||||||||||||||||||||||||||||||||||
| "n_cells", | ||||||||||||||||||||||||||||||||||||||||||||||
| "n_high", | ||||||||||||||||||||||||||||||||||||||||||||||
| "n_low", | ||||||||||||||||||||||||||||||||||||||||||||||
| "train_accuracy", | ||||||||||||||||||||||||||||||||||||||||||||||
| "test_accuracy", | ||||||||||||||||||||||||||||||||||||||||||||||
| "test_auc", | ||||||||||||||||||||||||||||||||||||||||||||||
| ]) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return summary, models | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| __all__ = ["rna_protein_cluster_analysis", "ProteinModelResult"] | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议: 如果缺少 'protein',会引发 KeyError,但没有提供备用方案或指导。
考虑更新错误消息,包含添加或计算 'protein' 模态的说明,以便用户更容易解决此问题。
Original comment in English
suggestion: KeyError is raised if 'protein' is missing, but no fallback or guidance is provided.
Consider updating the error message to include instructions for adding or computing the 'protein' modality, making it easier for users to address the issue.