Skip to content

Commit e5a9129

Browse files
committed
add new feature
1 parent 14f4190 commit e5a9129

File tree

3 files changed

+202
-0
lines changed

3 files changed

+202
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# pyXenium/analysis/differential.py
2+
3+
import pandas as pd
4+
import numpy as np
5+
from scipy.stats import ttest_ind, spearmanr
6+
from statsmodels.stats import multitest
7+
8+
def get_rna_expr_df(adata, layer_key="rna"):
    """Return one layer of ``adata`` as a dense cells-by-genes DataFrame.

    Args:
        adata: Object exposing ``layers`` (mapping with ``.get``), ``obs`` and
            ``var`` (both with an ``index``).
        layer_key: Name of the layer to extract.

    Returns:
        ``pd.DataFrame`` indexed by ``adata.obs.index`` with columns taken
        from ``adata.var.index``.

    Raises:
        KeyError: If ``layer_key`` is not present in ``adata.layers``.
    """
    expr = adata.layers.get(layer_key)
    if expr is None:
        raise KeyError(f"adata.layers does not have {layer_key}")

    # Densify only when the container offers ``toarray`` (e.g. scipy sparse
    # matrices). An explicit capability check replaces the former bare
    # ``except:`` so that genuine failures inside ``toarray`` are not
    # silently swallowed.
    to_dense = getattr(expr, "toarray", None)
    arr = to_dense() if callable(to_dense) else expr
    return pd.DataFrame(arr, index=adata.obs.index, columns=adata.var.index)
17+
18+
def _add_fdr_column(df, pval_col="pval", out_col="adj_pval"):
    """Add a Benjamini-Hochberg adjusted p-value column to ``df`` in place.

    Only finite p-values take part in the correction; rows whose p-value is
    NaN/inf (e.g. zero-variance genes) get NaN in ``out_col`` instead of
    contaminating the adjusted values of every other row. The column is added
    even when ``df`` is empty so the output schema is stable.
    """
    pvals = df[pval_col].to_numpy(dtype=float)
    adjusted = np.full(pvals.shape, np.nan)
    finite = np.isfinite(pvals)
    if finite.any():
        adjusted[finite] = multitest.multipletests(pvals[finite], method="fdr_bh")[1]
    df[out_col] = adjusted
    return df


def analyze_one_score(adata, rna_expr, cluster, protein, cluster_key="rna_cluster", score_prefix="score", min_cells=3):
    """Relate one per-cell score column to gene expression within one cluster.

    The score column ``f"{score_prefix}:{cluster}:{protein}"`` is read from
    ``adata.obs`` for the cells of ``cluster``. Cells are split at the median
    score into a high group (``>= median``) and a low group (``< median``).
    For every gene in ``rna_expr`` the function runs a Welch t-test between
    the two groups and a Spearman correlation between expression and score.

    Args:
        adata: Object whose ``obs`` DataFrame holds ``cluster_key`` and the
            score column.
        rna_expr: Cells-by-genes DataFrame indexed like ``adata.obs``.
        cluster: Cluster label to analyse.
        protein: Protein name used to build the score column name.
        cluster_key: Column of ``adata.obs`` holding cluster labels.
        score_prefix: Prefix of the score column name.
        min_cells: Minimum number of cells required in the cluster, in each
            score group, and (per gene) of non-missing values in each group.

    Returns:
        ``None`` when the score column is missing or there are too few cells;
        otherwise a dict with keys ``cluster``, ``protein``, ``n_cells``,
        ``de`` (columns gene, tstat, pval, mean_diff, adj_pval) and ``corr``
        (columns gene, spearman_r, pval, adj_pval).
    """
    score_col = f"{score_prefix}:{cluster}:{protein}"
    if score_col not in adata.obs.columns:
        return None

    in_cluster = (adata.obs[cluster_key] == cluster)
    cells = adata.obs.index[in_cluster]
    if len(cells) < min_cells:
        return None

    # Cells without a score cannot be placed in either group and would turn
    # every Spearman result into NaN, so they are dropped up front.
    scores = adata.obs.loc[cells, score_col].astype(float).dropna()
    median = scores.median()
    high = scores[scores >= median].index
    low = scores[scores < median].index
    if len(high) < min_cells or len(low) < min_cells:
        return None

    # Differential expression: slice each group once instead of once per gene.
    high_expr = rna_expr.loc[high]
    low_expr = rna_expr.loc[low]
    de_rows = []
    for gene in rna_expr.columns:
        gh = high_expr[gene].dropna()
        gl = low_expr[gene].dropna()
        # Honour the caller's min_cells rather than a hard-coded 3.
        if len(gh) < min_cells or len(gl) < min_cells:
            continue
        t, p = ttest_ind(gh, gl, equal_var=False)
        de_rows.append((gene, t, p, gh.mean() - gl.mean()))
    de_df = _add_fdr_column(
        pd.DataFrame(de_rows, columns=["gene", "tstat", "pval", "mean_diff"])
    )

    # Correlation: missing expression is treated as zero, as before.
    scored_expr = rna_expr.loc[scores.index].fillna(0)
    y = scores.to_numpy()
    corr_rows = []
    for gene in rna_expr.columns:
        r, p = spearmanr(scored_expr[gene].to_numpy(), y)
        corr_rows.append((gene, r, p))
    corr_df = _add_fdr_column(
        pd.DataFrame(corr_rows, columns=["gene", "spearman_r", "pval"])
    )

    return {
        "cluster": cluster,
        "protein": protein,
        "n_cells": len(cells),
        "de": de_df,
        "corr": corr_df,
    }
