|
3 | 3 | import seaborn as sns |
4 | 4 | import matplotlib.pyplot as plt |
5 | 5 | import pandas as pd |
| 6 | +import numpy as np |
| 7 | +import scipy.cluster.hierarchy as sch |
| 8 | +from scipy.spatial.distance import pdist |
6 | 9 |
|
7 | | -def plot_auc_heatmap(summary: pd.DataFrame, figsize=(10,8)): |
| 10 | +def plot_auc_heatmap(summary: pd.DataFrame, |
| 11 | + figsize=(10, 8), |
| 12 | + metric="euclidean", |
| 13 | + method="average", |
| 14 | + min_shared=2): |
| 15 | + """ |
| 16 | + 对 summary pivot 出的 AUC 矩阵做热图 + 聚类(clustermap), |
| 17 | + 在距离计算阶段忽略 NaN(即只用两个行/列共有的那部分非 NaN 值计算距离); |
| 18 | + 在热图阶段则对 NaN 用 mask 使其不可着色,但保留行/列显示顺序。 |
| 19 | +
|
| 20 | + 参数 |
| 21 | + ---- |
| 22 | + summary : pd.DataFrame |
| 23 | + 包含至少三列 “cluster”, “protein”, “test_auc”,用于 pivot。 |
| 24 | + figsize : tuple (宽, 高) |
| 25 | + 图像尺寸传给 seaborn.clustermap。 |
| 26 | + metric : str |
| 27 | + 用于 pdist 的距离度量(默认 “euclidean”)。 |
| 28 | + method : str |
| 29 | + 用于 sch.linkage 的聚类方法(默认 “average”)。 |
| 30 | + min_shared : int |
| 31 | + 当两行/列用来计算距离时,要求它们“共同非 NaN 的特征数量” ≥ min_shared; |
| 32 | + 若共有特征太少,则认为距离为缺失(会在后续被替换为一个较大距离)。 |
| 33 | +
|
| 34 | + 返回 |
| 35 | + ---- |
| 36 | + sns.ClusterGrid 对象(clustermap 画出的结果)。 |
| 37 | + """ |
| 38 | + # 构造矩阵 |
8 | 39 | mat = summary.pivot(index="cluster", columns="protein", values="test_auc") |
9 | 40 | mat = mat.apply(pd.to_numeric, errors="coerce") |
10 | | - g = sns.clustermap(mat, cmap="viridis", linewidths=.3, figsize=figsize) |
11 | | - g.ax_heatmap.set_xlabel("Protein"); g.ax_heatmap.set_ylabel("Cluster") |
| 41 | + mat = mat.replace([np.inf, -np.inf], np.nan) |
| 42 | + |
| 43 | + # 内部辅助函数:用来给 pdist 提供“忽略 NaN 的距离函数” |
| 44 | + def _pairwise_dist(u, v): |
| 45 | + # u, v 是一维 numpy 数组 |
| 46 | + mask = (~np.isnan(u)) & (~np.isnan(v)) |
| 47 | + if mask.sum() < min_shared: |
| 48 | + return np.nan |
| 49 | + uu = u[mask] |
| 50 | + vv = v[mask] |
| 51 | + # 注意:这里用了 pdist 但其实只计算两个向量之间距离 |
| 52 | + # 用 numpy 或其它实现也可以 |
| 53 | + return pdist(np.vstack([uu, vv]), metric=metric)[0] |
| 54 | + |
| 55 | + # 计算行 linkage |
| 56 | + # pdist 会把上三角所有行对的距离打包成“压缩距离向量” |
| 57 | + row_dist = pdist(mat.values, metric=_pairwise_dist) |
| 58 | + # 对 row_dist 中的 NaN 距离赋一个较大数(最大有限值 * 1.1) |
| 59 | + finite = row_dist[np.isfinite(row_dist)] |
| 60 | + if finite.size > 0: |
| 61 | + maxd = np.nanmax(finite) |
| 62 | + row_dist = np.where(np.isfinite(row_dist), row_dist, maxd * 1.1) |
| 63 | + row_linkage = sch.linkage(row_dist, method=method) |
| 64 | + else: |
| 65 | + row_linkage = None |
| 66 | + |
| 67 | + # 计算列 linkage(对转置矩阵做同样操作) |
| 68 | + col_dist = pdist(mat.values.T, metric=_pairwise_dist) |
| 69 | + finite2 = col_dist[np.isfinite(col_dist)] |
| 70 | + if finite2.size > 0: |
| 71 | + maxd2 = np.nanmax(finite2) |
| 72 | + col_dist = np.where(np.isfinite(col_dist), col_dist, maxd2 * 1.1) |
| 73 | + col_linkage = sch.linkage(col_dist, method=method) |
| 74 | + else: |
| 75 | + col_linkage = None |
| 76 | + |
| 77 | + # mask:在热图可视化阶段遮蔽 NaN |
| 78 | + mask = mat.isna() |
| 79 | + |
| 80 | + # 调用 clustermap,传入计算好的 linkage,保留行 / 列显示顺序 |
| 81 | + g = sns.clustermap(mat, |
| 82 | + row_linkage=row_linkage, |
| 83 | + col_linkage=col_linkage, |
| 84 | + mask=mask, |
| 85 | + figsize=figsize, |
| 86 | + cmap="viridis", |
| 87 | + linewidths=.3) |
| 88 | + |
| 89 | + g.ax_heatmap.set_xlabel("Protein") |
| 90 | + g.ax_heatmap.set_ylabel("Cluster") |
12 | 91 | return g |
13 | 92 |
|
14 | 93 | def plot_topk_per_cluster(summary: pd.DataFrame, k=5, metric="test_auc"): |
|
0 commit comments