From 2dce4f35f41b497a479141b7205acd1a97954f83 Mon Sep 17 00:00:00 2001 From: Andrew Boardman Date: Tue, 13 Sep 2022 08:30:59 +0100 Subject: [PATCH 1/3] Fix standard max-likelihood decoder --- EVE/VAE_decoder.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/EVE/VAE_decoder.py b/EVE/VAE_decoder.py index 7bfa57a..18acb2d 100644 --- a/EVE/VAE_decoder.py +++ b/EVE/VAE_decoder.py @@ -170,8 +170,7 @@ class VAE_Standard_MLP_decoder(nn.Module): """ Standard MLP decoder class for the VAE model. """ - def __init__(self, seq_len, alphabet_size, hidden_layers_sizes, z_dim, first_hidden_nonlinearity, last_hidden_nonlinearity, dropout_proba, - convolve_output, convolution_depth, include_temperature_scaler, include_sparsity, num_tiles_sparsity): + def __init__(self, params): """ Required input parameters: - seq_len: (Int) Sequence length of sequence alignment From 23bbeac1412aaacf55b675352c6152632743e2b2 Mon Sep 17 00:00:00 2001 From: Andrew Boardman Date: Tue, 13 Sep 2022 14:39:49 +0100 Subject: [PATCH 2/3] refactor and create global-local GMM model class --- EVE/GMM_model.py | 95 ++++++++++ examples/full_pipeline_PTEN.sh | 73 ++++++++ plot_scores_and_labels.py | 90 ++++++++++ compute_evol_indices.py => predict_ELBO.py | 150 +++++++++------- predict_GMM_score.py | 182 +++++++++++++++++++ train_GMM_and_compute_EVE_scores.py | 192 --------------------- 6 files changed, 527 insertions(+), 255 deletions(-) create mode 100644 EVE/GMM_model.py create mode 100644 examples/full_pipeline_PTEN.sh create mode 100644 plot_scores_and_labels.py rename compute_evol_indices.py => predict_ELBO.py (63%) create mode 100644 predict_GMM_score.py delete mode 100644 train_GMM_and_compute_EVE_scores.py diff --git a/EVE/GMM_model.py b/EVE/GMM_model.py new file mode 100644 index 0000000..c9f6457 --- /dev/null +++ b/EVE/GMM_model.py @@ -0,0 +1,95 @@ +from sklearn import mixture +import numpy as np +import os +import tqdm + +class GMM_model: + """ + Class for the global-local GMM model trained on single-point mutations for a wild-type protein sequence + """ + def __init__(self, params) -> None: + # Set up logging + self.log_location = params['log_location'] + if not os.path.exists(os.path.basename(self.log_location)): + os.makedirs(os.path.basename(self.log_location)) + with open(self.log_location, "a") as logs: + logs.write("protein_name,weight_pathogenic,mean_pathogenic,mean_benign,std_dev_pathogenic,std_dev_benign\n") + + # store parameters + self.protein_GMM_weight = params['protein_GMM_weight'] + + def fit_single(self, X_train, protein_name=None, verbose = False): + model = mixture.GaussianMixture( + n_components=2, + covariance_type='full', + max_iter=1000, + n_init=30, + tol=1e-4 + ) + model.fit(X_train) + #The pathogenic cluster is the cluster with higher mean value + pathogenic_cluster_index = np.argmax(np.array(self.model.means_).flatten()) + + with open(self.log_location, "a") as logs: + logs.write(",".join(str(x) for x in [ + protein_name, + np.array(self.model.weights_).flatten()[self.pathogenic_cluster_index], + np.array(self.model.means_).flatten()[self.pathogenic_cluster_index], + np.array(self.model.means_).flatten()[1 - self.pathogenic_cluster_index], + np.sqrt(np.array(self.model.covariances_).flatten()[self.pathogenic_cluster_index]), + np.sqrt(np.array(self.model.covariances_).flatten()[1 - self.pathogenic_cluster_index]) + ]) + +"\n" + ) + + if verbose: + inferred_params = self.model.get_params() + print("Index of mixture component with highest mean: "+str(self.pathogenic_cluster_index)) + print("Model parameters: "+str(inferred_params)) + print("Mixture component weights: "+str(self.model.weights_)) + print("Mixture component means: "+str(self.model.means_)) + print("Cluster component cov: "+str(self.model.covariances_)) + + return model, pathogenic_cluster_index + + def fit(self, X_train, proteins_train, verbose = True): + # set up to train + self.models = {} + self.indices = {} + + # train global model + gmm, index = self.fit_single(X_train,'main',verbose=verbose) + self.models['main'] = gmm + self.indices['main'] = index + + # train local models + if self.protein_GMM_weight > 0.0: + proteins_list = list(set(proteins_train)) + for protein in tqdm.tqdm(proteins_list, "Training all protein GMMs"): + X_train_protein = X_train[proteins_train == protein] + gmm, index = self.fit_single(X_train_protein,protein,verbose=verbose) + self.models[protein] = gmm + self.indices[protein] = index + + return self.models, self.indices + + def predict(self, X_pred, protein): + model = self.models[protein] + cluster_index = self.indices[protein] + scores = model.predict_proba(X_pred)[:,cluster_index] + classes = (scores > 0.5).astype(int) + return scores, classes + + def predict_weighted(self, X_pred, protein): + scores_protein = self.predict(X_pred, protein) + scores_main = self.predict(X_pred, 'main') + + scores_weighted = scores_main * (1 - self.protein_GMM_weight) + \ + scores_protein * self.protein_GMM_weight + classes_weighted = (scores_weighted > 0.5).astype(int) + return scores_weighted, classes_weighted + + + + +X_train, groups_train = preprocess_indices(all_evol_indices) \ No newline at end of file diff --git a/examples/full_pipeline_PTEN.sh b/examples/full_pipeline_PTEN.sh new file mode 100644 index 0000000..1818160 --- /dev/null +++ b/examples/full_pipeline_PTEN.sh @@ -0,0 +1,73 @@ +export MSA_data_folder='./data/MSA' +export MSA_list='./data/mappings/example_mapping.csv' +export MSA_weights_location='./data/weights' +export VAE_checkpoint_location='./results/VAE_parameters' +export model_name_suffix='Jan1_PTEN_example' +export model_parameters_location='./EVE/default_model_params.json' +export training_logs_location='./logs/' +export protein_index=0 + +python train_VAE.py \ + --MSA_data_folder ${MSA_data_folder} \ + --MSA_list ${MSA_list} \ + --protein_index ${protein_index} \ + --MSA_weights_location ${MSA_weights_location} \ + --VAE_checkpoint_location ${VAE_checkpoint_location} \ + --model_name_suffix ${model_name_suffix} \ + --model_parameters_location ${model_parameters_location} \ + --training_logs_location ${training_logs_location} \ + --verbose + +export computation_mode='all_singles' +export all_singles_mutations_folder='./data/mutations' +export evol_indices_location='./results/evol_indices' +export num_samples_compute_evol_indices=20000 +export batch_size=2048 + +python predict_ELBO.py \ + --MSA_data_folder ${MSA_data_folder} \ + --MSA_list ${MSA_list} \ + --protein_index ${protein_index} \ + --MSA_weights_location ${MSA_weights_location} \ + --VAE_checkpoint_location ${VAE_checkpoint_location} \ + --model_name_suffix ${model_name_suffix} \ + --model_parameters_location ${model_parameters_location} \ + --computation_mode ${computation_mode} \ + --all_singles_mutations_folder ${all_singles_mutations_folder} \ + --output_evol_indices_location ${output_evol_indices_location} \ + --num_samples_compute_evol_indices ${num_samples_compute_evol_indices} \ + --batch_size ${batch_size} + --verbose + +export evol_indices_filename_suffix='_20000_samples' +export protein_list='./data/mappings/example_mapping.csv' +export eve_scores_location='./results/EVE_scores' +export eve_scores_filename_suffix='Jan1_PTEN_example' +export GMM_parameter_location='./results/GMM_parameters/Default_GMM_parameters' +export GMM_parameter_filename_suffix='default' +export protein_GMM_weight=0.3 +export default_uncertainty_threshold_file_location='./utils/default_uncertainty_threshold.json' + +python predict_GMM_score.py \ + --input_evol_indices_location ${evol_indices_location} \ + --input_evol_indices_filename_suffix ${evol_indices_filename_suffix} \ + --protein_list ${protein_list} \ + --output_eve_scores_location ${eve_scores_location} \ + --output_eve_scores_filename_suffix ${eve_scores_filename_suffix} \ + --GMM_parameter_location ${GMM_parameter_location} \ + --GMM_parameter_filename_suffix ${GMM_parameter_filename_suffix} \ + --protein_GMM_weight ${protein_GMM_weight} \ + --compute_uncertainty_thresholds \ + --verbose + +export plot_location='./results' +export labels_file_location='./data/labels/PTEN_ClinVar_labels.csv' + +python plot_scores_and_labels.py \ + --input_eve_scores_location ${eve_scores_location} \ + --input_eve_scores_filename_suffix ${eve_scores_filename_suffix} \ + --labels_file_location ${labels_file_location} \ + --plot_location ${plot_location} \ + --plot_histograms \ + --plot_scores_vs_labels \ + --verbose \ No newline at end of file diff --git a/plot_scores_and_labels.py b/plot_scores_and_labels.py new file mode 100644 index 0000000..5b724e0 --- /dev/null +++ b/plot_scores_and_labels.py @@ -0,0 +1,90 @@ +import argparse +import os +import tqdm +import pickle +import pandas as pd +from utils import plot_helpers + + +def main(args): + model_location = args.gmm_model_location + with open(model_location, 'rb') as fid: + gmm_model = pickle.load(fid) + scores_location = args.input_eve_scores_location + os.sep + \ + 'all_EVE_scores_'+ \ + args.output_eve_scores_filename_suffix+'.csv' + all_scores = pd.read_csv( + scores_location, + index=False + ) + protein_list = list(all_scores.protein_name.unique()) + + if args.plot_elbo_histograms: + # Plot fit of mixture model to predicted scores + histograms_location = \ + args.plot_location+os.sep+\ + 'plots_histograms'+os.sep+\ + args.output_eve_scores_filename_suffix + if not os.path.exists(histograms_location): + os.makedirs(histograms_location) + plot_helpers.plot_histograms( + all_scores, + gmm_model, + histograms_location, + protein_list + ) + + if args.plot_scores_vs_labels: + labels_dataset = pd.read_csv( + args.labels_file_location, + low_memory=False, + usecols=['protein_name','mutations','ClinVar_labels'] + ) + labels_dataset = labels_dataset[labels_dataset.ClinVar_labels.isin([0,1])] + all_scores_labelled = pd.merge( + all_scores, + labels_dataset, + how='inner', + on=['protein_name','mutations'] + ) + labelled_scores_location = args.input_eve_scores_location + os.sep + \ + 'all_EVE_scores_labelled_'+ \ + args.output_eve_scores_filename_suffix+'.csv' + all_scores_labelled.to_csv( + labelled_scores_location, index=False + ) + + # Plot scores against clinical labels + scores_vs_labels_plot_location = \ + args.plot_location+os.sep+\ + 'plots_scores_vs_labels'+os.sep+\ + args.output_eve_scores_filename_suffix + if not os.path.exists(scores_vs_labels_plot_location): + os.makedirs(scores_vs_labels_plot_location) + for protein in tqdm.tqdm(protein_list,"Plot scores Vs labels"): + protein_scores = all_scores_labelled[ + all_scores_labelled.protein_name==protein + ] + output_suffix = args.output_eve_scores_filename_suffix+'_'+protein + plot_helpers.plot_scores_vs_labels( + score_df=protein_scores, + plot_location=scores_vs_labels_plot_location, + output_eve_scores_filename_suffix=output_suffix, + mutation_name='mutations', + score_name="EVE_scores", + label_name='ClinVar_labels' + ) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Plot EVE scores against standard labels') + parser.add_argument('--input_eve_scores_location', type=str, help='Folder where all EVE scores are stored') + parser.add_argument('--input_eve_scores_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') + parser.add_argument('--labels_file_location', default=None, type=str, help='File with ground truth labels for all proteins of interest (e.g., ClinVar)') + parser.add_argument('--plot_location', default=None, type=str, help='Location of the different plots') + parser.add_argument('--plot_elbo_histograms', default=False, action='store_true', help='Plots all evol indices histograms with GMM fits') + parser.add_argument('--plot_scores_vs_labels', default=False, action='store_true', help='Plots EVE scores Vs labels at each protein position') + parser.add_argument('--verbose', action='store_true', help='Print detailed information during run') + args = parser.parse_args() + + + main(args) \ No newline at end of file diff --git a/compute_evol_indices.py b/predict_ELBO.py similarity index 63% rename from compute_evol_indices.py rename to predict_ELBO.py index 9e87034..f8f1388 100644 --- a/compute_evol_indices.py +++ b/predict_ELBO.py @@ -7,32 +7,43 @@ from EVE import VAE_model from utils import data_utils -if __name__=='__main__': - - parser = argparse.ArgumentParser(description='Evol indices') - parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored') - parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name') - parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file') - parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') - parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting') - parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored') - parser.add_argument('--model_name_suffix', default='Jan1', type=str, help='model checkpoint name is the protein name followed by this suffix') - parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters') - parser.add_argument('--computation_mode', type=str, help='Computes evol indices for all single AA mutations or for a passed in list of mutations (singles or multiples) [all_singles,input_mutations_list]') - parser.add_argument('--all_singles_mutations_folder', type=str, help='Location for the list of generated single AA mutations') - parser.add_argument('--mutations_location', type=str, help='Location of all mutations to compute the evol indices for') - parser.add_argument('--output_evol_indices_location', type=str, help='Output location of computed evol indices') - parser.add_argument('--output_evol_indices_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') - parser.add_argument('--num_samples_compute_evol_indices', type=int, help='Num of samples to approximate delta elbo when computing evol indices') - parser.add_argument('--batch_size', default=256, type=int, help='Batch size when computing evol indices') - args = parser.parse_args() - +def main(args): + # Load protein name mapping_file = pd.read_csv(args.MSA_list) protein_name = mapping_file['protein_name'][args.protein_index] - msa_location = args.MSA_data_folder + os.sep + mapping_file['msa_location'][args.protein_index] print("Protein name: "+str(protein_name)) - print("MSA file: "+str(msa_location)) + # Load model + model_name = protein_name + "_" + args.model_name_suffix + print("Model name: "+str(model_name)) + model_params = json.load(open(args.model_parameters_location)) + model = VAE_model.VAE_model( + model_name=model_name, + data=data, + encoder_parameters=model_params["encoder_parameters"], + decoder_parameters=model_params["decoder_parameters"], + random_seed=42 + ) + model = model.to(model.device) + try: + checkpoint_name = \ + str(args.VAE_checkpoint_location) + os.sep + \ + model_name + "_final" + checkpoint = torch.load(checkpoint_name) + model.load_state_dict(checkpoint['model_state_dict']) + print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name)) + except: + print("Unable to locate VAE model checkpoint") + sys.exit(0) + + # Load and preprocess MSA + msa_location = \ + args.MSA_data_folder + os.sep + \ + mapping_file['msa_location'][args.protein_index] + print("MSA file: "+str(msa_location)) + weights_location = \ + args.MSA_weights_location + os.sep + \ + protein_name + '_theta_' + str(theta) + '.npy' if args.theta_reweighting is not None: theta = args.theta_reweighting else: @@ -41,57 +52,70 @@ except: theta = 0.2 print("Theta MSA re-weighting: "+str(theta)) - data = data_utils.MSA_processing( MSA_location=msa_location, + weights_location=weights_location, theta=theta, - use_weights=True, - weights_location=args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy' + use_weights=True ) + # Load mutations location if args.computation_mode=="all_singles": - data.save_all_singles(output_filename=args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv") - args.mutations_location = args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv" + mutations_location = \ + args.all_singles_mutations_folder + os.sep + \ + protein_name + "_all_singles.csv" + data.save_all_singles(output_filename = mutations_location) else: - args.mutations_location = args.mutations_location + os.sep + protein_name + ".csv" - - model_name = protein_name + "_" + args.model_name_suffix - print("Model name: "+str(model_name)) - - model_params = json.load(open(args.model_parameters_location)) + args.mutations_location = \ + args.mutations_location + os.sep + \ + protein_name + ".csv" - model = VAE_model.VAE_model( - model_name=model_name, - data=data, - encoder_parameters=model_params["encoder_parameters"], - decoder_parameters=model_params["decoder_parameters"], - random_seed=42 - ) - model = model.to(model.device) + # Run inference + evol_indices = model.compute_evol_indices( + msa_data=data, + list_mutations_location=args.mutations_location, + num_samples=args.num_samples_compute_evol_indices, + batch_size=args.batch_size) + list_valid_mutations, evol_indices, mean_elbo, std_elbo = evol_indices - try: - checkpoint_name = str(args.VAE_checkpoint_location) + os.sep + model_name + "_final" - checkpoint = torch.load(checkpoint_name) - model.load_state_dict(checkpoint['model_state_dict']) - print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name)) - except: - print("Unable to locate VAE model checkpoint") - sys.exit(0) - - list_valid_mutations, evol_indices, _, _ = model.compute_evol_indices(msa_data=data, - list_mutations_location=args.mutations_location, - num_samples=args.num_samples_compute_evol_indices, - batch_size=args.batch_size) - - df = {} - df['protein_name'] = protein_name - df['mutations'] = list_valid_mutations - df['evol_indices'] = evol_indices - df = pd.DataFrame(df) - - evol_indices_output_filename = args.output_evol_indices_location+os.sep+protein_name+'_'+str(args.num_samples_compute_evol_indices)+'_samples'+args.output_evol_indices_filename_suffix+'.csv' + # Format as dataframe and write to file + df = pd.DataFrame(dict( + protein_name = protein_name, + mutations = list_valid_mutations, + evol_indices = evol_indices + )) + evol_indices_output_filename = \ + args.output_evol_indices_location + os.sep + \ + protein_name + '_' + \ + str(args.num_samples_compute_evol_indices) + '_samples' + \ + args.output_evol_indices_filename_suffix + '.csv' try: keep_header = os.stat(evol_indices_output_filename).st_size == 0 except: keep_header=True - df.to_csv(path_or_buf=evol_indices_output_filename, index=False, mode='a', header=keep_header) \ No newline at end of file + df.to_csv( + path_or_buf=evol_indices_output_filename, + index=False, mode='a', header=keep_header + ) + + +if __name__=='__main__': + parser = argparse.ArgumentParser(description='Evol indices') + parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored') + parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name') + parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file') + parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') + parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting') + parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored') + parser.add_argument('--model_name_suffix', default='Jan1', type=str, help='model checkpoint name is the protein name followed by this suffix') + parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters') + parser.add_argument('--computation_mode', type=str, help='Computes evol indices for all single AA mutations or for a passed in list of mutations (singles or multiples) [all_singles,input_mutations_list]') + parser.add_argument('--all_singles_mutations_folder', type=str, help='Location for the list of generated single AA mutations') + parser.add_argument('--mutations_location', type=str, help='Location of all mutations to compute the evol indices for') + parser.add_argument('--output_evol_indices_location', type=str, help='Output location of computed evol indices') + parser.add_argument('--output_evol_indices_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') + parser.add_argument('--num_samples_compute_evol_indices', type=int, help='Num of samples to approximate delta elbo when computing evol indices') + parser.add_argument('--batch_size', default=256, type=int, help='Batch size when computing evol indices') + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/predict_GMM_score.py b/predict_GMM_score.py new file mode 100644 index 0000000..50d0df8 --- /dev/null +++ b/predict_GMM_score.py @@ -0,0 +1,182 @@ +from cProfile import label +from email.policy import default +import os +from statistics import quantiles +from tabnanny import verbose +import numpy as np +import pandas as pd +import argparse +import pickle +import tqdm +import json +from utils import performance_helpers as ph, plot_helpers +from EVE import GMM_model + + +def preprocess_evol_indices(all_evol_indices, protein_name=None, verbose=False): + evol_indices = all_evol_indices.drop_duplicates() + X_train = evol_indices['evol_indices'].values.reshape(-1, 1) + proteins_train = list(evol_indices['protein_name'].values) + if verbose: + print("Training data size: "+str(len(X_train))) + print("Number of distinct proteins in protein_list: "+str(len(np.unique(all_evol_indices['protein_name'])))) + return X_train, proteins_train + + +def predict_scores_GMM(gmm_model, all_evol_indices, + recompute_uncertainty_threshold = True, uncertainty_threshold_location = None, + verbose = False): + + all_preds = all_evol_indices.copy() + + if gmm_model.protein_GMM_weight > 0.0: + all_preds['scores'] = np.nan + all_preds['classes'] = "" + protein_list = list(gmm_model.models.keys()).drop('main') + for protein in tqdm.tqdm(protein_list,"Scoring all protein mutations"): + preds_protein = all_preds[all_preds.protein_name==protein].copy() + X_pred_protein = preds_protein['evol_indices'].values.reshape(-1, 1) + scores, classes = gmm_model.predict_weighted(gmm_model, X_pred_protein) + preds_protein['scores'] = scores + preds_protein['classes'] = classes + all_preds.loc[all_preds.protein_name==protein, :] = preds_protein + + else: + X_pred = all_preds['evol_indices'].values.reshape(-1, 1) + scores, classes = gmm_model.predict(X_pred, 'main') + all_preds['scores'] = scores + all_preds['classes'] = classes + + if verbose: + scores_stats = all_preds['scores'].describe() + print("Score stats: \n", scores_stats) + len_before_drop_na = len(all_preds) + len_after_drop_na = len(all_preds['scores'].dropna()) + print("Dropped mutations due to missing EVE scores: "+str(len_after_drop_na-len_before_drop_na)) + + return all_preds + + +def filter_uncertainties(all_scores, n_quantiles, default_uc_location, recompute = False): + # Compute uncertainty from mixture model + y_pred = all_scores['scores'] + uncertainty = ph.predictive_entropy_binary_classifier(y_pred) + all_scores['uncertainty'] = uncertainty + + # Get quantiles for uncertainty + if not recompute: + with open(default_uc_location,'r') as fid: + uc_quantiles = json.load(fid) + uc_quantiles = ph.get_uncertainty_thresholds(uncertainty, n_quantiles) + if verbose: + print('Quantiles', uc_quantiles) + + # Assign classes at each quantile + for i, quantile in enumerate(quantiles): + level = f'class_ucq_{i}' + all_scores[level] = all_scores['class'] * (uncertainty < quantile) + if verbose: + print("Stats classification by uncertainty for quantile #:"+str(decile)) + print(all_scores[level].value_counts(normalize=True)) + return all_scores + + +def main(args): + # Load evolutionary indices from files + mapping_file = pd.read_csv(args.protein_list,low_memory=False) + protein_list = np.unique(mapping_file['protein_name']) + list_variables_to_keep=['protein_name','mutations','evol_indices'] + all_evol_indices = [] + for protein in protein_list: + evol_indices_location = \ + args.input_evol_indices_location + os.sep + \ + protein + args.input_evol_indices_filename_suffix + '.csv' + if os.path.exists(evol_indices_location): + evol_indices = pd.read_csv( + evol_indices_location, + low_memory=False, ignore_index=True, + usecols=list_variables_to_keep) + all_evol_indices.append(evol_indices) + all_evol_indices = pd.concat(all_evol_indices) + + if args.load_GMM_models: + # Load GMM models from file + gmm_model_location = \ + args.GMM_parameter_location+os.sep+\ + 'GMM_model_dictionary_'+ \ + args.GMM_parameter_filename_suffix + with open(gmm_model_location, "rb" ) as fid: + dict_models = pickle.load(fid) + # Load GMM indices from file + gmm_index_location = args.GMM_parameter_location+os.sep+\ + 'GMM_pathogenic_cluster_index_dictionary_'+\ + args.GMM_parameter_filename_suffix + with open(gmm_index_location, "rb" ) as fid: + dict_pathogenic_cluster_index = pickle.load(fid) + + else: + # train GMMs on mutation evolutionary indices + gmm_log_location = \ + args.GMM_parameter_location+os.sep+\ + args.output_eve_scores_filename_suffix+os.sep+\ + 'GMM_stats_'+args.output_eve_scores_filename_suffix+'.csv' + gmm_params = { + 'log_location':gmm_log_location, + 'protein_GMM_weight':args.protein_GMM_weight + } + gmm = GMM_model.GMM_model(gmm_params) + gmm.fit( + all_evol_indices, + protein_list, + verbose=args.verbose + ) + + # Write GMM models to file + gmm_model_location = \ + args.GMM_parameter_location+os.sep+\ + 'GMM_model_dictionary_'+ \ + args.output_eve_scores_filename_suffix + with open(gmm_model_location, "wb" ) as fid: + pickle.dump(gmm, fid) + + # Compute EVE classification scores for all mutations and write to file + all_scores = predict_scores_GMM( + all_evol_indices, + dict_models, + dict_pathogenic_cluster_index, + protein_list + ) + all_scores.to_csv( + args.output_eve_scores_location+os.sep+ \ + 'all_EVE_scores_'+ \ + args.output_eve_scores_filename_suffix+'.csv', + index=False + ) + + +if __name__=='__main__': + parser = argparse.ArgumentParser(description='GMM fit and EVE scores computation') + parser.add_argument('--input_evol_indices_location', type=str, help='Folder where all individual files with evolutionary indices are stored') + parser.add_argument('--input_evol_indices_filename_suffix', type=str, default='', help='Suffix that was added when generating the evol indices files') + parser.add_argument('--protein_list', type=str, help='List of proteins to be included (one per row)') + parser.add_argument('--output_eve_scores_location', type=str, help='Folder where all EVE scores are stored') + parser.add_argument('--output_eve_scores_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') + + parser.add_argument('--load_GMM_models', default=False, action='store_true', help='If True, load GMM model parameters. If False, train GMMs from evol indices files') + parser.add_argument('--GMM_parameter_location', default=None, type=str, help='Folder where GMM objects are stored if loading / to be stored if we are re-training') + parser.add_argument('--GMM_parameter_filename_suffix', default=None, type=str, help='Suffix of GMMs model files to load') + parser.add_argument('--protein_GMM_weight', default=0.3, type=float, help='Value of global-local GMM mixing parameter') + + parser.add_argument('--compute_EVE_scores', default=False, action='store_true', help='Computes EVE scores and uncertainty metrics for all input protein mutations') + parser.add_argument('--recompute_uncertainty_threshold', default=False, action='store_true', help='Recompute uncertainty thresholds based on all evol indices in file. Otherwise loads default threhold.') + parser.add_argument('--default_uncertainty_threshold_file_location', default='./utils/default_uncertainty_threshold.json', type=str, help='Location of default uncertainty threholds.') + + parser.add_argument('--plot_histograms', default=False, action='store_true', help='Plots all evol indices histograms with GMM fits') + parser.add_argument('--plot_scores_vs_labels', default=False, action='store_true', help='Plots EVE scores Vs labels at each protein position') + parser.add_argument('--labels_file_location', default=None, type=str, help='File with ground truth labels for all proteins of interest (e.g., ClinVar)') + parser.add_argument('--plot_location', default=None, type=str, help='Location of the different plots') + parser.add_argument('--verbose', action='store_true', help='Print detailed information during run') + args = parser.parse_args() + + main(args) + \ No newline at end of file diff --git a/train_GMM_and_compute_EVE_scores.py b/train_GMM_and_compute_EVE_scores.py deleted file mode 100644 index acee314..0000000 --- a/train_GMM_and_compute_EVE_scores.py +++ /dev/null @@ -1,192 +0,0 @@ -import os -import numpy as np -import pandas as pd -import argparse -import pickle -import tqdm -import json -from sklearn import mixture, linear_model, svm, gaussian_process - -from utils import performance_helpers as ph, plot_helpers - -if __name__=='__main__': - parser = argparse.ArgumentParser(description='GMM fit and EVE scores computation') - parser.add_argument('--input_evol_indices_location', type=str, help='Folder where all individual files with evolutionary indices are stored') - parser.add_argument('--input_evol_indices_filename_suffix', type=str, default='', help='Suffix that was added when generating the evol indices files') - parser.add_argument('--protein_list', type=str, help='List of proteins to be included (one per row)') - parser.add_argument('--output_eve_scores_location', type=str, help='Folder where all EVE scores are stored') - parser.add_argument('--output_eve_scores_filename_suffix', default='', type=str, help='(Optional) Suffix to be added to output filename') - - parser.add_argument('--load_GMM_models', default=False, action='store_true', help='If True, load GMM model parameters. If False, train GMMs from evol indices files') - parser.add_argument('--GMM_parameter_location', default=None, type=str, help='Folder where GMM objects are stored if loading / to be stored if we are re-training') - parser.add_argument('--GMM_parameter_filename_suffix', default=None, type=str, help='Suffix of GMMs model files to load') - parser.add_argument('--protein_GMM_weight', default=0.3, type=float, help='Value of global-local GMM mixing parameter') - - parser.add_argument('--compute_EVE_scores', default=False, action='store_true', help='Computes EVE scores and uncertainty metrics for all input protein mutations') - parser.add_argument('--recompute_uncertainty_threshold', default=False, action='store_true', help='Recompute uncertainty thresholds based on all evol indices in file. Otherwise loads default threhold.') - parser.add_argument('--default_uncertainty_threshold_file_location', default='./utils/default_uncertainty_threshold.json', type=str, help='Location of default uncertainty threholds.') - - parser.add_argument('--plot_histograms', default=False, action='store_true', help='Plots all evol indices histograms with GMM fits') - parser.add_argument('--plot_scores_vs_labels', default=False, action='store_true', help='Plots EVE scores Vs labels at each protein position') - parser.add_argument('--labels_file_location', default=None, type=str, help='File with ground truth labels for all proteins of interest (e.g., ClinVar)') - parser.add_argument('--plot_location', default=None, type=str, help='Location of the different plots') - parser.add_argument('--verbose', action='store_true', help='Print detailed information during run') - args = parser.parse_args() - - mapping_file = pd.read_csv(args.protein_list,low_memory=False) - protein_list = np.unique(mapping_file['protein_name']) - list_variables_to_keep=['protein_name','mutations','evol_indices'] - all_evol_indices = pd.concat([pd.read_csv(args.input_evol_indices_location+os.sep+protein+args.input_evol_indices_filename_suffix+'.csv',low_memory=False)[list_variables_to_keep] \ - for protein in protein_list if os.path.exists(args.input_evol_indices_location+os.sep+protein+args.input_evol_indices_filename_suffix+'.csv')], ignore_index=True) - - all_evol_indices = all_evol_indices.drop_duplicates() - X_train = np.array(all_evol_indices['evol_indices']).reshape(-1, 1) - if args.verbose: - print("Training data size: "+str(len(X_train))) - print("Number of distinct proteins in protein_list: "+str(len(np.unique(all_evol_indices['protein_name'])))) - - if args.load_GMM_models: - dict_models = pickle.load( open( args.GMM_parameter_location+os.sep+'GMM_model_dictionary_'+args.GMM_parameter_filename_suffix, "rb" ) ) - dict_pathogenic_cluster_index = pickle.load( open( args.GMM_parameter_location+os.sep+'GMM_pathogenic_cluster_index_dictionary_'+args.GMM_parameter_filename_suffix, "rb" ) ) - else: - dict_models = {} - dict_pathogenic_cluster_index = {} - if not os.path.exists(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix): - os.makedirs(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix) - GMM_stats_log_location=args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_stats_'+args.output_eve_scores_filename_suffix+'.csv' - with open(GMM_stats_log_location, "a") as logs: - logs.write("protein_name,weight_pathogenic,mean_pathogenic,mean_benign,std_dev_pathogenic,std_dev_benign\n") - - main_GMM = mixture.GaussianMixture(n_components=2, covariance_type='full',max_iter=1000,n_init=30,tol=1e-4) - main_GMM.fit(X_train) - - dict_models['main'] = main_GMM - pathogenic_cluster_index = np.argmax(np.array(main_GMM.means_).flatten()) #The pathogenic cluster is the cluster with higher mean value - dict_pathogenic_cluster_index['main'] = pathogenic_cluster_index - if args.verbose: - inferred_params = main_GMM.get_params() - print("Index of mixture component with highest mean: "+str(pathogenic_cluster_index)) - print("Model parameters: "+str(inferred_params)) - print("Mixture component weights: "+str(main_GMM.weights_)) - print("Mixture component means: "+str(main_GMM.means_)) - print("Cluster component cov: "+str(main_GMM.covariances_)) - with open(GMM_stats_log_location, "a") as logs: - logs.write(",".join(str(x) for x in [ - 'main', np.array(main_GMM.weights_).flatten()[dict_pathogenic_cluster_index['main']], np.array(main_GMM.means_).flatten()[dict_pathogenic_cluster_index['main']], - np.array(main_GMM.means_).flatten()[1 - dict_pathogenic_cluster_index['main']], np.sqrt(np.array(main_GMM.covariances_).flatten()[dict_pathogenic_cluster_index['main']]), - np.sqrt(np.array(main_GMM.covariances_).flatten()[1 - dict_pathogenic_cluster_index['main']]) - ])+"\n") - - if args.protein_GMM_weight > 0.0: - for protein in tqdm.tqdm(protein_list, "Training all protein GMMs"): - X_train_protein = np.array(all_evol_indices['evol_indices'][all_evol_indices.protein_name==protein]).reshape(-1, 1) - if len(X_train_protein) > 0: #We have evol indices computed for protein on file - protein_GMM = mixture.GaussianMixture(n_components=2,covariance_type='full',max_iter=1000,tol=1e-4,weights_init=main_GMM.weights_,means_init=main_GMM.means_,precisions_init=main_GMM.precisions_) - protein_GMM.fit(X_train_protein) - dict_models[protein] = protein_GMM - dict_pathogenic_cluster_index[protein] = np.argmax(np.array(protein_GMM.means_).flatten()) - with open(GMM_stats_log_location, "a") as logs: - logs.write(",".join(str(x) for x in [ - protein, np.array(protein_GMM.weights_).flatten()[dict_pathogenic_cluster_index[protein]], np.array(protein_GMM.means_).flatten()[dict_pathogenic_cluster_index[protein]], - np.array(protein_GMM.means_).flatten()[1 - dict_pathogenic_cluster_index[protein]], np.sqrt(np.array(protein_GMM.covariances_).flatten()[dict_pathogenic_cluster_index[protein]]), - np.sqrt(np.array(protein_GMM.covariances_).flatten()[1 - dict_pathogenic_cluster_index[protein]]) - ])+"\n") - else: - if args.verbose: - print("No evol indices for the protein: "+str(protein)+". Skipping.") - - pickle.dump(dict_models, open(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_model_dictionary_'+args.output_eve_scores_filename_suffix, 'wb')) - pickle.dump(dict_pathogenic_cluster_index, open(args.GMM_parameter_location+os.sep+args.output_eve_scores_filename_suffix+os.sep+'GMM_pathogenic_cluster_index_dictionary_'+args.output_eve_scores_filename_suffix, 'wb')) - - if args.plot_histograms: - if not os.path.exists(args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix): - os.makedirs(args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix) - plot_helpers.plot_histograms(all_evol_indices, dict_models, dict_pathogenic_cluster_index, args.protein_GMM_weight, args.plot_location+os.sep+'plots_histograms'+os.sep+args.output_eve_scores_filename_suffix, args.output_eve_scores_filename_suffix, protein_list) - - if args.compute_EVE_scores: - if args.protein_GMM_weight > 0.0: - all_scores = all_evol_indices.copy() - all_scores['EVE_scores'] = np.nan - all_scores['EVE_classes_100_pct_retained'] = "" - for protein in tqdm.tqdm(protein_list,"Scoring all protein mutations"): - try: - test_data_protein = all_scores[all_scores.protein_name==protein] - X_test_protein = np.array(test_data_protein['evol_indices']).reshape(-1, 1) - mutation_scores_protein = ph.compute_weighted_score_two_GMMs(X_pred=X_test_protein, - main_model = dict_models['main'], - protein_model=dict_models[protein], - cluster_index_main = dict_pathogenic_cluster_index['main'], - cluster_index_protein = dict_pathogenic_cluster_index[protein], - protein_weight = args.protein_GMM_weight) - gmm_class_protein = ph.compute_weighted_class_two_GMMs(X_pred=X_test_protein, - main_model = dict_models['main'], - protein_model=dict_models[protein], - cluster_index_main = dict_pathogenic_cluster_index['main'], - cluster_index_protein = dict_pathogenic_cluster_index[protein], - protein_weight = args.protein_GMM_weight) - gmm_class_label_protein = pd.Series(gmm_class_protein).map(lambda x: 'Pathogenic' if x == 1 else 'Benign') - - all_scores.loc[all_scores.protein_name==protein, 'EVE_scores'] = np.array(mutation_scores_protein) - all_scores.loc[all_scores.protein_name==protein, 'EVE_classes_100_pct_retained'] = np.array(gmm_class_label_protein) - except: - print("Issues with protein: "+str(protein)+". Skipping.") - else: - all_scores = all_evol_indices.copy() - mutation_scores = dict_models['main'].predict_proba(np.array(all_scores['evol_indices']).reshape(-1, 1)) - all_scores['EVE_scores'] = mutation_scores[:,dict_pathogenic_cluster_index['main']] - gmm_class = dict_models['main'].predict(np.array(all_scores['evol_indices']).reshape(-1, 1)) - all_scores['EVE_classes_100_pct_retained'] = np.array(pd.Series(gmm_class).map(lambda x: 'Pathogenic' if x == dict_pathogenic_cluster_index['main'] else 'Benign')) - - len_before_drop_na = len(all_scores) - all_scores = all_scores.dropna(subset=['EVE_scores']) - len_after_drop_na = len(all_scores) - - if args.verbose: - scores_stats = ph.compute_stats(all_scores['EVE_scores']) - print("Score stats: "+str(scores_stats)) - print("Dropped mutations due to missing EVE scores: "+str(len_after_drop_na-len_before_drop_na)) - all_scores['uncertainty'] = ph.predictive_entropy_binary_classifier(all_scores['EVE_scores']) - - if args.recompute_uncertainty_threshold: - uncertainty_cutoffs_deciles, _, _ = ph.compute_uncertainty_deciles(all_scores) - uncertainty_cutoffs_quartiles, _, _ = ph.compute_uncertainty_quartiles(all_scores) - if args.verbose: - print("Uncertainty cutoffs deciles: "+str(uncertainty_cutoffs_deciles)) - print("Uncertainty cutoffs quartiles: "+str(uncertainty_cutoffs_quartiles)) - else: - uncertainty_thresholds = json.load(open(args.default_uncertainty_threshold_file_location)) - uncertainty_cutoffs_deciles = uncertainty_thresholds["deciles"] - uncertainty_cutoffs_quartiles = uncertainty_thresholds["quartiles"] - - for decile in range(1,10): - classification_name = 'EVE_classes_'+str((decile)*10)+"_pct_retained" - all_scores[classification_name] = all_scores['EVE_classes_100_pct_retained'] - all_scores.loc[all_scores['uncertainty'] > uncertainty_cutoffs_deciles[str(decile)], classification_name] = 'Uncertain' - if args.verbose: - print("Stats classification by uncertainty for decile #:"+str(decile)) - print(all_scores[classification_name].value_counts(normalize=True)) - if args.verbose: - print("Stats classification by uncertainty for decile #:"+str(10)) - print(all_scores['EVE_classes_100_pct_retained'].value_counts(normalize=True)) - - for quartile in [1,3]: - classification_name = 'EVE_classes_'+str((quartile)*25)+"_pct_retained" - all_scores[classification_name] = all_scores['EVE_classes_100_pct_retained'] - all_scores.loc[all_scores['uncertainty'] > uncertainty_cutoffs_quartiles[str(quartile)], classification_name] = 'Uncertain' - if args.verbose: - print("Stats classification by uncertainty for quartile #:"+str(quartile)) - print(all_scores[classification_name].value_counts(normalize=True)) - - all_scores.to_csv(args.output_eve_scores_location+os.sep+'all_EVE_scores_'+args.output_eve_scores_filename_suffix+'.csv', index=False) - - if args.plot_scores_vs_labels: - labels_dataset=pd.read_csv(args.labels_file_location,low_memory=False) - all_scores_mutations_with_labels = pd.merge(all_scores, labels_dataset[['protein_name','mutations','ClinVar_labels']], how='inner', on=['protein_name','mutations']) - all_PB_scores = all_scores_mutations_with_labels[all_scores_mutations_with_labels.ClinVar_labels!=0.5].copy() - if not os.path.exists(args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix): - os.makedirs(args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix) - for protein in tqdm.tqdm(protein_list,"Plot scores Vs labels"): - plot_helpers.plot_scores_vs_labels(score_df=all_PB_scores[all_PB_scores.protein_name==protein], - plot_location=args.plot_location+os.sep+'plots_scores_vs_labels'+os.sep+args.output_eve_scores_filename_suffix, - output_eve_scores_filename_suffix=args.output_eve_scores_filename_suffix+'_'+protein, - mutation_name='mutations', score_name="EVE_scores", label_name='ClinVar_labels') \ No newline at end of file From 528dcd7dcbd5f8c5e8105b1f5baa4ee4d8aa5a3a Mon Sep 17 00:00:00 2001 From: Andrew Boardman Date: Mon, 26 Sep 2022 14:11:49 +0100 Subject: [PATCH 3/3] Refactored and made changes --- EVE/GMM_model.py | 67 ++--- EVE/VAE_decoder.py | 270 ++++++++++-------- data/weights/PTEN_HUMAN_theta_0.2.npy | Bin 0 -> 9560 bytes ...s_all_singles.sh => Step2_predict_ELBO.sh} | 2 +- ..._singles.sh => Step3_predict_GMM_score.sh} | 12 +- logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv | 33 +++ predict_ELBO.py | 47 +-- predict_GMM_score.py | 36 +-- 8 files changed, 252 insertions(+), 215 deletions(-) create mode 100644 data/weights/PTEN_HUMAN_theta_0.2.npy rename examples/{Step2_compute_evol_indices_all_singles.sh => Step2_predict_ELBO.sh} (97%) rename examples/{Step3_train_GMM_and_compute_EVE_scores_all_singles.sh => Step3_predict_GMM_score.sh} (73%) create mode 100644 logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv diff --git a/EVE/GMM_model.py b/EVE/GMM_model.py index c9f6457..26253cf 100644 --- a/EVE/GMM_model.py +++ b/EVE/GMM_model.py @@ -2,63 +2,38 @@ import numpy as np import os import tqdm +import logging class GMM_model: """ Class for the global-local GMM model trained on single-point mutations for a wild-type protein sequence """ def __init__(self, params) -> None: - # Set up logging - self.log_location = params['log_location'] - if not os.path.exists(os.path.basename(self.log_location)): - os.makedirs(os.path.basename(self.log_location)) - with open(self.log_location, "a") as logs: - logs.write("protein_name,weight_pathogenic,mean_pathogenic,mean_benign,std_dev_pathogenic,std_dev_benign\n") - + # store parameters self.protein_GMM_weight = params['protein_GMM_weight'] - - def fit_single(self, X_train, protein_name=None, verbose = False): - model = mixture.GaussianMixture( + self.params = dict( n_components=2, covariance_type='full', max_iter=1000, n_init=30, tol=1e-4 - ) + ) + + def fit_single(self, X_train): + model = mixture.GaussianMixture(**self.params) model.fit(X_train) #The pathogenic cluster is the cluster with higher mean value - pathogenic_cluster_index = np.argmax(np.array(self.model.means_).flatten()) - - with open(self.log_location, "a") as logs: - logs.write(",".join(str(x) for x in [ - protein_name, - np.array(self.model.weights_).flatten()[self.pathogenic_cluster_index], - np.array(self.model.means_).flatten()[self.pathogenic_cluster_index], - np.array(self.model.means_).flatten()[1 - self.pathogenic_cluster_index], - np.sqrt(np.array(self.model.covariances_).flatten()[self.pathogenic_cluster_index]), - np.sqrt(np.array(self.model.covariances_).flatten()[1 - self.pathogenic_cluster_index]) - ]) - +"\n" - ) - - if verbose: - inferred_params = self.model.get_params() - print("Index of mixture component with highest mean: "+str(self.pathogenic_cluster_index)) - print("Model parameters: "+str(inferred_params)) - print("Mixture component weights: "+str(self.model.weights_)) - print("Mixture component means: "+str(self.model.means_)) - print("Cluster component cov: "+str(self.model.covariances_)) - + pathogenic_cluster_index = np.argmax(np.array(model.means_).flatten()) return model, pathogenic_cluster_index - def fit(self, X_train, proteins_train, verbose = True): + def fit(self, X_train, proteins_train): # set up to train self.models = {} self.indices = {} # train global model - gmm, index = self.fit_single(X_train,'main',verbose=verbose) + gmm, index = self.fit_single(X_train,'main') self.models['main'] = gmm self.indices['main'] = index @@ -67,7 +42,7 @@ def fit(self, X_train, proteins_train, verbose = True): proteins_list = list(set(proteins_train)) for protein in tqdm.tqdm(proteins_list, "Training all protein GMMs"): X_train_protein = X_train[proteins_train == protein] - gmm, index = self.fit_single(X_train_protein,protein,verbose=verbose) + gmm, index = self.fit_single(X_train_protein,protein) self.models[protein] = gmm self.indices[protein] = index @@ -89,7 +64,19 @@ def predict_weighted(self, X_pred, protein): classes_weighted = (scores_weighted > 0.5).astype(int) return scores_weighted, classes_weighted - - - -X_train, groups_train = preprocess_indices(all_evol_indices) \ No newline at end of file + def get_fitted_params(self): + fitted_params = ( + self.models.keys(), + self.models.values(), + self.indices.values() + ) + output = {} + for protein, model, index in zip(*fitted_params): + output[protein] = { + 'index': index, + 'means': model.means_, + 'covar': model.covariances_, + 'weights': model.weights_ + } + return output + diff --git a/EVE/VAE_decoder.py b/EVE/VAE_decoder.py index 18acb2d..87277e0 100644 --- a/EVE/VAE_decoder.py +++ b/EVE/VAE_decoder.py @@ -1,6 +1,15 @@ import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as functional + +HIDDEN_LAYER_NONLINEARITIES = { + 'relu': nn.ReLU(), + 'tanh': nn.Tanh(), + 'sigmoid': nn.Sigmoid(), + 'elu': nn.ELU(), + 'linear': nn.Identity(), +} + class VAE_Bayesian_MLP_decoder(nn.Module): """ @@ -40,130 +49,178 @@ def __init__(self, params): self.mu_bias_init = 0.1 self.logvar_init = -10.0 self.logit_scale_p = 0.001 - - self.hidden_layers_mean=nn.ModuleDict() - self.hidden_layers_log_var=nn.ModuleDict() - for layer_index in range(len(self.hidden_layers_sizes)): - if layer_index==0: - self.hidden_layers_mean[str(layer_index)] = nn.Linear(self.z_dim, self.hidden_layers_sizes[layer_index]) - self.hidden_layers_log_var[str(layer_index)] = nn.Linear(self.z_dim, self.hidden_layers_sizes[layer_index]) - nn.init.constant_(self.hidden_layers_mean[str(layer_index)].bias, self.mu_bias_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].weight, self.logvar_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].bias, self.logvar_init) - else: - self.hidden_layers_mean[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index]) - self.hidden_layers_log_var[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index]) - nn.init.constant_(self.hidden_layers_mean[str(layer_index)].bias, self.mu_bias_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].weight, self.logvar_init) - nn.init.constant_(self.hidden_layers_log_var[str(layer_index)].bias, self.logvar_init) - - if params['first_hidden_nonlinearity'] == 'relu': - self.first_hidden_nonlinearity = nn.ReLU() - elif params['first_hidden_nonlinearity'] == 'tanh': - self.first_hidden_nonlinearity = nn.Tanh() - elif params['first_hidden_nonlinearity'] == 'sigmoid': - self.first_hidden_nonlinearity = nn.Sigmoid() - elif params['first_hidden_nonlinearity'] == 'elu': - self.first_hidden_nonlinearity = nn.ELU() - elif params['first_hidden_nonlinearity'] == 'linear': - self.first_hidden_nonlinearity = nn.Identity() - - if params['last_hidden_nonlinearity'] == 'relu': - self.last_hidden_nonlinearity = nn.ReLU() - elif params['last_hidden_nonlinearity'] == 'tanh': - self.last_hidden_nonlinearity = nn.Tanh() - elif params['last_hidden_nonlinearity'] == 'sigmoid': - self.last_hidden_nonlinearity = nn.Sigmoid() - elif params['last_hidden_nonlinearity'] == 'elu': - self.last_hidden_nonlinearity = nn.ELU() - elif params['last_hidden_nonlinearity'] == 'linear': - self.last_hidden_nonlinearity = nn.Identity() - - if self.dropout_proba > 0.0: - self.dropout_layer = nn.Dropout(p=self.dropout_proba) - if self.convolve_output: - self.output_convolution_mean = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False) - self.output_convolution_log_var = nn.Conv1d(in_channels=self.convolution_depth,out_channels=self.alphabet_size,kernel_size=1,stride=1,bias=False) - nn.init.constant_(self.output_convolution_log_var.weight, self.logvar_init) + self.fcnn_output_size = self.seq_len * self.hidden_layers_sizes[-1] self.channel_size = self.convolution_depth else: self.channel_size = self.alphabet_size - if self.include_sparsity: - self.sparsity_weight_mean = nn.Parameter(torch.zeros(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len)) - self.sparsity_weight_log_var = nn.Parameter(torch.ones(int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity), self.seq_len)) - nn.init.constant_(self.sparsity_weight_log_var, self.logvar_init) + # Set hidden layer nonlinearities for first and last layers + self.first_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['first_hidden_nonlinearity']] + self.last_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['last_hidden_nonlinearity']] - self.last_hidden_layer_weight_mean = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1])) - self.last_hidden_layer_weight_log_var = nn.Parameter(torch.zeros(self.channel_size * self.seq_len,self.hidden_layers_sizes[-1])) - nn.init.xavier_normal_(self.last_hidden_layer_weight_mean) #Glorot initialization - nn.init.constant_(self.last_hidden_layer_weight_log_var, self.logvar_init) + # set dropout + if self.dropout_proba > 0.0: + self.use_dropout=True + self.dropout_layer = nn.Dropout(p=self.dropout_proba) + else: + self.use_dropout=False - self.last_hidden_layer_bias_mean = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len)) - self.last_hidden_layer_bias_log_var = nn.Parameter(torch.zeros(self.alphabet_size * self.seq_len)) + self.initialise_params() + + + def initialise_params(self): + # Set mean and variance for hidden layer distributions + self.hidden_layers_mean=nn.ModuleDict() + self.hidden_layers_log_var=nn.ModuleDict() + input_size = self.z_dim + for i, output_size in enumerate(self.hidden_layers_sizes): + self.hidden_layers_mean[str(i)] = nn.Linear(input_size, output_size) + self.hidden_layers_log_var[str(i)] = nn.Linear(input_size, output_size) + input_size = output_size + nn.init.constant_(self.hidden_layers_mean[str(i)].bias, self.mu_bias_init) + nn.init.constant_(self.hidden_layers_log_var[str(i)].weight, self.logvar_init) + nn.init.constant_(self.hidden_layers_log_var[str(i)].bias, self.logvar_init) + + # set mean and variance for last hidden weight and bias + output_size = self.channel_size * self.seq_len + self.last_hidden_layer_weight_mean = nn.Parameter(torch.zeros(output_size, input_size)) + self.last_hidden_layer_weight_log_var = nn.Parameter(torch.zeros(output_size, input_size)) + nn.init.xavier_normal_(self.last_hidden_layer_weight_mean) #Glorot initialization + nn.init.constant_(self.last_hidden_layer_weight_log_var, self.logvar_init) + + output_size = self.alphabet_size * self.seq_len + self.last_hidden_layer_bias_mean = nn.Parameter(torch.zeros(output_size)) + self.last_hidden_layer_bias_log_var = nn.Parameter(torch.zeros(output_size)) nn.init.constant_(self.last_hidden_layer_bias_mean, self.mu_bias_init) nn.init.constant_(self.last_hidden_layer_bias_log_var, self.logvar_init) - + + # Set mean and variance for 1-stride convolutions if they are used + if self.convolve_output: + self.output_convolution_mean = nn.Conv1d( + in_channels=self.convolution_depth, + out_channels=self.alphabet_size, + kernel_size=1,stride=1,bias=False) + self.output_convolution_log_var = nn.Conv1d( + in_channels=self.convolution_depth, + out_channels=self.alphabet_size, + kernel_size=1,stride=1,bias=False) + nn.init.constant_( + self.output_convolution_log_var.weight, + self.logvar_init + ) + + # Set mean and variance for sparsity prior + if self.include_sparsity: + sparsity_size = int(self.hidden_layers_sizes[-1]/self.num_tiles_sparsity) + self.sparsity_weight_mean = nn.Parameter(torch.zeros(sparsity_size, self.seq_len)) + self.sparsity_weight_log_var = nn.Parameter(torch.ones(sparsity_size, self.seq_len)) + nn.init.constant_(self.hidden_layers_log_var_sparsity, self.logvar_init) + if self.include_temperature_scaler: self.temperature_scaler_mean = nn.Parameter(torch.ones(1)) self.temperature_scaler_log_var = nn.Parameter(torch.ones(1) * self.logvar_init) - + + def sampler(self, mean, log_var): """ - Samples a latent vector via reparametrization trick + Samples a parameter from normal distribution via reparametrization trick """ eps = torch.randn_like(mean).to(self.device) z = torch.exp(0.5*log_var) * eps + mean return z - def forward(self, z): - batch_size = z.shape[0] - if self.dropout_proba > 0.0: - x = self.dropout_layer(z) - else: - x = z + def sample_weight_and_bias(self, mean, log_var): + weight = self.sampler(mean.weight, log_var.weight) + bias = self.sampler(mean.bias, log_var.bias) + return weight, bias - for layer_index in range(len(self.hidden_layers_sizes)-1): - layer_i_weight = self.sampler(self.hidden_layers_mean[str(layer_index)].weight, self.hidden_layers_log_var[str(layer_index)].weight) - layer_i_bias = self.sampler(self.hidden_layers_mean[str(layer_index)].bias, self.hidden_layers_log_var[str(layer_index)].bias) - x = self.first_hidden_nonlinearity(F.linear(x, weight=layer_i_weight, bias=layer_i_bias)) - if self.dropout_proba > 0.0: - x = self.dropout_layer(x) + def sample_custom(self, custom): + mean = self.hidden_layers_mean[custom] + log_var = self.hidden_layers_log_var[custom] + return self.sampler(mean, log_var) - last_index = len(self.hidden_layers_sizes)-1 - last_layer_weight = self.sampler(self.hidden_layers_mean[str(last_index)].weight, self.hidden_layers_log_var[str(last_index)].weight) - last_layer_bias = self.sampler(self.hidden_layers_mean[str(last_index)].bias, self.hidden_layers_log_var[str(last_index)].bias) - x = self.last_hidden_nonlinearity(F.linear(x, weight=last_layer_weight, bias=last_layer_bias)) - if self.dropout_proba > 0.0: + def apply_dropout(self, x): + if self.use_dropout: x = self.dropout_layer(x) + return x - W_out = self.sampler(self.last_hidden_layer_weight_mean, self.last_hidden_layer_weight_log_var) - b_out = self.sampler(self.last_hidden_layer_bias_mean, self.last_hidden_layer_bias_log_var) - + def forward(self, z): + """Decode latent vector into one-hot encoded sequence""" + + # Take input + x = self.apply_dropout(z) + # Sample parameters for each hidden layer and apply + for i in range(len(self.hidden_layers_sizes)): + mean = self.hidden_layers_mean[str(i)] + log_var = self.hidden_layers_log_var[str(i)] + weight, bias = self.sample_weight_and_bias(mean, log_var) + x = functional.linear(x, weight=weight, bias=bias) + if i < len(self.hidden_layers_sizes): + x = self.first_hidden_nonlinearity(x) + else: + x = self.last_hidden_nonlinearity(x) + x = self.apply_dropout(x) + + # Sample weight and bias for last layer + W_out = self.sampler( + self.last_hidden_layer_weight_mean, + self.last_hidden_layer_weight_log_var + ) + b_out = self.sampler( + self.last_hidden_layer_bias_mean, + self.last_hidden_layer_bias_log_var + ) + + # optionally, perform convolutions with stride 1 on this layer if self.convolve_output: - output_convolution_weight = self.sampler(self.output_convolution_mean.weight, self.output_convolution_log_var.weight) - W_out = torch.mm(W_out.view(self.seq_len * self.hidden_layers_sizes[-1], self.channel_size), - output_convolution_weight.view(self.channel_size,self.alphabet_size)) #product of size (H * seq_len, alphabet) - + output_convolution_weight = self.sampler( + self.output_convolution_mean.weight, + self.output_convolution_log_var.weight + ) + output_convolution_weight = output_convolution_weight.view( + self.channel_size,self.alphabet_size + ) + #product of size (H * seq_len, alphabet) + W_out = W_out.view(self.fcnn_output_size, self.channel_size) + W_out = torch.mm(W_out, output_convolution_weight) + W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1]) + + # optionally, place a sparsity prior on the last hidden layer if self.include_sparsity: - sparsity_weights = self.sampler(self.sparsity_weight_mean,self.sparsity_weight_log_var) + # sparsity weights are shrunk towards either 0 or 1 by sigmoid + sparsity_weights = self.sampler( + self.sparsity_weight_mean, + self.sparsity_weight_log_var + ) sparsity_tiled = sparsity_weights.repeat(self.num_tiles_sparsity,1) sparsity_tiled = nn.Sigmoid()(sparsity_tiled).unsqueeze(2) - - W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size) * sparsity_tiled - - W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1]) + # Scale output layer by sparsity vector + W_out = W_out.view(self.hidden_layers_sizes[-1], self.seq_len, self.alphabet_size) + W_out = W_out * sparsity_tiled + W_out = W_out.view(self.seq_len * self.alphabet_size, self.hidden_layers_sizes[-1]) - x = F.linear(x, weight=W_out, bias=b_out) + # apply last layer + print(x.shape) + print(W_out.shape) + print(b_out.shape) + x = functional.linear(x, weight=W_out, bias=b_out) + # optionally, apply temperature scaling to final output if self.include_temperature_scaler: - temperature_scaler = self.sampler(self.temperature_scaler_mean,self.temperature_scaler_log_var) - x = torch.log(1.0+torch.exp(temperature_scaler)) * x - + temperature_scaler = self.sampler( + self.temperature_scaler_mean, + self.temperature_scaler_log_var + ) + x = x * torch.log(1.0+torch.exp(temperature_scaler)) + + # reshape output to shape (batch_size, seq_len, alphabet) + self.output_dim = () + batch_size = z.shape[0] x = x.view(batch_size, self.seq_len, self.alphabet_size) - x_recon_log = F.log_softmax(x, dim=-1) #of shape (batch_size, seq_len, alphabet) - + + # return reconstruction loss + x_recon_log = functional.log_softmax(x, dim=-1) return x_recon_log class VAE_Standard_MLP_decoder(nn.Module): @@ -212,29 +269,12 @@ def __init__(self, params): self.hidden_layers[str(layer_index)] = nn.Linear(self.hidden_layers_sizes[layer_index-1],self.hidden_layers_sizes[layer_index]) nn.init.constant_(self.hidden_layers[str(layer_index)].bias, self.mu_bias_init) - if params['first_hidden_nonlinearity'] == 'relu': - self.first_hidden_nonlinearity = nn.ReLU() - elif params['first_hidden_nonlinearity'] == 'tanh': - self.first_hidden_nonlinearity = nn.Tanh() - elif params['first_hidden_nonlinearity'] == 'sigmoid': - self.first_hidden_nonlinearity = nn.Sigmoid() - elif params['first_hidden_nonlinearity'] == 'elu': - self.first_hidden_nonlinearity = nn.ELU() - elif params['first_hidden_nonlinearity'] == 'linear': - self.first_hidden_nonlinearity = nn.Identity() - - if params['last_hidden_nonlinearity'] == 'relu': - self.last_hidden_nonlinearity = nn.ReLU() - elif params['last_hidden_nonlinearity'] == 'tanh': - self.last_hidden_nonlinearity = nn.Tanh() - elif params['last_hidden_nonlinearity'] == 'sigmoid': - self.last_hidden_nonlinearity = nn.Sigmoid() - elif params['last_hidden_nonlinearity'] == 'elu': - self.last_hidden_nonlinearity = nn.ELU() - elif params['last_hidden_nonlinearity'] == 'linear': - self.last_hidden_nonlinearity = nn.Identity() + # Set hidden layer nonlinearities for first and last layers + self.first_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['first_hidden_nonlinearity']] + self.last_hidden_nonlinearity = HIDDEN_LAYER_NONLINEARITIES[params['last_hidden_nonlinearity']] if self.dropout_proba > 0.0: + self.use_dropout = True self.dropout_layer = nn.Dropout(p=self.dropout_proba) if self.convolve_output: diff --git a/data/weights/PTEN_HUMAN_theta_0.2.npy b/data/weights/PTEN_HUMAN_theta_0.2.npy new file mode 100644 index 0000000000000000000000000000000000000000..19048055aa6de463fa6e16bb384bfa0fc38ae0d3 GIT binary patch literal 9560 zcmb`M2~d>h8OK?pq9(Cu3?fA#HHtJ$qgAVd?JL%JMT|#Gwe0aNf;C1l0>+cd5=Fr) zDAh^gfmf7hychzeXqHY;iD^YEi4~_^jmKDt*%%L0s@?GJ@BL;!vO}8bx-dL;a+Mo}ZT&V{KGF;?D0qbJD+t zTSe>Jdk1XJvWh>HxOex;(`opx#b2%2mTVR2mra#1ztinq8nb>!Xu8ga#tt<4*G(VI z&)jPjZ(kh}bb7N@KtJ*Ll1Bp%TLpC>Z`R>sx8_pDH)p=J2=w4%7epU~@2oWy1%9UI z%X+g_FVr2oW))prCePO@h#SOTmxhm3V26HCpB;ZMDs-K*s`!v+9X;eexSmsVxXvnq zZ*=N7Hd?PI>(m2%`bylyXRJ4VkwXu;fD5LDbw9BR<^X@t9m5BGawgcO=7YMh?hmPx zzj~7gNIb~H2U~?Nbe9L()Vz={d7;M~fJVF2nYk`*({;$5ZZ?&F`}?N%!?xH2a@fHi z^MO8c#Al2HeuOLB^H29fU8xf|Yw*ZT$wO>v{^%clsa^5zlJ0$N0-Cvl#x8LZKl4UB z@X>=$J)wyUIr>lD*u@?+^3<1g;$j`dAM()XK^yB!e8dC7H};A8(ogK7$9nFfE?sY~ zu?XTN4%X2pextvjhL0w&j~w*qi2K`$#@ht*2_G7J=;4q0us%bwCe4%oFE1_KBByMh_aKZtlQ((;}x$z(*bo zcx0*%blb$fSjUH1c{Tx`c>>|r56db!+Fug`u2lye-ER}r6~r%g$(wcN4nF4?>(oD@ z&JkYc@WfM7uqSA;P0b5(=!YiMZ;Ot!sB?)tsn@i&Tk?lh*#vcEE{t;cMV`6md_tdn z3{9VzpKDQb>XUkC${%vrCvFfu*69O$;-qfaL!Z777kOh3xy%4lv*xWd!953`__2e1 zBi*WbGsmC^P0a`T$breBu}+=n zSD((#7w7k}h@j?ip06g_jJN%5VX94pTEmO;dTQc9`*|CRr)ffabKiS48#Tecdhyi5 zEu*g|3hE7?bBekf_dWf^Zbee?XCp>vf_|`%;4=rrPrb+k|6um))Pft6G<6>`$Lw3| zW7p_+k}ItKXBBppH~#Q{@!zFUn-^#Tnte*Yhy$9sQa{#-b8EMv`(NGG>x*6bL!W2G zEzjQR^vq%HFH)c1rz7#fZ|8m(TjF}`U(=)ytKhNEmQ2O@3H5~UYb>@iFOSi zyLu0i&$^+jcSY1_>OE}wW&N0yNp^93W4GEXC7RfG(uVj4E9;@K&u|#kFohsQu7RosQYm_mv^1 zHT8Z)zjsf|)rnp0>U^i(oU_Ecsr{v>xJ`-b9k4iPaI0q)YvN+^?4sQ(G{HLoHlSyR|_79m#u*_dIehw>AwBf6_#{Gi7!06Bg0=Oq}QI`91h*zo?j(EJDi75ccDa z?(dzNqTYwhCw4e5hNpks`?+c8SyX+9FY2qIgGT<@E+(8c-HKmgQRgLg_#FbmHzM|k zpLm-oqfqZ zho%k(Q)g%HFzyoW-yQvUh&kn^S*&?>gy;Jg;c_}23v6o^$iZ(4 z!<;71*R19ZeeB7+-OB!g-z-c1&nF2&?%%y)Lc^z8#U;-*hqq2{(PQ71vh16=ev!H& zXQbidk9goe{_bNrubf>Wz|ie+o^P%IW&T;-I%SEzk&en!H*DPt+VOvIXnV%;W3fG7r6|ljv6GF*NlsB6{%g3k@PS_4zG_ z7vEVUK>5yjOK=}RBmd6Z-+d~DKWx|U_p9z5NfhiK-aE(#%-gfde3`nqz4KyL=N0SR zgS>~0b%K_0In}*Hz99KROL>=icS8H3TfKj!-VSg5n?~%QCu#cEw7#rUVUzykIkHR1 zp-26UaS#uPK1hF>Lf&f$ACD{ed|#5PJ9$d1^2Sx65I^*T_c`&vM_;~|W~ls-V;y_Y ztj`O+wk@o;S-l(OI{Ujyy{ph;-Pb?+T-RfsedSnpdG|x1dS|MBG&u0gTf?W?doS!V z)J3Dc5IgCTc-X$4jS~;cdEY9AMyr$<%tu&E~PK!hyyv0KCzFfn~|<=l!q~H z;`)Bn_l^Af>}sz);xZz6z;~-U8~3yHr`v}{4w^b}5BSDe;f;s7FfSFVj>uDYXkV;Q z|8FkkT zZ4an>kh&UC+AHxRSLwCuZ=E@%J|${DNk50Ye?KOUO0|w(`e39bf4`Eaj^vLX`AUEK pJ6OI~GQ4w8rtHu!^qR7RoWx>fhj`G(E`0g>I^VCnw5Olj{{?nNxIq8_ literal 0 HcmV?d00001 diff --git a/examples/Step2_compute_evol_indices_all_singles.sh b/examples/Step2_predict_ELBO.sh similarity index 97% rename from examples/Step2_compute_evol_indices_all_singles.sh rename to examples/Step2_predict_ELBO.sh index c5b95f2..2f34311 100644 --- a/examples/Step2_compute_evol_indices_all_singles.sh +++ b/examples/Step2_predict_ELBO.sh @@ -13,7 +13,7 @@ export output_evol_indices_location='./results/evol_indices' export num_samples_compute_evol_indices=20000 export batch_size=2048 -python compute_evol_indices.py \ +python predict_ELBO.py \ --MSA_data_folder ${MSA_data_folder} \ --MSA_list ${MSA_list} \ --protein_index ${protein_index} \ diff --git a/examples/Step3_train_GMM_and_compute_EVE_scores_all_singles.sh b/examples/Step3_predict_GMM_score.sh similarity index 73% rename from examples/Step3_train_GMM_and_compute_EVE_scores_all_singles.sh rename to examples/Step3_predict_GMM_score.sh index a68b4f4..7bccad9 100644 --- a/examples/Step3_train_GMM_and_compute_EVE_scores_all_singles.sh +++ b/examples/Step3_predict_GMM_score.sh @@ -3,15 +3,12 @@ export input_evol_indices_filename_suffix='_20000_samples' export protein_list='./data/mappings/example_mapping.csv' export output_eve_scores_location='./results/EVE_scores' export output_eve_scores_filename_suffix='Jan1_PTEN_example' - export GMM_parameter_location='./results/GMM_parameters/Default_GMM_parameters' export GMM_parameter_filename_suffix='default' export protein_GMM_weight=0.3 -export plot_location='./results' -export labels_file_location='./data/labels/PTEN_ClinVar_labels.csv' export default_uncertainty_threshold_file_location='./utils/default_uncertainty_threshold.json' -python train_GMM_and_compute_EVE_scores.py \ +python predict_GMM_score.py \ --input_evol_indices_location ${input_evol_indices_location} \ --input_evol_indices_filename_suffix ${input_evol_indices_filename_suffix} \ --protein_list ${protein_list} \ @@ -22,11 +19,6 @@ python train_GMM_and_compute_EVE_scores.py \ --GMM_parameter_filename_suffix ${GMM_parameter_filename_suffix} \ --compute_EVE_scores \ --protein_GMM_weight ${protein_GMM_weight} \ - --plot_histograms \ - --plot_scores_vs_labels \ - --plot_location ${plot_location} \ - --labels_file_location ${labels_file_location} \ - --default_uncertainty_threshold_file_location ${default_uncertainty_threshold_file_location} \ + --compute_uncertainty_thresholds \ --verbose - \ No newline at end of file diff --git a/logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv b/logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv new file mode 100644 index 0000000..fc2975f --- /dev/null +++ b/logs/PTEN_HUMAN_Jan1_PTEN_example_losses.csv @@ -0,0 +1,33 @@ +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 +Number of sequences in alignment file: 1179 +Neff: 131.5475617794571 +Alignment sequence length: 387 diff --git a/predict_ELBO.py b/predict_ELBO.py index f8f1388..2e61793 100644 --- a/predict_ELBO.py +++ b/predict_ELBO.py @@ -12,6 +12,29 @@ def main(args): mapping_file = pd.read_csv(args.MSA_list) protein_name = mapping_file['protein_name'][args.protein_index] print("Protein name: "+str(protein_name)) + if args.theta_reweighting is not None: + theta = args.theta_reweighting + else: + try: + theta = float(mapping_file['theta'][args.protein_index]) + except: + theta = 0.2 + print("Theta MSA re-weighting: "+str(theta)) + + # Load and preprocess MSA + msa_location = \ + args.MSA_data_folder + os.sep + \ + mapping_file['msa_location'][args.protein_index] + print("MSA file: "+str(msa_location)) + weights_location = \ + args.MSA_weights_location + os.sep + \ + protein_name + '_theta_' + str(theta) + '.npy' + data = data_utils.MSA_processing( + MSA_location=msa_location, + weights_location=weights_location, + theta=theta, + use_weights=True + ) # Load model model_name = protein_name + "_" + args.model_name_suffix @@ -36,29 +59,7 @@ def main(args): print("Unable to locate VAE model checkpoint") sys.exit(0) - # Load and preprocess MSA - msa_location = \ - args.MSA_data_folder + os.sep + \ - mapping_file['msa_location'][args.protein_index] - print("MSA file: "+str(msa_location)) - weights_location = \ - args.MSA_weights_location + os.sep + \ - protein_name + '_theta_' + str(theta) + '.npy' - if args.theta_reweighting is not None: - theta = args.theta_reweighting - else: - try: - theta = float(mapping_file['theta'][args.protein_index]) - except: - theta = 0.2 - print("Theta MSA re-weighting: "+str(theta)) - data = data_utils.MSA_processing( - MSA_location=msa_location, - weights_location=weights_location, - theta=theta, - use_weights=True - ) - + # Load mutations location if args.computation_mode=="all_singles": mutations_location = \ diff --git a/predict_GMM_score.py b/predict_GMM_score.py index 50d0df8..0437059 100644 --- a/predict_GMM_score.py +++ b/predict_GMM_score.py @@ -1,8 +1,4 @@ -from cProfile import label -from email.policy import default import os -from statistics import quantiles -from tabnanny import verbose import numpy as np import pandas as pd import argparse @@ -76,7 +72,7 @@ def filter_uncertainties(all_scores, n_quantiles, default_uc_location, recompute level = f'class_ucq_{i}' all_scores[level] = all_scores['class'] * (uncertainty < quantile) if verbose: - print("Stats classification by uncertainty for quantile #:"+str(decile)) + print("Stats classification by uncertainty for quantile #:"+str(quantile)) print(all_scores[level].value_counts(normalize=True)) return all_scores @@ -106,45 +102,33 @@ def main(args): 'GMM_model_dictionary_'+ \ args.GMM_parameter_filename_suffix with open(gmm_model_location, "rb" ) as fid: - dict_models = pickle.load(fid) - # Load GMM indices from file - gmm_index_location = args.GMM_parameter_location+os.sep+\ - 'GMM_pathogenic_cluster_index_dictionary_'+\ - args.GMM_parameter_filename_suffix - with open(gmm_index_location, "rb" ) as fid: - dict_pathogenic_cluster_index = pickle.load(fid) + gmm = pickle.load(fid) else: # train GMMs on mutation evolutionary indices - gmm_log_location = \ - args.GMM_parameter_location+os.sep+\ - args.output_eve_scores_filename_suffix+os.sep+\ - 'GMM_stats_'+args.output_eve_scores_filename_suffix+'.csv' gmm_params = { - 'log_location':gmm_log_location, 'protein_GMM_weight':args.protein_GMM_weight } - gmm = GMM_model.GMM_model(gmm_params) - gmm.fit( + gmm_model = GMM_model.GMM_model(gmm_params) + gmm_model.fit( all_evol_indices, - protein_list, - verbose=args.verbose + protein_list ) - # Write GMM models to file + # Write GMM models to pickle file gmm_model_location = \ args.GMM_parameter_location+os.sep+\ 'GMM_model_dictionary_'+ \ args.output_eve_scores_filename_suffix with open(gmm_model_location, "wb" ) as fid: - pickle.dump(gmm, fid) + pickle.dump(gmm_model, fid) # Compute EVE classification scores for all mutations and write to file all_scores = predict_scores_GMM( + gmm_model, all_evol_indices, - dict_models, - dict_pathogenic_cluster_index, - protein_list + protein_list, + args.recompute_uncertainty_threshold ) all_scores.to_csv( args.output_eve_scores_location+os.sep+ \