Skip to content

Commit cd31fe0

Browse files
committed
fix
1 parent 990192c commit cd31fe0

File tree

1 file changed

+82
-3
lines changed

1 file changed

+82
-3
lines changed

src/pyXenium/analysis/plotting.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,91 @@
33
import seaborn as sns
44
import matplotlib.pyplot as plt
55
import pandas as pd
6+
import numpy as np
7+
import scipy.cluster.hierarchy as sch
8+
from scipy.spatial.distance import pdist
69

7-
def plot_auc_heatmap(summary: pd.DataFrame, figsize=(10,8)):
10+
def plot_auc_heatmap(summary: pd.DataFrame,
11+
figsize=(10, 8),
12+
metric="euclidean",
13+
method="average",
14+
min_shared=2):
15+
"""
16+
对 summary pivot 出的 AUC 矩阵做热图 + 聚类(clustermap),
17+
在距离计算阶段忽略 NaN(即只用两个行/列共有的那部分非 NaN 值计算距离);
18+
在热图阶段则对 NaN 用 mask 使其不可着色,但保留行/列显示顺序。
19+
20+
参数
21+
----
22+
summary : pd.DataFrame
23+
包含至少三列 “cluster”, “protein”, “test_auc”,用于 pivot。
24+
figsize : tuple (宽, 高)
25+
图像尺寸传给 seaborn.clustermap。
26+
metric : str
27+
用于 pdist 的距离度量(默认 “euclidean”)。
28+
method : str
29+
用于 sch.linkage 的聚类方法(默认 “average”)。
30+
min_shared : int
31+
当两行/列用来计算距离时,要求它们“共同非 NaN 的特征数量” ≥ min_shared;
32+
若共有特征太少,则认为距离为缺失(会在后续被替换为一个较大距离)。
33+
34+
返回
35+
----
36+
sns.ClusterGrid 对象(clustermap 画出的结果)。
37+
"""
38+
# 构造矩阵
839
mat = summary.pivot(index="cluster", columns="protein", values="test_auc")
940
mat = mat.apply(pd.to_numeric, errors="coerce")
10-
g = sns.clustermap(mat, cmap="viridis", linewidths=.3, figsize=figsize)
11-
g.ax_heatmap.set_xlabel("Protein"); g.ax_heatmap.set_ylabel("Cluster")
41+
mat = mat.replace([np.inf, -np.inf], np.nan)
42+
43+
# 内部辅助函数:用来给 pdist 提供“忽略 NaN 的距离函数”
44+
def _pairwise_dist(u, v):
45+
# u, v 是一维 numpy 数组
46+
mask = (~np.isnan(u)) & (~np.isnan(v))
47+
if mask.sum() < min_shared:
48+
return np.nan
49+
uu = u[mask]
50+
vv = v[mask]
51+
# 注意:这里用了 pdist 但其实只计算两个向量之间距离
52+
# 用 numpy 或其它实现也可以
53+
return pdist(np.vstack([uu, vv]), metric=metric)[0]
54+
55+
# 计算行 linkage
56+
# pdist 会把上三角所有行对的距离打包成“压缩距离向量”
57+
row_dist = pdist(mat.values, metric=_pairwise_dist)
58+
# 对 row_dist 中的 NaN 距离赋一个较大数(最大有限值 * 1.1)
59+
finite = row_dist[np.isfinite(row_dist)]
60+
if finite.size > 0:
61+
maxd = np.nanmax(finite)
62+
row_dist = np.where(np.isfinite(row_dist), row_dist, maxd * 1.1)
63+
row_linkage = sch.linkage(row_dist, method=method)
64+
else:
65+
row_linkage = None
66+
67+
# 计算列 linkage(对转置矩阵做同样操作)
68+
col_dist = pdist(mat.values.T, metric=_pairwise_dist)
69+
finite2 = col_dist[np.isfinite(col_dist)]
70+
if finite2.size > 0:
71+
maxd2 = np.nanmax(finite2)
72+
col_dist = np.where(np.isfinite(col_dist), col_dist, maxd2 * 1.1)
73+
col_linkage = sch.linkage(col_dist, method=method)
74+
else:
75+
col_linkage = None
76+
77+
# mask:在热图可视化阶段遮蔽 NaN
78+
mask = mat.isna()
79+
80+
# 调用 clustermap,传入计算好的 linkage,保留行 / 列显示顺序
81+
g = sns.clustermap(mat,
82+
row_linkage=row_linkage,
83+
col_linkage=col_linkage,
84+
mask=mask,
85+
figsize=figsize,
86+
cmap="viridis",
87+
linewidths=.3)
88+
89+
g.ax_heatmap.set_xlabel("Protein")
90+
g.ax_heatmap.set_ylabel("Cluster")
1291
return g
1392

1493
def plot_topk_per_cluster(summary: pd.DataFrame, k=5, metric="test_auc"):

0 commit comments

Comments
 (0)