diff --git a/.gitignore b/.gitignore
index d495c4d..32c1873 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,4 +3,5 @@
.bash_history
lp_solution*
.vscode
-.coverage
\ No newline at end of file
+.coverage
+combined_scores.csv
\ No newline at end of file
diff --git a/README.md b/README.md
index ddcd193..de6c14d 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,6 @@
-
![Status](https://img.shields.io/badge/Status-Active-green.svg)
![Python](https://img.shields.io/badge/Python-3.9-blue.svg)
[![Paper](https://img.shields.io/badge/Paper-Download-green.svg)](https://www.biorxiv.org/content/10.1101/2024.11.03.621763v1)
@@ -16,30 +15,27 @@
## Introduction
-Welcome to the `protlib-designer` repository! This repository contains a Python package that designs diverse protein libraries by seeding linear programming with deep mutational scanning data (or any other data that can be represented as a matrix of scores per single-point mutation). The software takes as input the score matrix, where each row corresponds to a mutation and each column corresponds to a different source of scores, and outputs a subset of mutations that maximize the diversity of the library while Pareto-optimizing the scores from the different sources.
+Welcome to the `protlib-designer` repository! This repository contains a lightweight Python library for designing diverse protein libraries by seeding linear programming with deep mutational scanning data (or any other data that can be represented as a matrix of scores per single-point mutation). The software takes as input the score matrix, where each row corresponds to a mutation and each column corresponds to a different source of scores, and outputs a subset of mutations that maximize the diversity of the library while Pareto-optimizing the scores from the different sources.
The paper [Antibody Library Design by Seeding Linear Programming with Inverse Folding and Protein Language Models](https://www.biorxiv.org/content/10.1101/2024.11.03.621763v1) uses this software to design diverse antibody libraries by seeding linear programming with scores computed by Protein Language Models (PLMs) and Inverse Folding models.
-    protlib-designer designs diverse protein libraries by seeding linear programming with deep mutational scanning data.
-    (a) The input to the method is an antibody-antigen complex and a target antibody sequence. (b) We generate in silico deep mutational scanning data using protein language and inverse folding models. (c) The result is fed into a multi-objective linear programming solver. (d) The solver generates a library of antibodies that are co-optimized for the in silico scores while satisfying diversity constraints.
+    protlib-designer designs diverse protein libraries by seeding linear programming with deep mutational scanning data. (a) The input to the method is a target protein sequence and, if available, a structure of the protein or protein complex (in this case, the antibody trastuzumab in complex with the HER2 receptor). (b) We generate in silico deep mutational scanning data using protein language and inverse folding models. (c) The result is fed into a multi-objective linear programming solver. (d) The solver generates a library of antibodies that are co-optimized for the in silico scores while satisfying diversity constraints.
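To make the input/output contract described in the introduction concrete, here is a minimal sketch of how a per-mutation score matrix can seed an integer linear program with `pulp` (one of the package's declared dependencies). This is an illustration only, not `protlib-designer`'s actual model, which adds diversity constraints, mutant-order handling, and Pareto/weighted multi-objective optimization; the equal weights and library size below are hypothetical.

```python
# Illustrative only: select K mutations that minimize a weighted sum of their
# scores. protlib-designer's real formulation layers diversity and
# multi-objective constraints on top of this kind of model.
import pandas as pd
import pulp

scores = pd.read_csv("example_data/trastuzumab_spm.csv")  # columns: Mutation plus one column per score source
score_columns = [c for c in scores.columns if c != "Mutation"]
weights = {c: 1.0 / len(score_columns) for c in score_columns}  # hypothetical equal weights
library_size = 10

problem = pulp.LpProblem("toy_library_design", pulp.LpMinimize)
x = {m: pulp.LpVariable(f"x_{m}", cat="Binary") for m in scores["Mutation"]}

# Objective: weighted sum of the scores of the selected mutations.
problem += pulp.lpSum(
    weights[c] * row[c] * x[row["Mutation"]]
    for _, row in scores.iterrows()
    for c in score_columns
)

# Constraint: pick exactly `library_size` single-point mutations.
problem += pulp.lpSum(x.values()) == library_size

problem.solve(pulp.PULP_CBC_CMD(msg=False))
print([m for m, var in x.items() if var.value() == 1])
```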
-
## Getting Started

In this section, we provide instructions on how to install the software and run the code.

### Installation

-Create an environment with Python >=3.9 and install the dependencies:
+Create an environment with Python >=3.9,<3.11 and install the dependencies:
```bash
python -m venv .venv
source .venv/bin/activate
```
@@ -52,12 +48,10 @@
pip install -e .[dev]
```
which will allow you to run the tests and the linter. You can run the linting with:
```bash
-black -S -t py39 protlib_designer scripts
+black -S -t py39 protlib_designer scripts && \
flake8 --ignore=E501,E203,W503 protlib_designer scripts
```
-
-
### Run the code

To run the code to create a diverse protein library of size 10 from the example data, run the following command:
@@ -66,26 +60,31 @@
protlib-designer ./example_data/trastuzumab_spm.csv 10
```

-We provide a rich set of command-line arguments to customize the behavior of `protlib-designer`. For example, the following command runs `protlib-designer` with a range of 3 to 5 mutations per sequence, enforcing the interleaving of the mutant order and balancing the mutant order, and using a weighted multi-objective optimization:
+We provide a rich set of command-line arguments to customize the behavior of `protlib-designer`. For example, the following command runs `protlib-designer` with a range of 3 to 5 mutations per sequence, enforcing the interleaving of the mutant order and balancing the mutant order, allowing each mutation to appear at most `1` time and each position to be mutated at most `4` times,
+and using a weighted multi-objective optimization:
```bash
protlib-designer ./example_data/trastuzumab_spm.csv 10 \
---min-mut 3 --max-mut 5 --interleave-mutant-order True --force-mutant-order-balance True \
---weighted-multi-objective True
+ --min-mut 3 \
+ --max-mut 5 \
+ --interleave-mutant-order True \
+ --force-mutant-order-balance True \
+ --schedule 2 \
+ --schedule-param '1,4' \
+ --weighted-multi-objective True
```
-
For more information on the command-line arguments, run:
```bash
protlib-designer --help
```

-### Input data
+### Input data: In silico deep mutational scanning data

The input to the software is a matrix of per-mutation scores (the CSV file `trastuzumab_spm.csv` in the example above). Typically, the score matrix is defined by *in silico* deep mutational scanning data, where each row corresponds to a mutation and each column corresponds to the score computed by a deep learning model. See the example data in the `example_data` directory for an example of the input data format. The structure of the input data is shown below:

-| MutationHL | score-1 | score-2 | ... | score-N |
+| Mutation | score-1 | score-2 | ... | score-N |
|------------|--------|--------|-----|--------|
| AH106C | -0.1 | 0.2 | ... | 0.3 |
| AH106D | 0.2 | -0.3 | ... | -0.4 |
@@ -95,7 +94,7 @@
Important notes about the input data:

-• The `MutationHL` column contains the mutation in the format : `WT_residue` + `chain` + `position_index` + `mutant_residue`. For example, `A+H+106+C = AH106C` represents the mutation of the residue at position 106 in chain H from alanine to cysteine.
+• The `Mutation` column contains the mutation in the format: `WT_residue` + `chain` + `position_index` + `mutant_residue`.
For example, `A+H+106+C = AH106C` represents the mutation of the residue at position 106 in chain H from alanine to cysteine.
• The `score-1`, `score-2`, ..., `score-N` columns contain the scores computed by the deep learning models for each mutation. Typically, the scores are the negative log-likelihood ratios of the mutant residue and the wild-type residue, computed by the deep learning model:
@@ -105,27 +104,35 @@
s_{ij}^{\text{PLM}} = -\log \left( \frac{p(x_i = a_j | w)}{p(x_i = w_i | w)} \right)

where $w$ is the wild-type sequence, and $p(x_i = a_j | w)$ is the probability of the mutant residue $a_j$ at position $i$ given the wild-type sequence $w$ as estimated by a Protein Language Model (PLM) or an Inverse Folding model (or any other deep learning model). For example, in [Antibody Library Design by Seeding Linear Programming with Inverse Folding and Protein Language Models](https://www.biorxiv.org/content/10.1101/2024.11.03.621763v1), we used the scores computed by the [ProtBert](https://pubmed.ncbi.nlm.nih.gov/34232869/) and [AntiFold](https://arxiv.org/abs/2405.03370) models.

+### Scoring functions
+
+We provide a set of scoring functions that can be used to compute the scores for the input data. The scoring functions are defined in the `protlib_designer/scorer` module. To use this functionality, you need to install additional dependencies:
+
+```bash
+pip install -e .[plm]
+```
+
+After installing the dependencies, you can use the scoring functions to compute the scores for the input data. For example, we can compute the scores using the `Rostlab/prot_bert` and `facebook/esm2_t6_8M_UR50D` models, and then call `protlib-designer` to design a diverse protein library of size 10:
+
+```bash
+protlib-plm-scorer EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS WH99 GH100 GH101 DH102 GH103 FH104 YH105 AH106 MH107 DH108 \
+--models Rostlab/prot_bert --models facebook/esm2_t6_8M_UR50D \
+--chain-type heavy \
+--score-type minus_llr \
+--mask \
+--output-file combined_scores.csv \
+&& protlib-designer combined_scores.csv 10 --weighted-multi-objective True
+```
+
## Contributing

Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for details on our code of conduct, and the process for submitting pull requests to us.

-## Citing This Work
+## Citation

If you use this software in your research, please cite the following paper:

-```latex
-@article {Hayes2024.11.03.621763,
-  author = {Hayes, Conor F. and Magana-Zook, Steven A. and Gon{\c c}alves, Andre and Solak, Ahmet Can and Faissol, Daniel and Landajuela, Mikel},
-  title = {Antibody Library Design by Seeding Linear Programming with Inverse Folding and Protein Language Models},
-  elocation-id = {2024.11.03.621763},
-  year = {2024},
-  doi = {10.1101/2024.11.03.621763},
-  publisher = {Cold Spring Harbor Laboratory},
-  URL = {https://www.biorxiv.org/content/early/2024/11/03/2024.11.03.621763},
-  eprint = {https://www.biorxiv.org/content/early/2024/11/03/2024.11.03.621763.full.pdf},
-  journal = {bioRxiv}
-}
-```
+Hayes, C. F., Magana-Zook, S. A., Gonçalves, A., Solak, A. C., Faissol, D., & Landajuela, M. (2024). *Antibody Library Design by Seeding Linear Programming with Inverse Folding and Protein Language Models*. **bioRxiv**.
[https://doi.org/10.1101/2024.11.03.621763](https://doi.org/10.1101/2024.11.03.621763) ## License diff --git a/example_data/trastuzumab_spm.csv b/example_data/trastuzumab_spm.csv index 14b5084..50ae655 100644 --- a/example_data/trastuzumab_spm.csv +++ b/example_data/trastuzumab_spm.csv @@ -1,4 +1,4 @@ -MutationHL,antifold_antigen_neg_llr,protbert_neg_llr +Mutation,antifold_antigen_neg_llr,protbert_neg_llr AH106C,5.8979,5.614669799804688 AH106D,5.1441,3.7915658950805664 AH106E,6.1135,4.867663383483887 diff --git a/protlib_designer/dataloader.py b/protlib_designer/dataloader.py index 23596cb..9382272 100644 --- a/protlib_designer/dataloader.py +++ b/protlib_designer/dataloader.py @@ -18,9 +18,9 @@ def extract_positions_and_wildtype_amino_from_data(df: pd.DataFrame): df : pd.DataFrame The dataframe containing the data. """ - mutation_full = df["MutationHL"].values.tolist() - positions = [] # Positions that have mutations - wildtype_position_amino = {} # Position to wild type amino acid mapping + mutation_full = df["Mutation"].values.tolist() + positions = [] # Positions that have mutations. + wildtype_position_amino = {} # Position to wild type amino acid mapping. for mutation in mutation_full: wildtype_amino, position, _ = parse_mutation(mutation) @@ -36,14 +36,14 @@ def extract_positions_and_wildtype_amino_from_data(df: pd.DataFrame): ) exit() - # Save the wild type amino acid at this position + # Save the wild type amino acid at this position. wildtype_position_amino[position] = wildtype_amino - # Get distinct positions + # Get distinct positions. positions = list(set(positions)) - # Order the positions in ascending order - # Consider positions like H28 < H100A + # Order the positions in ascending order. + # Consider positions like H28 < H100A. positions_df = pd.DataFrame.from_dict( { i: { @@ -61,7 +61,7 @@ def extract_positions_and_wildtype_amino_from_data(df: pd.DataFrame): ascending=[True, True, True], ) - # Get the order by merging the strings + # Get the order by merging the strings. positions = [ f"{row['chain']}{row['pos']}{row['pos_extra']}" for _, row in positions_df.iterrows() @@ -92,7 +92,7 @@ def load_data(self): logger.info(f"Detected wild type amino acid: {self.wildtype_position_amino}") def update_config_with_data(self, config: Dict[str, Any]): - # Check that max_mut is less than the number of positions + # Check that max_mut is less than the number of positions. 
if ( config["max_mut"] > len(self.positions) and config["interleave_mutant_order"] diff --git a/protlib_designer/filter/filter.py b/protlib_designer/filter/filter.py index 53fc191..b07b747 100644 --- a/protlib_designer/filter/filter.py +++ b/protlib_designer/filter/filter.py @@ -3,7 +3,7 @@ class Filter(ABC): @abstractmethod - def filter(self): + def filter(self, solution): pass @abstractmethod diff --git a/protlib_designer/generator/generator.py b/protlib_designer/generator/generator.py index e5ccfe1..2c85c0a 100644 --- a/protlib_designer/generator/generator.py +++ b/protlib_designer/generator/generator.py @@ -1,5 +1,3 @@ -# write a generator abstract class - from abc import ABC, abstractmethod diff --git a/protlib_designer/generator/ilp_generator.py b/protlib_designer/generator/ilp_generator.py index e13e5b7..bda4b77 100644 --- a/protlib_designer/generator/ilp_generator.py +++ b/protlib_designer/generator/ilp_generator.py @@ -1,5 +1,6 @@ import time from pathlib import Path +import warnings import numpy as np import pandas as pd @@ -10,6 +11,9 @@ from protlib_designer.generator.generator import Generator from protlib_designer.utils import amino_acids, aromatic_amino_acids, parse_mutation +# Ignore UserWarnings from pulp +warnings.filterwarnings("ignore", category=UserWarning, module="pulp") + class ILPGenerator(Generator): def __init__(self, data_loader, config): @@ -100,15 +104,15 @@ def _prepare_variables_and_zero_pad_matrix(self): self.forbidden_vars.append(x_var) self.forbidden_vars_dict[mutation_name] = x_var # Check if row exists in the input dataframe. - if mutation_name in data_df["MutationHL"].values: + if mutation_name in data_df["Mutation"].values: # Extract the row from the dataframe in a dictionary format. - row = data_df[data_df["MutationHL"] == mutation_name].to_dict( + row = data_df[data_df["Mutation"] == mutation_name].to_dict( "records" )[0] data_df_padded.append(row) else: # The row does not exist in the input dataframe. # Add 0-vector row for the new mutation. - new_row = {"MutationHL": mutation_name} + new_row = {"Mutation": mutation_name} # Save the position and aa to add X_pos_a = 0 constraint later in the script. zero_enforced_mutations.append((wt, position, aa)) self.missing_vars.append(x_var) @@ -141,13 +145,13 @@ def _check_data_and_variables_consistency(self): ) exit() - # Check that data_df["MutationHL"].values is equivalent (ordered in the same way) as x_vars. + # Check that data_df["Mutation"].values is equivalent (ordered in the same way) as x_vars. for index, x_var in enumerate(self.x_vars): mutation_name = x_var.getName().split("_")[1] - if mutation_name != self.data_df["MutationHL"].values[index]: + if mutation_name != self.data_df["Mutation"].values[index]: logger.error( f"Error adding missing position-amino acid pairs. Expected {mutation_name}. \ - Got {self.data_df['MutationHL'].values[index]}" + Got {self.data_df['Mutation'].values[index]}" ) exit() @@ -339,7 +343,9 @@ def generate_one_solution(self, iteration: int): status = self.problem.solve(self.solver) if status != 1: - logger.error(f"Error Status: {pulp.LpStatus[status]}") + logger.error( + f"Error in ILPGenerator when solving the problem. 
Status: {pulp.LpStatus[status]}" + ) return None cpu_time = time.time() - cpu_time_start diff --git a/protlib_designer/scorer/plm_scorer.py b/protlib_designer/scorer/plm_scorer.py new file mode 100644 index 0000000..bb9f569 --- /dev/null +++ b/protlib_designer/scorer/plm_scorer.py @@ -0,0 +1,341 @@ +import torch +import pandas as pd +from typing import List + +from transformers import AutoModelForMaskedLM, AutoTokenizer + +from protlib_designer.utils import amino_acids +from protlib_designer.scorer.scorer import ( + score_function, + from_user_input_to_scorer_input, + from_scorer_output_to_user_output, + Scorer, +) + + +def load_huggingface_model(model_reference: str, device: torch.device): + """ + Load the Hugging Face model and tokenizer. + + Parameters + ---------- + model_reference : str + Reference to the model, it can be a model name (e.g., 'Rostlab/prot_bert') or a path to an existing model. + device : torch.device + Device to load the model. + """ + tokenizer = AutoTokenizer.from_pretrained(model_reference) + model = AutoModelForMaskedLM.from_pretrained( + model_reference, output_hidden_states=True + ) + model.resize_token_embeddings(len(tokenizer)) + model = model.to(device) + model = model.eval() + return model, tokenizer + + +class PLMScorer(Scorer): + def __init__( + self, + model_name: str = "Rostlab/prot_bert", + model_path: str = None, + score_type: str = "minus_llr", + mask: bool = True, + mapping: dict = None, + ): + """Initialize the PLM Scorer. + + Parameters + ---------- + model_name : str + Name of the model to use. Defaults to 'Rostlab/prot_bert'. + model_path : str + Path to the model to use. If None, the model is loaded from the Hugging Face model hub using the model name. + score_type : str + Type of score to use. Options are 'll', 'llr', 'minus_llr', 'probs'. + mask : bool + If True: mask positions before passing to PLM, if False: do not mask positions before passing to PLM. + mapping : dict + A dictionary to map the positions from the user input to the numbering used by the scorer. Example: + {'H': {'H1': 1, 'H5': 2, 'H6': 3}, 'L': {'L1': 1, 'L2': 2, 'L3': 3}} + """ + + if model_name is None: + raise ValueError("Please provide a model name or a model path.") + + self.model_name = model_name + self.model_path = model_path + self.score_type = score_type + self.mask = mask + self.mapping = mapping + + # Set the device to GPU if available, otherwise use CPU. + self.device = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + ) + + # Load the PLM model and tokenizer. + self.model, self.tokenizer = self.load_model() + + # If the model is not ESM, use token type ids. + self.use_token_type_ids = "esm" not in model_name.lower() + + # Get indices of the PLM token library from the tokenizer with respect to all amino acids. + self.aa_token_indices = self.tokenizer.convert_tokens_to_ids(amino_acids) + + def load_model(self): + """Load the PLM model and tokenizer. + + Parameters + ---------- + model_name : str + Name of the model. + model_path : str + Path to the model. + device : torch.device + Device to load the model. + + Returns + ------- + model : torch.nn.Module + The loaded model. + tokenizer : transformers.AutoTokenizer + The loaded tokenizer. + """ + if self.model_path is not None: + return load_huggingface_model(self.model_path, self.device) + try: + return load_huggingface_model(self.model_name, self.device) + except Exception as e: + raise ValueError( + f"Model {self.model_name} not found in Hugging Face model hub. 
Please provide a valid model name." + ) from e + + def prepare_input(self, sequence: str, positions: List[str], chain_type: str): + """Prepare the input for the model. + + Parameters + ---------- + sequence : str + Sequence used to generated the scores. This a string of amino acids. + positions : list + Positions on the sequence to be used to generate the score. + Positions must be in the following format: {WT}{CHAIN}{PDBINDEX}. + Note: PDBINDEX is 1-indexed, that is, the first position is 1. For example, the first positions in + the list of positions are [EH1, VH2, QH3, ...]. + chain_type : str + Optional parameter to specify the chain type. For example, heavy or light chain for antibodies. + + Returns + ------- + batch : list + If mask is True, the batch is a list of copies of the sequence with the required position masked. + If mask is False, the batch is a single copy of the sequence. + chain_token : int + Chain token to pass to the PLM model. + wildtype : list + List of wildtype amino acids for each position. + """ + + if self.mapping is not None: + positions = from_user_input_to_scorer_input(positions, self.mapping) + + # Check that all positions have the same chain letter (2nd character). + chain_letter = {position[1] for position in positions} + if len(chain_letter) > 1: + raise ValueError( + "All positions must have the same chain letter. Please provide positions with the same chain type." + ) + chain_letter = chain_letter.pop() + + # Get the positions indices. + position_indices = [int(position[2:]) for position in positions] + + # Get wildtype dict: {position: wildtype} + wildtype_dict = {int(position[2:]): position[0] for position in positions} + + # Create batch to generate score in one pass. + batch = [] + # Iterate over positions to apply mask token at each required position. + if self.mask: + # Create mask token. + mask_token = self.tokenizer.mask_token + for position_index in position_indices: + # Convert sequence to list. + sequence_list = list(sequence) + # Mask the required position. + # Subtract 1 from the position index to convert to 0-indexed. + position_index = position_index - 1 + sequence_list[position_index] = mask_token + # Append sequence to batch, and add whitespace to sequence for protbert tokenizer. + batch.append(" ".join(sequence_list)) + else: + sequence_list = list(sequence) + batch.append(" ".join(sequence_list)) + + chain_token = 1 if chain_type == "light" else 0 + + return batch, wildtype_dict, position_indices, chain_letter, chain_token + + def forward_pass(self, batch, chain_token): + """Perform the forward pass. + + Parameters + ---------- + batch : list + List of sequences to pass to the model. + chain_token : int + Chain token to pass to the PLM model. + """ + input_ids = self.tokenizer(batch, padding=True)["input_ids"] + input_ids = torch.tensor(input_ids, device=self.device) + bz, seq_length = input_ids.shape + token_type_ids = ( + torch.zeros(bz, seq_length).fill_(chain_token).to(self.device).long() + ) + + with torch.no_grad(): + if self.use_token_type_ids: + logits = self.model(input_ids=input_ids, token_type_ids=token_type_ids)[ + "logits" + ] + else: + logits = self.model(input_ids=input_ids)["logits"] + + return logits + + def get_scores(self, sequence: str, positions: List[str], chain_type: str): + """Compute the scores (in silico deep mutational scanning) for a given sequence and positions. + + Parameters + ---------- + sequence : str + Sequence used to generated the scores. This a string of amino acids. 
+ positions : list + Positions on the sequence to be used to generate the score. + Positions must be in the following format: {WT}{CHAIN}{PDBINDEX}. + Note: PDBINDEX is 1-indexed, that is, the first position is 1. For example, the first positions in + the list of positions are [EH1, VH2, QH3, ...]. + chain_type : str + Type of antibody chain (heavy or light). This is used to determine the chain token to pass to the LLM model. + """ + # Prepare the input for the model. + ( + batch, + wildtype_dict, + position_indices, + chain_letter, + chain_token, + ) = self.prepare_input(sequence, positions, chain_type) + # Get the logits from the forward pass (shape: (batch_size, sequence_length, num_tokens)). + logits = self.forward_pass(batch, chain_token) + # Get AA tokens (shape: (batch_size, sequence_length, num_amino_acids)). + logits = logits[:, :, self.aa_token_indices] + # Apply softmax and take log over the logits (shape: (batch_size, sequence_length, num_amino_acids)). + logps = torch.log_softmax(logits, dim=-1) + # Create list to store elements for dataframe output + mutation2score = {} + # Iterate over each position and corresponding wildtype to get the scores. + for batch_idx, position_index in enumerate(position_indices): + batch_idx = batch_idx if self.mask else 0 + # Get the wildtype amino acid at the current position. + wildtype_aa = wildtype_dict[position_index] + # Index 0 of logits corresponds to CLS token, index 1 is where the sequence begins. + sequence_index = position_index + # Per position logp and probs (shape: (num_amino_acids)). + position_logps = logps[batch_idx][sequence_index].cpu().numpy() + # Get the wildtype logp. + wildtype_aa_id = self.tokenizer.convert_tokens_to_ids(wildtype_aa) + wildtype_aa_logp = position_logps[ + self.aa_token_indices.index(wildtype_aa_id) + ] + # Compute the scores. + position_scores = list( + score_function( + position_logps, wildtype_aa_logp, score_type=self.score_type + ) + ) + # Update the mutation2score dictionary. + for i, amino_acid in enumerate(amino_acids): + mutation = f"{wildtype_aa}{chain_letter}{position_index}{amino_acid}" + if self.mapping is not None: + mutation = from_scorer_output_to_user_output(mutation, self.mapping) + mutation2score[mutation] = position_scores[i] + + # create a dataframe from the mutation2score dictionary. 
columns: Mutation, score + return pd.DataFrame( + mutation2score.items(), + columns=["Mutation", f"{self.model_name}_{self.score_type}"], + ) + + def __str__(self): + return super().__str__() + ": PLM Scorer" + + +if __name__ == "__main__": + + sequence = "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS" + chain_type = "heavy" + positions = ["WH99", "GH100", "GH101"] + score_type = "minus_llr" + model_name = "Rostlab/prot_bert" + model_path = None + mask = True + mapping = None + + # test the class + plm_scorer = PLMScorer( + model_name=model_name, + model_path=model_path, + score_type=score_type, + mask=mask, + mapping=mapping, + ) + + df = plm_scorer.get_scores(sequence, positions, chain_type) + print(df) + + model_name = "facebook/esm2_t6_8M_UR50D" + plm_scorer_2 = PLMScorer( + model_name=model_name, + model_path=model_path, + score_type=score_type, + mask=mask, + mapping=mapping, + ) + + df = plm_scorer_2.get_scores(sequence, positions, chain_type) + print(df) + + model_name = "abbert" + model_path = "/Users/landajuelala1/Repositories/abbert/abbert_legacy" + plm_scorer_3 = PLMScorer( + model_name=model_name, + model_path=model_path, + score_type=score_type, + mask=mask, + mapping=mapping, + ) + + df = plm_scorer_3.get_scores(sequence, positions, chain_type) + print(df) + + sequence = "EVQLVESGGGLVQPGGSLRLSCAASGFNIKDTYIHWVRQAPGKGLEWVARIYPTNGYTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWGGDGFYAMDYWGQGTLVTVSS" + chain_type = "heavy" + positions = ["WB99", "GB100", "GB101"] + scoring_func = "minus_llr" + model_name = "abbert" + model_path = "/Users/landajuelala1/Repositories/abbert/abbert_legacy" + mask = True + plm_scorer_3 = PLMScorer( + model_name=model_name, + model_path=model_path, + score_type=score_type, + mask=mask, + mapping=mapping, + ) + # print the dataframe + df = plm_scorer_3.get_scores(sequence, positions, chain_type) + print(df) + # save the dataframe to a csv + df.to_csv("plm_scorer_abbert.csv", index=False) diff --git a/protlib_designer/scorer/scorer.py b/protlib_designer/scorer/scorer.py new file mode 100644 index 0000000..edfbeb4 --- /dev/null +++ b/protlib_designer/scorer/scorer.py @@ -0,0 +1,88 @@ +from abc import ABC, abstractmethod +import numpy as np +from numpy.typing import NDArray +from typing import List + + +def score_function( + position_logps: NDArray, wt_logps: NDArray, score_type: str = "llr" +) -> NDArray: + """Calculate the score for the positions.""" + if score_type == "ll": + position_score = position_logps + elif score_type == "minus_ll": + position_score = -position_logps + elif score_type == "llr": + position_score = position_logps - wt_logps + elif score_type == "minus_llr": + position_score = -position_logps + wt_logps + elif score_type == "probs": + position_score = np.exp(position_logps) + return position_score + + +def from_user_input_to_scorer_input(positions: List[str], mapping: dict) -> List[str]: + """Use the mapping to convert the positions to the numbering used by the scorer. + + Parameters + ---------- + positions : list + List of positions. + mapping : dict + mapping : dict + A dictionary to map the positions from the user input to the numbering used by the scorer. 
Example: + {'H': {'H1': 1, 'H5': 2, 'H6': 3}, 'L': {'L1': 1, 'L2': 2, 'L3': 3}} + """ + for i, position in enumerate(positions): + wt, chain, pos = position[0], position[1], position[2:] + mapped_pos = mapping[chain][str(chain) + str(pos)] + positions[i] = wt + chain + str(mapped_pos) + return positions + + +def from_scorer_output_to_user_output(mutation: str, mapping: dict) -> str: + """Use the mapping to convert the positions from the numbering used by the scorer to the user input. + + Parameters + ---------- + mutations : list + List of mutations. + mapping : dict + Mapping of chains and positions. + """ + + wt, chain, pos, aa = mutation[0], mutation[1], mutation[2:-1], mutation[-1] + chain_pos = list(mapping[chain].keys())[ + list(mapping[chain].values()).index(int(pos)) + ] + return wt + chain_pos + aa + + +class Scorer(ABC): + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def load_model(self): + """Load the model.""" + pass + + @abstractmethod + def prepare_input(self, **kwargs): + """Prepare the input for the model.""" + pass + + @abstractmethod + def forward_pass(self): + """Perform the forward pass.""" + pass + + @abstractmethod + def get_scores(self): + """Get the scores.""" + pass + + @abstractmethod + def __str__(self): + return "Generator" diff --git a/protlib_designer/solution_manager.py b/protlib_designer/solution_manager.py index 149b162..08491c3 100644 --- a/protlib_designer/solution_manager.py +++ b/protlib_designer/solution_manager.py @@ -27,8 +27,8 @@ def output_results(self): output_path = Path(output_folder) if not output_path.exists(): output_path.mkdir(parents=True, exist_ok=True) - # Rename `solution` to `MutationHL` for consistency with other outputs. - self.solutions_df.rename(columns={"solution": "MutationHL"}, inplace=True) + # Rename `solution` to `Mutation` for consistency with other outputs. + self.solutions_df.rename(columns={"solution": "Mutation"}, inplace=True) self.solutions_df.to_csv(output_path / "solutions.csv", index=False) logger.info(f"Solutions saved to {output_path / 'solutions.csv'}") diff --git a/protlib_designer/solver/generate_and_remove_solver.py b/protlib_designer/solver/generate_and_remove_solver.py index 83fd9ba..bcbf3f3 100644 --- a/protlib_designer/solver/generate_and_remove_solver.py +++ b/protlib_designer/solver/generate_and_remove_solver.py @@ -1,5 +1,7 @@ import time +from protlib_designer import logger + class GenerateAndRemoveSolver: def __init__( @@ -24,7 +26,11 @@ def run(self): ): self.generator.update_generator_before_generation(iteration=iteration) solution_dict = self.generator.generate_one_solution(iteration=iteration) - solution = solution_dict.get("solution") + try: + solution = solution_dict.get("solution") + except AttributeError: + logger.info(f"No solution found for iteration {iteration}. Exiting.") + break if self.filter.filter(solution): self.list_of_solution_dicts.append(solution_dict) number_of_solutions += 1 diff --git a/protlib_designer/utils.py b/protlib_designer/utils.py index 539dafa..0291f50 100644 --- a/protlib_designer/utils.py +++ b/protlib_designer/utils.py @@ -135,20 +135,20 @@ def format_and_validate_parameters( def validate_data(df: pd.DataFrame): """Validate the data file. The data file must have the following columns: - MutationHL, Target1, Target2, ..., TargetN + Mutation, Target1, Target2, ..., TargetN Parameters ---------- df : pd.DataFrame The dataframe containing the data. 
""" - if "MutationHL" not in df.columns: - logger.error("Data file must have a MutationHL column.") + if "Mutation" not in df.columns: + logger.error("Data file must have a Mutation column.") sys.exit(2) if len(df.columns) < 2: logger.error( - "Data file must have at minimum the MutationHL column and at least one Objective/Target." + "Data file must have at minimum the Mutation column and at least one Objective/Target." ) sys.exit(3) @@ -228,7 +228,7 @@ def extract_mutation_key(mutationbreak: str): mutationbreak : str The mutationbreak string. """ - chars = list(mutationbreak) # Convert string to list of chars + chars = list(mutationbreak) # Convert string to list of chars. return f"{chars[1]}_{chars[2]}" diff --git a/scripts/_run_protlib_designer.py b/scripts/_run_protlib_designer.py index 9e510a5..9e01691 100644 --- a/scripts/_run_protlib_designer.py +++ b/scripts/_run_protlib_designer.py @@ -180,7 +180,7 @@ def ilp( write_config(config, given_path) - problem = pulp.LpProblem("GUIDE_Antibody_Optimization", pulp.LpMinimize) + problem = pulp.LpProblem("Protein_Library_Optimization", pulp.LpMinimize) logger.info("Linear programming problem initialized") solver_msg = debug > 2 @@ -227,16 +227,16 @@ def ilp( ) # Check if row exists in the input dataframe - if mutation_name in data_df["MutationHL"].values: + if mutation_name in data_df["Mutation"].values: # Extract the row from the dataframe in a dictionary format - row = data_df[data_df["MutationHL"] == mutation_name].to_dict( - "records" - )[0] + row = data_df[data_df["Mutation"] == mutation_name].to_dict("records")[ + 0 + ] # Append the row to the padded dataframe data_df_padded.append(row) else: # The row does not exist in the input dataframe # Add 0-vector row for the new mutation - new_row = {"MutationHL": mutation_name} + new_row = {"Mutation": mutation_name} # Save the position and aa to add X_pos_a = 0 constraint later in the script. zero_enforced_mutations.append((wt, position, aa)) for curr_target in targets: @@ -265,13 +265,13 @@ def ilp( ) exit() - # Check that data_df["MutationHL"].values is equivalent (ordered in the same way) as x_vars + # Check that data_df["Mutation"].values is equivalent (ordered in the same way) as x_vars for index, x_var in enumerate(x_vars): mutation_name = x_var.getName().split("_")[1] - if mutation_name != data_df["MutationHL"].values[index]: + if mutation_name != data_df["Mutation"].values[index]: logger.error( f"Error adding missing position-amino acid pairs. Expected {mutation_name}. 
\ - Got {data_df['MutationHL'].values[index]}" + Got {data_df['Mutation'].values[index]}" ) exit() @@ -620,7 +620,7 @@ def ilp( if list_of_solution_dicts: df_solutions = pd.DataFrame(list_of_solution_dicts) - df_solutions.rename(columns={"solution": "MutationHL"}).to_csv( + df_solutions.rename(columns={"solution": "Mutation"}).to_csv( Path(given_path) / "solutions.csv", index=False ) diff --git a/scripts/run_plm_scorer.py b/scripts/run_plm_scorer.py new file mode 100644 index 0000000..e53d83a --- /dev/null +++ b/scripts/run_plm_scorer.py @@ -0,0 +1,97 @@ +import click +import pandas as pd + +from protlib_designer import logger +from protlib_designer.scorer.plm_scorer import PLMScorer + +CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) + + +@click.command(context_settings=CONTEXT_SETTINGS) +@click.argument('sequence', type=str, required=True) +@click.argument('positions', type=str, required=True, nargs=-1) +@click.option('--models', 'model_names', type=str, multiple=True, required=True) +@click.option('--chain-type', type=str, required=True) +@click.option('--model-paths', 'model_paths', type=str, multiple=True) +@click.option( + '--score-type', + type=click.Choice(['minus_ll', 'minus_llr']), + default='minus_llr', +) +@click.option('--mask/--no-mask', default=True) +@click.option('--mapping', type=str, default=None) +@click.option('--output-file', type=str, default='combined_scores.csv') +def run_plm_scorer( + sequence, + positions, + model_names, + chain_type, + model_paths, + score_type, + mask, + mapping, + output_file, +): + """ + Compute in silico mutagenesis scores using Protein Language Models (PLM). + + \b + Parameters + ---------- + sequence : str + The protein sequence. + positions : str + The positions to mutate. + model_names : str + The model names. + chain_type : str + The chain type. + model_paths : str + The model paths. + score_type : str + The score type. + mask : bool + Whether to mask the wild type amino acid. + mapping : str + The mapping. 
+ """ + + dataframes = [] + + for model_name in model_names: + plm_scorer = PLMScorer( + model_name=model_name, + model_path=None, + score_type=score_type, + mask=mask, + mapping=mapping, + ) + df = plm_scorer.get_scores(sequence, list(positions), chain_type) + dataframes.append(df) + + for model_path in model_paths: + plm_scorer = PLMScorer( + model_name=model_path, + model_path=model_path, + score_type=score_type, + mask=mask, + mapping=mapping, + ) + df = plm_scorer.get_scores(sequence, list(positions), chain_type) + dataframes.append(df) + + if not dataframes: + logger.error("No dataframes to combine.") + return + + # Merge the dataframes over the column "Mutation" + combined_df = None + for i, df in enumerate(dataframes): + combined_df = df if i == 0 else pd.merge(combined_df, df, on="Mutation") + + combined_df.to_csv(output_file, index=False) + logger.info(f"Combined scores saved to {output_file}") + + +if __name__ == "__main__": + run_plm_scorer() diff --git a/setup.cfg b/setup.cfg index da473ea..17aaa4e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ name = protlib_designer version = 0.0.1 author = Mikel Landajuela author_email = -description = Linear programming for protein design +description = Integer Linear Programming for Protein Library Design long_description = file: README.md long_description_content_type = text/markdown @@ -12,7 +12,7 @@ packages = find: python_requires = >=3.9 install_requires = click - numpy + numpy==1.26.4 pandas pulp @@ -21,4 +21,7 @@ dev = black==22.8.0 flake8==5.0.4 pytest - coverage \ No newline at end of file + coverage +plm = + torch + transformers \ No newline at end of file diff --git a/setup.py b/setup.py index 31dbe6c..581b465 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,10 @@ from setuptools import setup -setup(entry_points={"console_scripts": ["protlib-designer=scripts.run_protlib_designer:run_protlib_designer"]}) +setup( + entry_points={ + "console_scripts": [ + "protlib-designer=scripts.run_protlib_designer:run_protlib_designer", + "protlib-plm-scorer=scripts.run_plm_scorer:run_plm_scorer", + ] + } +)