diff --git a/EVE/GMM_model.py b/EVE/GMM_model.py
new file mode 100644
index 0000000..26253cf
--- /dev/null
+++ b/EVE/GMM_model.py
@@ -0,0 +1,82 @@
+from sklearn import mixture
+import numpy as np
+import os
+import tqdm
+import logging
+
+class GMM_model:
+    """
+    Class for the global-local GMM model trained on single-point mutations for a wild-type protein sequence
+    """
+    def __init__(self, params) -> None:
+
+        # store parameters
+        self.protein_GMM_weight = params['protein_GMM_weight']
+        self.params = dict(
+            n_components=2,
+            covariance_type='full',
+            max_iter=1000,
+            n_init=30,
+            tol=1e-4
+        )
+
+    def fit_single(self, X_train):
+        model = mixture.GaussianMixture(**self.params)
+        model.fit(X_train)
+        #The pathogenic cluster is the cluster with higher mean value
+        pathogenic_cluster_index = np.argmax(np.array(model.means_).flatten())
+        return model, pathogenic_cluster_index
+
+    def fit(self, X_train, proteins_train):
+        # set up to train
+        self.models = {}
+        self.indices = {}
+
+        # train global model
+        gmm, index = self.fit_single(X_train)
+        self.models['main'] = gmm
+        self.indices['main'] = index
+
+        # train local models
+        if self.protein_GMM_weight > 0.0:
+            proteins_list = list(set(proteins_train))
+            for protein in tqdm.tqdm(proteins_list, "Training all protein GMMs"):
+                X_train_protein = X_train[proteins_train == protein]
+                gmm, index = self.fit_single(X_train_protein)
+                self.models[protein] = gmm
+                self.indices[protein] = index
+
+        return self.models, self.indices
+
+    def predict(self, X_pred, protein):
+        model = self.models[protein]
+        cluster_index = self.indices[protein]
+        scores = model.predict_proba(X_pred)[:,cluster_index]
+        classes = (scores > 0.5).astype(int)
+        return scores, classes
+
+    def predict_weighted(self, X_pred, protein):
+        scores_protein, _ = self.predict(X_pred, protein)
+        scores_main, _ = self.predict(X_pred, 'main')
+
+        scores_weighted = scores_main * (1 - self.protein_GMM_weight) + \
+            scores_protein * self.protein_GMM_weight
+
classes_weighted = (scores_weighted > 0.5).astype(int) + return scores_weighted, classes_weighted + + def get_fitted_params(self): + fitted_params = ( + self.models.keys(), + self.models.values(), + self.indices.values() + ) + output = {} + for protein, model, index in zip(*fitted_params): + output[protein] = { + 'index': index, + 'means': model.means_, + 'covar': model.covariances_, + 'weights': model.weights_ + } + return output + diff --git a/EVE/VAE_decoder.py b/EVE/VAE_decoder.py index 7bfa57a..87277e0 100644 --- a/EVE/VAE_decoder.py +++ b/EVE/VAE_decoder.py @@ -1,6 +1,15 @@ import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as functional + +HIDDEN_LAYER_NONLINEARITIES = { + 'relu': nn.ReLU(), + 'tanh': nn.Tanh(), + 'sigmoid': nn.Sigmoid(), + 'elu': nn.ELU(), + 'linear': nn.Identity(), +} + class VAE_Bayesian_MLP_decoder(nn.Module): """ @@ -40,138 +49,185 @@ def __init__(self, params): self.mu_bias_init = 0.1 self.logvar_init = -10.0 self.logit_scale_p = 0.001 - - self.hidden_layers_mean=nn.ModuleDict() - self.hidden_layers_log_var=nn.ModuleDict() - for layer_index in range(len(self.hidden_layers_sizes)): - if layer_index==0: - self.hidden_layers_mean[str(layer_index)] = nn.Linear(self.z_dim, self.hidden_layers_sizes[layer_index]) - self.hidden_layers_log_var[str(layer_index)] = nn.Linear(self.z_dim, self.hidden_layers_sizes[layer_index]) - nn.init.constant_(self.hidden_layers_mean[str(layer_index)].bias, self.mu_bias_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].weight, self.logvar_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].bias, self.logvar_init) - else: - self.hidden_layers_mean[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index]) - self.hidden_layers_log_var[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index]) - 
nn.init.constant_(self.hidden_layers_mean[str(layer_index)].bias, self.mu_bias_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].weight, self.logvar_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].bias, self.logvar_init) - - if params['first_hidden_nonlinearity'] == 'relu': - self.first_hidden_nonlinearity = nn.ReLU() - elif params['first_hidden_nonlinearity'] == 'tanh': - self.first_hidden_nonlinearity = nn.Tanh() - elif params['first_hidden_nonlinearity'] == 'sigmoid': - self.first_hidden_nonlinearity = nn.Sigmoid() - elif params['first_hidden_nonlinearity'] == 'elu': - self.first_hidden_nonlinearity = nn.ELU() - elif params['first_hidden_nonlinearity'] == 'linear': - self.first_hidden_nonlinearity = nn.Identity() - - if params['last_hidden_nonlinearity'] == 'relu': - self.last_hidden_nonlinearity = nn.ReLU() - elif params['last_hidden_nonlinearity'] == 'tanh': - self.last_hidden_nonlinearity = nn.Tanh() - elif params['last_hidden_nonlinearity'] == 'sigmoid': - self.last_hidden_nonlinearity = nn.Sigmoid() - elif params['last_hidden_nonlinearity'] == 'elu': - self.last_hidden_nonlinearity = nn.ELU() - elif params['last_hidden_nonlinearity'] == 'linear': - self.last_hidden_nonlinearity = nn.Identity() - - if self.dropout_proba > 0.0: - self.dropout_layer = nn.Dropout(p=self.dropout_proba) - if self.convolve_output: - self.output_convolution_mean = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False) - self.output_convolution_log_var = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False) - nn.init.constant_(self.output_convolution_log_var.weight, self.logvar_init) + self.fcnn_output_size = self.seq_len * self.hidden_layers_sizes[-1] self.channel_size = self.convolution_depth else: self.channel_size = self.alphabet_size - if self.include_sparsity: - self.sparsity_weight_mean = 
nn.Parameter(torch.zeros(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len)) - self.sparsity_weight_log_var = nn.Parameter(torch.ones(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len)) - nn.init.constant_(self.sparsity_weight_log_var, self.logvar_init) + # Set hidden layer nonlinearities for first and last layers + self.first_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['first_hidden_nonlinearity']] + self.last_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['last_hidden_nonlinearity']] - self.last_hidden_layer_weight_mean = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1])) - self.last_hidden_layer_weight_log_var = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1])) - nn.init.xavier_normal_(self.last_hidden_layer_weight_mean) #Glorot initialization - nn.init.constant_(self.last_hidden_layer_weight_log_var, self.logvar_init) + # set dropout + if self.dropout_proba > 0.0: + self.use_dropout=True + self.dropout_layer = nn.Dropout(p=self.dropout_proba) + else: + self.use_dropout=False - self.last_hidden_layer_bias_mean = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len)) - self.last_hidden_layer_bias_log_var = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len)) + self.initialise_params() + + + def initialise_params(self): + # Set mean and variance for hidden layer distributions + self.hidden_layers_mean=nn.ModuleDict() + self.hidden_layers_log_var=nn.ModuleDict() + input_size = self.z_dim + for i, output_size in enumerate(self.hidden_layers_sizes): + self.hidden_layers_mean[str(i)] = nn.Linear(input_size, output_size) + self.hidden_layers_log_var[str(i)] = nn.Linear(input_size, output_size) + input_size = output_size + nn.init.constant_(self.hidden_layers_mean[str(i)].bias, self.mu_bias_init) + nn.init.constant_(self.hidden_layers_log_var[str(i)].weight, self.logvar_init) + 
nn.init.constant_(self.hidden_layers_log_var[str(i)].bias, self.logvar_init)
+
+        # set mean and variance for last hidden weight and bias
+        output_size = self.channel_size * self.seq_len
+        self.last_hidden_layer_weight_mean = nn.Parameter(torch.zeros(output_size, input_size))
+        self.last_hidden_layer_weight_log_var = nn.Parameter(torch.zeros(output_size, input_size))
+        nn.init.xavier_normal_(self.last_hidden_layer_weight_mean) #Glorot initialization
+        nn.init.constant_(self.last_hidden_layer_weight_log_var, self.logvar_init)
+
+        output_size = self.alphabet_size * self.seq_len
+        self.last_hidden_layer_bias_mean = nn.Parameter(torch.zeros(output_size))
+        self.last_hidden_layer_bias_log_var = nn.Parameter(torch.zeros(output_size))
         nn.init.constant_(self.last_hidden_layer_bias_mean, self.mu_bias_init)
         nn.init.constant_(self.last_hidden_layer_bias_log_var, self.logvar_init)
-
+
+        # Set mean and variance for 1-stride convolutions if they are used
+        if self.convolve_output:
+            self.output_convolution_mean = nn.Conv1d(
+                in_channels=self.convolution_depth,
+                out_channels=self.alphabet_size,
+                kernel_size=1,stride=1,bias=False)
+            self.output_convolution_log_var = nn.Conv1d(
+                in_channels=self.convolution_depth,
+                out_channels=self.alphabet_size,
+                kernel_size=1,stride=1,bias=False)
+            nn.init.constant_(
+                self.output_convolution_log_var.weight,
+                self.logvar_init
+            )
+
+        # Set mean and variance for sparsity prior
+        if self.include_sparsity:
+            sparsity_size = int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity)
+            self.sparsity_weight_mean = nn.Parameter(torch.zeros(sparsity_size, self.seq_len))
+            self.sparsity_weight_log_var = nn.Parameter(torch.ones(sparsity_size, self.seq_len))
+            nn.init.constant_(self.sparsity_weight_log_var, self.logvar_init)
+
         if self.include_temperature_scaler:
             self.temperature_scaler_mean = nn.Parameter(torch.ones(1))
             self.temperature_scaler_log_var = nn.Parameter(torch.ones(1) * self.logvar_init)
-
+
+
     def sampler(self, mean, log_var):
         """
