Skip to content

Commit 833ab4f

Browse files
committed
add new feature
1 parent 3537859 commit 833ab4f

File tree

2 files changed

+331
-0
lines changed

2 files changed

+331
-0
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "initial_id",
7+
"metadata": {
8+
"collapsed": true
9+
},
10+
"outputs": [],
11+
"source": [
12+
"from pyXenium.io.xenium_gene_protein_loader import load_xenium_gene_protein\n",
13+
"from pyXenium.analysis import ProteinMicroEnv\n",
14+
"from pyXenium.vis.fig_microenv import build_microenv_figure\n",
15+
"\n",
16+
"# 1) 载入数据\n",
17+
"adata = load_xenium_gene_protein(\n",
18+
" base_path=\"~/proj/data/Xenium_V1_Human_Kidney_FFPE_Protein\",\n",
19+
" clusters_relpath=\"analysis/clustering/gene_expression_graphclust/clusters.csv\",\n",
20+
" cluster_column_name=\"cluster\"\n",
21+
")\n",
22+
"\n",
23+
"# 2) 运行微环境分析(如果你已运行过,可直接跳到第3步)\n",
24+
"anl = ProteinMicroEnv(\n",
25+
" adata=adata,\n",
26+
" protein_obsm=\"protein\",\n",
27+
" protein_norm_obsm=\"protein_norm\",\n",
28+
" cluster_key=\"cluster\",\n",
29+
" spatial_obsm=\"spatial\", # 若没有 'spatial',会回退 x_centroid/y_centroid\n",
30+
" obs_xy=(\"x_centroid\", \"y_centroid\"),\n",
31+
" random_state=0\n",
32+
")\n",
33+
"\n",
34+
"cluster_id = \"3\" # 替换为你的目标簇\n",
35+
"protein = \"Ki-67\" # 替换为你的目标蛋白;比如 EPCAM/KRT8/KRT18/CD68/PDCD1 等\n",
36+
"res = anl.analyze(\n",
37+
" cluster_id=cluster_id,\n",
38+
" protein=protein,\n",
39+
" group_key=\"cluster\", # 若已有 obs['cell_type'],优先用 'cell_type'\n",
40+
" radius=None, # None 自动;也可给 40~60.0\n",
41+
" permutations=999,\n",
42+
" save_dir=\"./microenv_out\"\n",
43+
")\n",
44+
"\n",
45+
"# 3) 生成论文图板(PDF/PNG/SVG)\n",
46+
"outbase = build_microenv_figure(\n",
47+
" adata=adata,\n",
48+
" res=res,\n",
49+
" cluster_id=cluster_id,\n",
50+
" protein=protein,\n",
51+
" group_key=\"cluster\", # 仅用于标题说明\n",
52+
" spatial_obsm=\"spatial\",\n",
53+
" obs_xy=(\"x_centroid\", \"y_centroid\"),\n",
54+
" outdir=\"./figures\",\n",
55+
" basename=None, # 用默认命名 Fig_microenv_cluster{cluster_id}_{protein}\n",
56+
" figsize_inches=(7.0, 5.0),\n",
57+
" scatter_s=0.4, # 点的像素大小(大样本建议 0.3~0.5)\n",
58+
" scale_bar_um=100.0 # 如果坐标单位不是 μm,可改为 None 或正确单位长度\n",
59+
")\n",
60+
"\n",
61+
"print(\"Saved:\", outbase + \".pdf\")\n"
62+
]
63+
}
64+
],
65+
"metadata": {
66+
"kernelspec": {
67+
"display_name": "Python 3",
68+
"language": "python",
69+
"name": "python3"
70+
},
71+
"language_info": {
72+
"codemirror_mode": {
73+
"name": "ipython",
74+
"version": 2
75+
},
76+
"file_extension": ".py",
77+
"mimetype": "text/x-python",
78+
"name": "python",
79+
"nbconvert_exporter": "python",
80+
"pygments_lexer": "ipython2",
81+
"version": "2.7.6"
82+
}
83+
},
84+
"nbformat": 4,
85+
"nbformat_minor": 5
86+
}

