|
| 1 | +import logging |
| 2 | +import warnings |
| 3 | +from functools import reduce |
| 4 | +from itertools import chain, product, combinations |
| 5 | +from typing import List, Optional, Tuple, Union, Dict |
| 6 | +import json |
| 7 | +import os |
| 8 | +import anndata as ad |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +import numpy as np |
| 11 | +import pandas as pd |
| 12 | +import scanpy as sc |
| 13 | +import seaborn as sns |
| 14 | +from scipy.stats import pearsonr, ranksums |
| 15 | +from statsmodels.stats.multitest import multipletests |
| 16 | +import networkx as nx |
| 17 | +import requests |
| 18 | +from pyvis.network import Network |
| 19 | +import pyucell as ucell |
| 20 | + |
| 21 | + |
| 22 | + |
| 23 | + |
| 24 | +from itertools import combinations |
| 25 | +from collections import Counter |
| 26 | + |
| 27 | + |
| 28 | +def select_top_edges(gene_inter_adata, adata, top_per_source=10, col_cluster='leiden_remap', min_reg_size=10, verbose=True, return_copy = False, tf_column=None): |
| 29 | + """ |
| 30 | + Selects top gene targets per source from a clustered gene interaction AnnData. |
| 31 | +
|
| 32 | + Parameters |
| 33 | + ---------- |
| 34 | + gene_inter_adata : AnnData |
| 35 | + Gene interaction AnnData with `var` containing 'source' and 'target'. |
| 36 | + adata : AnnData |
| 37 | + Expression AnnData for ranking genes. |
| 38 | + top_per_source : int, default=750 |
| 39 | + Number of top targets to select per source. |
| 40 | + col_cluster : str, default='spectral' |
| 41 | + Column in obs defining clusters.grn_adata3.var |
| 42 | +
|
| 43 | + Returns |
| 44 | + ------- |
| 45 | + gene_inter_adata_filtered : AnnData |
| 46 | + Filtered AnnData containing top edges. |
| 47 | + reglon_sizes : list of int |
| 48 | + Sizes of regulatory regions per source. |
| 49 | +
|
| 50 | + """ |
| 51 | + |
| 52 | + min_edge_support = 0.5 |
| 53 | + |
| 54 | + clusters = list(np.unique(gene_inter_adata.obs[col_cluster])) |
| 55 | + |
| 56 | + keep_edges_dict = {} |
| 57 | + |
| 58 | + |
| 59 | + for c in clusters: |
| 60 | + Keep_edges = [] |
| 61 | + if verbose: print(f"Selecting targets for cluster: {c}") |
| 62 | + |
| 63 | + cells_c = gene_inter_adata.obs[col_cluster] == c |
| 64 | + gene_inter_adata.var['edge_support_c'] = ( |
| 65 | + gene_inter_adata.layers['mask'][cells_c, :].mean(axis=0) |
| 66 | + ) |
| 67 | + |
| 68 | + if tf_column is not None: |
| 69 | + tfs = gene_inter_adata.var[gene_inter_adata.var[tf_column]]['source'].unique() |
| 70 | + source_list = set(gene_inter_adata.var["source"].unique()).intersection(set(tfs)) |
| 71 | + |
| 72 | + else: |
| 73 | + source_list = gene_inter_adata.var["source"].unique() |
| 74 | + |
| 75 | + for source in source_list: |
| 76 | + |
| 77 | + df_targets = gene_inter_adata.var[ |
| 78 | + (gene_inter_adata.var['source'] == source) & |
| 79 | + (gene_inter_adata.var['edge_support_c'] >= min_edge_support)] |
| 80 | + #print(gene_inter_adata[cells_c, gene_inter_adata[cells_c, df_targets.index].X.sum(axis = 0)].X.sum(axis = 0)) |
| 81 | + |
| 82 | + df_targets['sum_of_edge'] = gene_inter_adata[cells_c, df_targets.index].X.sum(axis = 0) |
| 83 | + df_targets = df_targets.sort_values('sum_of_edge', ascending=False).head(top_per_source) |
| 84 | + |
| 85 | + if len(df_targets) >= min_reg_size: |
| 86 | + Keep_edges.extend(f"{source}_{t}" for t in df_targets['target']) |
| 87 | + |
| 88 | + keep_edges_dict[c] = Keep_edges |
| 89 | + keep_edges_dict = process_cell_edges(keep_edges_dict) |
| 90 | + return keep_edges_dict |
| 91 | + |
| 92 | + |
| 93 | +def process_cell_edges(keep_edges): |
| 94 | + results = {'unique': {}, 'all': {}} |
| 95 | + all_cells = list(keep_edges.keys()) |
| 96 | + |
| 97 | + |
| 98 | + def get_source_summary(edge_set): |
| 99 | + # Handles (source, target) tuples OR strings with a separator like '->' |
| 100 | + sources = [] |
| 101 | + for e in edge_set: |
| 102 | + sources.append(e.split('_')[0]) |
| 103 | + |
| 104 | + source_dict = dict(Counter(sources)) |
| 105 | + sources = pd.DataFrame({'source' :source_dict.keys(), 'count': source_dict.values()}).sort_values('count', ascending=False) |
| 106 | + return sources |
| 107 | + |
| 108 | + # Calculate Uniques |
| 109 | + for cell in all_cells: |
| 110 | + others = set().union(*(set(keep_edges[c]) for c in all_cells if c != cell)) |
| 111 | + unique = set(keep_edges[cell]) - others |
| 112 | + |
| 113 | + df = pd.DataFrame( |
| 114 | + [e.split('_', 1) for e in unique], |
| 115 | + columns=['source', 'target'] |
| 116 | + ) |
| 117 | + df_all = pd.DataFrame( |
| 118 | + [e.split('_', 1) for e in set(keep_edges[cell])], |
| 119 | + columns=['source', 'target'] |
| 120 | + ) |
| 121 | + |
| 122 | + results['unique'][cell] = { |
| 123 | + 'edges': df, |
| 124 | + 'summary': get_source_summary(unique) |
| 125 | + } |
| 126 | + |
| 127 | + results['all'][cell] = { |
| 128 | + 'edges': df_all, |
| 129 | + 'summary': get_source_summary(set(keep_edges[cell])) |
| 130 | + } |
| 131 | + |
| 132 | + return results |
| 133 | + |
| 134 | + |
| 135 | +def compute_signatures_UCell_scores(selected_edges, adata, key='unique') -> pd.DataFrame: |
| 136 | + """ |
| 137 | + Filters gene signatures by cluster and computes UCell scores. |
| 138 | +
|
| 139 | + Parameters |
| 140 | + ---------- |
| 141 | + grn_adata : AnnData |
| 142 | + AnnData object containing GRN (gene regulatory network) information. |
| 143 | + adata : AnnData |
| 144 | + AnnData object containing gene expression counts in the 'counts' layer. |
| 145 | +
|
| 146 | + Returns |
| 147 | + ------- |
| 148 | + pd.DataFrame |
| 149 | + DataFrame with UCell scores merged with the 'spectral' cluster labels. |
| 150 | + """ |
| 151 | + |
| 152 | + all_signatures = {} |
| 153 | + for ct in selected_edges[key]: |
| 154 | + sign = selected_edges[key][ct]['edges'].groupby('source')['target'].apply(list).to_dict() |
| 155 | + sign = {f"{ct}_{k}": v for k, v in sign.items()} |
| 156 | + all_signatures = all_signatures | sign |
| 157 | + |
| 158 | + ucell.compute_ucell_scores(adata, signatures=all_signatures, n_jobs=1) |
| 159 | + data_ucell = adata.obs.filter(like='_UCell') |
| 160 | + data_ucell.columns = [x.replace('_UCell', '') for x in data_ucell.columns] |
| 161 | + |
| 162 | + return data_ucell |
| 163 | + |
| 164 | + |
| 165 | +def aggregate_edges(selected_edges, grn_adata, key='unique') -> pd.DataFrame: |
| 166 | + """ |
| 167 | + Filters gene signatures by cluster and computes UCell scores. |
| 168 | +
|
| 169 | + Parameters |
| 170 | + ---------- |
| 171 | + grn_adata : AnnData |
| 172 | + AnnData object containing GRN (gene regulatory network) information. |
| 173 | + adata : AnnData |
| 174 | + AnnData object containing gene expression counts in the 'counts' layer. |
| 175 | +
|
| 176 | + Returns |
| 177 | + ------- |
| 178 | + pd.DataFrame |
| 179 | + DataFrame with UCell scores merged with the 'spectral' cluster labels. |
| 180 | + """ |
| 181 | + |
| 182 | + regulons = {} |
| 183 | + for ct in selected_edges[key]: |
| 184 | + print(ct) |
| 185 | + sign = selected_edges[key][ct]['edges'].groupby('source').apply(lambda x: (x['source'] + "_" + x['target']).tolist()).to_dict() |
| 186 | + for g in sign: |
| 187 | + regulons[f'{ct}_{g}'] = grn_adata[:, sign[g]].X.sum(axis = 1) |
| 188 | + regulons = pd.DataFrame(regulons) |
| 189 | + return regulons |
| 190 | + |
0 commit comments