-
Samples a latent vector via reparametrization trick + Samples a parameter from normal distribution via reparametrization trick """ eps = torch.randn_like(mean).to(self.device) z = torch.exp(0.5*log_var) * eps + mean return z - def forward(self, z): - batch_size = z.shape[0] - if self.dropout_proba > 0.0: - x = self.dropout_layer(z) - else: - x = z + def sample_weight_and_bias(self, mean, log_var): + weight = self.sampler(mean.weight, log_var.weight) + bias = self.sampler(mean.bias, log_var.bias) + return weight, bias - for layer_index in range(len(self.hidden_layers_sizes)-1): - layer_i_weight = self.sampler(self.hidden_layers_mean[str(layer_index)].weight, self.hidden_layers_log_var[str(layer_index)].weight) - layer_i_bias = self.sampler(self.hidden_layers_mean[str(layer_index)].bias, self.hidden_layers_log_var[str(layer_index)].bias) - x = self.first_hidden_nonlinearity(F.linear(x, weight=layer_i_weight, bias=layer_i_bias)) - if self.dropout_proba > 0.0: - x = self.dropout_layer(x) + def sample_custom(self, custom): + mean = self.hidden_layers_mean[custom] + log_var = self.hidden_layers_log_var[custom] + return self.sampler(mean, log_var) - last_index = len(self.hidden_layers_sizes)-1 - last_layer_weight = self.sampler(self.hidden_layers_mean[str(last_index)].weight, self.hidden_layers_log_var[str(last_index)].weight) - last_layer_bias = self.sampler(self.hidden_layers_mean[str(last_index)].bias, self.hidden_layers_log_var[str(last_index)].bias) - x = self.last_hidden_nonlinearity(F.linear(x, weight=last_layer_weight, bias=last_layer_bias)) - if self.dropout_proba > 0.0: + def apply_dropout(self, x): + if self.use_dropout: x = self.dropout_layer(x) + return x - W_out = self.sampler(self.last_hidden_layer_weight_mean, self.last_hidden_layer_weight_log_var) - b_out = self.sampler(self.last_hidden_layer_bias_mean, self.last_hidden_layer_bias_log_var) - + def forward(self, z): + """Decode latent vector into one-hot encoded sequence""" + + # Take input + x = 
self.apply_dropout(z)
+        # Sample parameters for each hidden layer and apply
+        for i in range(len(self.hidden_layers_sizes)):
+            mean = self.hidden_layers_mean[str(i)]
+            log_var = self.hidden_layers_log_var[str(i)]
+            weight, bias = self.sample_weight_and_bias(mean, log_var)
+            x = functional.linear(x, weight=weight, bias=bias)
+            if i < len(self.hidden_layers_sizes) - 1:
+                x = self.first_hidden_nonlinearity(x)
+            else:
+                x = self.last_hidden_nonlinearity(x)
+            x = self.apply_dropout(x)
+
+        # Sample weight and bias for last layer
+        W_out = self.sampler(
+            self.last_hidden_layer_weight_mean,
+            self.last_hidden_layer_weight_log_var
+        )
+        b_out = self.sampler(
+            self.last_hidden_layer_bias_mean,
+            self.last_hidden_layer_bias_log_var
+        )
+
+        # optionally, perform convolutions with stride 1 on this layer
         if self.convolve_output:
-            output_convolution_weight = self.sampler(self.output_convolution_mean.weight, self.output_convolution_log_var.weight)
-            W_out = torch.mm(W_out.view(self.seq_len * self.hidden_layers_sizes[-1], self.channel_size),
-                    output_convolution_weight.view(self.channel_size,self.alphabet_size)) #product of size (H * seq_len, alphabet)
-
+            output_convolution_weight = self.sampler(
+                self.output_convolution_mean.weight,
+                self.output_convolution_log_var.weight
+            )
+            output_convolution_weight = output_convolution_weight.view(
+                self.channel_size,self.alphabet_size
+            )
+            #product of size (H * seq_len, alphabet)
+            W_out = W_out.view(self.fcnn_output_size, self.channel_size)
+            W_out = torch.mm(W_out, output_convolution_weight)
+            W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1])
+
+        # optionally, place a sparsity prior on the last hidden layer
         if self.include_sparsity:
-            sparsity_weights = self.sampler(self.sparsity_weight_mean,self.sparsity_weight_log_var)
+            # sparsity weights are shrunk towards either 0 or 1 by sigmoid
+            sparsity_weights = self.sampler(
+                self.sparsity_weight_mean,
+                self.sparsity_weight_log_var
+            )
+            sparsity_tiled = 
sparsity_weights.repeat(self.num_tiles_sparsity,1)
             sparsity_tiled = nn.Sigmoid()(sparsity_tiled).unsqueeze(2)
-
-            W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size) * sparsity_tiled
+            # Scale output layer by sparsity vector
+            W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size)
+            W_out = W_out * sparsity_tiled
+            W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1])
-            W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1])
-
-        x = F.linear(x, weight=W_out, bias=b_out)
+        # apply last layer
+        x = functional.linear(x, weight=W_out, bias=b_out)
+        # optionally, apply temperature scaling to final output
         if self.include_temperature_scaler:
-            temperature_scaler = self.sampler(self.temperature_scaler_mean,self.temperature_scaler_log_var)
-            x = torch.log(1.0+torch.exp(temperature_scaler)) * x
-
+            temperature_scaler = self.sampler(
+                self.temperature_scaler_mean,
+                self.temperature_scaler_log_var
+            )
+            x = x * torch.log(1.0+torch.exp(temperature_scaler))
+
+        # reshape output to shape (batch_size, seq_len, alphabet)
+        batch_size = z.shape[0]
         x = x.view(batch_size, self.seq_len, self.alphabet_size)
-        x_recon_log = F.log_softmax(x, dim=-1) #of shape (batch_size, seq_len, alphabet)
-
+
+        # return reconstruction loss
+        x_recon_log = functional.log_softmax(x, dim=-1)
         return x_recon_log
 
 
 class VAE_Standard_MLP_decoder(nn.Module):
     """
     Standard MLP decoder class for the VAE model.
