import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
from scipy.stats import pearsonr, ranksums
# from scipy.stats import mannwhitneyu
from scipy.cluster import hierarchy          # needed by make_global_target_similarity_plot
from scipy.spatial.distance import pdist     # needed by make_global_target_similarity_plot
from statsmodels.stats.multitest import multipletests
import networkx as nx
import requests
from pyvis.network import Network
import pyucell as ucell

from netmap.downstream.clustering import process, spectral_clustering, downstream_recipe
from netmap.downstream.edge_selection import add_top_edge_annotation_global

from itertools import combinations
from collections import Counter


def select_top_edges(gene_inter_adata, adata, top_per_source=10, col_cluster='leiden_remap',
                     min_reg_size=10, verbose=True, return_copy=False, tf_column=None):
    """
    Select the top gene targets per source within each cluster of a gene
    interaction AnnData.

    Parameters
    ----------
    gene_inter_adata : AnnData
        Gene interaction AnnData whose `var` contains 'source' and 'target'.
    adata : AnnData
        Expression AnnData for ranking genes.
    top_per_source : int, default=10
        Number of top targets to keep per source.
    col_cluster : str, default='leiden_remap'
        Column in `obs` defining clusters.
    min_reg_size : int, default=10
        Minimum number of surviving targets a source needs to be kept.
    tf_column : str, optional
        Boolean column in `var`; if given, sources are restricted to these rows.

    Returns
    -------
    dict
        Output of `process_cell_edges`: per-cluster 'unique' and 'all' edge
        tables with per-source summaries.
    """
    min_edge_support = 0.5
    clusters = list(np.unique(gene_inter_adata.obs[col_cluster]))
    keep_edges_dict = {}

    for c in clusters:
        keep_edges = []  # stores tuples: (edge_name, mean_edge_value)
        if verbose:
            print(f"Selecting targets for cluster: {c}")

        cells_c = gene_inter_adata.obs[col_cluster] == c

        # ... (a few lines omitted in the diff view; they are expected to
        # populate gene_inter_adata.var['edge_support_c'] for cluster c) ...

        if tf_column is not None:
            tfs = gene_inter_adata.var[gene_inter_adata.var[tf_column]]['source'].unique()
            source_list = set(gene_inter_adata.var["source"].unique()).intersection(set(tfs))
        else:
            source_list = gene_inter_adata.var["source"].unique()

        for source in source_list:
            df_targets = gene_inter_adata.var[
                (gene_inter_adata.var['source'] == source) &
                (gene_inter_adata.var['edge_support_c'] >= min_edge_support)].copy()

            # Mean edge weight over the cells of this cluster; np.asarray/ravel
            # guards against sparse .X returning an np.matrix
            df_targets['sum_of_edge'] = np.asarray(
                gene_inter_adata[cells_c, df_targets.index].X.mean(axis=0)).ravel()
            df_targets = df_targets.sort_values('sum_of_edge', ascending=False).head(top_per_source)

            if len(df_targets) >= min_reg_size:
                for _, row in df_targets.iterrows():
                    keep_edges.append((f"{source}_{row['target']}", row['sum_of_edge']))

        keep_edges_dict[c] = keep_edges

    return process_cell_edges(keep_edges_dict)


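# Usage sketch (hypothetical object names; assumes `gene_inter_adata` carries
# obs['leiden_remap'] plus var['source'] / var['target'], and `adata` is the
# matching expression AnnData):
#
#     keep_edges = select_top_edges(gene_inter_adata, adata,
#                                   top_per_source=10, col_cluster='leiden_remap')
#     keep_edges['unique']['0']['edges']    # edges found only in cluster '0'
#     keep_edges['all']['0']['summary']     # per-source target counts
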
def process_cell_edges(keep_edges_with_vals):
    """Split per-cluster (edge_name, value) lists into 'unique' and 'all' sets."""
    results = {'unique': {}, 'all': {}}
    all_cells = list(keep_edges_with_vals.keys())

    def get_source_summary(edge_list):
        # edge_list is a list of (edge_name, value) tuples
        sources = [e[0].split('_')[0] for e in edge_list]
        source_dict = dict(Counter(sources))
        return pd.DataFrame(
            {'source': list(source_dict.keys()), 'count': list(source_dict.values())}
        ).sort_values('count', ascending=False)

    def build_df(names, lookup_dict):
        # Build a (source, target, value) DataFrame from edge name strings
        data = []
        for name in names:
            source, target = name.split('_', 1)
            data.append([source, target, lookup_dict[name]])
        return pd.DataFrame(data, columns=['source', 'target', 'sum_of_edge'])

    for cell in all_cells:
        # Dict for easy lookup of an edge's value by its name string
        current_edges_dict = {name: val for name, val in keep_edges_with_vals[cell]}

        # Edges unique to this cluster: present in no other cluster
        others_names = set()
        for c in all_cells:
            if c != cell:
                others_names.update(e[0] for e in keep_edges_with_vals[c])
        unique_names = set(current_edges_dict.keys()) - others_names

        results['unique'][cell] = {
            'edges': build_df(unique_names, current_edges_dict),
            'summary': get_source_summary([(n, current_edges_dict[n]) for n in unique_names])
        }

        results['all'][cell] = {
            'edges': build_df(current_edges_dict.keys(), current_edges_dict),
            'summary': get_source_summary(keep_edges_with_vals[cell])
        }

    return results


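# Returned structure (sketch, inferred from the code above):
#
#     results['unique'][cluster] = {'edges':   DataFrame[source, target, sum_of_edge],
#                                   'summary': DataFrame[source, count]}
#     results['all'][cluster]    = same shape, over every kept edge
#
# make_cluster_regulon_dataframe() below flattens this into one long table.
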
# ... (functions omitted in the diff view) ...


def aggregate_edges(selected_edges, grn_adata, key='unique') -> pd.DataFrame:
    # Opening lines not shown in the diff; this assumes `regulons` is a dict
    # filled per cluster, as the surviving lines imply.
    regulons = {}
    for ct in selected_edges[key]:
        print(ct)
        sign = selected_edges[key][ct]['edges'].groupby('source').apply(
            lambda x: (x['source'] + "_" + x['target']).tolist()).to_dict()
        for g in sign:
            # np.asarray/ravel guards against sparse .X returning an np.matrix
            regulons[f'{ct}_{g}'] = np.asarray(grn_adata[:, sign[g]].X.mean(axis=1)).ravel()
    return pd.DataFrame(regulons)
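
# Usage sketch (hypothetical names): score every cell of `grn_adata` against
# each cluster-specific regulon selected above:
#
#     regulon_scores = aggregate_edges(keep_edges, grn_adata, key='unique')
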

def make_cluster_regulon_dataframe(keep_edges):
    """Flatten process_cell_edges() output into one long edge DataFrame."""
    all_regulons = []
    for set_type in keep_edges:              # 'unique' / 'all'
        for cluster in keep_edges[set_type]:
            df = keep_edges[set_type][cluster]['edges'].copy()
            if df.shape[0] > 0:
                df['cluster'] = cluster
                df['set_type'] = set_type
                all_regulons.append(df)
    return pd.concat(all_regulons)


def jaccard_similarity(set1, set2):
    intersection = len(set1.intersection(set2))
    union = len(set1.union(set2))
    return intersection / union if union > 0 else 0
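
# e.g. jaccard_similarity({'a', 'b', 'c'}, {'b', 'c', 'd'}) == 2 / 4 == 0.5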


def get_sourcewise_jaccard_regulons(all_signatures, keep_edges, n_top=50):
    top_sources = {}
    top_counter = {}  # how many clusters each top source shows up in

    # Group names as stored by sc.tl.rank_genes_groups under key 'wilcoxon'
    # (assumes the ranking has been run on `all_signatures` with that key)
    groups = all_signatures.uns['wilcoxon']['names'].dtype.names

    for ct in groups:
        print(ct)
        try:
            bcrank = sc.get.rank_genes_groups_df(all_signatures, group=ct, key='wilcoxon')
            bcrank[['celltype', 'gene']] = bcrank['names'].str.rsplit('_', n=1, expand=True)
            topg = set(bcrank.head(n_top).gene)

            for g in topg:
                edges_df = keep_edges['all'][ct]['edges']
                targets = list(edges_df[edges_df.source == g].target)
                if g in top_sources:
                    top_sources[g][ct] = targets
                    top_counter[g] += 1
                else:
                    top_sources[g] = {ct: targets}
                    top_counter[g] = 1
        except KeyError:
            # cluster missing from the ranking results or from keep_edges
            continue

    # One square Jaccard matrix per source gene
    gene_matrices = {}

    for g, celltype_dict in top_sources.items():
        # Only process genes that appear in more than one cell type
        if len(celltype_dict) < 2:
            continue

        results = []
        celltypes = sorted(celltype_dict.keys())

        for s1 in celltypes:
            set1 = set(celltype_dict[s1])
            for s2 in celltypes:
                set2 = set(celltype_dict[s2])
                results.append({'ct1': s1, 'ct2': s2,
                                'jaccard': jaccard_similarity(set1, set2)})

        # Long format -> square matrix
        df_long = pd.DataFrame(results)
        gene_matrices[g] = df_long.pivot(index='ct1', columns='ct2', values='jaccard')

    return gene_matrices
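
# Usage sketch (hypothetical names; assumes rank_genes_groups results are
# stored in `all_signatures` under key 'wilcoxon'):
#
#     sc.tl.rank_genes_groups(all_signatures, 'leiden_remap',
#                             method='wilcoxon', key_added='wilcoxon')
#     gene_matrices = get_sourcewise_jaccard_regulons(all_signatures, keep_edges)
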

def make_global_target_similarity_plot(gene_matrices):
    # 1. Stack all per-gene matrices and average them. Rows sharing a cluster
    #    label are averaged; cluster pairs absent everywhere stay NaN.
    all_mats = list(gene_matrices.values())

    if len(all_mats) == 0:
        print("No matrices found to average.")
        return

    global_matrix = pd.concat(all_mats).groupby(level=0).mean()
    # Reorder columns to match the row index so the matrix is square
    global_matrix = global_matrix[global_matrix.index]

    # 2. Order rows and columns symmetrically by hierarchical clustering
    row_idx = hierarchy.leaves_list(hierarchy.linkage(pdist(global_matrix.fillna(0)), method='ward'))
    ordered_global = global_matrix.iloc[row_idx, row_idx]

    # 3. Plot
    fig, ax = plt.subplots(figsize=(7, 8))

    sns.heatmap(
        ordered_global,
        mask=(ordered_global == 0),
        cmap='YlGnBu',
        square=True,
        linewidths=.5,
        linecolor='#eeeeee',
        ax=ax,
        cbar_kws={"shrink": 0.2, "orientation": "horizontal", "label": "Mean Jaccard"},
        annot=False
    )

    # Move the colorbar to the lower left
    cbar = ax.collections[0].colorbar
    cbar.ax.set_position([0.15, 0.05, 0.2, 0.015])

    # Formatting
    ax.tick_params(axis='both', which='major', pad=0.5, length=0)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha='center', fontsize=9)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0, fontsize=9)
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title('Global Target Similarity (Average across all Source Genes)', pad=25, fontweight='bold')

    plt.subplots_adjust(bottom=0.25, left=0.25)
    plt.show()
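
# Usage sketch: inspect one source gene's matrix, then plot the global average:
#
#     gene_matrices['GATA1']            # hypothetical source gene key
#     make_global_target_similarity_plot(gene_matrices)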