Skip to content

Commit c4e9868

Browse files
add regulon, update downstream
1 parent ba719ac commit c4e9868

2 files changed

Lines changed: 171 additions & 64 deletions

File tree

src/netmap/downstream/final_downstream.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
def make_cluster_regulon_dataframe(keep_edges):
    """
    Flatten the nested output of ``process_cell_edges`` into one DataFrame.

    Parameters
    ----------
    keep_edges : dict
        ``{'unique'/'all': {cluster: {'edges': DataFrame, ...}}}`` as
        returned by ``process_cell_edges``.

    Returns
    -------
    pd.DataFrame
        All non-empty per-cluster edge tables concatenated, with added
        'cluster' (cluster label) and 'set_type' ('unique' or 'all')
        columns. Empty DataFrame if no cluster has edges.
    """
    all_regulons = []
    for set_type in keep_edges:
        for cluster in keep_edges[set_type]:
            df = keep_edges[set_type][cluster]['edges']
            if df.shape[0] > 0:
                # Copy so we do not mutate the caller's tables in place.
                df = df.copy()
                # NOTE(review): the outer keys are 'unique'/'all' (built by
                # process_cell_edges), i.e. the set type; the inner keys are
                # the cluster labels. The commit had these two swapped.
                df['cluster'] = cluster
                df['set_type'] = set_type
                all_regulons.append(df)
    if not all_regulons:
        # pd.concat([]) raises ValueError; return an empty frame instead.
        return pd.DataFrame(columns=['source', 'target', 'cluster', 'set_type'])
    return pd.concat(all_regulons)

src/netmap/downstream/regulon.py

Lines changed: 169 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,52 +12,30 @@
1212
import scanpy as sc
1313
import seaborn as sns
1414
from scipy.stats import pearsonr, ranksums
15+
#from scipy.stats import mannwhitneyu
1516
from statsmodels.stats.multitest import multipletests
1617
import networkx as nx
1718
import requests
1819
from pyvis.network import Network
1920
import pyucell as ucell
2021

2122

23+
from netmap.downstream.clustering import process, spectral_clustering, downstream_recipe
24+
from netmap.downstream.edge_selection import add_top_edge_annotation_global
2225

2326

2427
from itertools import combinations
2528
from collections import Counter
2629

2730

2831
def select_top_edges(gene_inter_adata, adata, top_per_source=10, col_cluster='leiden_remap', min_reg_size=10, verbose=True, return_copy = False, tf_column=None, min_edge_support=0.5):
    """
    Select the top-scoring edges per source gene for each cluster.

    For every cluster in ``gene_inter_adata.obs[col_cluster]`` and every
    source, edges with sufficient support are ranked by their mean value
    over the cluster's cells; the top ``top_per_source`` are kept when the
    source retains at least ``min_reg_size`` edges.

    Parameters
    ----------
    gene_inter_adata : AnnData
        Edge-level AnnData whose ``var`` contains 'source', 'target' and
        'edge_support_c' columns.
    adata : AnnData
        Expression AnnData (currently unused; kept for interface stability).
    top_per_source : int, default=10
        Number of top targets to keep per source.
    col_cluster : str, default='leiden_remap'
        Column in ``obs`` defining clusters.
    min_reg_size : int, default=10
        Minimum number of surviving edges for a source to be kept.
    verbose : bool, default=True
        Print progress per cluster.
    return_copy : bool, default=False
        Currently unused; kept for interface stability.
    tf_column : str or None, default=None
        Optional boolean ``var`` column restricting sources to TFs.
    min_edge_support : float, default=0.5
        Minimum 'edge_support_c' for an edge to be considered
        (previously a hard-coded constant).

    Returns
    -------
    dict
        Output of :func:`process_cell_edges`: per-cluster 'unique' and
        'all' edge tables carrying per-edge mean values.
    """
    clusters = list(np.unique(gene_inter_adata.obs[col_cluster]))
    keep_edges_dict = {}

    for c in clusters:
        kept_edges = []  # list of (edge_name, mean_value) tuples
        if verbose:
            print(f"Selecting targets for cluster: {c}")

        cells_c = gene_inter_adata.obs[col_cluster] == c

        if tf_column is not None:
            tfs = gene_inter_adata.var[gene_inter_adata.var[tf_column]]['source'].unique()
            source_list = set(gene_inter_adata.var["source"].unique()).intersection(set(tfs))
        else:
            source_list = gene_inter_adata.var["source"].unique()

        for source in source_list:
            df_targets = gene_inter_adata.var[
                (gene_inter_adata.var['source'] == source) &
                (gene_inter_adata.var['edge_support_c'] >= min_edge_support)].copy()

            # Rank this source's edges by their mean value over the
            # cluster's cells and keep the strongest ones.
            df_targets['sum_of_edge'] = gene_inter_adata[cells_c, df_targets.index].X.mean(axis=0)
            df_targets = df_targets.sort_values('sum_of_edge', ascending=False).head(top_per_source)

            if len(df_targets) >= min_reg_size:
                # Vectorized pairing instead of per-row iterrows.
                kept_edges.extend(
                    (f"{source}_{t}", v)
                    for t, v in zip(df_targets['target'], df_targets['sum_of_edge'])
                )

        keep_edges_dict[c] = kept_edges

    return process_cell_edges(keep_edges_dict)
