|
28 | 28 | from collections import Counter |
29 | 29 |
|
30 | 30 |
|
| 31 | +def select_top_edges(gene_inter_adata, adata, top_per_source=10, col_cluster='leiden_remap', min_reg_size=10, verbose=True, return_copy = False, tf_column=None): |
| 32 | + """ |
| 33 | + Selects top gene targets per source from a clustered gene interaction AnnData. |
| 34 | +
|
| 35 | + Parameters |
| 36 | + ---------- |
| 37 | + gene_inter_adata : AnnData |
| 38 | + Gene interaction AnnData with `var` containing 'source' and 'target'. |
| 39 | + adata : AnnData |
| 40 | + Expression AnnData for ranking genes. |
| 41 | + top_per_source : int, default=750 |
| 42 | + Number of top targets to select per source. |
| 43 | + col_cluster : str, default='spectral' |
| 44 | + Column in obs defining clusters.grn_adata3.var |
| 45 | +
|
| 46 | + Returns |
| 47 | + ------- |
| 48 | + gene_inter_adata_filtered : AnnData |
| 49 | + Filtered AnnData containing top edges. |
| 50 | + reglon_sizes : list of int |
| 51 | + Sizes of regulatory regions per source. |
| 52 | +
|
| 53 | + """ |
| 54 | + |
| 55 | + min_edge_support = 0.5 |
| 56 | + |
| 57 | + clusters = list(np.unique(gene_inter_adata.obs[col_cluster])) |
| 58 | + |
| 59 | + keep_edges_dict = {} |
| 60 | + |
| 61 | + |
| 62 | + for c in clusters: |
| 63 | + Keep_edges = [] |
| 64 | + if verbose: print(f"Selecting targets for cluster: {c}") |
| 65 | + |
| 66 | + cells_c = gene_inter_adata.obs[col_cluster] == c |
| 67 | + gene_inter_adata.var['edge_support_c'] = ( |
| 68 | + gene_inter_adata.layers['mask'][cells_c, :].mean(axis=0) |
| 69 | + ) |
| 70 | + |
| 71 | + if tf_column is not None: |
| 72 | + tfs = gene_inter_adata.var[gene_inter_adata.var[tf_column]]['source'].unique() |
| 73 | + source_list = set(gene_inter_adata.var["source"].unique()).intersection(set(tfs)) |
| 74 | + |
| 75 | + else: |
| 76 | + source_list = gene_inter_adata.var["source"].unique() |
| 77 | + |
| 78 | + for source in source_list: |
| 79 | + |
| 80 | + df_targets = gene_inter_adata.var[ |
| 81 | + (gene_inter_adata.var['source'] == source) & |
| 82 | + (gene_inter_adata.var['edge_support_c'] >= min_edge_support)] |
| 83 | + #print(gene_inter_adata[cells_c, gene_inter_adata[cells_c, df_targets.index].X.sum(axis = 0)].X.sum(axis = 0)) |
| 84 | + |
| 85 | + df_targets['sum_of_edge'] = gene_inter_adata[cells_c, df_targets.index].X.sum(axis = 0) |
| 86 | + df_targets = df_targets.sort_values('sum_of_edge', ascending=False).head(top_per_source) |
| 87 | + |
| 88 | + if len(df_targets) >= min_reg_size: |
| 89 | + Keep_edges.extend(f"{source}_{t}" for t in df_targets['target']) |
| 90 | + |
| 91 | + keep_edges_dict[c] = Keep_edges |
| 92 | + keep_edges_dict = process_cell_edges(keep_edges_dict) |
| 93 | + return keep_edges_dict |
| 94 | + |
| 95 | + |
| 96 | +def process_cell_edges(keep_edges): |
| 97 | + results = {'unique': {}, 'all': {}} |
| 98 | + all_cells = list(keep_edges.keys()) |
| 99 | + |
| 100 | + |
| 101 | + def get_source_summary(edge_set): |
| 102 | + # Handles (source, target) tuples OR strings with a separator like '->' |
| 103 | + sources = [] |
| 104 | + for e in edge_set: |
| 105 | + sources.append(e.split('_')[0]) |
| 106 | + |
| 107 | + source_dict = dict(Counter(sources)) |
| 108 | + sources = pd.DataFrame({'source' :source_dict.keys(), 'count': source_dict.values()}).sort_values('count', ascending=False) |
| 109 | + return sources |
| 110 | + |
| 111 | + # Calculate Uniques |
| 112 | + for cell in all_cells: |
| 113 | + others = set().union(*(set(keep_edges[c]) for c in all_cells if c != cell)) |
| 114 | + unique = set(keep_edges[cell]) - others |
| 115 | + |
| 116 | + df = pd.DataFrame( |
| 117 | + [e.split('_', 1) for e in unique], |
| 118 | + columns=['source', 'target'] |
| 119 | + ) |
| 120 | + df_all = pd.DataFrame( |
| 121 | + [e.split('_', 1) for e in set(keep_edges[cell])], |
| 122 | + columns=['source', 'target'] |
| 123 | + ) |
| 124 | + |
| 125 | + results['unique'][cell] = { |
| 126 | + 'edges': df, |
| 127 | + 'summary': get_source_summary(unique) |
| 128 | + } |
| 129 | + |
| 130 | + results['all'][cell] = { |
| 131 | + 'edges': df_all, |
| 132 | + 'summary': get_source_summary(set(keep_edges[cell])) |
| 133 | + } |
| 134 | + |
| 135 | + return results |
| 136 | + |
| 137 | + |
| 138 | +def compute_signatures_UCell_scores(selected_edges, adata, key='unique') -> pd.DataFrame: |
| 139 | + """ |
| 140 | + Filters gene signatures by cluster and computes UCell scores. |
| 141 | +
|
| 142 | + Parameters |
| 143 | + ---------- |
| 144 | + grn_adata : AnnData |
| 145 | + AnnData object containing GRN (gene regulatory network) information. |
| 146 | + adata : AnnData |
| 147 | + AnnData object containing gene expression counts in the 'counts' layer. |
| 148 | +
|
| 149 | + Returns |
| 150 | + ------- |
| 151 | + pd.DataFrame |
| 152 | + DataFrame with UCell scores merged with the 'spectral' cluster labels. |
| 153 | + """ |
| 154 | + |
| 155 | + all_signatures = {} |
| 156 | + for ct in selected_edges[key]: |
| 157 | + sign = selected_edges[key][ct]['edges'].groupby('source')['target'].apply(list).to_dict() |
| 158 | + sign = {f"{ct}_{k}": v for k, v in sign.items()} |
| 159 | + all_signatures = all_signatures | sign |
| 160 | + |
| 161 | + ucell.compute_ucell_scores(adata, signatures=all_signatures, n_jobs=1) |
| 162 | + data_ucell = adata.obs.filter(like='_UCell') |
| 163 | + data_ucell.columns = [x.replace('_UCell', '') for x in data_ucell.columns] |
| 164 | + |
| 165 | + return data_ucell |
| 166 | + |
| 167 | + |
| 168 | +def aggregate_edges(selected_edges, grn_adata, key='unique') -> pd.DataFrame: |
| 169 | + """ |
| 170 | + Filters gene signatures by cluster and computes UCell scores. |
| 171 | +
|
| 172 | + Parameters |
| 173 | + ---------- |
| 174 | + grn_adata : AnnData |
| 175 | + AnnData object containing GRN (gene regulatory network) information. |
| 176 | + adata : AnnData |
| 177 | + AnnData object containing gene expression counts in the 'counts' layer. |
| 178 | +
|
| 179 | + Returns |
| 180 | + ------- |
| 181 | + pd.DataFrame |
| 182 | + DataFrame with UCell scores merged with the 'spectral' cluster labels. |
| 183 | + """ |
| 184 | + |
| 185 | + regulons = {} |
| 186 | + for ct in selected_edges[key]: |
| 187 | + print(ct) |
| 188 | + sign = selected_edges[key][ct]['edges'].groupby('source').apply(lambda x: (x['source'] + "_" + x['target']).tolist()).to_dict() |
| 189 | + for g in sign: |
| 190 | + regulons[f'{ct}_{g}'] = grn_adata[:, sign[g]].X.sum(axis = 1) |
| 191 | + regulons = pd.DataFrame(regulons) |
| 192 | + return regulons |
| 193 | + |
| 194 | + |
| 195 | + |
| 196 | + |
31 | 197 | def filter_clusters_by_cell_count(grn_adata: ad.AnnData, metric_tag: float, top_fraction: float) -> Tuple[Optional[Dict[str, float]], ad.AnnData]: |
32 | 198 | """ |
33 | 199 | Filter features (genes/edges) based on cell count differences between two clusters, |
|
0 commit comments