""" - def __init__(self, seq_len, alphabet_size, hidden_layers_sizes, z_dim, first_hidden_nonlinearity, last_hidden_nonlinearity, dropout_proba, - convolve_output, convolution_depth, include_temperature_scaler, include_sparsity, num_tiles_sparsity): + def __init__(self, params): """ Required input parameters: - seq_len: (Int) Sequence length of sequence alignment @@ -213,29 +269,12 @@ def __init__(self, seq_len, alphabet_size, hidden_layers_sizes, z_dim, first_hid self.hidden_layers[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index]) nn.init.constant_(self.hidden_layers[str(layer_index)].bias, self.mu_bias_init) - if params['first_hidden_nonlinearity'] == 'relu': - self.first_hidden_nonlinearity = nn.ReLU() - elif params['first_hidden_nonlinearity'] == 'tanh': - self.first_hidden_nonlinearity = nn.Tanh() - elif params['first_hidden_nonlinearity'] == 'sigmoid': - self.first_hidden_nonlinearity = nn.Sigmoid() - elif params['first_hidden_nonlinearity'] == 'elu': - self.first_hidden_nonlinearity = nn.ELU() - elif params['first_hidden_nonlinearity'] == 'linear': - self.first_hidden_nonlinearity = nn.Identity() - - if params['last_hidden_nonlinearity'] == 'relu': - self.last_hidden_nonlinearity = nn.ReLU() - elif params['last_hidden_nonlinearity'] == 'tanh': - self.last_hidden_nonlinearity = nn.Tanh() - elif params['last_hidden_nonlinearity'] == 'sigmoid': - self.last_hidden_nonlinearity = nn.Sigmoid() - elif params['last_hidden_nonlinearity'] == 'elu': - self.last_hidden_nonlinearity = nn.ELU() - elif params['last_hidden_nonlinearity'] == 'linear': - self.last_hidden_nonlinearity = nn.Identity() + # Set hidden layer nonlinearities for first and last layers + self.first_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['first_hidden_nonlinearity']] + self.last_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['last_hidden_nonlinearity']] if self.dropout_proba > 0.0: + self.use_dropout = True 
self.dropout_layer = nn.Dropout(p=self.dropout_proba) if self.convolve_output: diff --git a/data/weights/PTEN_HUMAN_theta_0.2.npy b/data/weights/PTEN_HUMAN_theta_0.2.npy new file mode 100644 index 0000000..1904805 Binary files /dev/null and b/data/weights/PTEN_HUMAN_theta_0.2.npy differ diff --git a/examples/Step2_compute_evol_indices_all_singles.sh b/examples/Step2_predict_ELBO.sh similarity index 97% rename from examples/Step2_compute_evol_indices_all_singles.sh rename to examples/Step2_predict_ELBO.sh index c5b95f2..2f34311 100644 --- a/examples/Step2_compute_evol_indices_all_singles.sh +++ b/examples/Step2_predict_ELBO.sh @@ -13,7 +13,7 @@ export output_evol_indices_location='./results/evol_indices' export num_samples_compute_evol_indices=20000 export batch_size=2048 -python compute_evol_indices.py \ +python predict_ELBO.py \ --MSA_data_folder ${MSA_data_folder} \ --MSA_list ${MSA_list} \ --protein_index ${protein_index} \ diff --git a/examples/Step3_train_GMM_and_compute_EVE_scores_all_singles.sh b/examples/Step3_predict_GMM_score.sh similarity index 73% rename from examples/Step3_train_GMM_and_compute_EVE_scores_all_singles.sh rename to examples/Step3_predict_GMM_score.sh index a68b4f4..7bccad9 100644 --- a/examples/Step3_train_GMM_and_compute_EVE_scores_all_singles.sh +++ b/examples/Step3_predict_GMM_score.sh @@ -3,15 +3,12 @@ export input_evol_indices_filename_suffix='_20000_samples' export protein_list='./data/mappings/example_mapping.csv' export output_eve_scores_location='./results/EVE_scores' export output_eve_scores_filename_suffix='Jan1_PTEN_example' - export GMM_parameter_location='./results/GMM_parameters/Default_GMM_parameters' export GMM_parameter_filename_suffix='default' export protein_GMM_weight=0.3 -export plot_location='./results' -export labels_file_location='./data/labels/PTEN_ClinVar_labels.csv' export default_uncertainty_threshold_file_location='./utils/default_uncertainty_threshold.json' -python train_GMM_and_compute_EVE_scores.py \ 
+python predict_GMM_score.py \ --input_evol_indices_location ${input_evol_indices_location} \ --input_evol_indices_filename_suffix ${input_evol_indices_filename_suffix} \ --protein_list ${protein_list} \ @@ -22,11 +19,6 @@ python train_GMM_and_compute_EVE_scores.py \ --GMM_parameter_filename_suffix ${GMM_parameter_filename_suffix} \ --compute_EVE_scores \ --protein_GMM_weight ${protein_GMM_weight} \ - --plot_histograms \ - --plot_scores_vs_labels \ - --plot_location ${plot_location} \ - --labels_file_location ${labels_file_location} \ - --default_uncertainty_threshold_file_location ${default_uncertainty_threshold_file_location} \ + --compute_uncertainty_thresholds \ --verbose - \ No newline at end of file diff --git a/examples/full_pipeline_PTEN.sh b/examples/full_pipeline_PTEN.sh new file mode 100644 index 0000000..1818160 --- /dev/null +++ b/examples/full_pipeline_PTEN.sh @@ -0,0 +1,73 @@ +export MSA_data_folder='./data/MSA' +export MSA_list='./data/mappings/example_mapping.csv' +export MSA_weights_location='./data/weights' +export VAE_checkpoint_location='./results/VAE_parameters' +export model_name_suffix='Jan1_PTEN_example' +export model_parameters_location='./EVE/default_model_params.json' +export training_logs_location='./logs/' +export protein_index=0 + +python train_VAE.py \ + --MSA_data_folder ${MSA_data_folder} \ + --MSA_list ${MSA_list} \ + --protein_index ${protein_index} \ + --MSA_weights_location ${MSA_weights_location} \ + --VAE_checkpoint_location ${VAE_checkpoint_location} \ + --model_name_suffix ${model_name_suffix} \ + --model_parameters_location ${model_parameters_location} \ + --training_logs_location ${training_logs_location} \ + --verbose + +export computation_mode='all_singles' +export all_singles_mutations_folder='./data/mutations' +export evol_indices_location='./results/evol_indices' +export num_samples_compute_evol_indices=20000 +export batch_size=2048 + +python predict_ELBO.py \ + --MSA_data_folder ${MSA_data_folder} \ + --MSA_list 
${MSA_list} \
+    --protein_index ${protein_index} \
+    --MSA_weights_location ${MSA_weights_location} \
+    --VAE_checkpoint_location ${VAE_checkpoint_location} \
+    --model_name_suffix ${model_name_suffix} \
+    --model_parameters_location ${model_parameters_location} \
+    --computation_mode ${computation_mode} \
+    --all_singles_mutations_folder ${all_singles_mutations_folder} \
+    --output_evol_indices_location ${evol_indices_location} \
+    --num_samples_compute_evol_indices ${num_samples_compute_evol_indices} \
+    --batch_size ${batch_size} \
+    --verbose
+
+export evol_indices_filename_suffix='_20000_samples'
+export protein_list='./data/mappings/example_mapping.csv'
+export eve_scores_location='./results/EVE_scores'
+export eve_scores_filename_suffix='Jan1_PTEN_example'
+export GMM_parameter_location='./results/GMM_parameters/Default_GMM_parameters'
+export GMM_parameter_filename_suffix='default'
+export protein_GMM_weight=0.3
+export default_uncertainty_threshold_file_location='./utils/default_uncertainty_threshold.json'
+
+python predict_GMM_score.py \
+    --input_evol_indices_location ${evol_indices_location} \
+    --input_evol_indices_filename_suffix ${evol_indices_filename_suffix} \
+    --protein_list ${protein_list} \
+    --output_eve_scores_location ${eve_scores_location} \
+    --output_eve_scores_filename_suffix ${eve_scores_filename_suffix} \
+    --GMM_parameter_location ${GMM_parameter_location} \
+    --GMM_parameter_filename_suffix ${GMM_parameter_filename_suffix} \
+    --protein_GMM_weight ${protein_GMM_weight} \
+    --compute_uncertainty_thresholds \
+    --verbose
+
+export plot_location='./results'
+export labels_file_location='./data/labels/PTEN_ClinVar_labels.csv'
+
+python plot_scores_and_labels.py \
+    --input_eve_scores_location ${eve_scores_location} \
+    --input_eve_scores_filename_suffix ${eve_scores_filename_suffix} \
+    --labels_file_location ${labels_file_location} \
+    --plot_location ${plot_location} \
+    --plot_elbo_histograms \
+    --plot_scores_vs_labels \
+    --verbose
\ No newline at end of file
diff --git a/logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv b/logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv
new file mode 100644
index 0000000..fc2975f
--- /dev/null
+++ b/logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv
@@ -0,0 +1,33 @@
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
+Number of sequences in alignment file: 1179
+Neff: 131.5475617794571
+Alignment sequence length: 387
diff --git a/plot_scores_and_labels.py b/plot_scores_and_labels.py
new file mode 100644
index 0000000..5b724e0
--- /dev/null
+++ b/plot_scores_and_labels.py
@@ -0,0 +1,90 @@
+import argparse
+import os
+import tqdm
+import pickle
+import pandas as pd
+from utils import plot_helpers
+
+
+def main(args):
+    model_location = args.gmm_model_location
+    with open(model_location, 'rb') as fid:
+        gmm_model = pickle.load(fid)
+    scores_location = args.input_eve_scores_location + os.sep + \
+        'all_EVE_scores_'+ \
+        args.input_eve_scores_filename_suffix+'.csv'
+    all_scores = pd.read_csv(
+        scores_location,
+        index_col=False
+
    )
