Skip to content

Commit 23cfe1b

Browse files
add further documentation
1 parent d176d97 commit 23cfe1b

21 files changed

Lines changed: 148 additions & 21609 deletions

src/netmap/grn/inferrence.py

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,36 @@ def attribution_one_model(
346346

347347
def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method, n_models = [10, 25, 50], background_type = 'zeros'):
348348

349+
"""
350+
The main inference function to compute the entire GRN model wise. Computes all
351+
attributions for all targets, aggregates them on the fly and creates an anndata.AnnData
352+
object with the edge names in the var slot.
353+
354+
Parameters
355+
----------
356+
models : list[torch.nn.Module]
357+
List of trained autoencoder models
358+
359+
data_train_full_tensor: torch.tensor
360+
input data tensor
361+
362+
gene_names: np.array
363+
Gene names indicating the order of the genes in the torch tensor
364+
365+
xai_method: str
366+
Method to be used [GradientShap, Deconvolution, GuidedBackprop]
367+
368+
n_models: list [int]
369+
returns aggregates of the attributions at these levels.
370+
371+
background_type: str
372+
Background to compute the LRP values against. One of ['zeros', 'randomize', 'data']
349373
374+
Returns
375+
-------
376+
grn_adata : anndata.AnnData
377+
A complete, aggregated GRN object
378+
"""
350379

351380
tms = []
352381

@@ -361,13 +390,9 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
361390
explainer, xai_type = _get_explainer(trained_model, xai_method)
362391
tms.append(explainer)
363392

364-
365-
thresholds = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
366-
367393
attributions = {}
368394
attribution_collector = None
369395
keynames = []
370-
top_egde_collector = {}
371396

372397

373398
for m in range(len(tms)):
@@ -380,19 +405,6 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
380405
background_type = background_type)
381406

382407

383-
# grn_adata_eph = attribution_to_anndata(current_attribution, var=cou)
384-
# b = np.argsort(grn_adata_eph.X, axis=1)
385-
# grn_adata_eph.layers['sorted'] = b
386-
# grn_adata_eph = edge_selection.add_top_edge_annotation_global(grn_adata=grn_adata_eph, top_edges = thresholds, key_name=f'agg_{m}')
387-
# df_subset = grn_adata_eph.var.iloc[:, 2:]
388-
# integral_results = df_subset.apply(
389-
# lambda row: np.sum(integrate.cumulative_trapezoid(row, thresholds )),
390-
# axis=1,
391-
# )
392-
# integral_results = integral_results/1000
393-
# top_egde_collector[f'agg_{m}'] = integral_results
394-
395-
396408
if attribution_collector is not None:
397409
# add current attribution to the collector
398410
attribution_collector = aggregate_attributions([attribution_collector, current_attribution], strategy='sum')
@@ -402,11 +414,13 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
402414
attribution_collector = current_attribution
403415

404416

405-
406-
if (m+1) in n_models:
407-
# dont reset, just save the correct matrix
408-
attributions[f'aggregated_{(m+1)}'] = attribution_collector/(m+1)
409-
keynames.append(f'aggregated_{(m+1)}')
417+
try:
418+
if (m+1) in n_models:
419+
# dont reset, just save the correct matrix
420+
attributions[f'aggregated_{(m+1)}'] = attribution_collector/(m+1)
421+
keynames.append(f'aggregated_{(m+1)}')
422+
except:
423+
pass
410424

411425