9169

9270

93-
def process_cell_edges(keep_edges_with_vals):
    """
    Split per-cluster edge lists into cluster-unique and full edge tables.

    Parameters
    ----------
    keep_edges_with_vals : dict
        Maps cluster/cell label -> list of ``(edge_name, value)`` tuples,
        where ``edge_name`` is ``"<source>_<target>"``.

    Returns
    -------
    dict
        ``{'unique': {label: {'edges': DataFrame, 'summary': DataFrame}},
           'all':    {label: {'edges': DataFrame, 'summary': DataFrame}}}``
        'unique' contains edges found in no other label; 'all' contains
        every edge of the label. 'edges' has columns
        ['source', 'target', 'sum_of_edge']; 'summary' counts edges per
        source, sorted descending by count.
    """
    results = {'unique': {}, 'all': {}}
    all_cells = list(keep_edges_with_vals.keys())

    # Helpers hoisted out of the loop (they were re-created per iteration).
    def get_source_summary(edge_names):
        # Count edges per source (text before the first '_').
        counts = Counter(name.split('_')[0] for name in edge_names)
        return pd.DataFrame(
            {'source': list(counts.keys()), 'count': list(counts.values())}
        ).sort_values('count', ascending=False)

    def build_df(names, lookup):
        # Materialize edge names and their stored value into a DataFrame.
        rows = [list(name.split('_', 1)) + [lookup[name]] for name in names]
        return pd.DataFrame(rows, columns=['source', 'target', 'sum_of_edge'])

    for cell in all_cells:
        # Edge name -> value lookup for this label (deduplicates names).
        current = dict(keep_edges_with_vals[cell])

        # Edge names appearing in any *other* label.
        others = {name for c in all_cells if c != cell
                  for name, _ in keep_edges_with_vals[c]}
        unique_names = set(current) - others

        results['unique'][cell] = {
            'edges': build_df(unique_names, current),
            'summary': get_source_summary(unique_names),
        }
        results['all'][cell] = {
            'edges': build_df(current.keys(), current),
            # Summarize the same deduplicated set the edge table is built
            # from (previously the raw list, which could double-count).
            'summary': get_source_summary(current.keys()),
        }

    return results
@@ -184,7 +164,134 @@ def aggregate_edges(selected_edges, grn_adata, key='unique') -> pd.DataFrame:
184164
print(ct)
185165
sign = selected_edges[key][ct]['edges'].groupby('source').apply(lambda x: (x['source'] + "_" + x['target']).tolist()).to_dict()
186166
for g in sign:
187-
regulons[f'{ct}_{g}'] = grn_adata[:, sign[g]].X.sum(axis = 1)/len(sign[g])
167+
regulons[f'{ct}_{g}'] = grn_adata[:, sign[g]].X.mean(axis = 1)
188168
regulons = pd.DataFrame(regulons)
189169
return regulons
170+
171+
172+
173+
def make_cluster_regulon_dataframe(keep_edges):
    """
    Flatten the nested output of ``process_cell_edges`` into one DataFrame.

    Parameters
    ----------
    keep_edges : dict
        ``{'unique'/'all': {cluster: {'edges': DataFrame, ...}}}`` as
        returned by ``process_cell_edges``.

    Returns
    -------
    pd.DataFrame
        All non-empty per-cluster edge tables concatenated, with added
        'cluster' (cluster label) and 'set_type' ('unique' or 'all')
        columns. Empty DataFrame if no cluster has edges.
    """
    all_regulons = []
    for set_type in keep_edges:
        for cluster in keep_edges[set_type]:
            df = keep_edges[set_type][cluster]['edges']
            if df.shape[0] > 0:
                # Copy so we do not mutate the caller's tables in place.
                df = df.copy()
                # NOTE(review): the outer keys are 'unique'/'all' (built by
                # process_cell_edges), i.e. the set type; the inner keys are
                # the cluster labels. The commit had these two swapped.
                df['cluster'] = cluster
                df['set_type'] = set_type
                all_regulons.append(df)
    if not all_regulons:
        # pd.concat([]) raises ValueError; return an empty frame instead.
        return pd.DataFrame(columns=['source', 'target', 'cluster', 'set_type'])
    return pd.concat(all_regulons)
184+
190185

186+
187+
def jaccard_similarity(set1, set2):
    """Return the Jaccard index |set1 ∩ set2| / |set1 ∪ set2| (0 when both sets are empty)."""
    union = set1 | set2
    if not union:
        return 0
    return len(set1 & set2) / len(union)
