Skip to content

Commit 0c048fb

Browse files
remove unnecessary imports from downstream, add convenience function for window filtering
1 parent ee93617 commit 0c048fb

2 files changed

Lines changed: 20 additions & 57 deletions

File tree

src/netmap/downstream/final_downstream.py

Lines changed: 12 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,24 @@
1-
# Standard library imports
2-
import gc
3-
import json
41
import logging
5-
import os
6-
import sys
7-
import time
82
import warnings
93
from functools import reduce
10-
from itertools import chain, product
11-
from typing import Tuple, List, Dict, Optional
12-
from anndata import AnnData
13-
from itertools import combinations
14-
15-
16-
# Third-party imports
4+
from itertools import chain, product, combinations
5+
from typing import List, Optional, Tuple, Union
6+
import json
7+
import os
178
import anndata as ad
189
import matplotlib.pyplot as plt
1910
import numpy as np
2011
import pandas as pd
2112
import scanpy as sc
22-
import scipy.sparse as scs
23-
from anndata import AnnData
2413
import seaborn as sns
25-
import scipy.sparse as sp
26-
import anndata
2714
from scipy.stats import pearsonr, ranksums
28-
from scipy.optimize import linear_sum_assignment
29-
from sklearn.cluster import KMeans, SpectralClustering
30-
from sklearn.metrics.cluster import contingency_matrix
31-
from statsmodels.stats.multitest import multipletests
32-
import pandas as pd
33-
from scipy.stats import mannwhitneyu
15+
#from scipy.stats import mannwhitneyu
3416
from statsmodels.stats.multitest import multipletests
35-
from anndata import AnnData
36-
import importlib
37-
import json
38-
import os
39-
from turtle import shape
40-
from typing import List, Optional, Tuple, Union
41-
4217
import networkx as nx
43-
import pandas as pd
4418
import requests
45-
from pandas import DataFrame
4619
from pyvis.network import Network
47-
from anndata import AnnData
48-
import importlib
49-
import json
50-
import os
51-
from typing import List, Optional, Tuple, Union
52-
53-
import networkx as nx
54-
import pandas as pd
55-
from pandas import DataFrame
5620
import pyucell as ucell
57-
from pyvis.network import Network
58-
59-
60-
#import omnipath as op
6121

62-
# Miscellaneous
63-
warnings.filterwarnings("ignore")
6422

6523
from netmap.downstream.clustering import process, spectral_clustering, downstream_recipe
6624
from netmap.downstream.edge_selection import add_top_edge_annotation_global
@@ -69,7 +27,7 @@
6927

7028

7129

72-
def filter_clusters_by_cell_count(grn_adata: AnnData, metric_tag: float, top_fraction: float) -> Tuple[Optional[Dict[str, float]], AnnData]:
30+
def filter_clusters_by_cell_count(grn_adata: ad.AnnData, metric_tag: float, top_fraction: float) -> Tuple[Optional[Dict[str, float]], ad.AnnData]:
7331
"""
7432
Filter features (genes/edges) based on cell count differences between two clusters,
7533
selecting the top fraction of features for each cluster.
@@ -193,7 +151,7 @@ def get_top_targets(gene_inter_adata, adata, top_per_source=750, col_cluster='sp
193151
return gene_inter_adata_filtered, reglon_sizes
194152

195153

196-
def filter_signatures_by_Ucell(grn_adata, adata, ncores: int = 100) -> pd.DataFrame:
154+
def filter_signatures_by_Ucell(grn_adata, adata) -> pd.DataFrame:
197155
"""
198156
Filters gene signatures by cluster and computes UCell scores.
199157
@@ -203,8 +161,6 @@ def filter_signatures_by_Ucell(grn_adata, adata, ncores: int = 100) -> pd.DataFr
203161
AnnData object containing GRN (gene regulatory network) information.
204162
adata : AnnData
205163
AnnData object containing gene expression counts in the 'counts' layer.
206-
ncores : int, optional
207-
Number of cores to use for parallel computation, by default 100.
208164
209165
Returns
210166
-------
@@ -219,7 +175,7 @@ def filter_signatures_by_Ucell(grn_adata, adata, ncores: int = 100) -> pd.DataFr
219175

220176

221177