+    protein_list = list(all_scores.protein_name.unique())
+
+    if args.plot_elbo_histograms:
+        # Plot fit of mixture model to predicted scores
+        histograms_location = \
+            args.plot_location+os.sep+\
+            'plots_histograms'+os.sep+\
+            args.input_eve_scores_filename_suffix
+        if not os.path.exists(histograms_location):
+            os.makedirs(histograms_location)
+        plot_helpers.plot_histograms(
+            all_scores,
+            gmm_model,
+            histograms_location,
+            protein_list
+        )
+
+    if args.plot_scores_vs_labels:
+        labels_dataset = pd.read_csv(
+            args.labels_file_location,
+            low_memory=False,
+            usecols=['protein_name','mutations','ClinVar_labels']
+        )
+        labels_dataset = labels_dataset[labels_dataset.ClinVar_labels.isin([0,1])]
+        all_scores_labelled = pd.merge(
+            all_scores,
+            labels_dataset,
+            how='inner',
+            on=['protein_name','mutations']
+        )
+        labelled_scores_location = args.input_eve_scores_location + os.sep + \
+            'all_EVE_scores_labelled_'+ \
+            args.input_eve_scores_filename_suffix+'.csv'
+        all_scores_labelled.to_csv(
+            labelled_scores_location, index=False
+        )
+
+        # Plot scores against clinical labels
+        scores_vs_labels_plot_location = \
+            args.plot_location+os.sep+\
+            'plots_scores_vs_labels'+os.sep+\
+            args.input_eve_scores_filename_suffix
+        if not os.path.exists(scores_vs_labels_plot_location):
+            os.makedirs(scores_vs_labels_plot_location)
+        for protein in tqdm.tqdm(protein_list,"Plot scores Vs labels"):
+            protein_scores = all_scores_labelled[
+                all_scores_labelled.protein_name==protein
+            ]
+            output_suffix = args.input_eve_scores_filename_suffix+'_'+protein
+            plot_helpers.plot_scores_vs_labels(
+                score_df=protein_scores,
+                plot_location=scores_vs_labels_plot_location,
+                output_eve_scores_filename_suffix=output_suffix,
+                mutation_name='mutations',
+                score_name="EVE_scores",
+                label_name='ClinVar_labels'
+            )
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Plot EVE scores against standard labels')
+    parser.add_argument('--input_eve_scores_location', type=str, 
help='Folder where all EVE scores are stored') + parser.add_argument('--input_eve_scores_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') + parser.add_argument('--labels_file_location', default=None, type=str, help='File with ground truth labels for all proteins of interest (e.g., ClinVar)') + parser.add_argument('--plot_location', default=None, type=str, help='Location of the different plots') + parser.add_argument('--plot_elbo_histograms', default=False, action='store_true', help='Plots all evol indices histograms with GMM fits') + parser.add_argument('--plot_scores_vs_labels', default=False, action='store_true', help='Plots EVE scores Vs labels at each protein position') + parser.add_argument('--verbose', action='store_true', help='Print detailed information during run') + args = parser.parse_args() + + + main(args) \ No newline at end of file diff --git a/compute_evol_indices.py b/predict_ELBO.py similarity index 63% rename from compute_evol_indices.py rename to predict_ELBO.py index 9e87034..2e61793 100644 --- a/compute_evol_indices.py +++ b/predict_ELBO.py @@ -7,32 +7,11 @@ from EVE import VAE_model from utils import data_utils -if __name__=='__main__': - - parser = argparse.ArgumentParser(description='Evol indices') - parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored') - parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name') - parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file') - parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') - parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting') - parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored') - parser.add_argument('--model_name_suffix', 
def main(args):
    """Compute evolutionary indices for one protein with a trained VAE.

    Loads the protein's MSA, restores the checkpointed VAE, scores the
    requested mutations (all singles or a user-supplied list) and appends
    the results to a per-protein CSV.

    Args:
        args: parsed argparse.Namespace (see the parser below).
    """
    # Resolve the protein to process from the mapping file row index.
    mapping_file = pd.read_csv(args.MSA_list)
    protein_name = mapping_file['protein_name'][args.protein_index]
    print("Protein name: " + str(protein_name))

    if args.theta_reweighting is not None:
        theta = args.theta_reweighting
    else:
        theta = 0.2  # default MSA re-weighting hyperparameter
    print("Theta MSA re-weighting: " + str(theta))

    # Load and preprocess MSA
    msa_location = \
        args.MSA_data_folder + os.sep + \
        mapping_file['msa_location'][args.protein_index]
    print("MSA file: " + str(msa_location))
    weights_location = \
        args.MSA_weights_location + os.sep + \
        protein_name + '_theta_' + str(theta) + '.npy'
    data = data_utils.MSA_processing(
        MSA_location=msa_location,
        weights_location=weights_location,
        theta=theta,
        use_weights=True
    )

    # Load model
    model_name = protein_name + "_" + args.model_name_suffix
    print("Model name: " + str(model_name))
    with open(args.model_parameters_location) as fid:
        model_params = json.load(fid)
    model = VAE_model.VAE_model(
        model_name=model_name,
        data=data,
        encoder_parameters=model_params["encoder_parameters"],
        decoder_parameters=model_params["decoder_parameters"],
        random_seed=42
    )
    model = model.to(model.device)
    checkpoint_name = \
        str(args.VAE_checkpoint_location) + os.sep + \
        model_name + "_final"
    try:
        checkpoint = torch.load(checkpoint_name)
        model.load_state_dict(checkpoint['model_state_dict'])
        print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name))
    except (OSError, KeyError, RuntimeError) as err:
        # Was a bare `except:` that hid the actual failure reason.
        print("Unable to load VAE model checkpoint '{}': {}".format(checkpoint_name, err))
        sys.exit(0)

    # Resolve mutations location.
    # BUG FIX: the previous version stored the all-singles path only in a
    # local variable and then passed the unmodified args.mutations_location
    # to compute_evol_indices, so "all_singles" mode scored the wrong file.
    if args.computation_mode == "all_singles":
        mutations_location = \
            args.all_singles_mutations_folder + os.sep + \
            protein_name + "_all_singles.csv"
        data.save_all_singles(output_filename=mutations_location)
    else:
        mutations_location = \
            args.mutations_location + os.sep + \
            protein_name + ".csv"

    # Run inference (mean/std delta-ELBO are not persisted here)
    list_valid_mutations, evol_indices, _mean_elbo, _std_elbo = \
        model.compute_evol_indices(
            msa_data=data,
            list_mutations_location=mutations_location,
            num_samples=args.num_samples_compute_evol_indices,
            batch_size=args.batch_size)

    # Format as dataframe and append to the output file
    df = pd.DataFrame(dict(
        protein_name=protein_name,
        mutations=list_valid_mutations,
        evol_indices=evol_indices
    ))
    evol_indices_output_filename = \
        args.output_evol_indices_location + os.sep + \
        protein_name + '_' + \
        str(args.num_samples_compute_evol_indices) + '_samples' + \
        args.output_evol_indices_filename_suffix + '.csv'
    try:
        # Only emit the header when the file is new or empty (append mode)
        keep_header = os.stat(evol_indices_output_filename).st_size == 0
    except OSError:
        keep_header = True
    df.to_csv(
        path_or_buf=evol_indices_output_filename,
        index=False, mode='a', header=keep_header
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Evol indices')
    parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored')
    parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name')
    parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file')
    parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored')
    parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting')
    parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored')
    parser.add_argument('--model_name_suffix', default='Jan1', type=str, help='model checkpoint name is the protein name followed by this suffix')
    parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters')
    parser.add_argument('--computation_mode', type=str, help='Computes evol indices for all single AA mutations or for a passed in list of mutations (singles or multiples) [all_singles,input_mutations_list]')
    parser.add_argument('--all_singles_mutations_folder', type=str, help='Location for the list of generated single AA mutations')
    parser.add_argument('--mutations_location', type=str, help='Location of all mutations to compute the evol indices for')
    parser.add_argument('--output_evol_indices_location', type=str, help='Output location of computed evol indices')
    parser.add_argument('--output_evol_indices_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename')
    parser.add_argument('--num_samples_compute_evol_indices', type=int, help='Num of samples to approximate delta elbo when computing evol indices')
    parser.add_argument('--batch_size', default=256, type=int, help='Batch size when computing evol indices')
    args = parser.parse_args()

    main(args)
def preprocess_evol_indices(all_evol_indices, protein_name=None, verbose=False):
    """Deduplicate evolutionary indices and split them into training arrays.

    Args:
        all_evol_indices: DataFrame with at least 'protein_name' and
            'evol_indices' columns.
        protein_name: unused; kept for interface compatibility.
        verbose: print dataset statistics when True.

    Returns:
        (X_train, proteins_train): an (N, 1) float array of evolutionary
        indices and the matching list of protein names.
    """
    evol_indices = all_evol_indices.drop_duplicates()
    X_train = evol_indices['evol_indices'].values.reshape(-1, 1)
    proteins_train = list(evol_indices['protein_name'].values)
    if verbose:
        print("Training data size: " + str(len(X_train)))
        print("Number of distinct proteins in protein_list: " +
              str(len(np.unique(all_evol_indices['protein_name']))))
    return X_train, proteins_train


def predict_scores_GMM(gmm_model, all_evol_indices,
                       recompute_uncertainty_threshold=True,
                       uncertainty_threshold_location=None,
                       verbose=False):
    """Score every mutation with the fitted (global-local) GMM.

    Args:
        gmm_model: fitted EVE.GMM_model.GMM_model instance.
        all_evol_indices: DataFrame of mutations with 'protein_name' and
            'evol_indices' columns.
        recompute_uncertainty_threshold, uncertainty_threshold_location:
            unused here; kept for interface compatibility (thresholding
            lives in filter_uncertainties).
        verbose: print score statistics when True.

    Returns:
        Copy of all_evol_indices with added 'scores' and 'classes' columns.
    """
    all_preds = all_evol_indices.copy()

    if gmm_model.protein_GMM_weight > 0.0:
        all_preds['scores'] = np.nan
        all_preds['classes'] = ""
        # BUG FIX: lists have no .drop(); exclude the global model key instead.
        protein_list = [p for p in gmm_model.models.keys() if p != 'main']
        for protein in tqdm.tqdm(protein_list, "Scoring all protein mutations"):
            mask = all_preds.protein_name == protein
            X_pred_protein = all_preds.loc[mask, 'evol_indices'].values.reshape(-1, 1)
            # BUG FIX: previous call was predict_weighted(gmm_model, X_pred),
            # passing the model as the data and the data as the protein key.
            scores, classes = gmm_model.predict_weighted(X_pred_protein, protein)
            all_preds.loc[mask, 'scores'] = scores
            all_preds.loc[mask, 'classes'] = classes
    else:
        # Pure global model: score every mutation in one shot.
        X_pred = all_preds['evol_indices'].values.reshape(-1, 1)
        scores, classes = gmm_model.predict(X_pred, 'main')
        all_preds['scores'] = scores
        all_preds['classes'] = classes

    if verbose:
        print("Score stats: \n", all_preds['scores'].describe())
        # Report rows that could not be scored (previous version computed
        # a negative "dropped" count from two misnamed lengths).
        print("Mutations with missing EVE scores: " +
              str(int(all_preds['scores'].isna().sum())))

    return all_preds


def filter_uncertainties(all_scores, n_quantiles, default_uc_location,
                         recompute=False, verbose=False):
    """Annotate scores with predictive uncertainty and per-quantile classes.

    Args:
        all_scores: DataFrame with 'scores' and 'classes' columns.
        n_quantiles: number of uncertainty quantiles to compute.
        default_uc_location: JSON file with precomputed quantiles.
        recompute: recompute quantiles from the data instead of loading.
        verbose: print per-quantile class statistics when True.

    Returns:
        all_scores with 'uncertainty' and 'class_ucq_<i>' columns added.
    """
    # Compute uncertainty from the mixture-model posterior scores.
    y_pred = all_scores['scores']
    uncertainty = ph.predictive_entropy_binary_classifier(y_pred)
    all_scores['uncertainty'] = uncertainty

    # BUG FIX: the previous version loaded the quantiles and then
    # unconditionally recomputed them (inverted `recompute` logic), and
    # referenced the undefined names `quantiles` and `verbose`.
    if recompute:
        uc_quantiles = ph.get_uncertainty_thresholds(uncertainty, n_quantiles)
    else:
        with open(default_uc_location, 'r') as fid:
            uc_quantiles = json.load(fid)
    if verbose:
        print('Quantiles', uc_quantiles)

    # Assign classes at each uncertainty quantile: mutations above the
    # threshold are zeroed out (treated as unclassified).
    for i, quantile in enumerate(uc_quantiles):
        level = f'class_ucq_{i}'
        # BUG FIX: scoring writes the 'classes' column, not 'class'.
        all_scores[level] = all_scores['classes'] * (uncertainty < quantile)
        if verbose:
            print("Stats classification by uncertainty for quantile #:" + str(quantile))
            print(all_scores[level].value_counts(normalize=True))
    return all_scores


def main(args):
    """Fit/load the GMMs, score all mutations and write the EVE scores CSV."""
    # Load evolutionary indices from per-protein files
    mapping_file = pd.read_csv(args.protein_list, low_memory=False)
    protein_list = np.unique(mapping_file['protein_name'])
    list_variables_to_keep = ['protein_name', 'mutations', 'evol_indices']
    all_evol_indices = []
    for protein in protein_list:
        evol_indices_location = \
            args.input_evol_indices_location + os.sep + \
            protein + args.input_evol_indices_filename_suffix + '.csv'
        if os.path.exists(evol_indices_location):
            # BUG FIX: `ignore_index` is a pd.concat argument, not a
            # read_csv one — passing it here raises TypeError.
            evol_indices = pd.read_csv(
                evol_indices_location,
                low_memory=False,
                usecols=list_variables_to_keep)
            all_evol_indices.append(evol_indices)
    all_evol_indices = pd.concat(all_evol_indices, ignore_index=True)

    if args.load_GMM_models:
        # Load GMM models from file
        gmm_model_location = \
            args.GMM_parameter_location + os.sep + \
            'GMM_model_dictionary_' + \
            args.GMM_parameter_filename_suffix
        with open(gmm_model_location, "rb") as fid:
            # BUG FIX: was bound to `gmm` while later code uses `gmm_model`.
            gmm_model = pickle.load(fid)
    else:
        # Train GMMs on mutation evolutionary indices
        gmm_params = {
            'protein_GMM_weight': args.protein_GMM_weight
        }
        gmm_model = GMM_model.GMM_model(gmm_params)
        # BUG FIX: fit expects (X_train, proteins_train) arrays, not the
        # raw DataFrame and the protein-name list from the mapping file.
        X_train, proteins_train = preprocess_evol_indices(
            all_evol_indices, verbose=args.verbose)
        gmm_model.fit(X_train, np.array(proteins_train))

        # Write GMM models to pickle file
        gmm_model_location = \
            args.GMM_parameter_location + os.sep + \
            'GMM_model_dictionary_' + \
            args.output_eve_scores_filename_suffix
        with open(gmm_model_location, "wb") as fid:
            pickle.dump(gmm_model, fid)

    # Compute EVE classification scores for all mutations and write to file.
    # BUG FIX: previous call passed protein_list into the
    # recompute_uncertainty_threshold positional slot.
    all_scores = predict_scores_GMM(
        gmm_model,
        all_evol_indices,
        verbose=args.verbose
    )
    all_scores.to_csv(
        args.output_eve_scores_location + os.sep +
        'all_EVE_scores_' +
        args.output_eve_scores_filename_suffix + '.csv',
        index=False
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GMM fit and EVE scores computation')
    parser.add_argument('--input_evol_indices_location', type=str, help='Folder where all individual files with evolutionary indices are stored')
    parser.add_argument('--input_evol_indices_filename_suffix', type=str, default='', help='Suffix that was added when generating the evol indices files')
    parser.add_argument('--protein_list', type=str, help='List of proteins to be included (one per row)')
    parser.add_argument('--output_eve_scores_location', type=str, help='Folder where all EVE scores are stored')
    parser.add_argument('--output_eve_scores_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename')

    parser.add_argument('--load_GMM_models', default=False, action='store_true', help='If True, load GMM model parameters. If False, train GMMs from evol indices files')
    parser.add_argument('--GMM_parameter_location', default=None, type=str, help='Folder where GMM objects are stored if loading / to be stored if we are re-training')
    parser.add_argument('--GMM_parameter_filename_suffix', default=None, type=str, help='Suffix of GMMs model files to load')
    parser.add_argument('--protein_GMM_weight', default=0.3, type=float, help='Value of global-local GMM mixing parameter')

    parser.add_argument('--compute_EVE_scores', default=False, action='store_true', help='Computes EVE scores and uncertainty metrics for all input protein mutations')
    parser.add_argument('--recompute_uncertainty_threshold', default=False, action='store_true', help='Recompute uncertainty thresholds based on all evol indices in file. Otherwise loads default threhold.')
    parser.add_argument('--default_uncertainty_threshold_file_location', default='./utils/default_uncertainty_threshold.json', type=str, help='Location of default uncertainty threholds.')

    parser.add_argument('--plot_histograms', default=False, action='store_true', help='Plots all evol indices histograms with GMM fits')
    parser.add_argument('--plot_scores_vs_labels', default=False, action='store_true', help='Plots EVE scores Vs labels at each protein position')
    parser.add_argument('--labels_file_location', default=None, type=str, help='File with ground truth labels for all proteins of interest (e.g., ClinVar)')
    parser.add_argument('--plot_location', default=None, type=str, help='Location of the different plots')
    parser.add_argument('--verbose', action='store_true', help='Print detailed information during run')
    args = parser.parse_args()

    main(args)
-import argparse -import pickle -import tqdm -import json -from sklearn import mixture, linear_model, svm, gaussian_process - -from utils import performance_helpers as ph, plot_helpers - -if __name__=='__main__': - parser = argparse.ArgumentParser(description='GMM fit and EVE scores computation') - parser.add_argument('--input_evol_indices_location', type=str, help='Folder where all individual files with evolutionary indices are stored') - parser.add_argument('--input_evol_indices_filename_suffix', type=str, default='', help='Suffix that was added when generating the evol indices files') - parser.add_argument('--protein_list', type=str, help='List of proteins to be included (one per row)') - parser.add_argument('--output_eve_scores_location', type=str, help='Folder where all EVE scores are stored') - parser.add_argument('--output_eve_scores_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') - - parser.add_argument('--load_GMM_models', default=False, action='store_true', help='If True, load GMM model parameters. If False, train GMMs from evol indices files') - parser.add_argument('--GMM_parameter_location', default=None, type=str, help='Folder where GMM objects are stored if loading / to be stored if we are re-training') - parser.add_argument('--GMM_parameter_filename_suffix', default=None, type=str, help='Suffix of GMMs model files to load') - parser.add_argument('--protein_GMM_weight', default=0.3, type=float, help='Value of global-local GMM mixing parameter') - - parser.add_argument('--compute_EVE_scores', default=False, action='store_true', help='Computes EVE scores and uncertainty metrics for all input protein mutations') - parser.add_argument('--recompute_uncertainty_threshold', default=False, action='store_true', help='Recompute uncertainty thresholds based on all evol indices in file. 
Otherwise loads default threhold.') - parser.add_argument('--default_uncertainty_threshold_file_location', default='./utils/default_uncertainty_threshold.json', type=str, help='Location of default uncertainty threholds.') - - parser.add_argument('--plot_histograms', default=False, action='store_true', help='Plots all evol indices histograms with GMM fits') - parser.add_argument('--plot_scores_vs_labels', default=False, action='store_true', help='Plots EVE scores Vs labels at each protein position') - parser.add_argument('--labels_file_location', default=None, type=str, help='File with ground truth labels for all proteins of interest (e.g., ClinVar)') - parser.add_argument('--plot_location', default=None, type=str, help='Location of the different plots') - parser.add_argument('--verbose', action='store_true', help='Print detailed information during run') - args = parser.parse_args() - - mapping_file = pd.read_csv(args.protein_list,low_memory=False) - protein_list = np.unique(mapping_file['protein_name']) - list_variables_to_keep=['protein_name','mutations','evol_indices'] - all_evol_indices = pd.concat([pd.read_csv(args.input_evol_indices_location+os.sep+protein+args.input_evol_indices_filename_suffix+'.csv',low_memory=False)[list_variables_to_keep] \ - for protein in protein_list if os.path.exists(args.input_evol_indices_location+os.sep+protein+args.input_evol_indices_filename_suffix+'.csv')], ignore_index=True) - - all_evol_indices = all_evol_indices.drop_duplicates() - X_train = np.array(all_evol_indices['evol_indices']).reshape(-1, 1) - if args.verbose: - print("Training data size: "+str(len(X_train))) - print("Number of distinct proteins in protein_list: "+str(len(np.unique(all_evol_indices['protein_name'])))) - - if args.load_GMM_models: - dict_models = pickle.load( open( args.GMM_parameter_location+os.sep+'GMM_model_dictionary_'+args.GMM_parameter_filename_suffix, "rb" ) ) - dict_pathogenic_cluster_index = pickle.load( open( 
args.GMM_parameter_location+os.sep+'GMM_pathogenic_cluster_index_dictionary_'+args.GMM_parameter_filename_suffix, "rb" ) ) - else: - dict_models = {} - dict_pathogenic_cluster_index = {} - if not os.path.exists(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix): - os.makedirs(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix) - GMM_stats_log_location=args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_stats_'+args.output_eve_scores_filename_suffix+'.csv' - with open(GMM_stats_log_location, "a") as logs: - logs.write("protein_name,weight_pathogenic,mean_pathogenic,mean_benign,std_dev_pathogenic,std_dev_benign\n") - - main_GMM = mixture.GaussianMixture(n_components=2, covariance_type='full',max_iter=1000,n_init=30,tol=1e-4) - main_GMM.fit(X_train) - - dict_models['main'] = main_GMM - pathogenic_cluster_index = np.argmax(np.array(main_GMM.means_).flatten()) #The pathogenic cluster is the cluster with higher mean value - dict_pathogenic_cluster_index['main'] = pathogenic_cluster_index - if args.verbose: - inferred_params = main_GMM.get_params() - print("Index of mixture component with highest mean: "+str(pathogenic_cluster_index)) - print("Model parameters: "+str(inferred_params)) - print("Mixture component weights: "+str(main_GMM.weights_)) - print("Mixture component means: "+str(main_GMM.means_)) - print("Cluster component cov: "+str(main_GMM.covariances_)) - with open(GMM_stats_log_location, "a") as logs: - logs.write(",".join(str(x) for x in [ - 'main', np.array(main_GMM.weights_).flatten()[dict_pathogenic_cluster_index['main']], np.array(main_GMM.means_).flatten()[dict_pathogenic_cluster_index['main']], - np.array(main_GMM.means_).flatten()[1 - dict_pathogenic_cluster_index['main']], np.sqrt(np.array(main_GMM.covariances_).flatten()[dict_pathogenic_cluster_index['main']]), - np.sqrt(np.array(main_GMM.covariances_).flatten()[1 - dict_pathogenic_cluster_index['main']]) - ])+"\n") - - 
if args.protein_GMM_weight > 0.0: - for protein in tqdm.tqdm(protein_list, "Training all protein GMMs"): - X_train_protein = np.array(all_evol_indices['evol_indices'][all_evol_indices.protein_name==protein]).reshape(-1, 1) - if len(X_train_protein) > 0: #We have evol indices computed for protein on file - protein_GMM = mixture.GaussianMixture(n_components=2,covariance_type='full',max_iter=1000,tol=1e-4,weights_init=main_GMM.weights_,means_init=main_GMM.means_,precisions_init=main_GMM.precisions_) - protein_GMM.fit(X_train_protein) - dict_models[protein] = protein_GMM - dict_pathogenic_cluster_index[protein] = np.argmax(np.array(protein_GMM.means_).flatten()) - with open(GMM_stats_log_location, "a") as logs: - logs.write(",".join(str(x) for x in [ - protein, np.array(protein_GMM.weights_).flatten()[dict_pathogenic_cluster_index[protein]], np.array(protein_GMM.means_).flatten()[dict_pathogenic_cluster_index[protein]], - np.array(protein_GMM.means_).flatten()[1 - dict_pathogenic_cluster_index[protein]], np.sqrt(np.array(protein_GMM.covariances_).flatten()[dict_pathogenic_cluster_index[protein]]), - np.sqrt(np.array(protein_GMM.covariances_).flatten()[1 - dict_pathogenic_cluster_index[protein]]) - ])+"\n") - else: - if args.verbose: - print("No evol indices for the protein: "+str(protein)+". 
Skipping.") - - pickle.dump(dict_models, open(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_model_dictionary_'+args.output_eve_scores_filename_suffix, 'wb')) - pickle.dump(dict_pathogenic_cluster_index, open(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_pathogenic_cluster_index_dictionary_'+args.output_eve_scores_filename_suffix, 'wb')) - - if args.plot_histograms: - if not os.path.exists(args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix): - os.makedirs(args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix) - plot_helpers.plot_histograms(all_evol_indices, dict_models, dict_pathogenic_cluster_index, args.protein_GMM_weight, args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix, args.output_eve_scores_filename_suffix, protein_list) - - if args.compute_EVE_scores: - if args.protein_GMM_weight > 0.0: - all_scores = all_evol_indices.copy() - all_scores['EVE_scores'] = np.nan - all_scores['EVE_classes_100_pct_retained'] = "" - for protein in tqdm.tqdm(protein_list,"Scoring all protein mutations"): - try: - test_data_protein = all_scores[all_scores.protein_name==protein] - X_test_protein = np.array(test_data_protein['evol_indices']).reshape(-1, 1) - mutation_scores_protein = ph.compute_weighted_score_two_GMMs(X_pred=X_test_protein, - main_model = dict_models['main'], - protein_model=dict_models[protein], - cluster_index_main = dict_pathogenic_cluster_index['main'], - cluster_index_protein = dict_pathogenic_cluster_index[protein], - protein_weight = args.protein_GMM_weight) - gmm_class_protein = ph.compute_weighted_class_two_GMMs(X_pred=X_test_protein, - main_model = dict_models['main'], - protein_model=dict_models[protein], - cluster_index_main = dict_pathogenic_cluster_index['main'], - cluster_index_protein = dict_pathogenic_cluster_index[protein], - protein_weight = 
args.protein_GMM_weight) - gmm_class_label_protein = pd.Series(gmm_class_protein).map(lambda x: 'Pathogenic' if x == 1 else 'Benign') - - all_scores.loc[all_scores.protein_name==protein, 'EVE_scores'] = np.array(mutation_scores_protein) - all_scores.loc[all_scores.protein_name==protein, 'EVE_classes_100_pct_retained'] = np.array(gmm_class_label_protein) - except: - print("Issues with protein: "+str(protein)+". Skipping.") - else: - all_scores = all_evol_indices.copy() - mutation_scores = dict_models['main'].predict_proba(np.array(all_scores['evol_indices']).reshape(-1, 1)) - all_scores['EVE_scores'] = mutation_scores[:,dict_pathogenic_cluster_index['main']] - gmm_class = dict_models['main'].predict(np.array(all_scores['evol_indices']).reshape(-1, 1)) - all_scores['EVE_classes_100_pct_retained'] = np.array(pd.Series(gmm_class).map(lambda x: 'Pathogenic' if x == dict_pathogenic_cluster_index['main'] else 'Benign')) - - len_before_drop_na = len(all_scores) - all_scores = all_scores.dropna(subset=['EVE_scores']) - len_after_drop_na = len(all_scores) - - if args.verbose: - scores_stats = ph.compute_stats(all_scores['EVE_scores']) - print("Score stats: "+str(scores_stats)) - print("Dropped mutations due to missing EVE scores: "+str(len_after_drop_na-len_before_drop_na)) - all_scores['uncertainty'] = ph.predictive_entropy_binary_classifier(all_scores['EVE_scores']) - - if args.recompute_uncertainty_threshold: - uncertainty_cutoffs_deciles, _, _ = ph.compute_uncertainty_deciles(all_scores) - uncertainty_cutoffs_quartiles, _, _ = ph.compute_uncertainty_quartiles(all_scores) - if args.verbose: - print("Uncertainty cutoffs deciles: "+str(uncertainty_cutoffs_deciles)) - print("Uncertainty cutoffs quartiles: "+str(uncertainty_cutoffs_quartiles)) - else: - uncertainty_thresholds = json.load(open(args.default_uncertainty_threshold_file_location)) - uncertainty_cutoffs_deciles = uncertainty_thresholds["deciles"] - uncertainty_cutoffs_quartiles = uncertainty_thresholds["quartiles"] 
- - for decile in range(1,10): - classification_name = 'EVE_classes_'+str((decile)*10)+"_pct_retained" - all_scores[classification_name] = all_scores['EVE_classes_100_pct_retained'] - all_scores.loc[all_scores['uncertainty'] > uncertainty_cutoffs_deciles[str(decile)], classification_name] = 'Uncertain' - if args.verbose: - print("Stats classification by uncertainty for decile #:"+str(decile)) - print(all_scores[classification_name].value_counts(normalize=True)) - if args.verbose: - print("Stats classification by uncertainty for decile #:"+str(10)) - print(all_scores['EVE_classes_100_pct_retained'].value_counts(normalize=True)) - - for quartile in [1,3]: - classification_name = 'EVE_classes_'+str((quartile)*25)+"_pct_retained" - all_scores[classification_name] = all_scores['EVE_classes_100_pct_retained'] - all_scores.loc[all_scores['uncertainty'] > uncertainty_cutoffs_quartiles[str(quartile)], classification_name] = 'Uncertain' - if args.verbose: - print("Stats classification by uncertainty for quartile #:"+str(quartile)) - print(all_scores[classification_name].value_counts(normalize=True)) - - all_scores.to_csv(args.output_eve_scores_location+os.sep+'all_EVE_scores_'+args.output_eve_scores_filename_suffix+'.csv', index=False) - - if args.plot_scores_vs_labels: - labels_dataset=pd.read_csv(args.labels_file_location,low_memory=False) - all_scores_mutations_with_labels = pd.merge(all_scores, labels_dataset[['protein_name','mutations','ClinVar_labels']], how='inner', on=['protein_name','mutations']) - all_PB_scores = all_scores_mutations_with_labels[all_scores_mutations_with_labels.ClinVar_labels!=0.5].copy() - if not os.path.exists(args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix): - os.makedirs(args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix) - for protein in tqdm.tqdm(protein_list,"Plot scores Vs labels"): - 
plot_helpers.plot_scores_vs_labels(score_df=all_PB_scores[all_PB_scores.protein_name==protein], - plot_location=args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix, - output_eve_scores_filename_suffix=args.output_eve_scores_filename_suffix+'_'+protein, - mutation_name='mutations', score_name="EVE_scores", label_name='ClinVar_labels') \ No newline at end of file