Skip to content

Commit 7cc7504

Browse files
committed
a
1 parent 8a23696 commit 7cc7504

File tree

1 file changed

+76
-61
lines changed

1 file changed

+76
-61
lines changed

src/pyXenium/io/partial_xenium_loader.py

Lines changed: 76 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
"""
2-
pyXenium.io.partial_xenium_loader_v2
3-
------------------------------------
2+
pyXenium.io.partial_xenium_loader
3+
---------------------------------
44
5-
Purpose
6-
=======
75
Load an AnnData object when you *don't* have a full Xenium `out/` folder, by
86
stitching together any subset of:
97
@@ -12,16 +10,13 @@
1210
- `cells.zarr` or `cells.zarr.zip` (cell centroids / spatial coords)
1311
- `transcripts.zarr` or `transcripts.zarr.zip` (per-gene transcript locations)
1412
15-
New capabilities (2025-09-22)
16-
-----------------------------
17-
- **Common base path** support: pass `base_dir` (local folder) or `base_url` (HTTP/HTTPS)
18-
plus filename knobs: `analysis_name`, `cells_name`, `transcripts_name`.
19-
- **Remote loading**: if a URL is provided for zarr (or via `base_url+name`), the file
20-
is downloaded to a temporary location transparently before reading (so `ZipStore`
21-
can seek). For MEX, you can pass a `mex_dir` URL (base) together with
22-
`mex_matrix_name/mex_features_name/mex_barcodes_name`.
13+
Updates (2025-09-24)
14+
--------------------
15+
- FIX: include root-level "cell_id" in search keys for all zarr readers.
16+
- FEAT: support `(N,2)` numeric cell_id by taking first column and normalizing to "cell_{num}".
17+
- Keep all previous APIs and behaviors.
2318
24-
Author: Taobo Hu (pyXenium project) — 2025-09-22
19+
Author: Taobo Hu (pyXenium project)
2520
"""
2621
from __future__ import annotations
2722

@@ -70,11 +65,9 @@ def _is_url(p: Optional[str | os.PathLike]) -> bool:
7065

7166

7267
def _fetch_to_temp(src: str, suffix: Optional[str] = None) -> Path:
73-
"""Download a URL to a temporary file and return its Path.
74-
Ensures a local, seekable file for ZipStore and mmread.
75-
"""
68+
"""Download a URL to a temporary file and return its Path."""
7669
logger.info(f"Downloading: {src}")
77-
r = requests.get(src, stream=True, timeout=60)
70+
r = requests.get(src, stream=True, timeout=120)
7871
r.raise_for_status()
7972
tmpdir = Path(tempfile.mkdtemp(prefix="pyxenium_"))
8073
name = Path(urllib.parse.urlparse(src).path).name or ("tmp" + (suffix or ""))
@@ -110,7 +103,7 @@ class PartialInputs:
110103

111104
def _load_mex(mex_dir: Path | str) -> AnnData:
112105
logger.info(f"Reading MEX from {mex_dir}")
113-
base = Path(mex_dir) if not _is_url(mex_dir) else Path(mex_dir) # URL case is pre-downloaded
106+
base = Path(mex_dir) if not _is_url(mex_dir) else Path(mex_dir) # URL case pre-downloaded
114107

