Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Xenium runs that include both RNA and protein measurements.
> - [Partial loading (incomplete exports)](#partial-loading-incomplete-exports)
> - [RNA + Protein loader](#rna--protein-loader)
> - [Gene–protein correlation](#gene–protein-correlation)
> - [RNA/protein joint analysis](#rnaprotein-joint-analysis)

## Features

Expand Down Expand Up @@ -106,6 +107,31 @@ summary = protein_gene_correlation(
print(summary)
```

### RNA/protein joint analysis

Cluster cells using RNA expression, then explain within-cluster protein
heterogeneity by training neural network classifiers on the RNA latent space.

```python
from pyXenium.analysis import rna_protein_cluster_analysis

summary, models = rna_protein_cluster_analysis(
adata,
n_clusters=12,
n_pcs=30,
min_cells_per_cluster=100,
min_cells_per_group=30,
hidden_layer_sizes=(128, 64),
)

# Inspect metrics for the first few cluster × protein combinations
print(summary.head())

# Retrieve the fitted model for a specific cluster and protein
podocin_model = models["cluster_3"]["Podocin"]
print(podocin_model.test_accuracy)
```

## Data format expectations

- **Cell-feature matrix (MEX)** under `cell_feature_matrix/`:
Expand Down
4 changes: 3 additions & 1 deletion src/pyXenium/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .protein_gene_correlation import protein_gene_correlation
from .rna_protein_joint import ProteinModelResult, rna_protein_cluster_analysis

__all__ = [
*globals().get("__all__", []),
"protein_gene_correlation",
"rna_protein_cluster_analysis",
"ProteinModelResult",
]
345 changes: 345 additions & 0 deletions src/pyXenium/analysis/rna_protein_joint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,345 @@
"""Utilities for joint RNA + protein analysis on Xenium AnnData objects.

This module provides a high-level pipeline that clusters cells using the RNA
expression matrix and, for each protein marker, trains a small neural network
to explain high-vs-low protein states within every RNA-defined cluster.

The function is intentionally self-contained and only relies on scikit-learn
primitives so that it can operate on large Xenium datasets without requiring
extra dependencies (e.g. Scanpy). It works directly on the data structures
produced by :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from anndata import AnnData
from scipy import sparse
from sklearn.cluster import KMeans
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler


@dataclass
class ProteinModelResult:
    """Container for the protein classification model of one cluster.

    One instance is produced per (cluster, protein) pair by
    :func:`rna_protein_cluster_analysis`; it bundles the fitted classifier
    with its scaler and the evaluation metrics needed to reuse it.
    """

    protein: str  # protein marker name (a column of adata.obsm['protein'])
    cluster: str  # RNA cluster label, e.g. "cluster_3"
    threshold: float  # intensity cut used to binarise the protein into high/low
    n_cells: int  # cells used for training after filtering non-finite values
    n_high: int  # number of cells labelled "high"
    n_low: int  # number of cells labelled "low"
    train_accuracy: float  # accuracy on the training split
    test_accuracy: float  # accuracy on the held-out test split
    test_auc: float  # ROC-AUC on the test split; NaN when undefined
    model: MLPClassifier  # fitted neural network
    scaler: StandardScaler  # fitted scaler; apply to features before model.predict


def _get_rna_matrix(adata: AnnData):
"""Return the raw RNA matrix from ``adata`` as CSR sparse matrix."""

X = adata.layers["rna"] if "rna" in adata.layers else adata.X
return X.tocsr() if sparse.issparse(X) else sparse.csr_matrix(np.asarray(X))


def _normalize_log1p(matrix: sparse.csr_matrix, target_sum: float = 1e4) -> sparse.csr_matrix:
"""Library-size normalisation followed by log1p for sparse matrices."""

matrix = matrix.astype(np.float32)
cell_sums = np.array(matrix.sum(axis=1)).ravel()
cell_sums[cell_sums == 0] = 1.0
inv = sparse.diags((target_sum / cell_sums).astype(np.float32))
norm = inv @ matrix
norm.data = np.log1p(norm.data)
return norm


def _fit_pcs(
    matrix: sparse.csr_matrix,
    n_components: int,
    random_state: Optional[int],
) -> np.ndarray:
    """Project the normalised RNA matrix onto its leading SVD components.

    ``n_components`` is clamped to ``[1, n_genes - 1]`` because TruncatedSVD
    cannot extract as many components as there are features.
    """

    n_genes = matrix.shape[1]
    if n_genes <= 1:
        raise ValueError("RNA matrix must contain at least two genes for SVD.")

    capped = min(max(n_components, 1), n_genes - 1)
    reducer = TruncatedSVD(n_components=capped, random_state=random_state)
    return reducer.fit_transform(matrix).astype(np.float32)


def _resolve_protein_frame(adata: AnnData) -> pd.DataFrame:
"""Return the protein matrix stored in ``adata.obsm['protein']`` as DataFrame."""

if "protein" not in adata.obsm:
raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].")
Comment on lines +86 to +87
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议: 如果缺少 'protein',会引发 KeyError,但没有提供备用方案或指导。

考虑更新错误消息,包含添加或计算 'protein' 模态的说明,以便用户更容易解决此问题。

Suggested change
if "protein" not in adata.obsm:
raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].")
if "protein" not in adata.obsm:
raise KeyError(
"AnnData is missing 'protein' modality in adata.obsm['protein'].\n"
"To resolve this, ensure that the protein data is loaded or computed and assigned to adata.obsm['protein'].\n"
"For example, use pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein or another appropriate method to add the protein modality."
)
Original comment in English

suggestion: KeyError is raised if 'protein' is missing, but no fallback or guidance is provided.

Consider updating the error message to include instructions for adding or computing the 'protein' modality, making it easier for users to address the issue.

Suggested change
if "protein" not in adata.obsm:
raise KeyError("AnnData is missing 'protein' modality in adata.obsm['protein'].")
if "protein" not in adata.obsm:
raise KeyError(
"AnnData is missing 'protein' modality in adata.obsm['protein'].\n"
"To resolve this, ensure that the protein data is loaded or computed and assigned to adata.obsm['protein'].\n"
"For example, use pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein or another appropriate method to add the protein modality."
)


protein_df = adata.obsm["protein"]
if isinstance(protein_df, pd.DataFrame):
return protein_df

return pd.DataFrame(np.asarray(protein_df), index=adata.obs_names)


def _prepare_protein_split(
    values: np.ndarray,
    latent: np.ndarray,
    method: str,
    quantile: float,
    min_cells_per_group: int,
) -> Optional[Tuple[float, np.ndarray, np.ndarray]]:
    """Binarise one protein's intensities into high/low labels for a cluster.

    Parameters mirror :func:`rna_protein_cluster_analysis`. Returns
    ``(threshold, latent_subset, labels)``, or ``None`` when the split cannot
    produce at least ``min_cells_per_group`` cells in both groups.

    Raises
    ------
    ValueError
        If ``method`` is unknown or ``quantile`` is outside ``(0.5, 1.0)``.
    """

    finite_mask = np.isfinite(values)
    if finite_mask.sum() < min_cells_per_group * 2:
        return None

    vals = values[finite_mask]
    X = latent[finite_mask]

    if method == "median":
        threshold = float(np.median(vals))
        labels = (vals >= threshold).astype(int)
    elif method == "quantile":
        q = float(quantile)
        if not 0.5 < q < 1.0:
            raise ValueError("protein_quantile must be between 0.5 and 1.0 (exclusive).")
        high_thr = np.quantile(vals, q)
        low_mask = vals <= np.quantile(vals, 1.0 - q)
        high_mask = vals >= high_thr
        selected = high_mask | low_mask
        # The middle band is discarded entirely, so re-check the cell budget.
        if selected.sum() < min_cells_per_group * 2:
            return None
        X = X[selected]
        labels = high_mask[selected].astype(int)
        threshold = float(high_thr)
    else:
        raise ValueError("protein_split_method must be 'median' or 'quantile'.")

    n_high = int(labels.sum())
    if n_high < min_cells_per_group or labels.size - n_high < min_cells_per_group:
        return None
    return threshold, X, labels


def _fit_protein_mlp(
    X: np.ndarray,
    labels: np.ndarray,
    *,
    test_size: float,
    random_state: Optional[int],
    hidden_layer_sizes: Tuple[int, ...],
    max_iter: int,
    early_stopping: bool,
) -> Optional[Tuple[MLPClassifier, StandardScaler, float, float, float]]:
    """Fit a scaled MLP on the latent features and score it.

    Returns ``(model, scaler, train_accuracy, test_accuracy, test_auc)``, or
    ``None`` when a stratified split or the fit itself fails. Skips are
    logged at WARNING level so they are traceable instead of silent.
    """

    log = logging.getLogger(__name__)

    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X,
            labels,
            test_size=test_size,
            random_state=random_state,
            stratify=labels,
        )
    except ValueError:
        # Not enough samples per class to stratify.
        log.warning("Skipping protein model: cannot stratify %d samples.", labels.size)
        return None

    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    clf = MLPClassifier(
        hidden_layer_sizes=hidden_layer_sizes,
        random_state=random_state,
        max_iter=max_iter,
        early_stopping=early_stopping,
    )
    try:
        clf.fit(X_train_scaled, y_train)
    except Exception:
        # Defensive: solver failures depend on the data; log instead of crashing
        # the whole multi-cluster analysis.
        log.warning("Skipping protein model: MLP training failed.", exc_info=True)
        return None

    train_acc = float(accuracy_score(y_train, clf.predict(X_train_scaled)))
    test_acc = float(accuracy_score(y_test, clf.predict(X_test_scaled)))

    # ROC-AUC is only defined when both classes appear in the test split.
    if hasattr(clf, "predict_proba") and len(np.unique(y_test)) == 2:
        try:
            test_auc = float(roc_auc_score(y_test, clf.predict_proba(X_test_scaled)[:, 1]))
        except ValueError:
            test_auc = float("nan")
    else:
        test_auc = float("nan")

    return clf, scaler, train_acc, test_acc, test_auc


def rna_protein_cluster_analysis(
    adata: AnnData,
    *,
    n_clusters: int = 12,
    n_pcs: int = 30,
    cluster_key: str = "rna_cluster",
    random_state: Optional[int] = 0,
    target_sum: float = 1e4,
    min_cells_per_cluster: int = 50,
    min_cells_per_group: int = 20,
    protein_split_method: str = "median",
    protein_quantile: float = 0.75,
    test_size: float = 0.2,
    hidden_layer_sizes: Tuple[int, ...] = (64, 32),
    max_iter: int = 200,
    early_stopping: bool = True,
) -> Tuple[pd.DataFrame, Dict[str, Dict[str, ProteinModelResult]]]:
    """Joint RNA/protein analysis for Xenium AnnData objects.

    The pipeline performs three consecutive steps:

    1. **RNA preprocessing** – library-size normalisation (counts per
       ``target_sum``) followed by ``log1p``, then
       :class:`~sklearn.decomposition.TruncatedSVD` to obtain ``n_pcs``
       latent dimensions.
    2. **Clustering** – :class:`~sklearn.cluster.KMeans` on the latent
       representation creates ``n_clusters`` RNA-driven cell groups. Cluster
       assignments are stored in ``adata.obs[cluster_key]`` and the latent
       space in ``adata.obsm['X_rna_pca']``.
    3. **Protein explanation** – for every cluster and every protein marker,
       cells are divided into "high" vs. "low" groups (median split by
       default) and a small :class:`~sklearn.neural_network.MLPClassifier`
       is trained to predict the labels from the RNA latent features.

    Parameters
    ----------
    adata:
        AnnData with ``adata.layers['rna']`` (or ``adata.X``) and
        ``adata.obsm['protein']``, e.g. as produced by
        :func:`pyXenium.io.xenium_gene_protein_loader.load_xenium_gene_protein`.
    n_clusters:
        Number of RNA clusters to compute with KMeans.
    n_pcs:
        Number of latent components; automatically capped at ``n_genes - 1``.
    cluster_key:
        ``adata.obs`` column that receives the cluster labels.
    random_state:
        Seed for SVD, KMeans and the neural networks; ``None`` for random.
    target_sum:
        Target library size after normalisation.
    min_cells_per_cluster:
        Clusters with fewer cells are skipped entirely.
    min_cells_per_group:
        Minimum cells required in both the "high" and the "low" group.
    protein_split_method:
        ``"median"`` (default) or ``"quantile"`` (keep only the top
        ``protein_quantile`` and bottom ``1 - protein_quantile`` fractions,
        discarding the middle portion).
    protein_quantile:
        Quantile used when ``protein_split_method='quantile'``.
    test_size:
        Fraction of each cluster held out as the test split.
    hidden_layer_sizes:
        Hidden-layer configuration passed to :class:`MLPClassifier`.
    max_iter:
        Maximum number of training iterations for the neural network.
    early_stopping:
        Whether to use early stopping in :class:`MLPClassifier`.

    Returns
    -------
    summary:
        :class:`pandas.DataFrame` with columns ``['cluster', 'protein',
        'threshold', 'n_cells', 'n_high', 'n_low', 'train_accuracy',
        'test_accuracy', 'test_auc']``.
    models:
        Nested dictionary ``{cluster -> {protein -> ProteinModelResult}}``.

    Raises
    ------
    ValueError
        If ``adata`` has no cells, the protein modality is empty, the RNA
        modality has fewer than two genes, or split parameters are invalid.
    KeyError
        If ``adata.obsm`` has no ``'protein'`` entry.
    """

    if adata.n_obs == 0:
        raise ValueError("AnnData contains no cells (n_obs == 0).")

    protein_df = _resolve_protein_frame(adata)
    if protein_df.shape[1] == 0:
        raise ValueError("AnnData.obsm['protein'] is empty – nothing to analyse.")

    rna_csr = _get_rna_matrix(adata)
    if rna_csr.shape[1] < 2:
        raise ValueError("RNA modality must have at least two genes for clustering.")

    # Step 1: normalise + log1p, then reduce to a latent space.
    log_norm = _normalize_log1p(rna_csr, target_sum=target_sum)
    pcs = _fit_pcs(log_norm, n_components=n_pcs, random_state=random_state)
    adata.obsm["X_rna_pca"] = pcs

    # Step 2: RNA-driven clustering on the latent representation.
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    cluster_labels = kmeans.fit_predict(pcs)
    cluster_names = np.array([f"cluster_{i}" for i in cluster_labels])
    adata.obs[cluster_key] = cluster_names

    results: List[Dict[str, float]] = []
    models: Dict[str, Dict[str, ProteinModelResult]] = {}

    # Step 3: per cluster, explain each protein's high/low state from RNA.
    for cluster in pd.Index(np.unique(cluster_names)):
        idx = np.where(cluster_names == cluster)[0]
        if idx.size < min_cells_per_cluster:
            continue

        cluster_pcs = pcs[idx]
        cluster_protein = protein_df.iloc[idx]
        models.setdefault(cluster, {})

        for protein in cluster_protein.columns:
            values_all = cluster_protein[protein].to_numpy(dtype=np.float32)
            split = _prepare_protein_split(
                values_all,
                cluster_pcs,
                protein_split_method,
                protein_quantile,
                min_cells_per_group,
            )
            if split is None:
                continue
            threshold, X_cluster, labels = split

            fitted = _fit_protein_mlp(
                X_cluster,
                labels,
                test_size=test_size,
                random_state=random_state,
                hidden_layer_sizes=hidden_layer_sizes,
                max_iter=max_iter,
                early_stopping=early_stopping,
            )
            if fitted is None:
                continue
            clf, scaler, train_acc, test_acc, test_auc = fitted

            n_selected = int(labels.size)
            n_high = int(labels.sum())
            n_low = n_selected - n_high

            result = ProteinModelResult(
                protein=protein,
                cluster=cluster,
                threshold=threshold,
                n_cells=n_selected,
                n_high=n_high,
                n_low=n_low,
                train_accuracy=train_acc,
                test_accuracy=test_acc,
                test_auc=test_auc,
                model=clf,
                scaler=scaler,
            )
            models[cluster][protein] = result
            results.append(
                {
                    "cluster": cluster,
                    "protein": protein,
                    "threshold": threshold,
                    "n_cells": n_selected,
                    "n_high": n_high,
                    "n_low": n_low,
                    "train_accuracy": train_acc,
                    "test_accuracy": test_acc,
                    "test_auc": test_auc,
                }
            )

    summary = pd.DataFrame(
        results,
        columns=[
            "cluster",
            "protein",
            "threshold",
            "n_cells",
            "n_high",
            "n_low",
            "train_accuracy",
            "test_accuracy",
            "test_auc",
        ],
    )

    return summary, models


__all__ = ["rna_protein_cluster_analysis", "ProteinModelResult"]

Loading