Skip to content

Commit 53bc280

Browse files
move regulon to new file
1 parent 0121b4e commit 53bc280

1 file changed

Lines changed: 190 additions & 0 deletions

File tree

src/netmap/downstream/regulon.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import logging
2+
import warnings
3+
from functools import reduce
4+
from itertools import chain, product, combinations
5+
from typing import List, Optional, Tuple, Union, Dict
6+
import json
7+
import os
8+
import anndata as ad
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
import pandas as pd
12+
import scanpy as sc
13+
import seaborn as sns
14+
from scipy.stats import pearsonr, ranksums
15+
from statsmodels.stats.multitest import multipletests
16+
import networkx as nx
17+
import requests
18+
from pyvis.network import Network
19+
import pyucell as ucell
20+
21+
22+
23+
24+
from itertools import combinations
25+
from collections import Counter
26+
27+
28+
def select_top_edges(gene_inter_adata, adata, top_per_source=10, col_cluster='leiden_remap', min_reg_size=10, verbose=True, return_copy = False, tf_column=None):
29+
"""
30+
Selects top gene targets per source from a clustered gene interaction AnnData.
31+
32+
Parameters
33+
----------
34+
gene_inter_adata : AnnData
35+
Gene interaction AnnData with `var` containing 'source' and 'target'.
36+
adata : AnnData
37+
Expression AnnData for ranking genes.
38+
top_per_source : int, default=750
39+
Number of top targets to select per source.
40+
col_cluster : str, default='spectral'
41+
Column in obs defining clusters.grn_adata3.var
42+
43+
Returns
44+
-------
45+
gene_inter_adata_filtered : AnnData
46+
Filtered AnnData containing top edges.
47+
reglon_sizes : list of int
48+
Sizes of regulatory regions per source.
49+
50+
"""
51+
52+
min_edge_support = 0.5
53+
54+
clusters = list(np.unique(gene_inter_adata.obs[col_cluster]))
55+
56+
keep_edges_dict = {}
57+
58+
59+
for c in clusters:
60+
Keep_edges = []
61+
if verbose: print(f"Selecting targets for cluster: {c}")
62+
63+
cells_c = gene_inter_adata.obs[col_cluster] == c
64+
gene_inter_adata.var['edge_support_c'] = (
65+
gene_inter_adata.layers['mask'][cells_c, :].mean(axis=0)
66+
)
67+
68+
if tf_column is not None:
69+
tfs = gene_inter_adata.var[gene_inter_adata.var[tf_column]]['source'].unique()
70+
source_list = set(gene_inter_adata.var["source"].unique()).intersection(set(tfs))
71+
72+
else:
73+
source_list = gene_inter_adata.var["source"].unique()
74+
75+
for source in source_list:
76+
77+
df_targets = gene_inter_adata.var[
78+
(gene_inter_adata.var['source'] == source) &
79+
(gene_inter_adata.var['edge_support_c'] >= min_edge_support)]
80+
#print(gene_inter_adata[cells_c, gene_inter_adata[cells_c, df_targets.index].X.sum(axis = 0)].X.sum(axis = 0))
81+
82+
df_targets['sum_of_edge'] = gene_inter_adata[cells_c, df_targets.index].X.sum(axis = 0)
83+
df_targets = df_targets.sort_values('sum_of_edge', ascending=False).head(top_per_source)
84+
85+
if len(df_targets) >= min_reg_size:
86+
Keep_edges.extend(f"{source}_{t}" for t in df_targets['target'])
87+
88+
keep_edges_dict[c] = Keep_edges
89+
keep_edges_dict = process_cell_edges(keep_edges_dict)
90+
return keep_edges_dict
91+
92+
93+
def process_cell_edges(keep_edges):
94+
results = {'unique': {}, 'all': {}}
95+
all_cells = list(keep_edges.keys())
96+
97+
98+
def get_source_summary(edge_set):
99+
# Handles (source, target) tuples OR strings with a separator like '->'
100+
sources = []
101+
for e in edge_set:
102+
sources.append(e.split('_')[0])
103+
104+
source_dict = dict(Counter(sources))
105+
sources = pd.DataFrame({'source' :source_dict.keys(), 'count': source_dict.values()}).sort_values('count', ascending=False)
106+
return sources
107+
108+
# Calculate Uniques
109+
for cell in all_cells:
110+
others = set().union(*(set(keep_edges[c]) for c in all_cells if c != cell))
111+
unique = set(keep_edges[cell]) - others
112+
113+
df = pd.DataFrame(
114+
[e.split('_', 1) for e in unique],
115+
columns=['source', 'target']
116+
)
117+
df_all = pd.DataFrame(
118+
[e.split('_', 1) for e in set(keep_edges[cell])],
119+
columns=['source', 'target']
120+
)
121+
122+
results['unique'][cell] = {
123+
'edges': df,
124+
'summary': get_source_summary(unique)
125+
}
126+
127+
results['all'][cell] = {
128+
'edges': df_all,
129+
'summary': get_source_summary(set(keep_edges[cell]))
130+
}
131+
132+
return results
133+
134+
135+
def compute_signatures_UCell_scores(selected_edges, adata, key='unique') -> pd.DataFrame:
136+
"""
137+
Filters gene signatures by cluster and computes UCell scores.
138+
139+
Parameters
140+
----------
141+
grn_adata : AnnData
142+
AnnData object containing GRN (gene regulatory network) information.
143+
adata : AnnData
144+
AnnData object containing gene expression counts in the 'counts' layer.
145+
146+
Returns
147+
-------
148+
pd.DataFrame
149+
DataFrame with UCell scores merged with the 'spectral' cluster labels.
150+
"""
151+
152+
all_signatures = {}
153+
for ct in selected_edges[key]:
154+
sign = selected_edges[key][ct]['edges'].groupby('source')['target'].apply(list).to_dict()
155+
sign = {f"{ct}_{k}": v for k, v in sign.items()}
156+
all_signatures = all_signatures | sign
157+
158+
ucell.compute_ucell_scores(adata, signatures=all_signatures, n_jobs=1)
159+
data_ucell = adata.obs.filter(like='_UCell')
160+
data_ucell.columns = [x.replace('_UCell', '') for x in data_ucell.columns]
161+
162+
return data_ucell
163+
164+
165+
def aggregate_edges(selected_edges, grn_adata, key='unique') -> pd.DataFrame:
166+
"""
167+
Filters gene signatures by cluster and computes UCell scores.
168+
169+
Parameters
170+
----------
171+
grn_adata : AnnData
172+
AnnData object containing GRN (gene regulatory network) information.
173+
adata : AnnData
174+
AnnData object containing gene expression counts in the 'counts' layer.
175+
176+
Returns
177+
-------
178+
pd.DataFrame
179+
DataFrame with UCell scores merged with the 'spectral' cluster labels.
180+
"""
181+
182+
regulons = {}
183+
for ct in selected_edges[key]:
184+
print(ct)
185+
sign = selected_edges[key][ct]['edges'].groupby('source').apply(lambda x: (x['source'] + "_" + x['target']).tolist()).to_dict()
186+
for g in sign:
187+
regulons[f'{ct}_{g}'] = grn_adata[:, sign[g]].X.sum(axis = 1)
188+
regulons = pd.DataFrame(regulons)
189+
return regulons
190+

0 commit comments

Comments
 (0)