115108
candidates = {
116109
"matrix": ["matrix.mtx", "matrix.mtx.gz"],
@@ -123,6 +116,7 @@ def _find(names: Sequence[str], base: Path) -> Optional[Path]:
123116
p = base / name
124117
if p.exists():
125118
return p
119+
# also scan one-level subfolders
126120
for child in base.iterdir():
127121
if child.is_dir():
128122
for name in names:
@@ -134,7 +128,6 @@ def _find(names: Sequence[str], base: Path) -> Optional[Path]:
134128
mtx_p = _find(candidates["matrix"], base)
135129
feat_p = _find(candidates["features"], base)
136130
barc_p = _find(candidates["barcodes"], base)
137-
138131
if not (mtx_p and feat_p and barc_p):
139132
missing = [k for k, v in {"matrix": mtx_p, "features": feat_p, "barcodes": barc_p}.items() if v is None]
140133
raise FileNotFoundError(f"MEX missing files: {', '.join(missing)} under {base}")
@@ -186,19 +179,16 @@ def _open_zarr(path: Path | str):
186179
path = _fetch_to_temp(pstr)
187180
pstr = str(path)
188181

189-
# 打开 *.zarr.zip:只读 ZipStore + 只读 group
182+
# *.zarr.zip
190183
if pstr.endswith(".zip"):
191184
store = ZipStore(pstr, mode="r")
192-
# Zarr v3 优先
193185
if hasattr(zarr, "open_group"):
194186
return zarr.open_group(store=store, mode="r")
195-
# 兼容旧版 Zarr v2
196187
if hasattr(zarr, "open"):
197188
return zarr.open(store, mode="r")
198-
# 最后兜底(极旧环境)
199189
return zarr.group(store=store, mode="r")
200190

201-
# 打开目录/文件路径:优先 v3 API,其次 v2,再兜底
191+
# directory *.zarr
202192
if hasattr(zarr, "open_group"):
203193
return zarr.open_group(pstr, mode="r")
204194
if hasattr(zarr, "open"):
@@ -219,6 +209,28 @@ def _first_available(root, candidates: Sequence[str]) -> Optional[np.ndarray]:
219209
return None
220210

221211

212+
def _normalize_cell_ids(arr: np.ndarray) -> Tuple[pd.Index, Optional[np.ndarray]]:
213+
"""
214+
Normalize a cell_id array to string index.
215+
Returns:
216+
- obs index (strings)
217+
- numeric_ids (np.ndarray) if we originated from numeric ids, else None
218+
"""
219+
if arr.dtype.kind in ("i", "u"):
220+
# numeric: accept both 1D and 2D
221+
if arr.ndim == 2:
222+
nums = arr[:, 0]
223+
else:
224+
nums = arr
225+
obs_names = pd.Index([f"cell_{int(x)}" for x in nums], name="cell_id")
226+
return obs_names, nums.astype(np.int64)
227+
else:
228+
# bytes/str/object: convert to str
229+
s = pd.Series(arr)
230+
obs_names = pd.Index(s.astype(str), name="cell_id")
231+
return obs_names, None
232+
233+
222234
# -----------------------------
223235
# Attach spatial from cells.zarr
224236
# -----------------------------
@@ -252,21 +264,21 @@ def _attach_spatial(adata: AnnData, cells_zarr: Path | str) -> None:
252264
],
253265
)
254266
z = _first_available(root, ["cells/centroids/z", "centroids/z", "cells/centroid_z", "z"]) # optional
255-
cell_ids = _first_available(root, ["cells/cell_id", "cells/ids", "cell_ids", "ids", "barcodes"]) # optional
256-
if cell_ids is not None:
257-
if cell_ids.dtype.kind in ("i", "u"):
258-
cell_ids = np.char.add("cell_", cell_ids.astype(str))
259-
else:
260-
cell_ids = cell_ids.astype(str)
267+
268+
# **FIX**: include root-level "cell_id"
269+
cell_ids_raw = _first_available(
270+
root, ["cell_id", "cells/cell_id", "cells/ids", "cell_ids", "ids", "barcodes"]
271+
)
261272

262273
if x is None or y is None:
263274
logger.warning("Could not locate centroid x/y in cells.zarr; skipping spatial attach.")
264275
return
265276

266277
coords = np.column_stack([x, y]) if z is None else np.column_stack([x, y, z])
267278

