diff --git a/.gitignore b/.gitignore index c4ef2bb..5655558 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,14 @@ EVE/__pycache__/ utils/__pycache__/ results/VAE_parameters/* -!results/VAE_parameters/.gitkeep \ No newline at end of file +!results/VAE_parameters/.gitkeep +logs/ +.idea/ +.ipynb_checkpoints/ +notebooks/ +results/*parameters?*/ +results/evol_indices/ +slurm/ +slurm_dan/ +# Reinclude examples +!data/mappings/example_mapping.csv diff --git a/EVE/VAE_model.py b/EVE/VAE_model.py index a637a8a..56ad337 100644 --- a/EVE/VAE_model.py +++ b/EVE/VAE_model.py @@ -1,8 +1,11 @@ +import datetime import os +import sys + import numpy as np import pandas as pd import time -import tqdm +from tqdm import tqdm from scipy.special import erfinv from sklearn.model_selection import train_test_split @@ -12,60 +15,67 @@ import torch.optim as optim import torch.backends.cudnn as cudnn +from utils.data_utils import one_hot_3D, get_training_dataloader, get_one_hot_dataloader from . import VAE_encoder, VAE_decoder + class VAE_model(nn.Module): """ Class for the VAE model with estimation of weights distribution parameters via Mean-Field VI. 
""" + def __init__(self, - model_name, - data, - encoder_parameters, - decoder_parameters, - random_seed - ): - + model_name, + data, + encoder_parameters, + decoder_parameters, + random_seed, + seq_len=None, + alphabet_size=None, + Neff=None, + ): + super().__init__() - + self.model_name = model_name self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.dtype = torch.float32 self.random_seed = random_seed torch.manual_seed(random_seed) - - self.seq_len = data.seq_len - self.alphabet_size = data.alphabet_size - self.Neff = data.Neff - self.encoder_parameters=encoder_parameters - self.decoder_parameters=decoder_parameters + self.seq_len = seq_len if seq_len is not None else data.seq_len + self.alphabet_size = alphabet_size if alphabet_size is not None else data.alphabet_size + self.Neff = Neff if Neff is not None else data.Neff + + self.encoder_parameters = encoder_parameters + self.decoder_parameters = decoder_parameters encoder_parameters['seq_len'] = self.seq_len encoder_parameters['alphabet_size'] = self.alphabet_size decoder_parameters['seq_len'] = self.seq_len decoder_parameters['alphabet_size'] = self.alphabet_size - + self.encoder = VAE_encoder.VAE_MLP_encoder(params=encoder_parameters) if decoder_parameters['bayesian_decoder']: self.decoder = VAE_decoder.VAE_Bayesian_MLP_decoder(params=decoder_parameters) else: self.decoder = VAE_decoder.VAE_Standard_MLP_decoder(params=decoder_parameters) self.logit_sparsity_p = decoder_parameters['logit_sparsity_p'] - + def sample_latent(self, mu, log_var): """ Samples a latent vector via reparametrization trick """ eps = torch.randn_like(mu).to(self.device) - z = torch.exp(0.5*log_var) * eps + mu + z = torch.exp(0.5 * log_var) * eps + mu return z def KLD_diag_gaussians(self, mu, logvar, p_mu, p_logvar): """ KL divergence between diagonal gaussian with prior diagonal gaussian. 
""" - KLD = 0.5 * (p_logvar - logvar) + 0.5 * (torch.exp(logvar) + torch.pow(mu-p_mu,2)) / (torch.exp(p_logvar)+1e-20) - 0.5 + KLD = 0.5 * (p_logvar - logvar) + 0.5 * (torch.exp(logvar) + torch.pow(mu - p_mu, 2)) / ( + torch.exp(p_logvar) + 1e-20) - 0.5 return torch.sum(KLD) @@ -74,7 +84,7 @@ def annealing_factor(self, annealing_warm_up, training_step): Annealing schedule of KL to focus on reconstruction error in early stages of training """ if training_step < annealing_warm_up: - return training_step/annealing_warm_up + return training_step / annealing_warm_up else: return 1 @@ -83,57 +93,60 @@ def KLD_global_parameters(self): KL divergence between the variational distributions and the priors (for the decoder weights). """ KLD_decoder_params = 0.0 - zero_tensor = torch.tensor(0.0).to(self.device) - + zero_tensor = torch.tensor(0.0).to(self.device) + for layer_index in range(len(self.decoder.hidden_layers_sizes)): - for param_type in ['weight','bias']: + for param_type in ['weight', 'bias']: KLD_decoder_params += self.KLD_diag_gaussians( - self.decoder.state_dict(keep_vars=True)['hidden_layers_mean.'+str(layer_index)+'.'+param_type].flatten(), - self.decoder.state_dict(keep_vars=True)['hidden_layers_log_var.'+str(layer_index)+'.'+param_type].flatten(), - zero_tensor, - zero_tensor - ) - - for param_type in ['weight','bias']: - KLD_decoder_params += self.KLD_diag_gaussians( - self.decoder.state_dict(keep_vars=True)['last_hidden_layer_'+param_type+'_mean'].flatten(), - self.decoder.state_dict(keep_vars=True)['last_hidden_layer_'+param_type+'_log_var'].flatten(), - zero_tensor, - zero_tensor + self.decoder.state_dict(keep_vars=True)[ + 'hidden_layers_mean.' + str(layer_index) + '.' + param_type].flatten(), + self.decoder.state_dict(keep_vars=True)[ + 'hidden_layers_log_var.' + str(layer_index) + '.' 
+ param_type].flatten(), + zero_tensor, + zero_tensor ) + for param_type in ['weight', 'bias']: + KLD_decoder_params += self.KLD_diag_gaussians( + self.decoder.state_dict(keep_vars=True)['last_hidden_layer_' + param_type + '_mean'].flatten(), + self.decoder.state_dict(keep_vars=True)['last_hidden_layer_' + param_type + '_log_var'].flatten(), + zero_tensor, + zero_tensor + ) + if self.decoder.include_sparsity: self.logit_scale_sigma = 4.0 - self.logit_scale_mu = 2.0**0.5 * self.logit_scale_sigma * erfinv(2.0 * self.logit_sparsity_p - 1.0) + self.logit_scale_mu = 2.0 ** 0.5 * self.logit_scale_sigma * erfinv(2.0 * self.logit_sparsity_p - 1.0) - sparsity_mu = torch.tensor(self.logit_scale_mu).to(self.device) - sparsity_log_var = torch.log(torch.tensor(self.logit_scale_sigma**2)).to(self.device) + sparsity_mu = torch.tensor(self.logit_scale_mu).to(self.device) + sparsity_log_var = torch.log(torch.tensor(self.logit_scale_sigma ** 2)).to(self.device) KLD_decoder_params += self.KLD_diag_gaussians( - self.decoder.state_dict(keep_vars=True)['sparsity_weight_mean'].flatten(), - self.decoder.state_dict(keep_vars=True)['sparsity_weight_log_var'].flatten(), - sparsity_mu, - sparsity_log_var + self.decoder.state_dict(keep_vars=True)['sparsity_weight_mean'].flatten(), + self.decoder.state_dict(keep_vars=True)['sparsity_weight_log_var'].flatten(), + sparsity_mu, + sparsity_log_var ) - + if self.decoder.convolve_output: for param_type in ['weight']: KLD_decoder_params += self.KLD_diag_gaussians( - self.decoder.state_dict(keep_vars=True)['output_convolution_mean.'+param_type].flatten(), - self.decoder.state_dict(keep_vars=True)['output_convolution_log_var.'+param_type].flatten(), - zero_tensor, - zero_tensor + self.decoder.state_dict(keep_vars=True)['output_convolution_mean.' + param_type].flatten(), + self.decoder.state_dict(keep_vars=True)['output_convolution_log_var.' 
+ param_type].flatten(), + zero_tensor, + zero_tensor ) if self.decoder.include_temperature_scaler: KLD_decoder_params += self.KLD_diag_gaussians( - self.decoder.state_dict(keep_vars=True)['temperature_scaler_mean'].flatten(), - self.decoder.state_dict(keep_vars=True)['temperature_scaler_log_var'].flatten(), - zero_tensor, - zero_tensor - ) + self.decoder.state_dict(keep_vars=True)['temperature_scaler_mean'].flatten(), + self.decoder.state_dict(keep_vars=True)['temperature_scaler_log_var'].flatten(), + zero_tensor, + zero_tensor + ) return KLD_decoder_params - def loss_function(self, x_recon_log, x, mu, log_var, kl_latent_scale, kl_global_params_scale, annealing_warm_up, training_step, Neff): + def loss_function(self, x_recon_log, x, mu, log_var, kl_latent_scale, kl_global_params_scale, annealing_warm_up, + training_step, Neff): """ Returns mean of negative ELBO, reconstruction loss and KL divergence across batch x. """ @@ -143,10 +156,11 @@ def loss_function(self, x_recon_log, x, mu, log_var, kl_latent_scale, kl_global_ KLD_decoder_params_normalized = self.KLD_global_parameters() / Neff else: KLD_decoder_params_normalized = 0.0 - warm_up_scale = self.annealing_factor(annealing_warm_up,training_step) - neg_ELBO = BCE + warm_up_scale * (kl_latent_scale * KLD_latent + kl_global_params_scale * KLD_decoder_params_normalized) + warm_up_scale = self.annealing_factor(annealing_warm_up, training_step) + neg_ELBO = BCE + warm_up_scale * ( + kl_latent_scale * KLD_latent + kl_global_params_scale * KLD_decoder_params_normalized) return neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized - + def all_likelihood_components(self, x): """ Returns tensors of ELBO, reconstruction loss and KL divergence for each point in batch x. 
@@ -155,188 +169,296 @@ def all_likelihood_components(self, x): z = self.sample_latent(mu, log_var) recon_x_log = self.decoder(z) - recon_x_log = recon_x_log.view(-1,self.alphabet_size*self.seq_len) - x = x.view(-1,self.alphabet_size*self.seq_len) - - BCE_batch_tensor = torch.sum(F.binary_cross_entropy_with_logits(recon_x_log, x, reduction='none'),dim=1) - KLD_batch_tensor = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(),dim=1)) - + recon_x_log = recon_x_log.view(-1, self.alphabet_size * self.seq_len) + x = x.view(-1, self.alphabet_size * self.seq_len) + + BCE_batch_tensor = torch.sum(F.binary_cross_entropy_with_logits(recon_x_log, x, reduction='none'), dim=1) + KLD_batch_tensor = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)) + ELBO_batch_tensor = -(BCE_batch_tensor + KLD_batch_tensor) return ELBO_batch_tensor, BCE_batch_tensor, KLD_batch_tensor - def train_model(self, data, training_parameters): + def all_likelihood_components_z(self, x, mu, log_var): + """Skip the encoder part and directly sample z""" + # Need to run mu, log_var = self.encoder(x) first + z = self.sample_latent(mu, log_var) + recon_x_log = self.decoder(z) + + recon_x_log = recon_x_log.view(-1, self.alphabet_size * self.seq_len) + x = x.view(-1, self.alphabet_size * self.seq_len) + + BCE_batch_tensor = torch.sum(F.binary_cross_entropy_with_logits(recon_x_log, x, reduction='none'), dim=1) + KLD_batch_tensor = (-0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)) + + ELBO_batch_tensor = -(BCE_batch_tensor + KLD_batch_tensor) + + return ELBO_batch_tensor, BCE_batch_tensor, KLD_batch_tensor + + def train_model(self, data, training_parameters, use_dataloader=False): """ Training procedure for the VAE model. If use_validation_set is True then: - we split the alignment data in train/val sets. 
- we train up to num_training_steps steps but store the version of the model with lowest loss on validation set across training - If not, then we train the model for num_training_steps and save the model at the end of training + If not, then we train the model for num_training_steps and save the model at the end of training. + + use_dataloader: Whether to stream in the one-hot encodings via a dataloader. + If False, loads in the entire one-hot encoding matrix into memory and iterates over it. """ if torch.cuda.is_available(): cudnn.benchmark = True self.train() - + if training_parameters['log_training_info']: - filename = training_parameters['training_logs_location']+os.sep+self.model_name+"_losses.csv" + filename = training_parameters['training_logs_location'] + os.sep + self.model_name + "_losses.csv" with open(filename, "a") as logs: - logs.write("Number of sequences in alignment file:\t"+str(data.num_sequences)+"\n") - logs.write("Neff:\t"+str(self.Neff)+"\n") - logs.write("Alignment sequence length:\t"+str(data.seq_len)+"\n") + logs.write("Number of sequences in alignment file:\t" + str(data.num_sequences) + "\n") + logs.write("Neff:\t" + str(self.Neff) + "\n") + logs.write("Alignment sequence length:\t" + str(data.seq_len) + "\n") + + optimizer = optim.Adam(self.parameters(), lr=training_parameters['learning_rate'], + weight_decay=training_parameters['l2_regularization']) - optimizer = optim.Adam(self.parameters(), lr=training_parameters['learning_rate'], weight_decay = training_parameters['l2_regularization']) - if training_parameters['use_lr_scheduler']: - scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=training_parameters['lr_scheduler_step_size'], gamma=training_parameters['lr_scheduler_gamma']) + scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=training_parameters['lr_scheduler_step_size'], + gamma=training_parameters['lr_scheduler_gamma']) + list_sequences = list(data.seq_name_to_sequence.values()) if 
training_parameters['use_validation_set']: - x_train, x_val, weights_train, weights_val = train_test_split(data.one_hot_encoding, data.weights, test_size=training_parameters['validation_set_pct'], random_state=self.random_seed) + if use_dataloader: + seqs_train, seqs_val, weights_train, weights_val = train_test_split(list_sequences, + data.weights, + test_size=training_parameters['validation_set_pct'], + random_state=self.random_seed) + # Validation still just passes in the whole one-hot encoding in one go + x_val = one_hot_3D(seqs_val, alphabet=data.alphabet, seq_length=data.seq_len) + assert len(seqs_train) == weights_train.shape[0] # One weight per sequence + else: + x_train, x_val, weights_train, weights_val = train_test_split(data.one_hot_encoding, data.weights, + test_size=training_parameters['validation_set_pct'], + random_state=self.random_seed) + assert x_train.shape[0] == weights_train.shape[0] # One weight per sequence best_val_loss = float('inf') - best_model_step_index=0 + best_model_step_index = 0 else: - x_train = data.one_hot_encoding + seqs_train = list_sequences weights_train = data.weights best_val_loss = None best_model_step_index = training_parameters['num_training_steps'] - - batch_order = np.arange(x_train.shape[0]) + seq_sample_probs = weights_train / np.sum(weights_train) + + # Keep old behaviour for comparison + if use_dataloader: + # Stream one-hot encodings + train_dataloader = get_training_dataloader(sequences=seqs_train, weights=weights_train, alphabet=data.alphabet, seq_len=data.seq_len, batch_size=training_parameters['batch_size'], num_training_steps=training_parameters['num_training_steps']) + else: + batch_order = np.arange(x_train.shape[0]) + assert batch_order.shape == seq_sample_probs.shape, f"batch_order and seq_sample_probs must have the same shape. 
batch_order.shape={batch_order.shape}, seq_sample_probs.shape={seq_sample_probs.shape}" + def get_mock_dataloader(): + while True: + # Sample a batch according to sequence weight + batch_index = np.random.choice(batch_order, training_parameters['batch_size'], p=seq_sample_probs).tolist() + batch = x_train[batch_index] + yield batch + train_dataloader = get_mock_dataloader() self.Neff_training = np.sum(weights_train) - N_training = x_train.shape[0] - + start = time.time() train_loss = 0 - - for training_step in tqdm.tqdm(range(1,training_parameters['num_training_steps']+1), desc="Training model"): + for training_step, batch in enumerate(tqdm(train_dataloader, desc="Training model", total=training_parameters['num_training_steps'], mininterval=5)): - batch_index = np.random.choice(batch_order, training_parameters['batch_size'], p=seq_sample_probs).tolist() - x = torch.tensor(x_train[batch_index], dtype=self.dtype).to(self.device) + # For the dataloader, we may have to manually end training at + if training_step >= training_parameters['num_training_steps']: + break + x = batch.to(self.device, dtype=self.dtype) + optimizer.zero_grad() mu, log_var = self.encoder(x) z = self.sample_latent(mu, log_var) recon_x_log = self.decoder(z) - - neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized = self.loss_function(recon_x_log, x, mu, log_var, training_parameters['kl_latent_scale'], training_parameters['kl_global_params_scale'], training_parameters['annealing_warm_up'], training_step, self.Neff_training) - + + neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized = self.loss_function( + recon_x_log, x, mu, log_var, + training_parameters['kl_latent_scale'], + training_parameters['kl_global_params_scale'], + training_parameters['annealing_warm_up'], + training_step, + self.Neff_training) + neg_ELBO.backward() optimizer.step() - + if training_parameters['use_lr_scheduler']: scheduler.step() - + if training_step % training_parameters['log_training_freq'] == 0: - progress = 
"|Train : Update {0}. Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format(training_step, neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized, time.time() - start) + progress = "|Train : Update {0}. Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format( + training_step, neg_ELBO, BCE, KLD_latent, KLD_decoder_params_normalized, time.time() - start) print(progress) if training_parameters['log_training_info']: - with open(filename, "a") as logs: - logs.write(progress+"\n") - - if training_step % training_parameters['save_model_params_freq']==0: - self.save(model_checkpoint=training_parameters['model_checkpoint_location']+os.sep+self.model_name+"_step_"+str(training_step), - encoder_parameters=self.encoder_parameters, - decoder_parameters=self.decoder_parameters, - training_parameters=training_parameters) - + with open(filename, "a+") as logs: + logs.write(progress + "\n") + + if training_step % training_parameters['save_model_params_freq'] == 0: + self.save(model_checkpoint=training_parameters[ + 'model_checkpoint_location'] + os.sep + self.model_name + "_step_" + str( + training_step), + encoder_parameters=self.encoder_parameters, + decoder_parameters=self.decoder_parameters, + training_parameters=training_parameters) + if training_parameters['use_validation_set'] and training_step % training_parameters['validation_freq'] == 0: - x_val = torch.tensor(x_val, dtype=self.dtype).to(self.device) - val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters = self.test_model(x_val, weights_val, training_parameters['batch_size']) + x_val = x_val.to(self.device, dtype=self.dtype) + val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters = self.test_model(x_val, weights_val, + training_parameters[ + 'batch_size']) - progress_val = "\t\t\t|Val : Update {0}. 
Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format(training_step, val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters, time.time() - start) + progress_val = "\t\t\t|Val : Update {0}. Negative ELBO : {1:.3f}, BCE: {2:.3f}, KLD_latent: {3:.3f}, KLD_decoder_params_norm: {4:.3f}, Time: {5:.2f} |".format( + training_step, val_neg_ELBO, val_BCE, val_KLD_latent, val_KLD_global_parameters, + time.time() - start) print(progress_val) if training_parameters['log_training_info']: - with open(filename, "a") as logs: - logs.write(progress_val+"\n") + with open(filename, "a+") as logs: + logs.write(progress_val + "\n") if val_neg_ELBO < best_val_loss: best_val_loss = val_neg_ELBO best_model_step_index = training_step - self.save(model_checkpoint=training_parameters['model_checkpoint_location']+os.sep+self.model_name+"_best", - encoder_parameters=self.encoder_parameters, - decoder_parameters=self.decoder_parameters, - training_parameters=training_parameters) + self.save(model_checkpoint=training_parameters[ + 'model_checkpoint_location'] + os.sep + self.model_name + "_best", + encoder_parameters=self.encoder_parameters, + decoder_parameters=self.decoder_parameters, + training_parameters=training_parameters) self.train() - + + def test_model(self, x_val, weights_val, batch_size): self.eval() - + with torch.no_grad(): val_batch_order = np.arange(x_val.shape[0]) val_seq_sample_probs = weights_val / np.sum(weights_val) val_batch_index = np.random.choice(val_batch_order, batch_size, p=val_seq_sample_probs).tolist() - x = torch.tensor(x_val[val_batch_index], dtype=self.dtype).to(self.device) + x = x_val[val_batch_index].to(self.device, dtype=self.dtype) mu, log_var = self.encoder(x) z = self.sample_latent(mu, log_var) recon_x_log = self.decoder(z) - - neg_ELBO, BCE, KLD_latent, KLD_global_parameters = self.loss_function(recon_x_log, x, mu, log_var, kl_latent_scale=1.0, kl_global_params_scale=1.0, 
annealing_warm_up=0, training_step=1, Neff = self.Neff_training) #set annealing factor to 1 - + + neg_ELBO, BCE, KLD_latent, KLD_global_parameters = self.loss_function(recon_x_log, x, mu, log_var, + kl_latent_scale=1.0, + kl_global_params_scale=1.0, + annealing_warm_up=0, training_step=1, + Neff=self.Neff_training) # set annealing factor to 1 + return neg_ELBO.item(), BCE.item(), KLD_latent.item(), KLD_global_parameters.item() - def save(self, model_checkpoint, encoder_parameters, decoder_parameters, training_parameters, batch_size=256): + # Create intermediate dirs above this + os.makedirs(os.path.dirname(model_checkpoint), exist_ok=True) torch.save({ - 'model_state_dict':self.state_dict(), - 'encoder_parameters':encoder_parameters, - 'decoder_parameters':decoder_parameters, - 'training_parameters':training_parameters, - }, model_checkpoint) - - def compute_evol_indices(self, msa_data, list_mutations_location, num_samples, batch_size=256): + 'model_state_dict': self.state_dict(), + 'encoder_parameters': encoder_parameters, + 'decoder_parameters': decoder_parameters, + 'training_parameters': training_parameters, + }, model_checkpoint) + + def compute_evol_indices(self, msa_data, list_mutations_location, num_samples, batch_size=256, + mutant_column="mutations"): """ - The column in the list_mutations dataframe that contains the mutant(s) for a given variant should be called "mutations" + The column in the list_mutations dataframe that contains the mutant(s) for a given variant should be called "mutations" """ - #Multiple mutations are to be passed colon-separated - list_mutations=pd.read_csv(list_mutations_location, header=0) - - #Remove (multiple) mutations that are invalid - list_valid_mutations = ['wt'] + + # Note: wt is added inside this function, so no need to add a row in csv/dataframe input with wt + list_mutations = pd.read_csv(list_mutations_location, header=0) + + # Multiple mutations are to be passed colon-separated + # Remove (multiple) mutations that 
are invalid + list_valid_mutations, list_valid_mutated_sequences = self.validate_mutants(msa_data=msa_data, mutations=list_mutations[mutant_column]) + + # first sequence in the list is the wild_type + list_valid_mutations = ['wt'] + list_valid_mutations + list_valid_mutated_sequences['wt'] = msa_data.focus_seq_trimmed + + dataloader = get_one_hot_dataloader(seq_keys=list_valid_mutations, + seq_name_to_sequence=list_valid_mutated_sequences, + alphabet=msa_data.alphabet, + seq_len=len(msa_data.focus_cols), + batch_size=batch_size) + + # Store wt_mean_predictions + with torch.no_grad(): + mean_predictions = torch.zeros(len(list_valid_mutations)) + std_predictions = torch.zeros(len(list_valid_mutations)) + for i, batch in enumerate(tqdm(dataloader, 'Looping through mutation batches')): + batch_samples = torch.zeros(len(batch), num_samples, dtype=self.dtype, device=self.device) # Keep this on GPU + x = batch.type(self.dtype).to(self.device) + mu, log_var = self.encoder(x) + for j in tqdm(range(num_samples), 'Looping through number of samples for batch #: ' + str(i + 1), mininterval=5): + # seq_predictions, _, _ = self.all_likelihood_components(x) + seq_predictions, _, _ = self.all_likelihood_components_z(x, mu, log_var) + batch_samples[:, j] = seq_predictions # Note: We could move this straight to CPU to save GPU space + + mean_predictions[i * batch_size:i * batch_size + len(x)] = batch_samples.mean(dim=1, keepdim=False) + std_predictions[i * batch_size:i * batch_size + len(x)] = batch_samples.std(dim=1, keepdim=False) + tqdm.write('\n') + + delta_elbos = mean_predictions - mean_predictions[0] + evol_indices = - delta_elbos.detach().cpu().numpy() + + return list_valid_mutations, evol_indices, mean_predictions[0].detach().cpu().numpy(), std_predictions.detach().cpu().numpy() + + def validate_mutants(self, msa_data, mutations): + list_valid_mutations = [] list_valid_mutated_sequences = {} - list_valid_mutated_sequences['wt'] = msa_data.focus_seq_trimmed # first sequence 
in the list is the wild_type - for mutation in list_mutations['mutations']: - individual_substitutions = mutation.split(':') + + for mutation in mutations: + try: + individual_substitutions = str(mutation).split(':') + except Exception as e: + print("Error with mutant {}".format(str(mutation))) + print("Specific error: " + str(e)) + continue mutated_sequence = list(msa_data.focus_seq_trimmed)[:] fully_valid_mutation = True for mut in individual_substitutions: - wt_aa, pos, mut_aa = mut[0], int(mut[1:-1]), mut[-1] - if pos not in msa_data.uniprot_focus_col_to_wt_aa_dict or msa_data.uniprot_focus_col_to_wt_aa_dict[pos] != wt_aa or mut not in msa_data.mutant_to_letter_pos_idx_focus_list: - print ("Not a valid mutant: "+mutation) + try: + wt_aa, pos, mut_aa = mut[0], int(mut[1:-1]), mut[-1] + if wt_aa == mut_aa: # Skip synonymous + continue + # Log specific invalid mutants + if pos not in msa_data.uniprot_focus_col_to_wt_aa_dict: + print("pos {} not in uniprot_focus_col_to_wt_aa_dict".format(pos)) + fully_valid_mutation = False + # Given it's in the dict, check if it's a valid mutation + elif msa_data.uniprot_focus_col_to_wt_aa_dict[pos] != wt_aa: + print("wt_aa {} != uniprot_focus_col_to_wt_aa_dict[{}] {}".format( + wt_aa, pos, msa_data.uniprot_focus_col_to_wt_aa_dict[pos])) + fully_valid_mutation = False + if mut not in msa_data.mutant_to_letter_pos_idx_focus_list: + print("mut {} not in mutant_to_letter_pos_idx_focus_list".format(mut)) + fully_valid_mutation = False + + if fully_valid_mutation: + wt_aa, pos, idx_focus = msa_data.mutant_to_letter_pos_idx_focus_list[mut] + mutated_sequence[idx_focus] = mut_aa # perform the corresponding AA substitution + else: + print("Not a valid mutant: " + mutation) + break + + except Exception as e: + print("Error processing mutation {} in mutant {}".format(str(mut), str(mutation))) + print("Specific error: " + str(e)) fully_valid_mutation = False break - else: - wt_aa,pos,idx_focus = 
msa_data.mutant_to_letter_pos_idx_focus_list[mut] - mutated_sequence[idx_focus] = mut_aa #perform the corresponding AA substitution - + if fully_valid_mutation: list_valid_mutations.append(mutation) list_valid_mutated_sequences[mutation] = ''.join(mutated_sequence) - - #One-hot encoding of mutated sequences - mutated_sequences_one_hot = np.zeros((len(list_valid_mutations),len(msa_data.focus_cols),len(msa_data.alphabet))) - for i,mutation in enumerate(list_valid_mutations): - sequence = list_valid_mutated_sequences[mutation] - for j,letter in enumerate(sequence): - if letter in msa_data.aa_dict: - k = msa_data.aa_dict[letter] - mutated_sequences_one_hot[i,j,k] = 1.0 - - mutated_sequences_one_hot = torch.tensor(mutated_sequences_one_hot) - dataloader = torch.utils.data.DataLoader(mutated_sequences_one_hot, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) - prediction_matrix = torch.zeros((len(list_valid_mutations),num_samples)) - - with torch.no_grad(): - for i, batch in enumerate(tqdm.tqdm(dataloader, 'Looping through mutation batches')): - x = batch.type(self.dtype).to(self.device) - for j in tqdm.tqdm(range(num_samples), 'Looping through number of samples for batch #: '+str(i+1)): - seq_predictions, _, _ = self.all_likelihood_components(x) - prediction_matrix[i*batch_size:i*batch_size+len(x),j] = seq_predictions - tqdm.tqdm.write('\n') - mean_predictions = prediction_matrix.mean(dim=1, keepdim=False) - std_predictions = prediction_matrix.std(dim=1, keepdim=False) - delta_elbos = mean_predictions - mean_predictions[0] - evol_indices = - delta_elbos.detach().cpu().numpy() - return list_valid_mutations, evol_indices, mean_predictions[0].detach().cpu().numpy(), std_predictions.detach().cpu().numpy() \ No newline at end of file + return list_valid_mutations, list_valid_mutated_sequences diff --git a/README.md b/README.md index 7499a28..6f2914e 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ EVE is a set of protein-specific models providing 
for any single amino acid muta The end to end process to compute EVE scores consists of three consecutive steps: 1. Train the Bayesian VAE on a re-weighted multiple sequence alignment (MSA) for the protein of interest => train_VAE.py 2. Compute the evolutionary indices for all single amino acid mutations => compute_evol_indices.py -3. Train a GMM to cluster variants on the basis of the evol indices then output scores and uncertainties on the class assignments => train_GMM_and_compute_EVE_scores.py +3. Train a GMM to cluster variants on the basis of the evol indices then output scores and uncertainties on the class assignments => train_GMM_and_compute_EVE_scores.py We also provide all EVE scores for all single amino acid mutations for thousands of proteins at the following address: http://evemodel.org/. ## Example scripts @@ -47,6 +47,7 @@ The entire codebase is written in python. Package requirements are as follows: - tqdm - matplotlib - seaborn + - numba The corresponding environment may be created via conda and the provided protein_env.yml file as follows: ``` diff --git a/calc_weights.py b/calc_weights.py new file mode 100644 index 0000000..23c9dc3 --- /dev/null +++ b/calc_weights.py @@ -0,0 +1,123 @@ +# Basically train_VAE.py but just calculating the weights +import argparse +import os +import time + +import numpy as np +import pandas as pd + +from utils import data_utils + + +def create_argparser(): +    parser = argparse.ArgumentParser(description='VAE') + +    # If we don't have a mapping file, just use a single MSA path +    parser.add_argument("--MSA_filepath", type=str, help="Full path to MSA") + +    # If we have a mapping file with one MSA path per line +    parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored', required=True) +    parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name', required=True) +    parser.add_argument('--protein_index', type=int, help='Row index of protein in input
def main(args):
    """Compute MSA sequence weights for one protein and save them to disk.

    The MSA is located either directly via --MSA_filepath, or through a row of
    the mapping CSV (--MSA_list + --protein_index).  The actual computation and
    saving of the weights happens as a side effect of data_utils.MSA_processing.

    Raises:
        ValueError: --MSA_filepath was used without --theta_reweighting (there
            is no mapping file to look theta up in).
        NotADirectoryError: --MSA_weights_location does not exist.
        FileExistsError: the weights file exists and neither --overwrite nor
            --skip_existing was given.
    """
    print("Arguments:", args)

    weights_file = None
    mapping_file = None

    if args.MSA_filepath is not None:
        assert os.path.isfile(args.MSA_filepath), f"MSA filepath {args.MSA_filepath} doesn't exist"
        msa_location = args.MSA_filepath
        # Fix: protein_name was previously left undefined on this code path and
        # crashed with NameError when building the default weights filename below.
        protein_name = os.path.splitext(os.path.basename(msa_location))[0]
    else:
        # Use mapping file
        assert os.path.isfile(args.MSA_list), f"MSA file list {args.MSA_list} doesn't seem to exist"
        mapping_file = pd.read_csv(args.MSA_list)
        protein_name = mapping_file['protein_name'][args.protein_index]
        msa_location = args.MSA_data_folder + os.sep + mapping_file['msa_location'][args.protein_index]
        print("Protein name: " + str(protein_name))
        # If weight_file_name is in the df_mapping, use that instead
        if "weight_file_name" in mapping_file.columns:
            weights_file = args.MSA_weights_location + os.sep + mapping_file["weight_file_name"][args.protein_index]
            print("Using weights filename from mapping file:", weights_file)

    print("MSA file: " + str(msa_location))

    if args.theta_reweighting is not None:
        theta = args.theta_reweighting
        print(f"Using custom theta value {theta} instead of loading from mapping file.")
    elif mapping_file is None:
        # Fix: this path previously raised a confusing NameError on mapping_file.
        raise ValueError("No mapping file available to look up theta; "
                         "please specify --theta_reweighting explicitly when using --MSA_filepath.")
    else:
        try:
            theta = float(mapping_file['theta'][args.protein_index])
        except KeyError as e:
            # Overriding previous errors is bad, but we're being nice to the user
            raise KeyError("Couldn't load theta from mapping file. "
                           "NOT using default value of theta=0.2; please specify theta manually. Specific line:",
                           mapping_file[args.protein_index],
                           "Previous error:", e)
    assert not np.isnan(theta), "Theta is NaN, please provide a custom theta value"

    print("Theta MSA re-weighting: " + str(theta))

    # Using data_kwargs so that if options aren't set, they'll be set to default values
    data_kwargs = {}
    if args.threshold_focus_cols_frac_gaps is not None:
        print("Using custom threshold_focus_cols_frac_gaps: ", args.threshold_focus_cols_frac_gaps)
        data_kwargs['threshold_focus_cols_frac_gaps'] = args.threshold_focus_cols_frac_gaps

    if not os.path.isdir(args.MSA_weights_location):
        # Fix: the old code created the directory (os.makedirs with exist_ok=True)
        # and then unconditionally raised anyway, leaving an empty directory
        # behind.  Auto-creation was deliberately rejected (see the commented-out
        # alternative in version control), so raise without the side effect.
        raise NotADirectoryError(f"{args.MSA_weights_location} is not a directory."
                                 f"Could create it automatically, but at the moment raising an error.")
    else:
        print(f"MSA weights directory: {args.MSA_weights_location}")

    if weights_file is None:
        print("Weights filename not found - writing to new file")
        weights_file = args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy'

    print(f"Writing to {weights_file}")
    # First check that the weights file doesn't exist
    if os.path.isfile(weights_file) and not args.overwrite:
        if args.skip_existing:
            print("Weights file already exists, skipping, since --skip_existing was specified")
            raise SystemExit(0)
        else:
            raise FileExistsError(f"File {weights_file} already exists. "
                                  f"Please delete it if you want to re-calculate it. "
                                  f"If you want to skip existing files, use --skip_existing.")

    # The msa_data processing has a side effect of saving a weights file
    _ = data_utils.MSA_processing(
        MSA_location=msa_location,
        theta=theta,
        use_weights=True,
        weights_location=weights_file,
        num_cpus=args.num_cpus,
        weights_calc_method=args.calc_method,
        overwrite_weights=args.overwrite,
        skip_one_hot_encodings=True,
        **data_kwargs,
    )


if __name__ == '__main__':
    start = time.perf_counter()
    parser = create_argparser()
    args = parser.parse_args()
    main(args)
    end = time.perf_counter()
    print(f"calc_weights.py took {end-start:.2f} seconds in total.")
name') parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file') - parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') - parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting') + # parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') + # parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting') parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored') parser.add_argument('--model_name_suffix', default='Jan1', type=str, help='model checkpoint name is the protein name followed by this suffix') parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters') @@ -27,29 +27,33 @@ parser.add_argument('--batch_size', default=256, type=int, help='Batch size when computing evol indices') args = parser.parse_args() + print("Arguments=", args) + mapping_file = pd.read_csv(args.MSA_list) protein_name = mapping_file['protein_name'][args.protein_index] msa_location = args.MSA_data_folder + os.sep + mapping_file['msa_location'][args.protein_index] print("Protein name: "+str(protein_name)) print("MSA file: "+str(msa_location)) - if args.theta_reweighting is not None: - theta = args.theta_reweighting - else: - try: - theta = float(mapping_file['theta'][args.protein_index]) - except: - theta = 0.2 - print("Theta MSA re-weighting: "+str(theta)) + # Theta reweighting not necessary for computing evol indices + # if args.theta_reweighting is not None: + # theta = args.theta_reweighting + # else: + # try: + # theta = float(mapping_file['theta'][args.protein_index]) + # except: + # print("Theta not found in mapping file. 
Using default value of 0.2") + # theta = 0.2 + # print("Theta MSA re-weighting: "+str(theta)) data = data_utils.MSA_processing( MSA_location=msa_location, - theta=theta, - use_weights=True, - weights_location=args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy' + # theta=theta, + use_weights=False, + # weights_location=args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy' ) - if args.computation_mode=="all_singles": + if args.computation_mode == "all_singles": data.save_all_singles(output_filename=args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv") args.mutations_location = args.all_singles_mutations_folder + os.sep + protein_name + "_all_singles.csv" else: @@ -58,25 +62,24 @@ model_name = protein_name + "_" + args.model_name_suffix print("Model name: "+str(model_name)) - model_params = json.load(open(args.model_parameters_location)) + # model_params = json.load(open(args.model_parameters_location)) + + checkpoint_name = str(args.VAE_checkpoint_location) + os.sep + model_name + "_final" + assert os.path.isdir(args.VAE_checkpoint_location), "Cannot find dir"+args.VAE_checkpoint_location + assert os.path.isfile(checkpoint_name), "Cannot find "+checkpoint_name+".\nOther options: "+str([f for f in os.listdir('.') if os.path.isfile(f)]) + checkpoint = torch.load(checkpoint_name) model = VAE_model.VAE_model( model_name=model_name, data=data, - encoder_parameters=model_params["encoder_parameters"], - decoder_parameters=model_params["decoder_parameters"], + encoder_parameters=checkpoint["encoder_parameters"], + decoder_parameters=checkpoint["decoder_parameters"], random_seed=42 ) model = model.to(model.device) - try: - checkpoint_name = str(args.VAE_checkpoint_location) + os.sep + model_name + "_final" - checkpoint = torch.load(checkpoint_name) - model.load_state_dict(checkpoint['model_state_dict']) - print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name)) - except: - 
import datetime
import os, sys
import json
import argparse
from resource import getrusage, RUSAGE_SELF

import pandas as pd
import torch

from EVE import VAE_model
from utils import data_utils


def parse_args():
    """Command-line arguments for scoring DMS mutations with trained EVE models."""
    parser = argparse.ArgumentParser(description='Evol indices')
    parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored')
    parser.add_argument('--DMS_reference_file_path', type=str, help='List of proteins and corresponding MSA file name')
    parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file')
    parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting')
    parser.add_argument('--seeds', type=int, nargs="+",
                        help='Random seeds of VAE checkpoints to load (can specify one or several)')
    parser.add_argument('--VAE_checkpoint_location', type=str,
                        help='Location where VAE model checkpoints will be stored')
    parser.add_argument('--model_parameters_location', type=str, help='Location of VAE model parameters')
    parser.add_argument('--DMS_data_folder', type=str, help='Location of all mutations to compute the evol indices for')
    parser.add_argument('--output_scores_folder', type=str, help='Output location of computed evol indices')
    parser.add_argument('--output_evol_indices_filename_suffix', default='', type=str,
                        help='(Optional) Suffix to be added to output filename')
    parser.add_argument('--num_samples_compute_evol_indices', type=int,
                        help='Num of samples to approximate delta elbo when computing evol indices')
    parser.add_argument('--batch_size', default=256, type=int, help='Batch size when computing evol indices')
    parser.add_argument("--skip_existing", action="store_true", help="Skip scoring if output file already exists")
    parser.add_argument("--threshold_focus_cols_frac_gaps", type=float,
                        help="Maximum fraction of gaps allowed in focus columns - see data_utils.MSA_processing")
    args = parser.parse_args()
    return args


if __name__=='__main__':
    args = parse_args()
    print("Arguments:", args)

    assert os.path.isfile(args.DMS_reference_file_path), 'MSA list file does not exist: {}'.format(args.DMS_reference_file_path)
    mapping_file = pd.read_csv(args.DMS_reference_file_path)
    DMS_id = mapping_file['DMS_id'][args.protein_index]
    # Use MSA_filename, because UniProt_ID is not unique
    protein_name = mapping_file['MSA_filename'][args.protein_index].split(".a2m")[0]
    DMS_filename = mapping_file['DMS_filename'][args.protein_index]
    msa_location = args.MSA_data_folder + os.sep + mapping_file['MSA_filename'][args.protein_index]
    DMS_mutant_column = "mutant"
    print("Protein name: "+str(protein_name))
    print("MSA file: "+str(msa_location))
    print("DMS id: "+str(DMS_id))
    if not DMS_filename.startswith(DMS_id):
        print(f"Warning: DMS id does not match DMS filename: {DMS_id} vs {DMS_filename}. Continuing for now.")

    # Check filepaths are valid
    evol_indices_output_filename = os.path.join(args.output_scores_folder, f'{DMS_id}.csv')

    if os.path.isfile(evol_indices_output_filename):
        print("Output file already exists: " + str(evol_indices_output_filename))
        if args.skip_existing:
            print("Skipping scoring since args.skip_existing is True")
            sys.exit(0)
        else:
            print("Overwriting existing file: " + str(evol_indices_output_filename))
            print("To skip scoring for existing files, use --skip_existing")
    else:
        # Check if surrounding directory exists
        print("Output file: " + str(evol_indices_output_filename))
        assert os.path.isdir(os.path.dirname(evol_indices_output_filename)), \
            'Output directory does not exist: {}. Please create directory before running script.\nOutput filename given: {}.\nDebugging curdir={}'\
            .format(os.path.dirname(evol_indices_output_filename), evol_indices_output_filename, os.getcwd())

    if args.theta_reweighting is not None:
        theta = args.theta_reweighting
    else:
        # Fall back to 0.2 when theta is absent or non-numeric in the mapping
        # file (narrowed from a bare except).
        try:
            theta = float(mapping_file['MSA_theta'][args.protein_index])
        except (KeyError, ValueError):
            theta = 0.2
    print("Theta MSA re-weighting: "+str(theta))

    # Using data_kwargs so that if options aren't set, they'll be set to default values
    data_kwargs = {}
    if args.threshold_focus_cols_frac_gaps is not None:
        print("Using custom threshold_focus_cols_frac_gaps: ", args.threshold_focus_cols_frac_gaps)
        data_kwargs['threshold_focus_cols_frac_gaps'] = args.threshold_focus_cols_frac_gaps

    data = data_utils.MSA_processing(
        MSA_location=msa_location,
        theta=theta,
        use_weights=False,  # Don't need weights for evol indices
        skip_one_hot_encodings=True,  # One-hot encodings computed on the fly
        **data_kwargs,
    )

    args.mutations_location = os.path.join(args.DMS_data_folder, DMS_filename)
    for seed in args.seeds:
        model_name = protein_name + f"_seed_{seed}"
        print("Model name: "+str(model_name))

        model_params = json.load(open(args.model_parameters_location))

        model = VAE_model.VAE_model(
            model_name=model_name,
            data=data,
            encoder_parameters=model_params["encoder_parameters"],
            decoder_parameters=model_params["decoder_parameters"],
            random_seed=42
        )
        model = model.to(model.device)
        checkpoint_name = str(args.VAE_checkpoint_location) + os.sep + model_name
        assert os.path.isfile(checkpoint_name), 'Checkpoint file does not exist: {}'.format(checkpoint_name)

        try:
            checkpoint = torch.load(checkpoint_name, map_location=model.device)  # Added map_location so that this works with CPU too
            model.load_state_dict(checkpoint['model_state_dict'])
            print("Initialized VAE with checkpoint '{}' ".format(checkpoint_name))
        except Exception as e:
            print("Unable to load VAE model checkpoint {}".format(checkpoint_name))
            raise e

        list_valid_mutations, evol_indices, _, _ = model.compute_evol_indices(
            msa_data=data,
            list_mutations_location=args.mutations_location,
            mutant_column=DMS_mutant_column,
            num_samples=args.num_samples_compute_evol_indices,
            batch_size=args.batch_size,
        )

        df = pd.DataFrame({
            'mutant': list_valid_mutations,
            f'evol_indices_seed_{seed}': evol_indices,
        })

        # Bug fix: the argparse option is --seeds, so this must read args.seeds;
        # args.random_seeds raised AttributeError on every seed after the first.
        if os.path.exists(evol_indices_output_filename) and seed != args.seeds[0]:
            prev_df = pd.read_csv(evol_indices_output_filename)
            df = pd.merge(prev_df, df, on='mutant', how='inner')
            # checking that the mutants match after the first seed (first seed will overwrite original score file)
            assert len(df) == len(prev_df), "Length of merged dataframe doesn't match previous length, mutants must not match across seeds"
            df.to_csv(evol_indices_output_filename, index=False)
        else:
            df.to_csv(evol_indices_output_filename, index=False)
+1,4 @@ -protein_name,msa_location,theta -PTEN_HUMAN,PTEN_HUMAN_b1.0.a2m,0.2 \ No newline at end of file +protein_name,MSA_filename,MSA_theta +YAP1_HUMAN,YAP1_HUMAN_full_11-26-2021_b02.a2m,0.2 +PTEN_HUMAN,PTEN_HUMAN_b1.0.a2m,0.2 +DLG4_RAT,DLG4_RAT_full_11-26-2021_b03.a2m,0.2 diff --git a/examples/Step0_optional_calc_weights.sh b/examples/Step0_optional_calc_weights.sh new file mode 100644 index 0000000..89e1a2e --- /dev/null +++ b/examples/Step0_optional_calc_weights.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e # fail fully on first line failure (from Joost slurm_for_ml) + +# Send python outputs (like print) directly to terminal/log without buffering +export PYTHONUNBUFFERED=1 + +export MSA_data_folder='./data/MSA' +export MSA_list='./data/mappings/example_mapping.csv' +export MSA_weights_location='./data/weights' +export protein_index=0 +export num_cpus=1 +export calc_method='evcouplings' + +python3 calc_weights.py \ + --MSA_data_folder ${MSA_data_folder} \ + --MSA_list ${MSA_list} \ + --protein_index "${protein_index}" \ + --MSA_weights_location "${MSA_weights_location}" \ + --num_cpus "$num_cpus" \ + --calc_method ${calc_method} +# --skip_existing \ No newline at end of file diff --git a/examples/Step1_train_VAE.sh b/examples/Step1_train_VAE.sh index 4ead15b..570d68b 100644 --- a/examples/Step1_train_VAE.sh +++ b/examples/Step1_train_VAE.sh @@ -1,3 +1,4 @@ +#! 
/bin/bash export MSA_data_folder='./data/MSA' export MSA_list='./data/mappings/example_mapping.csv' export MSA_weights_location='./data/weights' @@ -15,4 +16,6 @@ python train_VAE.py \ --VAE_checkpoint_location ${VAE_checkpoint_location} \ --model_name_suffix ${model_name_suffix} \ --model_parameters_location ${model_parameters_location} \ - --training_logs_location ${training_logs_location} \ No newline at end of file + --training_logs_location ${training_logs_location} \ + --batch_size 256 \ + --experimental_stream_data \ No newline at end of file diff --git a/protein_env.yml b/protein_env.yml index 07914cf..19ba34e 100644 --- a/protein_env.yml +++ b/protein_env.yml @@ -13,4 +13,8 @@ dependencies: - scipy=1.6.2 - tqdm - matplotlib - - seaborn \ No newline at end of file + - seaborn + - numba==0.54.1 + - pip + - pip: + - numba-progress \ No newline at end of file diff --git a/train_VAE.py b/train_VAE.py index 22cc4df..0823d9a 100644 --- a/train_VAE.py +++ b/train_VAE.py @@ -1,58 +1,141 @@ -import os, sys import argparse -import pandas as pd import json +import time +import os + + +import pandas as pd from EVE import VAE_model from utils import data_utils -if __name__=='__main__': +if __name__ == '__main__': parser = argparse.ArgumentParser(description='VAE') - parser.add_argument('--MSA_data_folder', type=str, help='Folder where MSAs are stored') - parser.add_argument('--MSA_list', type=str, help='List of proteins and corresponding MSA file name') - parser.add_argument('--protein_index', type=int, help='Row index of protein in input mapping file') - parser.add_argument('--MSA_weights_location', type=str, help='Location where weights for each sequence in the MSA will be stored') - parser.add_argument('--theta_reweighting', type=float, help='Parameters for MSA sequence re-weighting') - parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored') - parser.add_argument('--model_name_suffix', default='Jan1', 
if __name__ == '__main__':
    # Train a single EVE VAE model for one protein selected from a mapping CSV.
    parser = argparse.ArgumentParser(description='VAE')
    parser.add_argument('--MSA_data_folder', type=str,
                        help='Folder where MSAs are stored', required=True)
    parser.add_argument('--MSA_list', type=str,
                        help='List of proteins and corresponding MSA file name', required=True)
    parser.add_argument('--protein_index', type=int,
                        help='Row index of protein in input mapping file', required=True)
    parser.add_argument('--MSA_weights_location', type=str,
                        help='Location where weights for each sequence in the MSA will be stored', required=True)
    parser.add_argument('--theta_reweighting', type=float,
                        help='Parameters for MSA sequence re-weighting')
    parser.add_argument('--VAE_checkpoint_location', type=str, help='Location where VAE model checkpoints will be stored', required=True)
    parser.add_argument('--model_name_suffix', help='Model checkpoint name will be the protein name followed by this suffix')
    parser.add_argument('--model_parameters_location', type=str,
                        help='Location of VAE model parameters', required=True)
    parser.add_argument('--training_logs_location', type=str,
                        help='Location of VAE model parameters')
    parser.add_argument("--seed", type=int, help="Random seed", default=42)
    parser.add_argument('--z_dim', type=int, help='Specify a different latent dim than in the params file')
    parser.add_argument("--threshold_focus_cols_frac_gaps", type=float,
                        help="Maximum fraction of gaps allowed in focus columns - see data_utils.MSA_processing")
    parser.add_argument('--force_load_weights', action='store_true',
                        help="Force loading of weights from MSA_weights_location (useful if you want to make sure you're using precalculated weights). Will fail if weight file doesn't exist.",
                        default=False)
    parser.add_argument("--overwrite_weights",
                        help="Will overwrite weights file if it already exists", action="store_true", default=False)
    parser.add_argument("--skip_existing", help="Will quit gracefully if model checkpoint file already exists",
                        action="store_true", default=False)
    parser.add_argument("--batch_size", type=int,
                        help="Batch size for training", default=None)
    parser.add_argument("--experimental_stream_data",
                        help="Load one-hot-encodings on the fly. Saves a lot of memory by not storing the whole one-hot matrix (sometimes >300GB)", action="store_true", default=False)
    args = parser.parse_args()
    print("Arguments:", args)

    assert os.path.isfile(args.MSA_list), f"MSA file list {args.MSA_list} doesn't seem to exist"
    mapping_file = pd.read_csv(args.MSA_list)

    if mapping_file["MSA_filename"].duplicated().any():
        print("Note: Duplicate MSA_filename detected in the mapping file. Deduplicating to only have one EVE model per alignment.")
        mapping_file = mapping_file.drop_duplicates(subset=["MSA_filename"])
    protein_name = mapping_file['MSA_filename'][args.protein_index].split(".a2m")[0]
    if args.model_name_suffix is not None:
        protein_name = f"{protein_name}_{args.model_name_suffix}"
    msa_location = args.MSA_data_folder + os.sep + mapping_file['MSA_filename'][args.protein_index]
    print("Protein name: " + str(protein_name))
    print("MSA file: " + str(msa_location))

    if args.theta_reweighting is not None:
        theta = args.theta_reweighting
    else:
        # Narrowed from a bare except: default only when theta is absent or
        # non-numeric in the mapping file.
        try:
            theta = float(mapping_file['MSA_theta'][args.protein_index])
        except (KeyError, ValueError):
            print("Couldn't load theta from mapping file. Using default value of 0.2")
            theta = 0.2

    model_name = protein_name + f"_seed_{args.seed}"
    print("Model name: " + str(model_name))
    model_checkpoint_final_path = args.VAE_checkpoint_location + os.sep + model_name
    if os.path.isfile(model_checkpoint_final_path):
        if args.skip_existing:
            print("Model checkpoint already exists, skipping, since --skip_existing was specified")
            raise SystemExit(0)
        else:
            raise FileExistsError(f"Model checkpoint {model_checkpoint_final_path} already exists. \
Use --skip_existing to skip without raising an error, or delete the destination file if you want to rerun.")

    # Using data_kwargs so that if options aren't set, they'll be set to default values
    data_kwargs = {}
    if args.threshold_focus_cols_frac_gaps is not None:
        print("Using custom threshold_focus_cols_frac_gaps: ",
              args.threshold_focus_cols_frac_gaps)
        data_kwargs['threshold_focus_cols_frac_gaps'] = args.threshold_focus_cols_frac_gaps

    if args.overwrite_weights:
        print("Overwriting weights file")
        data_kwargs['overwrite_weights'] = True

    print("Theta MSA re-weighting: " + str(theta))

    # Load weights file if it's in the mapping file
    if "weight_file_name" in mapping_file.columns:
        weights_file = args.MSA_weights_location + os.sep + \
            mapping_file["weight_file_name"][args.protein_index]
        print("Using weights filename from mapping file")
    else:
        print(f"weight_file_name not provided in mapping file. Using default weights filename of {protein_name}_theta_{theta}.npy")
        weights_file = args.MSA_weights_location + os.sep + protein_name + '_theta_' + str(theta) + '.npy'

    print(f"Weights location: {weights_file}")

    if args.force_load_weights:
        print("Flag force_load_weights enabled - Forcing that we use weights from file:", weights_file)
        if not os.path.isfile(weights_file):
            raise FileNotFoundError(f"Weights file {weights_file} doesn't exist."
                                    f"To recompute weights, remove the flag --force_load_weights.")

    data = data_utils.MSA_processing(
        MSA_location=msa_location,
        theta=theta,
        use_weights=True,
        weights_location=weights_file,
        skip_one_hot_encodings=args.experimental_stream_data,  # When streaming, skip precomputing the full one-hot matrix.
        **data_kwargs,
    )

    assert os.path.isfile(args.model_parameters_location), args.model_parameters_location
    model_params = json.load(open(args.model_parameters_location))
    # Overwrite params if necessary
    if args.z_dim:
        model_params["encoder_parameters"]["z_dim"] = args.z_dim
        model_params["decoder_parameters"]["z_dim"] = args.z_dim
    if args.batch_size is not None:
        print("Using batch_size from command line: ", args.batch_size)
        model_params["training_parameters"]["batch_size"] = args.batch_size

    model = VAE_model.VAE_model(
        model_name=model_name,
        data=data,
        encoder_parameters=model_params["encoder_parameters"],
        decoder_parameters=model_params["decoder_parameters"],
        random_seed=args.seed
    )
    model = model.to(model.device)

    model_params["training_parameters"]['model_checkpoint_location'] = args.VAE_checkpoint_location

    print("Starting to train model: " + model_name)
    start = time.perf_counter()
    model.train_model(
        data=data, training_parameters=model_params["training_parameters"], use_dataloader=args.experimental_stream_data)
    end = time.perf_counter()
    # Bug fix: the old print used float floor-division (e.g. "0.0hours 0.0 minutes").
    hours, rem = divmod(int(end - start), 3600)
    minutes, seconds = divmod(rem, 60)
    print(f"Finished in {hours} hours {minutes} minutes and {seconds} seconds")

    print("Saving model: " + model_name)
    model.save(model_checkpoint=model_checkpoint_final_path,
               encoder_parameters=model_params["encoder_parameters"],
               decoder_parameters=model_params["decoder_parameters"],
               training_parameters=model_params["training_parameters"]
               )
# Protein alphabets (mirrored in utils/constants.py); the gap character is "-".
ALPHABET_PROTEIN_NOGAP = "ACDEFGHIKLMNPQRSTVWY"
ALPHABET_PROTEIN_GAP = "-" + ALPHABET_PROTEIN_NOGAP


class MSA_processing:
    def __init__(self,
                 MSA_location="",
                 theta=0.2,
                 use_weights=True,
                 weights_location="./data/weights",
                 preprocess_MSA=True,
                 threshold_sequence_frac_gaps=0.5,
                 threshold_focus_cols_frac_gaps=0.3,
                 remove_sequences_with_indeterminate_AA_in_focus_cols=True,
                 num_cpus=-1,
                 weights_calc_method="eve",
                 overwrite_weights=False,
                 skip_one_hot_encodings=False,
                 ):
        """Parse an MSA, optionally one-hot encode it, and compute sequence weights.

        Parameters:
        - MSA_location: (path) MSA file. The focus sequence must come first and its
          header must read ">focus_seq_name/start_pos-end_pos" (e.g. >SPIKE_SARS2/310-550).
        - theta: (float) sequence re-weighting hyperparameter. Generally 0.2 for
          prokaryotic/eukaryotic families, 0.01 for viruses.
        - use_weights: (bool) if False all weights are set to 1; if True, weights are
          loaded from weights_location when present, otherwise computed and saved there.
        - weights_location: (path) file to load the sequence weights from / save them to.
        - preprocess_MSA: (bool) drop short fragments and poorly covered positions.
        - threshold_sequence_frac_gaps: (float in [0,1]) sequences whose gap fraction
          exceeds this are removed (default 0.5).
        - threshold_focus_cols_frac_gaps: (float in [0,1]) columns whose gap fraction
          exceeds this are lower-cased and excluded from focus_cols (default 0.3).
        - remove_sequences_with_indeterminate_AA_in_focus_cols: (bool) drop sequences
          with indeterminate AAs (e.g. B, J, X, Z) at focus positions of the wild type.
        - num_cpus: (int) CPUs for parallel weight computation; -1 uses all, 1 is serial.
        - weights_calc_method: (str) method for sequence weights ("eve" by default).
        - overwrite_weights: (bool) recompute weights and overwrite the weights file
          instead of loading an existing one.
        - skip_one_hot_encodings: (bool) weights-only mode — skip the memory-heavy
          one-hot encodings and the enumeration of all single mutants.
        """
        np.random.seed(2021)
        self.MSA_location = MSA_location
        self.weights_location = weights_location
        self.theta = theta
        self.alphabet = ALPHABET_PROTEIN_NOGAP
        self.use_weights = use_weights
        self.overwrite_weights = overwrite_weights
        self.preprocess_MSA = preprocess_MSA
        self.threshold_sequence_frac_gaps = threshold_sequence_frac_gaps
        self.threshold_focus_cols_frac_gaps = threshold_focus_cols_frac_gaps
        self.remove_sequences_with_indeterminate_AA_in_focus_cols = remove_sequences_with_indeterminate_AA_in_focus_cols
        self.skip_one_hot_encodings = skip_one_hot_encodings
        self.weights_calc_method = weights_calc_method

        # Populated by gen_alignment
        self.aa_dict = {}
        self.focus_seq_name = ""
        self.seq_name_to_sequence = defaultdict(str)
        self.focus_seq = None
        self.focus_cols = None
        self.focus_seq_trimmed = None
        self.seq_len = None
        self.alphabet_size = None
        self.focus_start_loc = None
        self.focus_stop_loc = None
        self.uniprot_focus_col_to_wt_aa_dict = None
        self.uniprot_focus_col_to_focus_idx = None
        self.one_hot_encoding = None
        self.weights = None
        self.Neff = None
        self.num_sequences = None

        # Populated by create_all_singles
        self.mutant_to_letter_pos_idx_focus_list = None
        self.all_single_mutations = None

        # Fill in the instance variables
        self.gen_alignment()

        if not self.skip_one_hot_encodings:
            # Encode the sequences
            print("One-hot encoding sequences")
            self.one_hot_encoding = one_hot_3D(
                list(self.seq_name_to_sequence.values()),  # Note: Dicts are unordered for python < 3.6
                alphabet=self.alphabet,
                seq_length=self.seq_len,
                progress=True,
            )
            print("Data Shape =", self.one_hot_encoding.shape)
        else:
            print("Sequence length =", self.seq_len)

        self.calc_weights(num_cpus=num_cpus, method=weights_calc_method)

        if not self.skip_one_hot_encodings:
            print("Creating all single mutations")
            self.create_all_singles()
instance """ self.aa_dict = {} - for i,aa in enumerate(self.alphabet): + for i, aa in enumerate(self.alphabet): self.aa_dict[aa] = i self.seq_name_to_sequence = defaultdict(str) @@ -64,66 +124,49 @@ def gen_alignment(self): line = line.rstrip() if line.startswith(">"): name = line - if i==0: + if i == 0: self.focus_seq_name = name else: self.seq_name_to_sequence[name] += line + print("Number of sequences in MSA (before preprocessing):", len(self.seq_name_to_sequence)) - ## MSA pre-processing to remove inadequate columns and sequences if self.preprocess_MSA: - msa_df = pd.DataFrame.from_dict(self.seq_name_to_sequence, orient='index', columns=['sequence']) - # Data clean up - msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".","-")).apply(lambda x: ''.join([aa.upper() for aa in x])) - # Remove columns that would be gaps in the wild type - non_gap_wt_cols = [aa!='-' for aa in msa_df.sequence[self.focus_seq_name]] - msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa for aa,non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind])) - assert 0.0 <= self.threshold_sequence_frac_gaps <= 1.0,"Invalid fragment filtering parameter" - assert 0.0 <= self.threshold_focus_cols_frac_gaps <= 1.0,"Invalid focus position filtering parameter" - msa_array = np.array([list(seq) for seq in msa_df.sequence]) - gaps_array = np.array(list(map(lambda seq: [aa=='-' for aa in seq], msa_array))) - # Identify fragments with too many gaps - seq_gaps_frac = gaps_array.mean(axis=1) - seq_below_threshold = seq_gaps_frac <= self.threshold_sequence_frac_gaps - print("Proportion of sequences dropped due to fraction of gaps: "+str(round(float(1 - seq_below_threshold.sum()/seq_below_threshold.shape)*100,2))+"%") - # Identify focus columns - columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0) - index_cols_below_threshold = columns_gaps_frac <= self.threshold_focus_cols_frac_gaps - print("Proportion of non-focus columns removed: "+str(round(float(1 - 
index_cols_below_threshold.sum()/index_cols_below_threshold.shape)*100,2))+"%") - # Lower case non focus cols and filter fragment sequences - msa_df['sequence'] = msa_df['sequence'].apply(lambda x: ''.join([aa.upper() if upper_case_ind else aa.lower() for aa, upper_case_ind in zip(x, index_cols_below_threshold)])) - msa_df = msa_df[seq_below_threshold] - # Overwrite seq_name_to_sequence with clean version - self.seq_name_to_sequence = defaultdict(str) - for seq_idx in range(len(msa_df['sequence'])): - self.seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx] + # Overwrite self.seq_name_to_sequence + self.seq_name_to_sequence = self.preprocess_msa( + seq_name_to_sequence=self.seq_name_to_sequence, + focus_seq_name=self.focus_seq_name, + threshold_sequence_frac_gaps=self.threshold_sequence_frac_gaps, + threshold_focus_cols_frac_gaps=self.threshold_focus_cols_frac_gaps + ) self.focus_seq = self.seq_name_to_sequence[self.focus_seq_name] - self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s!='-'] - self.focus_seq_trimmed = [self.focus_seq[ix] for ix in self.focus_cols] + self.focus_cols = [ix for ix, s in enumerate(self.focus_seq) if s == s.upper() and s != '-'] + self.focus_seq_trimmed = "".join([self.focus_seq[ix] for ix in self.focus_cols]) self.seq_len = len(self.focus_cols) self.alphabet_size = len(self.alphabet) # Connect local sequence index with uniprot index (index shift inferred from 1st row of MSA) focus_loc = self.focus_seq_name.split("/")[-1] - start,stop = focus_loc.split("-") + start, stop = focus_loc.split("-") self.focus_start_loc = int(start) self.focus_stop_loc = int(stop) self.uniprot_focus_col_to_wt_aa_dict \ - = {idx_col+int(start):self.focus_seq[idx_col] for idx_col in self.focus_cols} + = {idx_col + int(start): self.focus_seq[idx_col] for idx_col in self.focus_cols} self.uniprot_focus_col_to_focus_idx \ - = {idx_col+int(start):idx_col for idx_col in self.focus_cols} + = {idx_col + int(start): 
idx_col for idx_col in self.focus_cols} # Move all letters to CAPS; keeps focus columns only - for seq_name,sequence in self.seq_name_to_sequence.items(): - sequence = sequence.replace(".","-") - self.seq_name_to_sequence[seq_name] = [sequence[ix].upper() for ix in self.focus_cols] + for seq_name, sequence in self.seq_name_to_sequence.items(): + sequence = sequence.replace(".", "-") + self.seq_name_to_sequence[seq_name] = "".join( + [sequence[ix].upper() for ix in self.focus_cols]) # Makes a str instead of a List[str] # Remove sequences that have indeterminate AA (e.g., B, J, X, Z) in the focus columns if self.remove_sequences_with_indeterminate_AA_in_focus_cols: alphabet_set = set(list(self.alphabet)) seq_names_to_remove = [] - for seq_name,sequence in self.seq_name_to_sequence.items(): + for seq_name, sequence in self.seq_name_to_sequence.items(): for letter in sequence: if letter not in alphabet_set and letter != "-": seq_names_to_remove.append(seq_name) @@ -131,46 +174,100 @@ def gen_alignment(self): seq_names_to_remove = list(set(seq_names_to_remove)) for seq_name in seq_names_to_remove: del self.seq_name_to_sequence[seq_name] + + print("Number of sequences after preprocessing:", len(self.seq_name_to_sequence)) + + self.num_sequences = len(self.seq_name_to_sequence.keys()) - # Encode the sequences - print ("Encoding sequences") - self.one_hot_encoding = np.zeros((len(self.seq_name_to_sequence.keys()),len(self.focus_cols),len(self.alphabet))) - for i,seq_name in enumerate(self.seq_name_to_sequence.keys()): - sequence = self.seq_name_to_sequence[seq_name] - for j,letter in enumerate(sequence): - if letter in self.aa_dict: - k = self.aa_dict[letter] - self.one_hot_encoding[i,j,k] = 1.0 + # Using staticmethod to keep this under the MSAProcessing namespace, but this is apparently not best practice + @staticmethod + def preprocess_msa(seq_name_to_sequence, focus_seq_name, threshold_sequence_frac_gaps, threshold_focus_cols_frac_gaps): + """Remove inadequate columns and 
sequences from MSA, overwrite self.seq_name_to_sequence.""" + print("Pre-processing MSA to remove inadequate columns and sequences...") + msa_df = pd.DataFrame.from_dict(seq_name_to_sequence, orient='index', columns=['sequence']) + # Data clean up + msa_df.sequence = msa_df.sequence.apply(lambda x: x.replace(".", "-")).apply( + lambda x: ''.join([aa.upper() for aa in x])) + # Remove columns that would be gaps in the wild type + non_gap_wt_cols = [aa != '-' for aa in msa_df.sequence[focus_seq_name]] + msa_df['sequence'] = msa_df['sequence'].apply( + lambda x: ''.join([aa for aa, non_gap_ind in zip(x, non_gap_wt_cols) if non_gap_ind])) + assert 0.0 <= threshold_sequence_frac_gaps <= 1.0, "Invalid fragment filtering parameter" + assert 0.0 <= threshold_focus_cols_frac_gaps <= 1.0, "Invalid focus position filtering parameter" + print("Calculating proportion of gaps") + msa_array = np.array([list(seq) for seq in msa_df.sequence]) + gaps_array = np.array(list(map(lambda seq: [aa == '-' for aa in seq], msa_array))) + # Identify fragments with too many gaps + seq_gaps_frac = gaps_array.mean(axis=1) + seq_below_threshold = seq_gaps_frac <= threshold_sequence_frac_gaps + print("Proportion of sequences dropped due to fraction of gaps: " + str( + round(float(1 - seq_below_threshold.sum() / seq_below_threshold.shape) * 100, 2)) + "%") + # Identify focus columns + columns_gaps_frac = gaps_array[seq_below_threshold].mean(axis=0) + index_cols_below_threshold = columns_gaps_frac <= threshold_focus_cols_frac_gaps + print("Proportion of non-focus columns removed: " + str( + round(float(1 - index_cols_below_threshold.sum() / index_cols_below_threshold.shape) * 100, 2)) + "%") + # Lower case non focus cols and filter fragment sequences + def _lower_case_and_filter_fragments(seq): + return ''.join([aa.lower() if aa_ix in index_cols_below_threshold else aa for aa_ix, aa in enumerate(seq)]) + msa_df['sequence'] = msa_df['sequence'].apply( + lambda seq: ''.join([aa.upper() if 
upper_case_ind else aa.lower() for aa, upper_case_ind in + zip(seq, index_cols_below_threshold)])) + msa_df = msa_df[seq_below_threshold] + # Overwrite seq_name_to_sequence with clean version + seq_name_to_sequence = defaultdict(str) + for seq_idx in range(len(msa_df['sequence'])): + seq_name_to_sequence[msa_df.index[seq_idx]] = msa_df.sequence[seq_idx] + return seq_name_to_sequence + + def calc_weights(self, num_cpus=1, method="eve"): + """ + If num_cpus == 1, weights are computed in serial. + If num_cpus == -1, weights are computed in parallel using all available cores. + Note: This will use multiprocessing.cpu_count() to get the number of available cores, which on clusters may + return all cores, not just the number of cores available to the user. + """ + # Refactored into its own function so that we can call it separately if self.use_weights: - try: + if os.path.isfile(self.weights_location) and not self.overwrite_weights: + print("Loading sequence weights from disk") self.weights = np.load(file=self.weights_location) - print("Loaded sequence weights from disk") - except: - print ("Computing sequence weights") - list_seq = self.one_hot_encoding - list_seq = list_seq.reshape((list_seq.shape[0], list_seq.shape[1] * list_seq.shape[2])) - def compute_weight(seq): - number_non_empty_positions = np.dot(seq,seq) - if number_non_empty_positions>0: - denom = np.dot(list_seq,seq) / np.dot(seq,seq) - denom = np.sum(denom > 1 - self.theta) - return 1/denom - else: - return 0.0 #return 0 weight if sequence is fully empty - self.weights = np.array(list(map(compute_weight,list_seq))) + else: + print("Computing sequence weights") + if num_cpus == -1: + num_cpus = get_num_cpus() + + if method == "eve": + alphabet_mapper = map_from_alphabet(ALPHABET_PROTEIN_GAP, default=GAP) + arrays = [] + for seq in self.seq_name_to_sequence.values(): + arrays.append(np.array(list(seq))) + sequences = np.vstack(arrays) + sequences_mapped = map_matrix(sequences, alphabet_mapper) + start = 
time.perf_counter() + self.weights = calc_weights_fast(sequences_mapped, identity_threshold=1 - self.theta, + empty_value=0, num_cpus=num_cpus) # GAP = 0 + end = time.perf_counter() + print(f"Weights calculation took {end - start:.2f} seconds") + elif method == "identity": + self.weights = np.ones(self.num_sequences) + else: + raise ValueError(f"Unknown method: {method}. Must be either 'eve' or 'identity'.") + print("Saving sequence weights to disk") np.save(file=self.weights_location, arr=self.weights) else: # If not using weights, use an isotropic weight matrix print("Not weighting sequence data") - self.weights = np.ones(self.one_hot_encoding.shape[0]) + self.weights = np.ones(self.num_sequences) self.Neff = np.sum(self.weights) - self.num_sequences = self.one_hot_encoding.shape[0] + print("Neff =", str(self.Neff)) + print("Number of sequences: ", self.num_sequences) + assert self.weights.shape[0] == self.num_sequences # == self.one_hot_encoding.shape[0] + + return self.weights - print ("Neff =",str(self.Neff)) - print ("Data Shape =",self.one_hot_encoding.shape) - def create_all_singles(self): start_idx = self.focus_start_loc focus_seq_index = 0 @@ -178,15 +275,15 @@ def create_all_singles(self): list_valid_mutations = [] # find all possible valid mutations that can be run with this alignment alphabet_set = set(list(self.alphabet)) - for i,letter in enumerate(self.focus_seq): + for i, letter in enumerate(self.focus_seq): if letter in alphabet_set and letter != "-": for mut in self.alphabet: - pos = start_idx+i + pos = start_idx + i if mut != letter: - mutant = letter+str(pos)+mut + mutant = letter + str(pos) + mut self.mutant_to_letter_pos_idx_focus_list[mutant] = [letter, pos, focus_seq_index] list_valid_mutations.append(mutant) - focus_seq_index += 1 + focus_seq_index += 1 self.all_single_mutations = list_valid_mutations def save_all_singles(self, output_filename): @@ -194,4 +291,182 @@ def save_all_singles(self, output_filename): output.write('mutations') 
for mutation in self.all_single_mutations: output.write('\n') - output.write(mutation) \ No newline at end of file + output.write(mutation) + + +def generate_mutated_sequences(msa_data, list_mutations): + """ + Copied from VAE_model.compute_evol_indices. + + Generate mutated sequences using a MSAProcessing data object and list of mutations of the form "A42T" where position + 42 on the wild type is changed from A to T. + Multiple mutations are separated by colons e.g. "A42T:C9A" + + Returns a tuple (list_valid_mutations, valid_mutated_sequences), + e.g. (['wt', 'A3T'], {'wt': 'AGAKLI', 'A3T': 'AGTKLI'}) + """ + list_valid_mutations = ['wt'] + valid_mutated_sequences = {} + valid_mutated_sequences['wt'] = msa_data.focus_seq_trimmed # first sequence in the list is the wild_type + + # Remove (multiple) mutations that are invalid + for mutation in list_mutations: + individual_substitutions = mutation.split(':') + mutated_sequence = list(msa_data.focus_seq_trimmed)[:] + fully_valid_mutation = True + for mut in individual_substitutions: + wt_aa, pos, mut_aa = mut[0], int(mut[1:-1]), mut[-1] + if pos not in msa_data.uniprot_focus_col_to_wt_aa_dict \ + or msa_data.uniprot_focus_col_to_wt_aa_dict[pos] != wt_aa \ + or mut not in msa_data.mutant_to_letter_pos_idx_focus_list: + print("Not a valid mutant: " + mutation) + fully_valid_mutation = False + break + else: + wt_aa, pos, idx_focus = msa_data.mutant_to_letter_pos_idx_focus_list[mut] + mutated_sequence[idx_focus] = mut_aa # perform the corresponding AA substitution + + if fully_valid_mutation: + list_valid_mutations.append(mutation) + valid_mutated_sequences[mutation] = ''.join(mutated_sequence) + + return list_valid_mutations, valid_mutated_sequences + + +# Copied from VAE_model.compute_evol_indices +# One-hot encoding of sequences +def one_hot_3D(sequences, alphabet, seq_length, progress=False): + """ + Take in a list of sequence names/keys and corresponding sequences, and generate a one-hot array according to an 
alphabet. + """ + aa_dict = {letter: i for (i, letter) in enumerate(alphabet)} + + one_hot_out = np.zeros((len(sequences), seq_length, len(alphabet))) + for i, sequence in enumerate(tqdm(sequences, desc="One-hot encoding sequences", mininterval=1, disable=not progress)): + for j, letter in enumerate(sequence): + if letter in aa_dict: + k = aa_dict[letter] + one_hot_out[i, j, k] = 1.0 + one_hot_out = torch.tensor(one_hot_out) + return one_hot_out + + +def gen_one_hot_to_sequence(one_hot_tensor, alphabet): + """Reverse of one_hot_3D. Need the msa_data again. Returns a list of sequences.""" + for seq_tensor in one_hot_tensor: # iterate through outer dimension + seq = "" + letters_idx = seq_tensor.argmax(-1) + + for idx in letters_idx.tolist(): # Could also do map(di.get, letters_idx) + letter = alphabet[idx] + seq += letter + yield seq + + +def one_hot_to_sequence_list(one_hot_tensor, alphabet): + return list(gen_one_hot_to_sequence(one_hot_tensor, alphabet)) + +def get_one_hot_3D_fn(alphabet, seq_len): + aa_dict = {letter: i for (i, letter) in enumerate(alphabet)} + + def fn(batch_seqs): + one_hot_out = np.zeros((len(batch_seqs), seq_len, len(alphabet))) + for i, sequence in enumerate(batch_seqs): + for j, letter in enumerate(sequence): + if letter in aa_dict: + k = aa_dict[letter] + one_hot_out[i, j, k] = 1.0 + one_hot_out = torch.tensor(one_hot_out) + return one_hot_out + return fn + +def get_num_cpus(): + if 'SLURM_CPUS_PER_TASK' in os.environ: + num_cpus = int(os.environ['SLURM_CPUS_PER_TASK']) + print("SLURM_CPUS_PER_TASK:", os.environ['SLURM_CPUS_PER_TASK']) + print("Using all available cores (calculated using SLURM_CPUS_PER_TASK):", num_cpus) + else: + num_cpus = len(os.sched_getaffinity(0)) + print("Using all available cores (calculated using len(os.sched_getaffinity(0))):", num_cpus) + return num_cpus + + +class SequenceDataset(Dataset): + def __init__(self, sequences): + self.sequences = sequences + + def __len__(self): + return len(self.sequences) + + def 
__getitem__(self, idx): + return self.sequences[idx] + + +class InfiniteDataLoader(DataLoader): + """Dataloader that reloads its dataset every epoch""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.iter_loader = super().__iter__() + + def __iter__(self): + return self + + def __next__(self): + try: + batch = next(self.iter_loader) + except StopIteration: + # If the inner DataLoader has exhausted the dataset, reset it + self.iter_loader = super().__iter__() + batch = next(self.iter_loader) + return batch + + +def get_one_hot_dataloader(sequences, alphabet, seq_len, batch_size): + """ + To avoid issues with storing all one-hot encodings in memory, we can calculate them on the fly with a small performance overhead. + Crucially, this only runs through the dataset once (unlike in training, where we need to iterate through each epoch) + """ + dataset = SequenceDataset(sequences) + + # num_cpus = get_num_cpus() + + one_hot_fn = get_one_hot_3D_fn(alphabet, seq_len) + + def collate_fn(batch_seqs): + # Construct a batch of one-hot-encodings + batch_seq_tensor = one_hot_fn(batch_seqs) + return batch_seq_tensor + + dataloader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=1, # collate_fn is not parallelized, so no speedup with multiple CPUs + collate_fn=collate_fn) # pin_memory=True + + return dataloader + + +def get_training_dataloader(sequences, weights, alphabet, seq_len, batch_size, num_training_steps): + """Similar to get_one_hot_dataloader, but based on the MSA_processing object, using weighted sampling and reloading the dataset each epoch.""" + dataset = SequenceDataset(sequences) + # This can take a ton of memory if the weights or num_training_steps*batch_size are large + sampler = WeightedRandomSampler(weights=weights, num_samples=num_training_steps*batch_size, replacement=True) + num_cpus = get_num_cpus() + + one_hot_fn = get_one_hot_3D_fn(alphabet, seq_len) + + def collate_fn(batch_seqs): 
+ # Construct a batch of one-hot-encodings + batch_seq_tensor = one_hot_fn(batch_seqs) + return batch_seq_tensor + + # Avoiding the problem of the dataset running out: Wrap it with an iterable that refreshes it every time + dataloader = InfiniteDataLoader( + dataset=dataset, + batch_size=batch_size, + num_workers=num_cpus, # collate_fn is not parallelized, so no speedup with multiple CPUs + sampler=sampler, + collate_fn=collate_fn,) #pin_memory=True + + return dataloader diff --git a/utils/weights.py b/utils/weights.py new file mode 100644 index 0000000..792f80d --- /dev/null +++ b/utils/weights.py @@ -0,0 +1,262 @@ +import multiprocessing +import time +from collections import defaultdict + +import numba +from numba import prange +from numba_progress import ProgressBar + +import numpy as np +from tqdm import tqdm + +def calc_weights_fast(matrix_mapped, identity_threshold, empty_value, num_cpus=1, print_progress=True): + """ + Modified from EVCouplings: https://github.com/debbiemarkslab/EVcouplings + + Note: Numba by default uses `multiprocessing.cpu_count()` threads. + On a cluster where a process might only have access to a subset of CPUs, this may be less than the number of CPUs available. + The caller should ideally use len(os.sched_getaffinity(0)) to get the number of CPUs available to the process. + + Calculate weights for sequences in alignment by + clustering all sequences with sequence identity + greater or equal to the given threshold. + Parameters + ---------- + identity_threshold : float + Sequence identity threshold + """ + empty_idx = is_empty_sequence_matrix(matrix_mapped, empty_value=empty_value) # e.g. sequences with just gaps or lowercase, no valid AAs + N = matrix_mapped.shape[0] + + # Original EVCouplings code structure, plus gap handling + if num_cpus != 1: + # print("Calculating weights using Numba parallel (experimental) since num_cpus > 1. 
If you want to disable multiprocessing set num_cpus=1.") + # print("Default number of threads for Numba:", numba.config.NUMBA_NUM_THREADS) + + # num_cpus > numba.config.NUMBA_NUM_THREADS will give an error. + # But we'll leave it so that the user has to be explicit. + numba.set_num_threads(num_cpus) + print("Set number of threads to:", numba.get_num_threads()) # Sometimes Numba uses all the CPUs anyway + + if print_progress: + update_frequency=1000 + with ProgressBar(total=N, update_interval=30, miniters=update_frequency) as progress: # can also use tqdm mininterval, maxinterval etc + num_cluster_members = calc_num_cluster_members_nogaps_parallel_print(matrix_mapped[~empty_idx], identity_threshold, + invalid_value=empty_value, progress_proxy=progress, update_frequency=update_frequency) + else: + num_cluster_members = calc_num_cluster_members_nogaps_parallel(matrix_mapped[~empty_idx], identity_threshold, + invalid_value=empty_value) + + else: + # Use the serial version + num_cluster_members = calc_num_cluster_members_nogaps(matrix_mapped[~empty_idx], identity_threshold, + invalid_value=empty_value) + + # Empty sequences: weight 0 + weights = np.zeros((N)) + weights[~empty_idx] = 1.0 / num_cluster_members + return weights + +# Below are util functions copied from EVCouplings +def is_empty_sequence_matrix(matrix, empty_value): + assert len(matrix.shape) == 2, f"Matrix must be 2D; shape={matrix.shape}" + assert isinstance(empty_value, (int, float)), f"empty_value must be a number; type={type(empty_value)}" + # Check for each sequence if all positions are equal to empty_value + empty_idx = np.all((matrix == empty_value), axis=1) + return empty_idx + + +def map_from_alphabet(alphabet, default): + """ + Creates a mapping dictionary from a given alphabet. + Parameters + ---------- + alphabet : str + Alphabet for remapping. 
Elements will + be remapped according to alphabet starting + from 0 + default : Elements in matrix that are not + contained in alphabet will be treated as + this character + Raises + ------ + ValueError + For invalid default character + """ + map_ = { + c: i for i, c in enumerate(alphabet) + } + + try: + default = map_[default] + except KeyError: + raise ValueError( + "Default {} is not in alphabet {}".format(default, alphabet) + ) + + return defaultdict(lambda: default, map_) + + + +def map_matrix(matrix, map_): + """ + Map elements in a numpy array using alphabet + Parameters + ---------- + matrix : np.array + Matrix that should be remapped + map_ : defaultdict + Map that will be applied to matrix elements + Returns + ------- + np.array + Remapped matrix + """ + return np.vectorize(map_.__getitem__)(matrix) + + +# Fastmath should be safe here, as we can assume that there are no NaNs in the input etc. +@numba.jit(nopython=True, fastmath=True) #parallel=True +def calc_num_cluster_members_nogaps(matrix, identity_threshold, invalid_value): + """ + From EVCouplings: https://github.com/debbiemarkslab/EVcouplings/blob/develop/evcouplings/align/alignment.py#L1172. + Modified to use non-gapped length and not counting gaps as sequence similarity matches. + + Calculate number of sequences in alignment + within given identity_threshold of each other + Parameters + ---------- + matrix : np.array + N x L matrix containing N sequences of length L. + Matrix must be mapped to range(0, num_symbols) using + map_matrix function + identity_threshold : float + Sequences with at least this pairwise identity will be + grouped in the same cluster. 
+ Returns + ------- + np.array + Vector of length N containing number of cluster + members for each sequence (inverse of sequence + weight) + """ + N, L = matrix.shape + L = 1.0 * L + + # Empty sequences are filtered out before this function and are ignored + # minimal cluster size is 1 (self) + num_neighbors = np.ones((N)) + L_non_gaps = L - np.sum(matrix == invalid_value, axis=1) # Edit: From EVE, use the non-gapped length + # compare all pairs of sequences + for i in range(N - 1): + for j in range(i + 1, N): + pair_matches = 0 + for k in range(L): + # Edit(Lood): Don't count gaps as matches + if matrix[i, k] == matrix[j, k] and matrix[i, k] != invalid_value: + pair_matches += 1 + + # Edit(Lood): Calculate identity as fraction of non-gapped positions (so asymmetric) + # Note: Changed >= to > to match EVE / DeepSequence code + if pair_matches / L_non_gaps[i] > identity_threshold: + num_neighbors[i] += 1 + if pair_matches / L_non_gaps[j] > identity_threshold: + num_neighbors[j] += 1 + + return num_neighbors + + +@numba.jit(nopython=True, fastmath=True, parallel=True) +def calc_num_cluster_members_nogaps_parallel(matrix, identity_threshold, invalid_value): + """ + Parallel implementation of calc_num_cluster_members_nogaps above. + + Calculate number of sequences in alignment + within given identity_threshold of each other + Parameters + ---------- + matrix : np.array + N x L matrix containing N sequences of length L. + Matrix must be mapped to range(0, num_symbols) using + map_matrix function + identity_threshold : float + Sequences with at least this pairwise identity will be + grouped in the same cluster. + invalid_value : int + Value in matrix that is considered invalid, e.g. gap or lowercase character. 
+ Returns + ------- + np.array + Vector of length N containing number of cluster + members for each sequence (inverse of sequence + weight) + """ + N, L = matrix.shape + L = 1.0 * L + + # Empty sequences are filtered out before this function and are ignored + # minimal cluster size is 1 (self) + num_neighbors = np.ones((N)) + L_non_gaps = L - np.sum(matrix == invalid_value, axis=1) # Edit: From EVE, use the non-gapped length + # compare all pairs of sequences + # Edit: Rewrote loop without any dependencies between inner and outer loops, so that it can be parallelized + for i in prange(N): + num_neighbors_i = 1 + for j in range(N): + if i == j: + continue + pair_matches = 0 + for k in range(L): # This should hopefully be vectorised by numba + # Edit(Lood): Don't count gaps as matches + if matrix[i, k] == matrix[j, k] and matrix[i, k] != invalid_value: + pair_matches += 1 + + # Edit(Lood): Calculate identity as fraction of non-gapped positions (so this similarity is asymmetric) + # Note: Changed >= to > to match EVE / DeepSequence code + if pair_matches / L_non_gaps[i] > identity_threshold: + num_neighbors_i += 1 + + num_neighbors[i] = num_neighbors_i + + return num_neighbors + +@numba.jit(nopython=True, fastmath=True, parallel=True) +def calc_num_cluster_members_nogaps_parallel_print(matrix, identity_threshold, invalid_value, progress_proxy=None, update_frequency=1000): + """ + Modified calc_num_cluster_members_nogaps_parallel to add tqdm progress bar - useful for multi-hour weights calc. 
+ + progress_proxy : numba_progress.ProgressBar + A handle on the progress bar to update + update_frequency : int + Similar to miniters in tqdm, how many iterations between updating the progress bar (which then will only print every `update_interval` seconds) + """ + + N, L = matrix.shape + L = 1.0 * L + + # Empty sequences are filtered out before this function and are ignored + # minimal cluster size is 1 (self) + num_neighbors = np.ones((N)) + L_non_gaps = L - np.sum(matrix == invalid_value, axis=1) # Edit: From EVE, use the non-gapped length + # compare all pairs of sequences + # Edit: Rewrote loop without any dependencies between inner and outer loops, so that it can be parallelized + for i in prange(N): + num_neighbors_i = 1 + for j in range(N): + if i == j: + continue + pair_matches = 0 + for k in range(L): # This should hopefully be vectorised by numba + # Edit(Lood): Don't count gaps as matches + if matrix[i, k] == matrix[j, k] and matrix[i, k] != invalid_value: + pair_matches += 1 + # Edit(Lood): Calculate identity as fraction of non-gapped positions (so this similarity is asymmetric) + # Note: Changed >= to > to match EVE / DeepSequence code + if pair_matches / L_non_gaps[i] > identity_threshold: + num_neighbors_i += 1 + + num_neighbors[i] = num_neighbors_i + if progress_proxy is not None and i % update_frequency == 0: + progress_proxy.update(update_frequency) + + return num_neighbors