src/pyXenium/vis/fig_microenv.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
pyXenium/vis/fig_microenv.py
4+
5+
Build a publication-ready multi-panel figure from ProteinMicroEnv analysis:
6+
A: spatial categorical (protein status), B: spatial numeric (protein level),
7+
C: neighbor enrichment bars, D: microenvironment predictability (coef + AUC),
8+
E: RNA DE volcano (within cluster), F: protein distribution (hist/KDE).
9+
"""
10+
11+
from __future__ import annotations
12+
import os
13+
from typing import Optional, Tuple, Dict
14+
15+
import numpy as np
16+
import pandas as pd
17+
import matplotlib as mpl
18+
import matplotlib.pyplot as plt
19+
from matplotlib.gridspec import GridSpec
20+
21+
import scanpy as sc
22+
from anndata import AnnData
23+
24+
# ---------------------- Global style (journal-friendly) ----------------------
25+
26+
def set_paper_rc(font_family: str = "Arial",
                 base_size: float = 8.0,
                 line_width: float = 0.8) -> None:
    """Install a compact, journal-style Matplotlib configuration.

    The settings are written into the global ``mpl.rcParams`` and therefore
    affect every figure created afterwards in the same process.

    Args:
        font_family: Name placed in the ``font.sans-serif`` list.
        base_size: Font size used for text, axis titles and axis labels;
            tick labels are set 0.5 smaller.
        line_width: Width of the axes frame lines.
    """
    tick_label_size = base_size - 0.5

    typography = {
        "font.family": "sans-serif",
        "font.sans-serif": [font_family],
        "font.size": base_size,
        "axes.titlesize": base_size,
        "axes.labelsize": base_size,
        "xtick.labelsize": tick_label_size,
        "ytick.labelsize": tick_label_size,
    }
    strokes = {
        "axes.linewidth": line_width,
        "grid.linewidth": 0.5,
        "legend.frameon": False,
    }
    export = {
        # Font type 42 embeds TrueType fonts so text stays editable
        # in vector-graphics editors.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "savefig.dpi": 600,
        "figure.dpi": 150,
    }

    mpl.rcParams.update({**typography, **strokes, **export})
46+
47+
# ---------------------- Helpers ----------------------
48+
49+
def _get_coords(adata: AnnData,
                prefer_obsm: str = "spatial",
                obs_xy: Tuple[str, str] = ("x_centroid", "y_centroid")) -> np.ndarray:
    """Return per-cell coordinates as a two-column array.

    Args:
        adata: Object exposing ``obsm`` (mapping) and ``obs`` (DataFrame).
        prefer_obsm: Key looked up in ``adata.obsm`` first.
        obs_xy: Names of the two ``adata.obs`` columns used when
            ``prefer_obsm`` is not present in ``adata.obsm``.

    Returns:
        The first two columns of ``adata.obsm[prefer_obsm]`` when that key
        exists, otherwise the two ``obs`` columns converted to NumPy.
    """
    # Fallback path first: no such obsm entry, read the obs columns instead.
    if prefer_obsm not in adata.obsm.keys():
        x_name = obs_xy[0]
        y_name = obs_xy[1]
        return adata.obs.loc[:, [x_name, y_name]].to_numpy()

    matrix = np.asarray(adata.obsm[prefer_obsm])
    return matrix[:, :2]
56+
57+
def _draw_scale_bar(ax, coords: np.ndarray, length_um: float = 100.0, pad_ratio: float = 0.04) -> None:
    """Draw a horizontal scale bar near the lower-left corner of the data.

    The bar is ``length_um`` long in data units, so the "μm" label is only
    correct when the coordinates themselves are expressed in micrometres.

    Args:
        ax: Axes-like object providing ``plot`` and ``text``.
        coords: Array whose first two columns are x and y positions.
        length_um: Bar length in data units.
        pad_ratio: Offset from the lower-left data corner, as a fraction of
            the x extent (the same offset is used for x and y).

    Nothing is drawn when ``coords`` is empty, because there is no extent to
    anchor the bar to.
    """
    pts = np.asarray(coords, dtype=float)
    if pts.size == 0:
        return

    xmin, ymin = pts[:, 0].min(), pts[:, 1].min()
    xmax = pts[:, 0].max()
    pad = pad_ratio * (xmax - xmin)
    x0 = xmin + pad
    y0 = ymin + pad

    # Show whole numbers without a decimal point, but never truncate a
    # fractional length: the label must match the bar actually drawn.
    shown = int(length_um) if float(length_um).is_integer() else length_um

    ax.plot([x0, x0 + length_um], [y0, y0], lw=1.2, color="black")
    ax.text(x0 + length_um / 2, y0 + 0.8 * pad, f"{shown} μm", ha="center", va="bottom")
67+
68+
def _rasterized_scatter(ax, x, y, c, title: str = "", rasterized: bool = True,
                        vmin=None, vmax=None, cmap="viridis", s=1.0, alpha=0.9):
    """Scatter numeric values on ``ax`` and attach a colorbar.

    The axes get an equal aspect ratio, the given title, and no ticks.

    Args:
        ax: Target axes.
        x, y: Point positions.
        c: Values mapped to colour.
        title: Panel title.
        rasterized: Forwarded to ``ax.scatter``.
        vmin, vmax: Colour-scale limits forwarded to ``ax.scatter``.
        cmap: Colormap forwarded to ``ax.scatter``.
        s: Marker size forwarded to ``ax.scatter``.
        alpha: Marker opacity.

    Returns:
        Tuple ``(scatter_artist, colorbar)``.
    """
    scatter_options = dict(c=c, s=s, alpha=alpha, cmap=cmap,
                           rasterized=rasterized, vmin=vmin, vmax=vmax)
    points = ax.scatter(x, y, **scatter_options)

    ax.set_aspect("equal", adjustable="box")
    ax.set_title(title)
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])

    colorbar = plt.colorbar(points, ax=ax, fraction=0.046, pad=0.02)
    return points, colorbar
76+
77+
def _categorical_scatter(ax, coords: np.ndarray, labels: pd.Series,
                         title: str = "", s=1.0, alpha=0.9):
    """Scatter points coloured by a categorical label; missing labels are grey.

    Args:
        ax: Target axes.
        coords: Array whose first two columns are x and y positions, in the
            same row order as ``labels``.
        labels: One label per point; converted with ``astype("category")``.
        title: Panel title.
        s: Marker size forwarded to ``ax.scatter``.
        alpha: Marker opacity.

    A colorbar with one tick per category is added when at least one point
    has a non-missing label.
    """
    cat = labels.astype("category")
    categories = list(cat.cat.categories)
    n_categories = len(categories)
    codes = cat.cat.codes.to_numpy()  # missing values are encoded as -1
    labelled = codes != -1

    # One discrete colour per category.
    palette = plt.get_cmap("tab20", max(n_categories, 1))

    if labelled.any():
        # Pin the colour scale to the full category range. Without explicit
        # limits the scale is fitted to the codes that happen to be present,
        # so a category's colour would change whenever another category is
        # absent, and colorbar ticks would sit on bin edges. With limits
        # [-0.5, n - 0.5] every integer code falls in the centre of its bin.
        sca = ax.scatter(coords[labelled, 0], coords[labelled, 1], c=codes[labelled],
                         s=s, alpha=alpha, cmap=palette,
                         vmin=-0.5, vmax=n_categories - 0.5, rasterized=True)
        cb = plt.colorbar(sca, ax=ax, fraction=0.046, pad=0.02)
        cb.set_ticks(np.arange(n_categories))
        cb.set_ticklabels(categories)

    if (~labelled).any():
        ax.scatter(coords[~labelled, 0], coords[~labelled, 1],
                   c="lightgrey", s=s, alpha=alpha, rasterized=True)

    ax.set_aspect("equal", adjustable="box")
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
95+
96+
def _barh_with_ci(ax, df: pd.DataFrame, value_col: str, label_col: str,
                  top_k: int = 10, title: str = "", invert: bool = True):
    """Draw horizontal bars for the ``top_k`` largest values of ``value_col``.

    Despite the name, only the bars are drawn; no interval is plotted.

    Args:
        ax: Target axes.
        df: Table holding the labels and values.
        value_col: Column with bar lengths; also the sort key (descending).
        label_col: Column with bar labels.
        top_k: Number of rows kept after sorting.
        title: Panel title.
        invert: When true, flip the y axis so the largest value is on top.
    """
    ranked = df.sort_values(by=value_col, ascending=False)
    top_rows = ranked.iloc[:top_k]

    ax.barh(top_rows[label_col], top_rows[value_col])
    if invert:
        ax.invert_yaxis()
    ax.set_title(title)
102+
103+
def _volcano(ax, de: pd.DataFrame, title: str = "", max_points: int = 20000):
    """Draw a volcano plot (fold change vs. -log10 adjusted p-value).

    Args:
        ax: Target axes.
        de: Table with columns ``logfoldchanges`` and ``pvals_adj``. When a
            ``group`` column is present, only one group is plotted:
            ``"protein_high"`` if it occurs, otherwise the first group in
            sorted order.
        title: Panel title.
        max_points: Upper bound on plotted points; larger tables are
            randomly down-sampled with a fixed seed.

    An empty table produces an empty, labelled panel instead of an error.
    """
    df = de.copy()

    # Keep a single comparison direction (protein_high vs. rest).
    if "group" in df.columns:
        # dropna: a missing group cannot be ordered against string names.
        groups = sorted(df["group"].dropna().unique())
        if groups:  # an empty table has no group to select
            grp = "protein_high" if "protein_high" in groups else groups[0]
            df = df[df["group"] == grp].copy()

    # Clip so that q == 0 gives a large finite value rather than +inf.
    df["neglog10q"] = -np.log10(np.clip(df["pvals_adj"].astype(float), 1e-300, 1.0))

    # Only the two plotted columns need cleaning; unrelated columns are
    # left out so their contents cannot affect which rows are kept.
    points = (
        df[["logfoldchanges", "neglog10q"]]
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
    )

    # Down-sample for plotting speed.
    if len(points) > max_points:
        points = points.sample(max_points, random_state=0)

    ax.scatter(points["logfoldchanges"], points["neglog10q"], s=6, alpha=0.7, rasterized=True)
    ax.set_xlabel("log2 fold change")
    ax.set_ylabel("-log10(q)")
    ax.set_title(title)
121+
122+
def _hist(ax, x: np.ndarray, bins: int = 40, title: str = ""):
    """Plot a histogram of ``x`` on ``ax`` and set the panel title.

    Args:
        ax: Target axes.
        x: Values to bin.
        bins: Bin specification forwarded to ``ax.hist``.
        title: Panel title.
    """
    hist_options = {"bins": bins}
    ax.hist(x, **hist_options)
    ax.set_title(title)
125+
126+
# ---------------------- Main figure builder ----------------------
127+
128+
def build_microenv_figure(adata: AnnData,
                          res: Dict,
                          cluster_id: str,
                          protein: str,
                          group_key: str = "cluster",
                          spatial_obsm: str = "spatial",
                          obs_xy: Tuple[str, str] = ("x_centroid", "y_centroid"),
                          outdir: str = "./figures",
                          basename: Optional[str] = None,
                          figsize_inches: Tuple[float, float] = (7.0, 5.0),
                          scatter_s: float = 0.5,
                          scale_bar_um: Optional[float] = 100.0,
                          cluster_key: str = "cluster") -> str:
    """
    Assemble a 2x3 multi-panel board (A-F) and save it as PDF, PNG and SVG.

    Panels: A status map, B protein level map, C neighbor enrichment bars,
    D model coefficients with AUC, E volcano plot, F protein histogram.
    Panels A, B and F only use the cells of the selected cluster.

    Args:
        adata: Annotated data. ``adata.obsm`` must hold ``"protein_norm"``
            or ``"protein"``, indexable by ``protein`` and supporting
            ``.to_numpy()`` (presumably a DataFrame; confirm with the loader).
        res: Analysis result. ``res["status_col"]`` is required and names the
            ``adata.obs`` column holding ``"protein_low"``/``"protein_high"``.
            Optional entries, each replaced by a placeholder panel when
            missing or empty:
            ``"neighbor_enrichment"`` (DataFrame with ``neighbor_type`` and
            ``delta_frac_high_minus_low``), ``"predict_coef"`` (DataFrame
            with ``feature`` and ``coef``), ``"predict_auc"`` (float),
            ``"de"`` (DataFrame for the volcano plot) and ``"moransI"``
            (dict with ``I`` and optionally ``p_value``).
        cluster_id: Cluster to display; compared as a string.
        protein: Protein name to display.
        group_key: Not used by this function; kept so existing calls work.
        spatial_obsm: ``adata.obsm`` key preferred for coordinates.
        obs_xy: ``adata.obs`` columns used when ``spatial_obsm`` is absent.
        outdir: Output directory; created if needed.
        basename: File name without extension; defaults to
            ``Fig_microenv_cluster{cluster_id}_{protein}``.
        figsize_inches: Figure size.
        scatter_s: Marker size for the spatial panels.
        scale_bar_um: Scale-bar length, or ``None`` to omit the bar.
        cluster_key: ``adata.obs`` column holding the cluster labels that
            ``cluster_id`` refers to.

    Returns:
        The saved base path (without extension).

    Raises:
        KeyError: If ``res`` has no ``"status_col"`` entry.
        ValueError: If no cell belongs to ``cluster_id``.
    """
    os.makedirs(outdir, exist_ok=True)
    base = basename or f"Fig_microenv_cluster{cluster_id}_{protein}"
    outbase = os.path.join(outdir, base)

    coords = _get_coords(adata, spatial_obsm, obs_xy)

    # Only the status column is mandatory; every other result is optional
    # because each panel below already has a "nothing to show" fallback.
    status_col = res["status_col"]
    enrich = res.get("neighbor_enrichment")
    coef = res.get("predict_coef")
    de = res.get("de")
    mi = res.get("moransI")
    auc = res.get("predict_auc")

    # Values for panel B: prefer the normalized protein table.
    prot_key = "protein_norm" if "protein_norm" in adata.obsm_keys() else "protein"
    prot_vals = adata.obsm[prot_key][protein].to_numpy()

    # Restrict to the target cluster (also keeps plotting fast).
    in_cluster = (adata.obs[cluster_key].astype(str) == str(cluster_id)).to_numpy()
    if not in_cluster.any():
        raise ValueError(
            f"No cells found for cluster {cluster_id!r} in adata.obs[{cluster_key!r}]"
        )
    coords_c = coords[in_cluster]
    status_c = adata.obs.loc[in_cluster, status_col]
    prot_c = prot_vals[in_cluster]

    # Apply the paper style only while this figure is built and saved, so
    # the caller's global Matplotlib settings are restored afterwards.
    with mpl.rc_context():
        set_paper_rc()

        fig = plt.figure(figsize=figsize_inches, constrained_layout=True)
        try:
            gs = GridSpec(2, 3, figure=fig)

            # ------- A: spatial categorical (status in cluster) -------
            axA = fig.add_subplot(gs[0, 0])
            _categorical_scatter(axA, coords_c, status_c,
                                 title=f"A  {protein} high/low (cluster {cluster_id})",
                                 s=scatter_s)
            if scale_bar_um is not None:
                _draw_scale_bar(axA, coords_c, length_um=scale_bar_um)

            # ------- B: spatial numeric (protein level, cluster only) -------
            axB = fig.add_subplot(gs[0, 1])
            # Clip the colour scale to the 2nd-98th percentile so outliers
            # do not flatten the rest of the range.
            vmin, vmax = np.nanpercentile(prot_c, [2, 98])
            _rasterized_scatter(axB, coords_c[:, 0], coords_c[:, 1], prot_c,
                                title=f"B  {protein} level", vmin=vmin, vmax=vmax,
                                s=scatter_s)
            if scale_bar_um is not None:
                _draw_scale_bar(axB, coords_c, length_um=scale_bar_um)

            # ------- C: neighbor enrichment (bars) -------
            axC = fig.add_subplot(gs[0, 2])
            if isinstance(enrich, pd.DataFrame) and not enrich.empty:
                # Show the ten largest differences.
                dfC = enrich.copy()
                dfC["label"] = dfC["neighbor_type"].astype(str)
                _barh_with_ci(axC, dfC, value_col="delta_frac_high_minus_low",
                              label_col="label", top_k=10,
                              title="C  Neighbor enrichment (Δfrac High-Low)")
                axC.set_xlabel("Δ fraction")
            else:
                axC.text(0.5, 0.5, "No enrichment", ha="center", va="center")
                axC.axis("off")

            # ------- D: microenvironment predictability (coef + AUC) -------
            axD = fig.add_subplot(gs[1, 0])
            if isinstance(coef, pd.DataFrame) and not coef.empty:
                dfD = coef.copy()
                dfD["label"] = dfD["feature"].str.replace("nbr_frac:", "", regex=False)
                dfD = dfD.sort_values("coef", ascending=True).tail(12)
                axD.barh(dfD["label"], dfD["coef"])
                axD.set_title("D  Microenvironment coefficients")
                axD.set_xlabel("logistic coef")
                # A missing AUC (None) is reported the same way as NaN.
                auc_ok = auc is not None and np.isfinite(auc)
                axD.text(0.98, 0.05, f"AUC={auc:.3f}" if auc_ok else "AUC=N/A",
                         ha="right", va="bottom", transform=axD.transAxes)
            else:
                axD.text(0.5, 0.5, "No model", ha="center", va="center")
                axD.axis("off")

            # ------- E: volcano (DE within cluster) -------
            axE = fig.add_subplot(gs[1, 1])
            if isinstance(de, pd.DataFrame) and not de.empty:
                _volcano(axE, de, title="E  RNA DE: protein-high vs low")
            else:
                axE.text(0.5, 0.5, "No DE", ha="center", va="center")
                axE.axis("off")

            # ------- F: protein distribution (hist) -------
            axF = fig.add_subplot(gs[1, 2])
            is_low = (status_c == "protein_low").to_numpy()
            is_high = (status_c == "protein_high").to_numpy()
            axF.hist([prot_c[is_low], prot_c[is_high]],
                     bins=40, label=["low", "high"], alpha=0.7)
            axF.set_title(f"F  {protein} distribution")
            axF.set_xlabel(f"{protein} (normalized)")
            axF.set_ylabel("cells")
            axF.legend(frameon=False)

            # Suptitle with Moran's I (p-value appended only when available).
            if isinstance(mi, dict) and mi.get("I") is not None:
                stat = f"Moran's I={mi['I']:.3f}"
                p_value = mi.get("p_value")
                if p_value is not None:
                    stat += f", p={p_value:.2g}"
                fig.suptitle(
                    f"Protein microenvironment (cluster {cluster_id}, {protein})  |  {stat}",
                    y=1.02, fontsize=8)

            # Save inside the rc context so the export settings apply.
            fig.savefig(outbase + ".pdf", bbox_inches="tight")
            fig.savefig(outbase + ".png", bbox_inches="tight", dpi=600)
            fig.savefig(outbase + ".svg", bbox_inches="tight")
        finally:
            # Release the figure even when a panel or a save step fails.
            plt.close(fig)

    return outbase

0 commit comments

Comments
 (0)