268-
if cell_ids is not None:
269-
df = pd.DataFrame(coords, index=pd.Index(cell_ids, name="cell_id"))
279+
if cell_ids_raw is not None:
280+
idx, _ = _normalize_cell_ids(cell_ids_raw)
281+
df = pd.DataFrame(coords, index=idx)
270282
try:
271283
coords_df = df.reindex(adata.obs.index)
272284
except Exception:
@@ -305,12 +317,9 @@ def _attach_clusters(adata: AnnData, analysis_zarr: Path | str, cluster_key: str
305317
)
306318
names = _first_available(root, ["clusters/names", "clustering/names", "names", "cluster_names"])
307319
ids = _first_available(root, ["clusters/ids", "clustering/ids", "ids", "cluster_ids"])
308-
cell_ids = _first_available(root, ["cells/cell_id", "cell_ids", "barcodes", "cells/ids"]) # optional
309-
if cell_ids is not None:
310-
if cell_ids.dtype.kind in ("i", "u"):
311-
cell_ids = np.char.add("cell_", cell_ids.astype(str))
312-
else:
313-
cell_ids = cell_ids.astype(str)
320+
321+
# **FIX**: include root-level "cell_id"
322+
cell_ids_raw = _first_available(root, ["cell_id", "cells/cell_id", "cell_ids", "barcodes", "cells/ids"])
314323

