Skip to content

Commit 94abdf8

Browse files
fix layer parameter issue and introduce chunked sorting
1 parent 01902a6 commit 94abdf8

1 file changed

Lines changed: 58 additions & 3 deletions

File tree

src/netmap/downstream/edge_selection.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,52 @@
22
import pandas as pd
33
import numpy as np
44
from collections import Counter
5+
import numpy as np
6+
from scipy.sparse import issparse
7+
8+
def chunked_argsort(adata, layer_name='sorted', chunk_size=500, dtype=None):
    """
    Computes np.argsort on adata.X in chunks to save memory.

    Parameters:
    -----------
    adata : AnnData
        The AnnData object to process. Must expose ``.shape``, ``.X``
        (row-sliceable, dense or scipy-sparse) and a dict-like ``.layers``.
    layer_name : str
        The name of the layer where results will be stored.
    chunk_size : int
        Number of rows (cells) to process per iteration. Must be positive.
    dtype : np.dtype
        The integer type for the output. If None, the smallest unsigned
        integer type able to hold every column index (0 .. n_vars - 1) is
        chosen automatically (uint16 or uint32).

    Raises:
    -------
    ValueError
        If ``chunk_size`` is not a positive integer.
    """
    if chunk_size <= 0:
        # A non-positive step would either raise inside range() or silently
        # leave the pre-allocated layer unfilled; fail loudly instead.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    n_obs, n_vars = adata.shape

    # 1. Automatically determine the smallest safe integer type.
    #    argsort yields indices in [0, n_vars - 1], so uint16 suffices
    #    whenever the largest index fits in 16 bits, i.e. n_vars <= 65536.
    #    (The previous check `n_vars < 65535` was off by one and fell back
    #    to uint32 unnecessarily for n_vars in {65535, 65536}.)
    if dtype is None:
        if n_vars <= np.iinfo(np.uint16).max + 1:
            dtype = np.uint16
        else:
            dtype = np.uint32

    # 2. Pre-allocate the layer so each chunk can be written in place.
    adata.layers[layer_name] = np.empty((n_obs, n_vars), dtype=dtype)

    # 3. Loop through chunks of rows to bound peak memory.
    for i in range(0, n_obs, chunk_size):
        end = min(i + chunk_size, n_obs)

        # Pull chunk and densify only if necessary — sparse matrices have
        # no row-wise argsort, so only the current chunk is made dense.
        chunk = adata.X[i:end]
        if issparse(chunk):
            chunk = chunk.toarray()

        # Perform sort and assign into the pre-allocated layer.
        adata.layers[layer_name][i:end] = np.argsort(chunk, axis=1)

    print(f"Successfully created layer '{layer_name}' using {dtype}.")
551

652
def _get_top_edges_global(grn_adata, top_edges: float):
753
"""
@@ -21,8 +67,17 @@ def _get_top_edges_global(grn_adata, top_edges: float):
2167
final_df : pd.DataFrame
2268
Processed Anndata object with the counted edges
2369
"""
24-
70+
if not 'sorted' in grn_adata.layers:
71+
try:
72+
chunked_argsort(grn_adata)
73+
except np._core._exceptions._ArrayMemoryError:
74+
print(f"You ran into an issue sorting the array. Please manually sort
75+
the array using chunked_argsort and reduce the chunk size (current default chunk
76+
size: 500)")
77+
2578
b = grn_adata.layers['sorted']
79+
80+
2681
# Calculate partition indices for all top_edges values
2782
top_edges_data_list = [int(np.round(grn_adata.shape[1] * t)) for t in top_edges]
2883
partition_indices = [grn_adata.shape[1]]+[grn_adata.shape[1] - n for n in top_edges_data_list]
@@ -151,7 +206,7 @@ def add_top_edge_annotation_cluster(grn_adata, top_edges = [0.1], nan_fill = 0,
151206

152207
for clu in grn_adata.obs[cluster_var].unique():
153208
grn_adata_sub = grn_adata[grn_adata.obs[cluster_var] == clu]
154-
top_edges_per_cell = _get_top_edges_global(grn_adata_sub, top_edges, layer='X')
209+
top_edges_per_cell = _get_top_edges_global(grn_adata_sub, top_edges)
155210

156211
for te in top_edges:
157212
if f'cell_count_{te}_{clu}' in var.columns:
@@ -198,7 +253,7 @@ def add_top_edge_annotation_global(grn_adata, top_edges = [0.1], nan_fill = 0, k
198253
else:
199254
var = var.reset_index()
200255

201-
top_edges_per_cell = _get_top_edges_global(grn_adata, top_edges, layer='X')
256+
top_edges_per_cell = _get_top_edges_global(grn_adata, top_edges)
202257
for te in top_edges:
203258
if f'{key_name}_cell_count_{te}' in var.columns:
204259
continue

0 commit comments

Comments
 (0)