222-
def filter_grn_by_top_signatures(data_ucell: pd.DataFrame, grn_adata: AnnData, keep_top_ranked: int = 100, filter_by: str = "z_score", cluster_col = 'spectral') -> Tuple[Optional[AnnData], List[str]]:
178+
def filter_grn_by_top_signatures(data_ucell: pd.DataFrame, grn_adata: ad.AnnData, keep_top_ranked: int = 100, filter_by: str = "z_score", cluster_col = 'spectral') -> Tuple[Optional[ad.AnnData], List[str]]:
223179
"""
224180
Filters a GRN (Gene Regulatory Network) AnnData object to keep only the edges
225181
corresponding to the top-ranked signatures per cluster based on UCell scores.
@@ -395,7 +351,7 @@ def plot_shared_targets_heatmap(grn_adata, genes=None, figsize=(6, 6), cmap='RdB
395351

396352

397353

398-
def compute_edge_overlaps_simple(grn_adata: AnnData, net_list: List[Tuple[str, pd.DataFrame]]) -> Dict[str, float]:
354+
def compute_edge_overlaps_simple(grn_adata: ad.AnnData, net_list: List[Tuple[str, pd.DataFrame]]) -> Dict[str, float]:
399355
"""
400356
Compute the percentage of overlapping edges between a GRN and multiple reference networks.
401357
@@ -431,7 +387,7 @@ def compute_edge_overlaps_simple(grn_adata: AnnData, net_list: List[Tuple[str, p
431387

432388

433389

434-
def filter_clusters_by_cell_count(grn_adata: AnnData, metric_tag: float, top_fraction: float) -> Tuple[Optional[Dict[str, float]], AnnData]:
390+
def filter_clusters_by_cell_count(grn_adata: ad.AnnData, metric_tag: float, top_fraction: float) -> Tuple[Optional[Dict[str, float]], ad.AnnData]:
435391
"""
436392
Filter features (genes/edges) based on cell count differences between two clusters,
437393
selecting the top fraction of features for each cluster.
@@ -477,7 +433,7 @@ def filter_clusters_by_cell_count(grn_adata: AnnData, metric_tag: float, top_fra
477433

478434

479435

480-
def create_regulon_activity_adata(grn_adata: AnnData, data_ucell: pd.DataFrame, top_sources_list: List[str]) -> AnnData:
436+
def create_regulon_activity_adata(grn_adata: ad.AnnData, data_ucell: pd.DataFrame, top_sources_list: List[str]) -> ad.AnnData:
481437
"""
482438
Creates an AnnData object with regulon activity scores based on top GRN sources.
483439
@@ -520,7 +476,7 @@ def create_regulon_activity_adata(grn_adata: AnnData, data_ucell: pd.DataFrame,
520476

521477

522478

523-
def plot_reg(grn_adata: AnnData, regulon: List, name="network", layout: Optional[str]="hierarchical", out_path="network_plots/"):
479+
def plot_reg(grn_adata: ad.AnnData, regulon: List, name="network", layout: Optional[str]="hierarchical", out_path="network_plots/"):
524480

525481
# Make all genes uppercase
526482
#df = df.applymap(lambda s: s.upper() if type(s) == str else s)

src/netmap/masking/external.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,11 @@ def get_regulators(crm_df, genes, window):
168168
results['edge'] = results['TFs'] + '_' + results['gene']
169169

170170
results['regulator'] = True
171-
return results
171+
return results
172+
173+
def add_genome_information_to_anndata(grn_adata, tf_to_gene_df, window_name = ''):
    """
    Flag the edges of a GRN AnnData object that are present in a TF-to-gene table.

    A boolean column named ``regulator{window_name}`` is added to
    ``grn_adata.var``. It is True for every edge whose key matches an
    ``edge`` entry of ``tf_to_gene_df`` with a non-missing ``regulator``
    value, and False otherwise.

    Parameters
    ----------
    grn_adata : AnnData
        Object whose ``.var`` table provides ``edge_key`` after
        ``reset_index()`` (i.e. as the index name or as a column).
        ``.var`` is replaced in place.
    tf_to_gene_df : pd.DataFrame
        Table with at least the columns ``edge`` and ``regulator``.
    window_name : str, optional
        Suffix appended to the name of the new column, by default ''.

    Returns
    -------
    AnnData
        The same ``grn_adata`` object, with the updated ``.var``.
    """
    regulator_col = f'regulator{window_name}'

    # Keep a single row per edge: a left merge against repeated keys would
    # multiply rows of ``.var`` and change its length.
    lookup = (
        tf_to_gene_df.loc[:, ['edge', 'regulator']]
        .dropna(subset=['regulator'])
        .drop_duplicates(subset='edge')
    )

    annotated = (
        grn_adata.var.reset_index()
        .merge(lookup, left_on='edge_key', right_on='edge', how='left')
        .set_index('edge_key')
    )

    # Edges without a match come out of the left merge as NaN -> False.
    annotated['regulator'] = annotated['regulator'].notna()

    # The right-hand join key is only a helper column; remove it again.
    annotated = annotated.drop(columns=['edge'])
    annotated = annotated.rename(columns={'regulator': regulator_col})

    grn_adata.var = annotated
    return grn_adata

0 commit comments

Comments
 (0)