forked from OATML/EVE
Numba parallel weights computation + dataloader #5
Open

loodvn wants to merge 104 commits into OATML-Markslab:master from loodvn:weights_2024
Commits (104)
4e98cb2
Added some slurm scripts, small changes to VAE_model for file handling
fcb8d1b
defined instance variables in __init__ for clarity; extracted one_hot…
loodvn 9606f5a
change focus_seq_trimmed to a string instead of list of chars; added …
loodvn fb13194
error checking in compute_evol_indices
loodvn 09c8b8b
merged data_utils
loodvn c29fd56
added joint training, one-hot sequence functions
loodvn f0984cf
Merge remote-tracking branch 'origin/master' into master
loodvn ba75f74
moved optimizer.zero_grad() outside of if-else
loodvn bf3f1c5
had to comment out the alternating joint training for now, to switch …
loodvn 8d54d4b
joint training script improvements:
loodvn a4c7be9
Joint training: parameterize lm_loss_weight
loodvn ab1055d
adding EVcouplings versions, some data checks, and need to check that…
loodvn 57b7fb7
temp hehe
loodvn 9687b7c
committing all ideas for parallelising the weights calculation for re…
loodvn d65d435
checking on O2 now
loodvn 0597e02
adding mapping files
loodvn 636f88f
running as array
loodvn cc30589
changed logging dir
loodvn b5af044
removed old training flags
loodvn 87c30d5
some slurm script changes
loodvn b9ab6b1
updating conda bin and log output dir
loodvn 7bf655f
print equality
loodvn d8a0951
added directory exists checking
loodvn 7611fd8
running all proteins now
loodvn d235b9d
74 MSAs
loodvn c29ddae
oops was still debugging
loodvn 5993736
using new MSAs
loodvn 2b66b67
wrong MSA location
loodvn ad85d42
explicitly setting number of CPUs to use
loodvn 91e4ac0
also testing only 1 cpu
loodvn 8786baa
trying to add method as a parameter - flags are pretentious
loodvn 92a5d8b
changed reading in calc_weights option
loodvn 91efd61
bugfixes
loodvn 79b7030
bugfixes
loodvn 9a2ff26
another bug
loodvn a024bb4
using specific python from conda env because other jobs were lagging
loodvn c020bb8
skipping existing files because we'll get runtimes from logs anyway
loodvn c50aa89
using multiprocessing + numba now, is roughly as fast as parallel num…
loodvn 0badb28
going to try run all the weights and check
loodvn cebd2db
typo
loodvn c915e5a
moved all the tmp EVE vs EVCouplings checks out into calc_weights.py,…
loodvn 289e84a
going big, 40 cpus
loodvn 6429f69
moved all the tmp EVE vs EVCouplings checks out into calc_weights.py,…
loodvn a3be96c
going to calc all weights with 2h timeout for 8cpus (and later 4cpus)
loodvn 4b4ef45
moved all the weights calc into a different utils file
loodvn 6838046
added flag option to train_VAE.py to fail if weights not found (usefu…
loodvn 93804fe
separated the slurm calc_weights from local calc_weights.sh
loodvn 3a11b0c
modifying some other changes I made early on
loodvn 417865f
Merge branch 'OATML:master' into lood/speedup_weights
loodvn 0fe130c
cleaned up one_hot_3D
loodvn 99945e0
Merge branch 'OATML:master' into master
loodvn 9a5484b
check - editing readme
loodvn 3fc7a2e
Merge branch 'master' of https://github.com/loodvn/EVE
loodvn a125835
Committing all together, other branch corrupt
loodvn 4312346
temporarily removing data/* because git tree corrupt
loodvn 289c218
removed circular dependency in utils/weights
loodvn e902012
removing ADRB2, and old masters-project-specific files
loodvn b8ad7c4
Merge branch 'master' into lood/speedup_weights2
loodvn eccc14b
changed EVcouplings code to > insteead of >= theta
loodvn 7b1e5d1
Merge remote-tracking branch 'origin/lood/speedup_weights2' into lood…
loodvn fe0e95b
moved calc_weights to top level instead of nested in gen_alignment;
loodvn bf45aba
removed all the timing and speed comparison stuff from calc_weights.py
loodvn 95ca192
changed initial num_neighbors to 1 because num_neighbors = 0 didn't m…
loodvn 1316094
merged with Marks-OATML master
loodvn c658d1c
Merge remote-tracking branch 'marks/master'
loodvn e0f3754
don't need weights for scoring
loodvn 3efbda2
Merge branch 'lood/speedup_weights'
loodvn 16c8c7b
grabbed some nice files from deepseq_reproduce branch
loodvn c980752
move constants out,
loodvn a153a1e
adding disorder mapping file and script
loodvn b3f15c7
Merge branch 'lood/speedup_weights'
loodvn b0f3467
got mapping from deepseq_reproduce branch
loodvn a0c3167
adding disorder mapping file and script
loodvn 671669e
added threshold_focus_cols_frac_gaps to calc_weights too, also rerunn…
loodvn c8917f9
added overwrite_weights for my specific use case, just to be sure
loodvn eec4a42
Merge remote-tracking branch 'origin/master'
loodvn e60ee9d
added weight shape check, using threshold_focus_cols_frac_gaps = 1 si…
loodvn 2c7338b
rerunning with longer time limit
loodvn 15abb33
imported compute_evol_indices from deepseq_reproduce
loodvn 534b0db
kicking off scoring, had to recheckout compute_evol_indices from deep…
loodvn 536b207
turned DMS filename assertion into just a warning for now, need to fi…
loodvn 2cb5cd0
using updated MSA and DMS files (v7?), rerunning training/scoring acc…
loodvn 07d3984
added new disordered MSA using notebook in disorder_human project
loodvn aeade79
using new DMS and MSA mapping and new suffix
loodvn f761149
adpred scripts
loodvn af99627
allowed to pass in a MSA file directly to calc_weights instead of a m…
loodvn bb902f7
Merge branch 'lood/speedup_weights2'
loodvn f70712c
reformatted whitespace PEP8
loodvn 6043996
some more minor whitespace formatting
loodvn 875a625
syntax errors: added overwrite_weights to signature and fixed :: synt…
loodvn 098ce58
added overwrite_weights option to calc_weights.py
loodvn cb9aabd
added overwrite_weights option to calc_weights.py
loodvn 37375c4
Weights calc:
loodvn c8249f0
Training: Added some checks to input/output files
loodvn f9c291c
Tweaked progress bar; removed debugging statements
loodvn 3709d7d
Streaming one-hot-encodings is working well
loodvn e30f784
Using a --experimental_stream_data flag, a bit cleaner
loodvn 74238ea
Skipping synonymous mutants in the filtering, fixed tqdm bug
loodvn 455ffaf
Using protein_name in compute_evol_indices, added some logging
loodvn c801b52
Removed weights calculation comparison tests, cleaned up dataloader
loodvn 70f63e2
Computing one-hot encodings on the fly for evol_indices using dataloa…
loodvn e18c56f
Using dataloaders for train and validation, use multi-cpu weights by …
loodvn 3d48173
Added files back from upstream repo to match master before PR
loodvn d129b81
deleted internal scripts
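Several of the commits above ("Streaming one-hot-encodings is working well", "Computing one-hot encodings on the fly for evol_indices using dataloa…", "cleaned up one_hot_3D") revolve around not materializing the full one-hot tensor up front. As an illustration only — this is not the PR's actual one_hot_3D or dataloader code, and the alphabet ordering is an assumption — a generator that expands sequences to one-hot arrays one at a time might look like:

```python
import numpy as np

# Standard 20-letter amino-acid alphabet; the real code may use a
# different ordering or include a gap character.
ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_IDX = {aa: i for i, aa in enumerate(ALPHABET)}

def stream_one_hot(sequences):
    """Yield (L, 20) one-hot arrays one sequence at a time, so the full
    (N, L, 20) tensor never has to be held in memory at once."""
    for seq in sequences:
        x = np.zeros((len(seq), len(ALPHABET)), dtype=np.float32)
        for pos, aa in enumerate(seq):
            x[pos, AA_TO_IDX[aa]] = 1.0
        yield x
```

A PyTorch `Dataset` with the same logic in `__getitem__`, wrapped in a `DataLoader`, gives the batched on-the-fly encodings the commits describe.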
```diff
@@ -1,4 +1,14 @@
 EVE/__pycache__/
 utils/__pycache__/
 results/VAE_parameters/*
 !results/VAE_parameters/.gitkeep
+logs/
+.idea/
+.ipynb_checkpoints/
+notebooks/
+results/*parameters?*/
+results/evol_indices/
+slurm/
+slurm_dan/
+# Reinclude examples
+!data/mappings/example_mapping.csv
```
Large diffs are not rendered by default.
```diff
@@ -11,7 +11,7 @@ EVE is a set of protein-specific models providing for any single amino acid muta
 The end to end process to compute EVE scores consists of three consecutive steps:
 1. Train the Bayesian VAE on a re-weighted multiple sequence alignment (MSA) for the protein of interest => train_VAE.py
 2. Compute the evolutionary indices for all single amino acid mutations => compute_evol_indices.py
-3. Train a GMM to cluster variants on the basis of the evol indices then output scores and uncertainties on the class assignments => train_GMM_and_compute_EVE_scores.py
+3. Train a GMM to cluster variants on the basis of the qevol indices then output scores and uncertainties on the class assignments => train_GMM_and_compute_EVE_scores.py
```

Review comment on the changed line: revert please

````diff
 We also provide all EVE scores for all single amino acid mutations for thousands of proteins at the following address: http://evemodel.org/.

 ## Example scripts

@@ -47,6 +47,7 @@ The entire codebase is written in python. Package requirements are as follows:
 - tqdm
 - matplotlib
 - seaborn
+- numba

 The corresponding environment may be created via conda and the provided protein_env.yml file as follows:
 ```
````
calc_weights.py (new file, 123 lines, all additions):

```python
# Basically train_VAE.py but just calculating the weights
import argparse
import os
import time

import numpy as np
import pandas as pd

from utils import data_utils


def create_argparser():
    parser = argparse.ArgumentParser(description='VAE')

    # If we don't have a mapping file, just use a single MSA path
    parser.add_argument("--MSA_filepath", type=str, help="Full path to MSA")

    # If we have a mapping file with one MSA path per line
    parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored', required=True)
    parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name', required=True)
    parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file', required=True)
    parser.add_argument('--MSA_weights_location', type=str,
                        help='Location where weights for each sequence in the MSA will be stored', required=True)
    parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting')
    parser.add_argument("--num_cpus", type=int, help="Number of CPUs to use", default=1)
    parser.add_argument("--skip_existing", help="Will quit gracefully if weights file already exists", action="store_true", default=False)
    parser.add_argument("--overwrite", help="Will overwrite existing weights file", action="store_true", default=False)
    parser.add_argument("--calc_method", choices=["evcouplings", "eve", "both", "identity"], help="Method to use for calculating weights. Note: Both produce the same results as we modified the evcouplings numba code to mirror the eve calculation", default="evcouplings")
    parser.add_argument("--threshold_focus_cols_frac_gaps", type=float,
                        help="Maximum fraction of gaps allowed in focus columns - see data_utils.MSA_processing")
    return parser


def main(args):
    print("Arguments:", args)

    weights_file = None

    if args.MSA_filepath is not None:
        assert os.path.isfile(args.MSA_filepath), f"MSA filepath {args.MSA_filepath} doesn't exist"
        msa_location = args.MSA_filepath
    else:
        # Use mapping file
        assert os.path.isfile(args.MSA_list), f"MSA file list {args.MSA_list} doesn't seem to exist"
        mapping_file = pd.read_csv(args.MSA_list)
        protein_name = mapping_file['protein_name'][args.protein_index]
        msa_location = args.MSA_data_folder + os.sep + mapping_file['msa_location'][args.protein_index]
        print("Protein name: " + str(protein_name))
        # If weights_file is in the df_mapping, use that instead
        if "weight_file_name" in mapping_file.columns:
            weights_file = args.MSA_weights_location + os.sep + mapping_file["weight_file_name"][args.protein_index]
            print("Using weights filename from mapping file:", weights_file)

    print("MSA file: " + str(msa_location))

    if args.theta_reweighting is not None:
        theta = args.theta_reweighting
        print(f"Using custom theta value {theta} instead of loading from mapping file.")
    else:
        try:
            theta = float(mapping_file['theta'][args.protein_index])
        except KeyError as e:
            # Overriding previous errors is bad, but we're being nice to the user
            raise KeyError("Couldn't load theta from mapping file. "
                           "NOT using default value of theta=0.2; please specify theta manually. Specific line:",
                           mapping_file[args.protein_index],
                           "Previous error:", e)
    assert not np.isnan(theta), "Theta is NaN, please provide a custom theta value"

    print("Theta MSA re-weighting: " + str(theta))

    # Using data_kwargs so that if options aren't set, they'll be set to default values
    data_kwargs = {}
    if args.threshold_focus_cols_frac_gaps is not None:
        print("Using custom threshold_focus_cols_frac_gaps: ", args.threshold_focus_cols_frac_gaps)
        data_kwargs['threshold_focus_cols_frac_gaps'] = args.threshold_focus_cols_frac_gaps

    if not os.path.isdir(args.MSA_weights_location):
        # exist_ok=True: Otherwise we'll get some race conditions between concurrent jobs
        os.makedirs(args.MSA_weights_location, exist_ok=True)
        # print(f"{args.MSA_weights_location} is not a directory. "
        #       f"Being nice and creating it for you, but this might be a mistake.")
        raise NotADirectoryError(f"{args.MSA_weights_location} is not a directory."
                                 f"Could create it automatically, but at the moment raising an error.")
    else:
        print(f"MSA weights directory: {args.MSA_weights_location}")

    if weights_file is None:
        print("Weights filename not found - writing to new file")
        weights_file = args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy'

    print(f"Writing to {weights_file}")
    # First check that the weights file doesn't exist
    if os.path.isfile(weights_file) and not args.overwrite:
        if args.skip_existing:
            print("Weights file already exists, skipping, since --skip_existing was specified")
            exit(0)
        else:
            raise FileExistsError(f"File {weights_file} already exists. "
                                  f"Please delete it if you want to re-calculate it. "
                                  f"If you want to skip existing files, use --skip_existing.")

    # The msa_data processing has a side effect of saving a weights file
    _ = data_utils.MSA_processing(
        MSA_location=msa_location,
        theta=theta,
        use_weights=True,
        weights_location=weights_file,
        num_cpus=args.num_cpus,
        weights_calc_method=args.calc_method,
        overwrite_weights=args.overwrite,
        skip_one_hot_encodings=True,
        **data_kwargs,
    )


if __name__ == '__main__':
    start = time.perf_counter()
    parser = create_argparser()
    args = parser.parse_args()
    main(args)
    end = time.perf_counter()
    print(f"calc_weights.py took {end-start:.2f} seconds in total.")
```
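For context, the theta reweighting this script drives assigns each sequence a weight of one over its number of neighbors within normalized Hamming distance theta, so redundant sequences are down-weighted. The PR parallelizes this in utils with numba and multiprocessing; the sketch below is a plain-NumPy illustration of the underlying computation, not the PR's code, and the strictness of the comparison is an assumption (one commit notes changing the EVcouplings code from >= to >).

```python
import numpy as np

def compute_sequence_weights(msa, theta=0.2):
    """msa: (N, L) integer-encoded alignment. Each sequence's weight is
    1 / (number of sequences, including itself, whose fraction of
    identical positions exceeds 1 - theta)."""
    msa = np.asarray(msa)
    # pairwise fractional identity, shape (N, N)
    identity = (msa[:, None, :] == msa[None, :, :]).mean(axis=2)
    num_neighbors = (identity > 1.0 - theta).sum(axis=1)  # self always counts
    return 1.0 / num_neighbors
```

The sum of the weights is the effective number of sequences; the O(N² L) pairwise comparison is the hot loop that the numba-parallel version speeds up.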
compute_evol_indices.py:

```diff
@@ -1,20 +1,20 @@
-import os,sys
 import json
 import argparse
+import os

 import pandas as pd
 import torch

 from EVE import VAE_model
 from utils import data_utils

-if __name__=='__main__':
+if __name__ == '__main__':

     parser = argparse.ArgumentParser(description='Evol indices')
     parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored')
     parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name')
     parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file')
-    parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored')
-    parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting')
+    # parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored')
+    # parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting')
     parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored')
     parser.add_argument('--model_name_suffix', default='Jan1', type=str, help='model checkpoint name is the protein name followed by this suffix')
     parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters')
```

Review comment on the commented-out arguments: should these arguments be deprecated instead of removed entirely?

```diff
@@ -27,29 +27,33 @@
     parser.add_argument('--batch_size', default=256, type=int, help='Batch size when computing evol indices')
     args = parser.parse_args()

     print("Arguments=", args)

     mapping_file = pd.read_csv(args.MSA_list)
     protein_name = mapping_file['protein_name'][args.protein_index]
     msa_location = args.MSA_data_folder + os.sep + mapping_file['msa_location'][args.protein_index]
     print("Protein name: "+str(protein_name))
     print("MSA file: "+str(msa_location))

-    if args.theta_reweighting is not None:
-        theta = args.theta_reweighting
-    else:
-        try:
-            theta = float(mapping_file['theta'][args.protein_index])
-        except:
-            theta = 0.2
-    print("Theta MSA re-weighting: "+str(theta))
+    # Theta reweighting not necessary for computing evol indices
+    # if args.theta_reweighting is not None:
+    #     theta = args.theta_reweighting
+    # else:
+    #     try:
+    #         theta = float(mapping_file['theta'][args.protein_index])
+    #     except:
+    #         print("Theta not found in mapping file. Using default value of 0.2")
+    #         theta = 0.2
+    # print("Theta MSA re-weighting: "+str(theta))

     data = data_utils.MSA_processing(
         MSA_location=msa_location,
-        theta=theta,
-        use_weights=True,
-        weights_location=args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy'
+        # theta=theta,
+        use_weights=False,
+        # weights_location=args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy'
     )

-    if args.computation_mode=="all_singles":
+    if args.computation_mode == "all_singles":
         data.save_all_singles(output_filename=args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv")
         args.mutations_location = args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv"
     else:

@@ -58,25 +62,24 @@
     model_name = protein_name + "_" + args.model_name_suffix
     print("Model name: "+str(model_name))

-    model_params = json.load(open(args.model_parameters_location))
+    # model_params = json.load(open(args.model_parameters_location))

+    checkpoint_name = str(args.VAE_checkpoint_location) + os.sep + model_name + "_final"
+    assert os.path.isdir(args.VAE_checkpoint_location), "Cannot find dir"+args.VAE_checkpoint_location
+    assert os.path.isfile(checkpoint_name), "Cannot find "+checkpoint_name+".\nOther options: "+str([f for f in os.listdir('.') if os.path.isfile(f)])
+    checkpoint = torch.load(checkpoint_name)

     model = VAE_model.VAE_model(
         model_name=model_name,
         data=data,
-        encoder_parameters=model_params["encoder_parameters"],
-        decoder_parameters=model_params["decoder_parameters"],
+        encoder_parameters=checkpoint["encoder_parameters"],
+        decoder_parameters=checkpoint["decoder_parameters"],
         random_seed=42
     )
     model = model.to(model.device)

-    try:
-        checkpoint_name = str(args.VAE_checkpoint_location) + os.sep + model_name + "_final"
-        checkpoint = torch.load(checkpoint_name)
-        model.load_state_dict(checkpoint['model_state_dict'])
-        print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name))
-    except:
-        print("Unable to locate VAE model checkpoint")
-        sys.exit(0)
+    model.load_state_dict(checkpoint['model_state_dict'])
+    print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name))

     list_valid_mutations, evol_indices, _, _ = model.compute_evol_indices(msa_data=data,
                                                                           list_mutations_location=args.mutations_location,

@@ -93,5 +96,6 @@
     try:
         keep_header = os.stat(evol_indices_output_filename).st_size == 0
     except:
-        keep_header=True
-    df.to_csv(path_or_buf=evol_indices_output_filename, index=False, mode='a', header=keep_header)
+        keep_header = True
+    df.to_csv(path_or_buf=evol_indices_output_filename, index=False, mode='a', header=keep_header)
+    print("Script completed successfully.")
```
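The keep_header logic in the final hunk — append to the output CSV, writing the header only when the file is new or empty — is a reusable pattern for scripts rerun across array jobs. A stdlib-only sketch of the same idea (hypothetical helper, not from the PR, which uses pandas to_csv):

```python
import csv
import os

def append_rows(path, fieldnames, rows):
    """Append dict rows to a CSV, emitting the header only when the
    file does not exist yet or is empty (safe for repeated job runs)."""
    try:
        keep_header = os.stat(path).st_size == 0
    except OSError:  # file doesn't exist yet
        keep_header = True
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if keep_header:
            writer.writeheader()
        writer.writerows(rows)
```

Successive invocations then accumulate rows under a single header, which is what lets many per-protein jobs share one evol_indices output file.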