315324
if label_arr is None:
316325
logger.warning("No cluster labels found in analysis.zarr; skipping.")
@@ -321,8 +330,9 @@ def _attach_clusters(adata: AnnData, analysis_zarr: Path | str, cluster_key: str
321330
mapping = {str(i): str(n) for i, n in zip(ids, names)}
322331
labels = np.array([mapping.get(str(x), str(x)) for x in label_arr], dtype=object)
323332

324-
if cell_ids is not None and len(cell_ids) == len(labels):
325-
s = pd.Series(labels, index=pd.Index(cell_ids, name="cell_id"))
333+
if cell_ids_raw is not None and len(cell_ids_raw) == len(labels):
334+
idx, _ = _normalize_cell_ids(cell_ids_raw)
335+
s = pd.Series(labels, index=idx)
326336
try:
327337
adata.obs[cluster_key] = s.reindex(adata.obs.index).astype("category")
328338
except Exception:
@@ -347,7 +357,10 @@ def _counts_from_transcripts(transcripts_zarr: Path | str, cell_id_index: pd.Ind
347357
root = _open_zarr(transcripts_zarr)
348358

349359
gene = _first_available(root, ["transcripts/gene", "transcripts/genes", "gene", "genes"]) # num or str
350-
cell = _first_available(root, ["transcripts/cell_id", "transcripts/cells", "cell_id", "cell"]) # str or num
360+
# **FIX**: include root-level "cell_id"
361+
cell = _first_available(
362+
root, ["transcripts/cell_id", "transcripts/cells", "cell_id", "cell", "cells/cell_id"]
363+
)
351364

352365
if gene is None or cell is None:
353366
raise KeyError("Could not locate transcript gene/cell arrays in transcripts.zarr")
@@ -366,6 +379,13 @@ def _counts_from_transcripts(transcripts_zarr: Path | str, cell_id_index: pd.Ind
366379
df = pd.DataFrame({"gene": pd.Categorical(gene), "cell": pd.Categorical(cell)})
367380
df = df[df["cell"].isin(cell_id_index)]
368381

382+
if df.empty:
383+
# Tolerate an empty upstream: return an empty matrix with aligned dimensions
384+
n_cells = len(cell_id_index)
385+
X = sparse.csr_matrix((0, n_cells), dtype=np.float32)
386+
gene_index = pd.Index([], name="feature_id")
387+
return X, gene_index
388+
369389
gi = df["gene"].cat.codes.to_numpy()
370390
ci = df["cell"].cat.codes.to_numpy()
371391
data = np.ones_like(gi, dtype=np.int32)
@@ -408,12 +428,10 @@ def load_anndata_from_partial(
408428
) -> AnnData:
409429
"""Create an AnnData from any combination of partial Xenium artifacts.
410430
411-
You can pass explicit paths, or a common ``base_dir`` / ``base_url`` plus the
412-
default filenames. Explicit paths take precedence over base resolution.
431+
Explicit paths take precedence over base resolution with (base_dir|base_url)+name.
413432
"""
414433
# Resolve MEX
415434
if mex_dir is not None and _is_url(mex_dir):
416-
# Download the triplet next to each other and point to that temp dir
417435
url_base = str(mex_dir).rstrip("/")
418436
m_p = _fetch_to_temp(url_base + "/" + mex_matrix_name)
419437
_ = _fetch_to_temp(url_base + "/" + mex_features_name)
@@ -448,43 +466,40 @@ def _resolve_zarr(explicit: Optional[os.PathLike | str], name: str) -> Optional[
448466
cells_p = _resolve_zarr(cells_zarr, cells_name)
449467
transcripts_p = _resolve_zarr(transcripts_zarr, transcripts_name)
450468

451-
# Start with counts
469+
# Start with counts or at least obs index
452470
if mex_dir_p is not None:
453471
adata = _load_mex(mex_dir_p)
454472
else:
455-
# create obs index from any available source
473+
# Discover cell ids from any available zarr
456474
cell_ids: Optional[pd.Index] = None
475+
probe_keys = ["cell_id", "cells/cell_id", "cell_ids", "barcodes", "cells/ids", "cell"]
457476
for p in (cells_p, analysis_p, transcripts_p):
458477
if p is None:
459478
continue
460479
try:
461480
root = _open_zarr(p)
462481
except Exception:
463482
continue
464-
arr = _first_available(root, ["cells/cell_id", "cell_ids", "barcodes", "cells/ids", "cell"]) # type: ignore
483+
arr = _first_available(root, probe_keys) # include root "cell_id"
465484
if arr is not None:
466-
if arr.dtype.kind in ("i", "u"):
467-
if arr.ndim == 2:
468-
ids_numeric = arr[:, 0]
469-
else:
470-
ids_numeric = arr
471-
cell_ids = pd.Index([f"cell_{int(x)}" for x in ids_numeric], name="cell_id")
472-
else:
473-
cell_ids = pd.Index(pd.Series(arr).astype(str), name="cell_id")
485+
idx, _ = _normalize_cell_ids(arr)
486+
cell_ids = idx
474487
break
488+
489+
# As a fallback, deduce from transcripts unique cells
475490
if cell_ids is None and transcripts_p is not None:
476491
root = _open_zarr(transcripts_p)
477-
c = _first_available(root, ["transcripts/cell_id", "transcripts/cells", "cell_id", "cell"]) # type: ignore
492+
c = _first_available(root, ["transcripts/cell_id", "transcripts/cells", "cell_id", "cell", "cells/cell_id"])
478493
if c is not None:
479494
if c.dtype.kind in ("i", "u"):
480-
unique_nums = pd.unique(pd.Series(c))
481-
cell_ids_list = [f"cell_{int(x)}" for x in unique_nums]
482-
cell_ids = pd.Index(cell_ids_list, name="cell_id")
495+
cell_ids = pd.Index([f"cell_{int(x)}" for x in pd.unique(pd.Series(c))], name="cell_id")
483496
else:
484497
cell_ids = pd.Index(pd.unique(pd.Series(c).astype(str)), name="cell_id")
498+
485499
if cell_ids is None:
486500
raise ValueError("Could not determine cell IDs; provide MEX or any zarr with cell ids.")
487501

502+
# Build counts from transcripts if asked & available
488503
if build_counts_if_missing and transcripts_p is not None:
489504
X, gene_index = _counts_from_transcripts(transcripts_p, cell_ids)
490505
obs = pd.DataFrame(index=cell_ids)
@@ -501,7 +516,7 @@ def _resolve_zarr(explicit: Optional[os.PathLike | str], name: str) -> Optional[
501516
if analysis_p is not None:
502517
try:
503518
_attach_clusters(adata, analysis_p, cluster_key=cluster_key)
504-
if not keep_unassigned:
519+
if not keep_unassigned and cluster_key in adata.obs:
505520
bad = {"-1", "NA", "None", "Unassigned", "unassigned"}
506521
mask = ~adata.obs[cluster_key].astype(str).isin(bad)
507522
dropped = int((~mask).sum())
@@ -575,7 +590,7 @@ def cli_import_partial(
575590
adata.write_h5ad(output_h5ad)
576591
typer.echo(f"Wrote {output_h5ad} (n_cells={adata.n_obs}, n_genes={adata.n_vars})")
577592

578-
except Exception: # Typer not installed
593+
except Exception:
579594
app = None # type: ignore
580595

581596

0 commit comments

Comments
 (0)