191+
192+
193+
def get_sourcewise_jaccard_regulons(all_signatures, keep_edges, n_top = 50):
    """
    For each top-ranked source gene, compare its target sets across cell types.

    Parameters
    ----------
    all_signatures : AnnData
        AnnData holding a 'wilcoxon' ranking in ``.uns`` (consumed via
        ``sc.get.rank_genes_groups_df``).
    keep_edges : dict
        Output of ``process_cell_edges``; ``keep_edges['all'][ct]['edges']``
        must be a DataFrame with 'source' and 'target' columns.
    n_top : int, default=50
        Number of top-ranked genes to consider per cell type.

    Returns
    -------
    dict
        Maps source gene -> square DataFrame of pairwise Jaccard
        similarities between its target sets in different cell types.
        Genes appearing in fewer than two cell types are skipped.
    """
    top_sources = {}
    top_counter = {}
    # Iterate the cell types actually present in keep_edges; the previous
    # version read `grn_adata3.obs.leiden_remap`, an undefined global.
    for ct in keep_edges['all']:
        print(ct)
        try:
            bcrank = sc.get.rank_genes_groups_df(all_signatures, group=ct, key='wilcoxon')
        except (KeyError, ValueError):
            # Cell type absent from the ranking results — skip it instead of
            # the previous bare `except: pass`, which hid all errors.
            continue
        bcrank[['celltype', 'gene']] = bcrank['names'].str.rsplit('_', n=1, expand=True)
        topg = set(bcrank[0:n_top].gene)

        for g in topg:
            # `re_df` (was `re`) avoids shadowing the stdlib `re` module.
            re_df = keep_edges['all'][ct]['edges']
            targets = list(re_df[re_df.source == g].target)
            if g in top_sources:
                top_sources[g][ct] = targets
                top_counter[g] += 1
            else:
                top_sources[g] = {ct: targets}
                top_counter[g] = 1

    # One square Jaccard matrix per source gene.
    gene_matrices = {}
    for g, celltype_dict in top_sources.items():
        # Only meaningful when the gene appears in more than one cell type.
        if len(celltype_dict) < 2:
            continue

        results = []
        celltypes = sorted(celltype_dict.keys())
        for s1 in celltypes:
            set1 = set(celltype_dict[s1])
            for s2 in celltypes:
                set2 = set(celltype_dict[s2])
                intersection = len(set1.intersection(set2))
                union = len(set1.union(set2))
                sim = intersection / union if union > 0 else 0
                results.append({'ct1': s1, 'ct2': s2, 'jaccard': sim})

        # Pivot the long-form pairs into a square matrix.
        gene_matrices[g] = pd.DataFrame(results).pivot(index='ct1', columns='ct2', values='jaccard')

    return gene_matrices
249+
250+
251+
def make_global_target_similarity_plot(gene_matrices):
    """
    Plot the element-wise mean of all per-gene Jaccard matrices as a
    hierarchically ordered heatmap.

    Parameters
    ----------
    gene_matrices : dict
        Maps source gene -> square DataFrame of pairwise Jaccard
        similarities (as built by ``get_sourcewise_jaccard_regulons``).
        If empty, a message is printed and nothing is plotted.

    Notes
    -----
    Relies on module-level ``hierarchy`` (scipy.cluster), ``pdist``
    (scipy.spatial.distance), ``plt`` and ``sns`` — presumably imported
    elsewhere in this file; not visible in this chunk.
    """
    all_mats = list(gene_matrices.values())

    if not all_mats:
        # Previously this printed and then fell through to use the
        # never-assigned `global_matrix`, raising a NameError.
        print("No matrices found to average.")
        return

    # Element-wise mean across all matrices, aligned on the row labels.
    global_matrix = pd.concat(all_mats).groupby(level=0).mean()
    # Reorder columns to match the row index so the matrix stays square.
    global_matrix = global_matrix[global_matrix.index]

    # Order rows and columns together via Ward hierarchical clustering.
    row_idx = hierarchy.leaves_list(hierarchy.linkage(pdist(global_matrix.fillna(0)), method='ward'))
    ordered_global = global_matrix.iloc[row_idx, row_idx]

    fig, ax = plt.subplots(figsize=(7, 8))

    sns.heatmap(
        ordered_global,
        mask=(ordered_global == 0),
        cmap='YlGnBu',
        square=True,
        linewidths=.5,
        linecolor='#eeeeee',
        ax=ax,
        cbar_kws={"shrink": 0.2, "orientation": "horizontal", "label": "Mean Jaccard"},
        annot=False
    )

    # Move the colorbar to the lower left of the figure.
    cbar = ax.collections[0].colorbar
    cbar.ax.set_position([0.15, 0.05, 0.2, 0.015])

    # Axis formatting.
    ax.tick_params(axis='both', which='major', pad=0.5, length=0)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha='center', fontsize=9)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=9)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title('Global Target Similarity (Average across all Source Genes)', pad=25, fontweight='bold')

    plt.subplots_adjust(bottom=0.25, left=0.25)
    plt.show()

0 commit comments

Comments
 (0)