66+
67+
def run_all_clusters_proteins(adata, rna_expr, cluster_label, protein_names, score_prefix="score", min_cells=3):
    """Run ``analyze_one_score`` for every (cluster, protein) combination.

    Clusters are the unique values of ``adata.obs[cluster_label]``; proteins
    come from ``protein_names``. Combinations for which
    ``analyze_one_score`` returns ``None`` are left out.

    Returns:
        List of the non-``None`` result records, in cluster-then-protein order.
    """
    # Lazily evaluate each combination, cluster by cluster.
    outcomes = (
        analyze_one_score(
            adata, rna_expr, cluster_value, protein_name,
            cluster_key=cluster_label, score_prefix=score_prefix, min_cells=min_cells
        )
        for cluster_value in adata.obs[cluster_label].unique()
        for protein_name in protein_names
    )
    return [outcome for outcome in outcomes if outcome is not None]
78+
79+
def summarize_results(results):
    """Flatten a list of result records into two long DataFrames.

    Args:
        results: Iterable of dicts with keys ``cluster``, ``protein``, ``de``
            and ``corr`` (the latter two being DataFrames).

    Returns:
        Tuple ``(all_de, all_corr)``. Each is the row-wise concatenation of
        the per-record frames with ``cluster`` and ``protein`` columns added.
        Both are empty DataFrames when ``results`` is empty. The input frames
        are not modified.
    """
    def _stack(frames):
        if not frames:
            return pd.DataFrame()
        # Empty frames carry object-dtype columns; concatenating them with
        # populated frames triggers pandas' deprecation warning and can
        # degrade numeric dtypes. Skip them when real data exists, but keep
        # them (for their column names) if every frame is empty.
        populated = [frame for frame in frames if not frame.empty]
        return pd.concat(populated or frames, ignore_index=True)

    de_frames = []
    corr_frames = []
    for rec in results:
        tags = {"cluster": rec["cluster"], "protein": rec["protein"]}
        # assign() returns a new frame, so the caller's frames stay untouched.
        de_frames.append(rec["de"].assign(**tags))
        corr_frames.append(rec["corr"].assign(**tags))

    return _stack(de_frames), _stack(corr_frames)

src/pyXenium/analysis/plotting.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# pyXenium/analysis/plotting.py
2+
3+
import seaborn as sns
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
7+
def plot_auc_heatmap(summary: pd.DataFrame, figsize=(10,8)):
    """Draw a clustered heatmap of ``test_auc`` by cluster and protein.

    ``summary`` is pivoted so that clusters form the rows and proteins the
    columns; values that cannot be parsed as numbers become NaN. Returns the
    object produced by ``sns.clustermap``.
    """
    auc_wide = summary.pivot(
        index="cluster", columns="protein", values="test_auc"
    ).apply(pd.to_numeric, errors="coerce")
    grid = sns.clustermap(auc_wide, cmap="viridis", linewidths=.3, figsize=figsize)
    grid.ax_heatmap.set_xlabel("Protein")
    grid.ax_heatmap.set_ylabel("Cluster")
    return grid