412426
# top_egde_collector = pd.DataFrame(top_egde_collector)
@@ -418,7 +432,8 @@ def inferrence_model_wise(models, data_train_full_tensor, gene_names, xai_method
418432

419433
grn_adata = attribution_to_anndata(attributions[keynames[0]], var=cou)
420434

421-
for k in keynames[1:len(keynames)]:
422-
# add remaining versions as masks
423-
grn_adata.layers[k] = attributions[k]
435+
if len(keynames)>0:
436+
for k in keynames[1:len(keynames)]:
437+
# add remaining versions as masks
438+
grn_adata.layers[k] = attributions[k]
424439
return grn_adata

src/netmap/masking/external.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _create_edge_mask_from_GRN(grn_df, gene_list, name_grn='external_grn'):
2020
names: numpy.ndarray: edge name vector (GeneA_GeneB)
2121
2222
"""
23+
2324
# Create a mapping from gene names to their matrix indices for efficient look-up.
2425
gene_to_index = {gene: i for i, gene in enumerate(gene_list)}
2526
num_genes = len(gene_list)
@@ -55,22 +56,28 @@ def _create_edge_mask_from_GRN(grn_df, gene_list, name_grn='external_grn'):
5556

5657

5758
def _get_all_genes_in_grn_object(grnad):
59+
"""
60+
Helper function to get all genes if not available
61+
62+
Args:
63+
grnad (anndata.Anndata) An anndata object containing a var object with the columns source
64+
and target
65+
"""
5866
all_sources = np.unique(grnad.var.source)
5967
all_targets = np.unique(grnad.var.target)
6068
all_genes = np.unique(np.concatenate([all_sources, all_targets]))
6169
return all_genes
6270

6371

6472
def add_external_grn(grn_ad, external_grn, name_grn = 'external_grn'):
65-
66-
"""
67-
Adds three columns to a anndate GRN object.
68-
is_target
69-
is_source
70-
is_egde
71-
7273
"""
74+
Add annotation columns for a reference GRN
7375
76+
Args:
77+
grn_ad
78+
:param external_grn: pd.DataFrame containing a source column and a target column
79+
:param name_grn:
80+
"""
7481
all_my_genes = _get_all_genes_in_grn_object(grn_ad)
7582
edge_mask = _create_edge_mask_from_GRN(external_grn, all_my_genes, name_grn = name_grn)
7683
grn_ad.var = grn_ad.var.merge(edge_mask, left_index=True, right_index=True)
@@ -81,6 +88,15 @@ def add_external_grn(grn_ad, external_grn, name_grn = 'external_grn'):
8188

8289

8390
def get_genome_annotation_from_gtf(gtf_df):
91+
""" Add genome information from a pandas data frame to the object.
92+
Returns the gene features from a gtf file.
93+
94+
Args:
95+
gtf_df (pd.DataFrame): Genome information
96+
97+
Returns:
98+
pd.DataFrame with genome information
99+
"""
84100
genes = gtf_df.filter(feature="gene")
85101
genes = pd.DataFrame(genes)
86102
genes.columns = gtf_df.columns
@@ -92,17 +108,36 @@ def get_genome_annotation_from_gtf(gtf_df):
92108
return genes
93109

94110

95-
def preprocess_bed_file(bed_file, gtf_df):
111+
def preprocess_bed_file(bed_file):
112+
""" Read the bed file as a tab separated csv file and obtain all TFs
113+
that are related to a gene from the object.
114+
115+
Args:
116+
bed_file (str): path containing the bed file
117+
118+
Returns:
119+
pd.DataFrame: Dataframe relating the TFs to the genes
120+
"""
96121
## ALL cis regulatory motifs
97122
crm_df = pd.read_csv(bed_file, sep="\t", header=None)
98123
crm_df.columns = ['chr', 'start', 'end', 'TF_list','TF_number', 'strand', 'number1', 'number2', 'large_number']
99-
crm_by_chr = {chr_: df for chr_, df in crm_df.groupby("chr")}
100124
crm_df['TF_list_list'] = crm_df['TF_list'].str.split(",")
101125
return crm_df
102126

103127

104128

105129
def get_regulators(crm_df, genes, window):
130+
""" Obtain the regulators of a target gene by searching
131+
in a window up and down from the TSS
132+
133+
Args:
134+
crm_df (_type_): _description_
135+
genes (_type_): _description_
136+
window (_type_): _description_
137+
138+
Returns:
139+
_type_: _description_
140+
"""
106141
gene_to_tfs = defaultdict(set)
107142

108143
crm_by_chr = {chr_: df for chr_, df in crm_df.groupby("chr")}

src/netmap/masking/internal.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,16 @@ def dict_to_dataframe(mask_dict, column_order_list):
148148

149149

150150
def add_neighbourhood_expression_mask(adata, grn_adata):
151+
""" Create a mask indicating whether the edge is likely actually
152+
expressed or not.
153+
154+
Args:
155+
adata (_type_): _description_
156+
grn_adata (_type_): _description_
157+
158+
Returns:
159+
_type_: _description_
160+
"""
151161
counts = pd.DataFrame(adata.X)
152162
counts.columns =adata.var.index
153163
ne = get_neighborhood_expression(adata, required_neighbours=5)

src/netmap/model/train_model.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,27 @@
22
from netmap.model.nbautoencoder import NegativeBinomialAutoencoder
33
from netmap.model.zinbautoencoder import ZINBAutoencoder
44

5+
import torch
6+
from torch.utils.data import DataLoader, TensorDataset
7+
from tqdm import tqdm
8+
59

10+
def create_model_zoo(data_tensor, n_models = 10, n_epochs = 10000, model_type = 'ZINBAutoencoder', dropout_rate = 0.02, latent_dim=8, hidden_dim=[128]):
11+
""" Creates a set of Autoencoders of the data using the speicified architecture. The architecture of the encoder can be specified using
12+
the `hidden_dim` parameter, the decoder architecture is mirrored. Early stopping is used by default.
613
7-
def create_model_zoo(data_tensor, n_models = 4, n_epochs = 500, model_type = 'ZINBAutoencoder', dropout_rate = 0.02, latent_dim=8, hidden_dim=[128]):
14+
Args:
15+
data_tensor (torch.tensor): The raw gene expression data
16+
n_models (int, optional): The number of models to compute. Defaults to 10.
17+
n_epochs (int, optional): Maximum number of epochs, if early stopping is not triggered. Defaults to 10000. Use
18+
model_type (str, optional): Model type, one of [ZINBAutoencoder, NegativeBinomialAutoencoder] Defaults to 'ZINBAutoencoder'.
19+
dropout_rate (float, optional): Dropout rate used during training. Defaults to 0.02.
20+
latent_dim (int, optional): Number of neurons in the latent dimension. Defaults to 8.
21+
hidden_dim (list, optional): Architecture specification, list of ints. Defaults to [128].
22+
23+
Returns:
24+
Model (list): The list of trained models.
25+
"""
826
model_zoo = []
927
counter = 0
1028
failures = 0
@@ -23,7 +41,7 @@ def create_model_zoo(data_tensor, n_models = 4, n_epochs = 500, model_type = 'ZI
2341

2442
optimizer2 = torch.optim.Adam(trained_model2.parameters(), lr=1e-4)
2543

26-
trained_model2 = train_autoencoder_early_stopping(
44+
trained_model2 = _train_autoencoder_early_stopping(
2745
trained_model2,
2846
data_train2.cuda(),
2947
data_test2.cuda(),
@@ -40,13 +58,24 @@ def create_model_zoo(data_tensor, n_models = 4, n_epochs = 500, model_type = 'ZI
4058
return model_zoo
4159

4260

43-
def train_autoencoder(
61+
def _train_autoencoder(
4462
model,
4563
data_train,
4664
optimizer,
4765
batch_size=32, # Minibatch size
4866
num_epochs=100,
4967
):
68+
"""Legacy version of the training loop without early stopping
69+
70+
Args:
71+
model (_type_): The model to be trained
72+
data_train (_type_): Training data
73+
optimizer (_type_): optimizer to be used
74+
batch_size (int, optional): Batch size. Defaults to 32.
75+
76+
Returns:
77+
Model: trained model
78+
"""
5079
# Prepare DataLoader for training
5180
train_dataset = TensorDataset(data_train)
5281
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
@@ -73,26 +102,36 @@ def train_autoencoder(
73102

74103
return model
75104

76-
import torch
77-
from torch.utils.data import DataLoader, TensorDataset
78-
79105

80-
from tqdm import tqdm
81-
from torch.utils.data import TensorDataset, DataLoader
82-
from tqdm import tqdm
83-
import torch
84106

85-
def train_autoencoder_early_stopping(
107+
def _train_autoencoder_early_stopping(
86108
model,
87109
data_train,
88110
data_val,
89111
optimizer,
90112
batch_size=32,
91-
num_epochs=100,
113+
num_epochs=10000,
92114
patience=10,
93115
min_delta=0.001,
94116
validation_freq = 10,
95117
):
118+
"""Training loop for the autoencoders.
119+
120+
Args:
121+
model (_type_): An instance of an autoencoder model
122+
data_train (_type_): Training data split
123+
data_val (_type_): Validation data split used for early stopping
124+
optimizer (_type_): Optimizer used
125+
batch_size (int, optional): Minibatch size. Defaults to 32.
126+
num_epochs (int, optional): Number of epochs. Defaults to 10000.
127+
patience (int, optional): Number of epochs with delta loss smaller
128+
than min delta before early stopping is triggered. Defaults to 10.
129+
min_delta (float, optional): Loss delta for early stopping. Defaults to 0.001.
130+
validation_freq (int, optional): Number of epochs before validation is run. Defaults to 10.
131+
132+
Returns:
133+
Model: Trained model with the parametrization of the best loss.
134+
"""
96135
# Prepare DataLoaders
97136
train_dataset = TensorDataset(data_train)
98137
val_dataset = TensorDataset(data_val)

src/netmap/utils/data_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ def attribution_to_anndata(attribution_list, var = None, obs = None)-> anndata.A
88
Transform attribution data frame into an anndata object
99
1010
Args:
11-
attribution_list: (sparse) Data frame of attribution values (one column per edge)
11+
attribution_list: (sparse) Data frame of attribution values (one column per edge)
1212
1313
returns:
14-
Anndata object with attribution values in X.
14+
anndata.Anndata: Anndata object with attribution values in X.
1515
"""
1616
print('Creating anndata')
1717
adata = anndata.AnnData(attribution_list)

src/netmap/utils/netmap_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class NetmapConfig:
99
input_data: str = "data.h5ad"
1010
layer: str = 'X'
1111
output_directory: str = "netmap"
12-
transcription_factors: str = "/data_nfs/datasets/SCENIC_DB/tf_lists/allTFs_hg38.txt"
12+
transcription_factors: str = ""
1313
tf_only: bool = True
1414
penalize_error: bool = True
1515
adata_filename: str = "grn_lrp.h5ad"

src/old/edge_selection.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

0 commit comments

Comments
 (0)