13+
14+
def plot_topk_per_cluster(summary: pd.DataFrame, k=5, metric="test_auc"):
    """Bar-plot the ``k`` best proteins of every cluster according to ``metric``.

    Args:
        summary: DataFrame with ``cluster``, ``protein`` and ``metric`` columns.
        k: Number of top rows kept per cluster.
        metric: Column used for ranking (descending) and as bar height.

    Returns:
        The matplotlib figure containing the bar chart.
    """
    topk = (summary.sort_values(["cluster", metric], ascending=[True, False])
            .groupby("cluster").head(k))
    fig, ax = plt.subplots(figsize=(max(10, k * 1.2), 6))

    labels = []
    vals = []
    for cl, sub in topk.groupby("cluster"):
        labels.extend(f"{cl}:{prot}" for prot in sub["protein"])
        vals.extend(sub[metric])

    ax.bar(labels, vals)
    ax.set_ylabel(metric)
    # Rotate the existing tick labels instead of overwriting them with
    # set_xticklabels, which depends on the current tick positions and is
    # discouraged by matplotlib.
    ax.tick_params(axis="x", labelrotation=90)
    # Lay out this figure explicitly rather than whatever pyplot considers
    # the current figure.
    fig.tight_layout()
    return fig
29+
30+
def plot_DE_volcano(de_df: pd.DataFrame, title="DE Volcano",
                    logfc_col="mean_diff", pval_col="pval", adj_col="adj_pval",
                    fdr_thresh=0.05):
    """Show a volcano plot of a differential-expression table.

    Args:
        de_df: DataFrame holding effect sizes, raw p-values and adjusted
            p-values; it is copied, not modified.
        title: Plot title.
        logfc_col: Column plotted on the x axis.
        pval_col: Column of raw p-values; plotted as ``-log10(p)`` on the y axis.
        adj_col: Column of adjusted p-values used for colouring.
        fdr_thresh: Rows with ``adj_col < fdr_thresh`` are drawn in red,
            all others in gray.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    # numpy is imported here because this module's top-level imports do not
    # provide ``np``; without it the function fails with NameError.
    import numpy as np

    df = de_df.copy()
    # Clip p-values away from zero so that -log10 stays finite and the most
    # significant genes remain visible instead of landing at +inf.
    safe_p = df[pval_col].astype(float).clip(lower=np.finfo(float).tiny)
    df["-log10p"] = -np.log10(safe_p)

    plt.figure(figsize=(6, 5))
    sns.scatterplot(data=df, x=logfc_col, y="-log10p",
                    hue=df[adj_col] < fdr_thresh,
                    palette={True: "red", False: "gray"}, legend=False)
    # Reference line at raw p = 0.05.
    plt.axhline(-np.log10(0.05), ls="--", color="black")
    plt.title(title)
    plt.xlabel("Mean difference (High vs Low)")
    plt.ylabel("-log10(p)")
    plt.tight_layout()
    plt.show()
45+
46+
def plot_model_diagnostics(adata, models, cluster, protein, feature_key="X_rna_pca",
                           cluster_key="rna_cluster"):
    """Show ROC, precision-recall and calibration plots for one fitted model.

    Args:
        adata: Object with ``obs`` (containing ``cluster_key`` and
            ``f"protein:{protein}"``) and ``obsm[feature_key]``.
        models: Nested mapping ``models[cluster][protein]`` whose entries
            expose ``model``, ``scaler`` and ``threshold`` attributes.
        cluster: Cluster whose cells are evaluated.
        protein: Protein whose model is evaluated.
        feature_key: Key of the feature matrix in ``adata.obsm``.
        cluster_key: Column of ``adata.obs`` holding cluster labels.

    Raises:
        ValueError: If the stored result has no ``threshold``; the binary
            ground truth cannot be derived without it.

    The figures are displayed with ``plt.show()``; nothing is returned.
    """
    from sklearn.metrics import RocCurveDisplay, PrecisionRecallDisplay
    from sklearn.calibration import calibration_curve

    res = models[cluster][protein]
    clf, scaler = res.model, res.scaler
    thr = getattr(res, "threshold", None)
    if thr is None:
        raise ValueError(
            f"Model for {cluster}:{protein} has no threshold; cannot derive labels."
        )

    mask = (adata.obs[cluster_key] == cluster)
    # Index the feature matrix with a plain boolean array, not a pandas Series.
    X = scaler.transform(adata.obsm[feature_key][mask.to_numpy(), :])
    # Ground truth is defined here as protein level >= threshold.
    # NOTE(review): confirm this matches how the model's labels were built.
    y = (adata.obs.loc[mask, f"protein:{protein}"] >= thr).astype(int).to_numpy()
    y_prob = clf.predict_proba(X)[:, 1]

    RocCurveDisplay.from_predictions(y, y_prob)
    plt.title(f"ROC — {cluster}:{protein}")
    PrecisionRecallDisplay.from_predictions(y, y_prob)
    plt.title(f"PR — {cluster}:{protein}")

    prob_true, prob_pred = calibration_curve(y, y_prob, n_bins=10, strategy="quantile")
    plt.figure()
    plt.plot(prob_pred, prob_true, marker="o")
    plt.plot([0, 1], [0, 1], "--")  # diagonal for comparison
    plt.xlabel("Predicted prob")
    plt.ylabel("Empirical freq")
    plt.title(f"Calibration — {cluster}:{protein}")
    plt.tight_layout()
    plt.show()

src/pyXenium/analysis/scoring.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# pyXenium/analysis/scoring.py
2+
3+
import numpy as np
4+
import pandas as pd
5+
6+
def write_model_scores(adata, models, feature_key="X_rna_pca", cluster_key="rna_cluster",
                       score_prefix="score"):
    """Write per-cell predicted probabilities of every model into ``adata.obs``.

    For each ``(cluster, protein)`` model in ``models`` the cells of that
    cluster are scaled with the model's scaler, the positive-class
    probability is predicted, and the result is stored in the column
    ``f"{score_prefix}:{cluster}:{protein}"``. Cells outside the cluster get
    NaN in that column. Clusters without any cells are skipped.

    Args:
        adata: Object with ``obs`` (DataFrame) and ``obsm`` (mapping).
        models: Nested mapping ``models[cluster][protein]`` whose entries
            expose ``model`` (with ``predict_proba``) and ``scaler`` (with
            ``transform``).
        feature_key: Key of the feature matrix in ``adata.obsm``.
        cluster_key: Column of ``adata.obs`` holding cluster labels.
        score_prefix: Prefix of the generated column names.

    Returns:
        The same ``adata`` object, with the score columns added or updated.

    Raises:
        KeyError: If ``feature_key`` is missing from ``adata.obsm`` and at
            least one cluster in ``models`` has cells.
    """
    n_obs = len(adata.obs.index)
    X_all = adata.obsm.get(feature_key)  # loop-invariant, so looked up once

    new_cols = {}
    for cluster, protodict in models.items():
        mask = (adata.obs[cluster_key] == cluster).to_numpy()
        if not mask.any():
            continue
        if X_all is None:
            raise KeyError(f"Feature key {feature_key} not in adata.obsm.")
        X_sub = X_all[mask, :]

        for protein, res in protodict.items():
            X_scaled = res.scaler.transform(X_sub)
            y_prob = res.model.predict_proba(X_scaled)[:, 1]
            # Start from an all-NaN column and fill in only this cluster's cells.
            column = np.full(n_obs, np.nan)
            column[mask] = y_prob
            new_cols[f"{score_prefix}:{cluster}:{protein}"] = column

    if new_cols:
        scores = pd.DataFrame(new_cols, index=adata.obs.index)
        # Columns that already exist are overwritten in place (keeps their
        # position and avoids duplicate column names on re-runs).
        existing = [name for name in scores.columns if name in adata.obs.columns]
        for name in existing:
            adata.obs[name] = scores[name]
        # Truly new columns are attached with a single concat; inserting them
        # one by one would fragment the obs DataFrame.
        fresh = scores.drop(columns=existing)
        if fresh.shape[1] > 0:
            adata.obs = pd.concat([adata.obs, fresh], axis=1)

    return adata

0 commit comments

Comments
 (0)