From 716d0c67cbae0b8a4b57aaeeaed08870e6050e69 Mon Sep 17 00:00:00 2001 From: yivlad Date: Wed, 17 May 2023 01:13:17 +0200 Subject: [PATCH 01/18] Test protbert model --- analysis.sh | 2 +- bertrand/pretraining/train_mlm.py | 4 ++-- bertrand/training/train.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) mode change 100644 => 100755 analysis.sh diff --git a/analysis.sh b/analysis.sh old mode 100644 new mode 100755 index 8cc2da0..a011da8 --- a/analysis.sh +++ b/analysis.sh @@ -4,7 +4,7 @@ CPU=$2 # First run MLM pre-training # This step is faster with a GPU -bash pretraining.sh "$DIR"/pretraining +# bash pretraining.sh "$DIR"/pretraining # Then generate negative decoys # This step is very CPU and RAM intensive diff --git a/bertrand/pretraining/train_mlm.py b/bertrand/pretraining/train_mlm.py index 8181e40..911a70d 100644 --- a/bertrand/pretraining/train_mlm.py +++ b/bertrand/pretraining/train_mlm.py @@ -11,7 +11,7 @@ DataCollatorForLanguageModeling, ) -from bertrand.training.config import BERT_CONFIG, MLM_TRAINING_ARGS +from bertrand.training.config import MLM_TRAINING_ARGS from bertrand.pretraining.dataset_mlm import PeptideTCRMLMDataset from bertrand.model.tokenization import tokenizer @@ -69,7 +69,7 @@ def get_training_args(output_dir: str) -> TrainingArguments: train_dataset = PeptideTCRMLMDataset(train) val_dataset = PeptideTCRMLMDataset(val) - model = BertForMaskedLM(BERT_CONFIG) + model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert") data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=True, mlm_probability=args.mlm_frac ) diff --git a/bertrand/training/train.py b/bertrand/training/train.py index a5944be..81d76bf 100644 --- a/bertrand/training/train.py +++ b/bertrand/training/train.py @@ -4,7 +4,7 @@ import os import pandas as pd -from transformers import TrainingArguments, DataCollatorWithPadding, Trainer +from transformers import TrainingArguments, DataCollatorWithPadding, Trainer, BertForMaskedLM from bertrand.training.dataset import PeptideTCRDataset from bertrand.training.metrics import mean_auroc_per_peptide_cluster @@ -132,6 +132,7 @@ def compute_metrics_and_save_predictions(p): train_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="train") val_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="val+test") logging.info("Training started") + model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert") train_and_evaluate( - train_dataset, val_dataset, BERTrand, args.model_ckpt, dataset_out_dir, + train_dataset, val_dataset, model, args.model_ckpt, dataset_out_dir, ) From 2a784d0127b58261fda2e30b148058e738df8301 Mon Sep 17 00:00:00 2001 From: yivlad Date: Thu, 18 May 2023 01:22:39 +0200 Subject: [PATCH 02/18] Use BertForSequenceClassification --- bertrand/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bertrand/training/train.py b/bertrand/training/train.py index 81d76bf..5b79136 100644 --- a/bertrand/training/train.py +++ b/bertrand/training/train.py @@ -4,7 +4,7 @@ import os import pandas as pd -from transformers import TrainingArguments, DataCollatorWithPadding, Trainer, BertForMaskedLM +from transformers import TrainingArguments, DataCollatorWithPadding, Trainer, BertForSequenceClassification from bertrand.training.dataset import PeptideTCRDataset from bertrand.training.metrics import mean_auroc_per_peptide_cluster @@ -132,7 +132,7 @@ def compute_metrics_and_save_predictions(p): train_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="train") val_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="val+test") logging.info("Training started") - model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert") + model = BertForSequenceClassification.from_pretrained("Rostlab/prot_bert") train_and_evaluate( train_dataset, val_dataset, model, args.model_ckpt, dataset_out_dir, ) From c50b2b7b35edb93cbe1171f00268170d27a6d19e Mon Sep 17 00:00:00 2001 From: yivlad Date: Tue, 23 May 2023 23:26:57 +0200 Subject: [PATCH 03/18] Fix pretrained model downloading --- bertrand/training/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bertrand/training/train.py b/bertrand/training/train.py index 5b79136..508650c 100644 --- a/bertrand/training/train.py +++ b/bertrand/training/train.py @@ -132,7 +132,7 @@ def compute_metrics_and_save_predictions(p): train_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="train") val_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="val+test") logging.info("Training started") - model = BertForSequenceClassification.from_pretrained("Rostlab/prot_bert") + model = BertForSequenceClassification train_and_evaluate( - train_dataset, val_dataset, model, args.model_ckpt, dataset_out_dir, + train_dataset, val_dataset, model, "Rostlab/prot_bert", dataset_out_dir, ) From d7eae414490c2ef77398a089ab544d3364205a07 Mon Sep 17 00:00:00 2001 From: yivlad Date: Thu, 25 May 2023 00:49:59 +0200 Subject: [PATCH 04/18] Add prot_bert wrapper --- bertrand/training/prot_bert.py | 15 +++++++++++++++ bertrand/training/train.py | 22 +++++++--------------- 2 files changed, 22 insertions(+), 15 deletions(-) create mode 100644 bertrand/training/prot_bert.py diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py new file mode 100644 index 0000000..19100de --- /dev/null +++ b/bertrand/training/prot_bert.py @@ -0,0 +1,15 @@ +from transformers import BertForSequenceClassification +import torch.nn as nn + +PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert' +class ProteinClassifier(nn.Module): + def __init__(self): + super(ProteinClassifier, self).__init__() + self.bert = BertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME) + self.classifier = nn.Sequential(nn.Dropout(p=0.2), + nn.Linear(self.bert.config.hidden_size, 1), + nn.Tanh()) + + def forward(self, *args, **kwargs): + output = self.bert(*args, **kwargs) + return self.classifier(output.pooler_output) diff --git a/bertrand/training/train.py b/bertrand/training/train.py index 508650c..216eee6 100644 --- a/bertrand/training/train.py +++ b/bertrand/training/train.py @@ -4,13 +4,13 @@ import os import pandas as pd -from transformers import TrainingArguments, DataCollatorWithPadding, Trainer, BertForSequenceClassification +from transformers import TrainingArguments, DataCollatorWithPadding, Trainer from bertrand.training.dataset import PeptideTCRDataset from bertrand.training.metrics import mean_auroc_per_peptide_cluster -from bertrand.training.config import BERT_CONFIG, SUPERVISED_TRAINING_ARGS -from bertrand.model.model import BERTrand +from bertrand.training.config import SUPERVISED_TRAINING_ARGS from bertrand.model.tokenization import tokenizer +from bertrand.training.prot_bert import ProteinClassifier def parse_args() -> argparse.Namespace: @@ -63,8 +63,6 @@ def get_training_args(output_dir: str) -> TrainingArguments: def train_and_evaluate( train_dataset: PeptideTCRDataset, val_dataset: PeptideTCRDataset, - model_class, - model_ckpt: str, output_dir: str, ) -> None: """ @@ -77,7 +75,6 @@ def train_and_evaluate( :param output_dir: folder to save model checkpoints and predictions for `val_dataset` for every epoch """ predictions = [] - logging.info(f"Model class: {model_class}") def compute_metrics_and_save_predictions(p): predictions.append(p) @@ -88,15 +85,11 @@ def compute_metrics_and_save_predictions(p): True, ) - if model_ckpt: - logging.info(f"Loading model from {model_ckpt}") - model = model_class.from_pretrained(model_ckpt) - else: - logging.info("Initializing model from scratch") - model = model_class(BERT_CONFIG) - training_args = get_training_args(output_dir) data_collator = DataCollatorWithPadding(tokenizer=tokenizer) + + model = ProteinClassifier() + trainer = Trainer( model=model, args=training_args, @@ -132,7 +125,6 @@ def compute_metrics_and_save_predictions(p): train_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="train") val_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="val+test") logging.info("Training started") - model = BertForSequenceClassification train_and_evaluate( - train_dataset, val_dataset, model, "Rostlab/prot_bert", dataset_out_dir, + train_dataset, val_dataset, dataset_out_dir, ) From 64406a5f86c594af28908bec678953c55ffc14f2 Mon Sep 17 00:00:00 2001 From: yivlad Date: Thu, 25 May 2023 00:59:52 +0200 Subject: [PATCH 05/18] Add flatten layer to prot_bert wrapper --- bertrand/training/prot_bert.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py index 19100de..7fc0b22 100644 --- a/bertrand/training/prot_bert.py +++ b/bertrand/training/prot_bert.py @@ -6,10 +6,11 @@ class ProteinClassifier(nn.Module): def __init__(self): super(ProteinClassifier, self).__init__() self.bert = BertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME) - self.classifier = nn.Sequential(nn.Dropout(p=0.2), - nn.Linear(self.bert.config.hidden_size, 1), + self.classifier = nn.Sequential(nn.Flatten(), + nn.Dropout(p=0.2), + nn.Linear(64, 1), nn.Tanh()) def forward(self, *args, **kwargs): output = self.bert(*args, **kwargs) - return self.classifier(output.pooler_output) + return self.classifier(output.logits) From 3081640b6b08bd458483dcaca9a24ac3dc517648 Mon Sep 17 00:00:00 2001 From: yivlad Date: Thu, 25 May 2023 01:05:00 +0200 Subject: [PATCH 06/18] Flatten tensor --- bertrand/training/prot_bert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py index 7fc0b22..c857a66 100644 --- a/bertrand/training/prot_bert.py +++ b/bertrand/training/prot_bert.py @@ -1,4 +1,5 @@ from transformers import BertForSequenceClassification +import torch import torch.nn as nn PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert' @@ -6,11 +7,10 @@ class ProteinClassifier(nn.Module): def __init__(self): super(ProteinClassifier, self).__init__() self.bert = BertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME) - self.classifier = nn.Sequential(nn.Flatten(), - nn.Dropout(p=0.2), + self.classifier = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(64, 1), nn.Tanh()) def forward(self, *args, **kwargs): output = self.bert(*args, **kwargs) - return self.classifier(output.logits) + return self.classifier(torch.flatten(output.logits)) From 998f36fb3839e529ac6230a3ffca8eb734942d5c Mon Sep 17 00:00:00 2001 From: yivlad Date: Thu, 25 May 2023 01:17:39 +0200 Subject: [PATCH 07/18] Bump pytorch --- env.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/env.yml b/env.yml index 52be727..a1ed750 100644 --- a/env.yml +++ b/env.yml @@ -12,7 +12,7 @@ dependencies: - pandas=1.4.1=py38h295c915_0 - pip=21.2.4=py38h06a4308_0 - python=3.8.0=h0371630_2 - - pytorch=1.11.0=py3.8_cuda10.2_cudnn7.6.5_0 + - pytorch=1.11.0=py3.8_cuda11.5_cudnn8.3.2_0 - scikit-learn=0.24.2=py38ha9443f7_0 - scipy=1.7.3=py38hc147768_0 - seaborn=0.11.1=pyhd3eb1b0_0 From e62421b6ab594beee31dfbcc1b0922fca3243db2 Mon Sep 17 00:00:00 2001 From: yivlad Date: Sun, 28 May 2023 21:28:13 +0200 Subject: [PATCH 08/18] Update protbert --- bertrand/training/prot_bert.py | 55 +++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py index c857a66..ca2d1aa 100644 --- a/bertrand/training/prot_bert.py +++ b/bertrand/training/prot_bert.py @@ -1,16 +1,57 @@ from transformers import BertForSequenceClassification import torch -import torch.nn as nn +from torch import nn +from transformers.modeling_outputs import SequenceClassifierOutput + +from bertrand.model.focal_loss import FocalLoss PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert' class ProteinClassifier(nn.Module): def __init__(self): super(ProteinClassifier, self).__init__() + self.num_labels = 2 self.bert = BertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME) - self.classifier = nn.Sequential(nn.Dropout(p=0.2), - nn.Linear(64, 1), - nn.Tanh()) - def forward(self, *args, **kwargs): - output = self.bert(*args, **kwargs) - return self.classifier(torch.flatten(output.logits)) + def forward( + self, + input_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + token_type_ids: torch.Tensor = None, + position_ids: torch.Tensor = None, + head_mask: torch.Tensor = None, + inputs_embeds: torch.Tensor = None, + labels: torch.Tensor = None, + weights: torch.Tensor = None, + output_attentions: bool = None, + output_hidden_states: bool = None, + return_dict: bool = None, + ): + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + logits = outputs.logits + + loss = None + if labels is not None: + loss_fct = FocalLoss(gamma=3, alpha=0.25, no_agg=True) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = torch.mean(loss * weights) + + if not return_dict: + output = (logits, outputs.hidden_states, outputs.attentions,) + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) From 629b4b385f913bee25ed7dc32e76453058d63f33 Mon Sep 17 00:00:00 2001 From: yivlad Date: Mon, 29 May 2023 01:12:42 +0200 Subject: [PATCH 09/18] Update env.yml --- env.yml | 143 +++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 125 insertions(+), 18 deletions(-) diff --git a/env.yml b/env.yml index a1ed750..b0aa302 100644 --- a/env.yml +++ b/env.yml @@ -1,27 +1,134 @@ name: bertrand channels: + - nvidia - pytorch - defaults dependencies: - - biopython=1.78=py38h7b6447c_0 - - h5py=2.10.0=py38hd6299e0_1 - - hdf5=1.10.6=hb1b8bf9_0 - - joblib=1.1.0=pyhd3eb1b0_0 - - matplotlib=3.3.4=py38h06a4308_0 - - numpy=1.21.2=py38h20f2e39_0 - - pandas=1.4.1=py38h295c915_0 - - pip=21.2.4=py38h06a4308_0 + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - biopython=1.78=py38h7f8727e_0 + - blas=1.0=mkl + - bottleneck=1.3.5=py38h7deecbd_0 + - brotli=1.0.9=h5eee18b_7 + - brotli-bin=1.0.9=h5eee18b_7 + - ca-certificates=2023.01.10=h06a4308_0 + - contourpy=1.0.5=py38hdb19cb5_0 + - cudatoolkit=11.5.1=hcf5317a_9 + - cycler=0.11.0=pyhd3eb1b0_0 + - dbus=1.13.18=hb2f20db_0 + - expat=2.4.9=h6a678d5_0 + - fftw=3.3.9=h27cfd23_1 + - fontconfig=2.14.1=h52c9d5c_1 + - fonttools=4.25.0=pyhd3eb1b0_0 + - freetype=2.12.1=h4a9f257_0 + - giflib=5.2.1=h5eee18b_3 + - glib=2.63.1=h5a9c865_0 + - gst-plugins-base=1.14.0=hbbd80ab_1 + - gstreamer=1.14.0=hb453b48_1 + - h5py=3.7.0=py38h737f45e_0 + - hdf5=1.10.6=h3ffc7dd_1 + - icu=58.2=he6710b0_3 + - importlib_resources=5.2.0=pyhd3eb1b0_1 + - intel-openmp=2021.4.0=h06a4308_3561 + - joblib=1.2.0=py38h06a4308_0 + - jpeg=9e=h5eee18b_1 + - kiwisolver=1.4.4=py38h6a678d5_0 + - lcms2=2.12=h3be6417_0 + - lerc=3.0=h295c915_0 + - libbrotlicommon=1.0.9=h5eee18b_7 + - libbrotlidec=1.0.9=h5eee18b_7 + - libbrotlienc=1.0.9=h5eee18b_7 + - libdeflate=1.17=h5eee18b_0 + - libedit=3.1.20221030=h5eee18b_0 + - libffi=3.2.1=hf484d3e_1007 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=11.2.0=h00389a5_1 + - libgfortran5=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libpng=1.6.39=h5eee18b_0 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtiff=4.5.0=h6a678d5_2 + - libuuid=1.41.5=h5eee18b_0 + - libuv=1.44.2=h5eee18b_0 + - libwebp=1.2.4=h11a3e52_1 + - libwebp-base=1.2.4=h5eee18b_1 + - libxcb=1.15=h7f8727e_0 + - libxml2=2.9.14=h74e7548_0 + - lz4-c=1.9.4=h6a678d5_0 + - matplotlib=3.7.1=py38h06a4308_1 + - matplotlib-base=3.7.1=py38h417a72b_1 + - mkl=2021.4.0=h06a4308_640 + - mkl-service=2.4.0=py38h7f8727e_0 + - mkl_fft=1.3.1=py38hd3c417c_0 + - mkl_random=1.2.2=py38h51133e4_0 + - munkres=1.1.4=py_0 + - ncurses=6.4=h6a678d5_0 + - numexpr=2.8.4=py38he184ba9_0 + - numpy=1.22.3=py38he7a7128_0 + - numpy-base=1.22.3=py38hf524024_0 + - openssl=1.1.1t=h7f8727e_0 + - packaging=23.0=py38h06a4308_0 + - pandas=1.5.3=py38h417a72b_0 + - pcre=8.45=h295c915_0 + - pillow=9.4.0=py38h6a678d5_0 + - pip=23.0.1=py38h06a4308_0 + - pyparsing=3.0.9=py38h06a4308_0 + - pyqt=5.9.2=py38h05f1152_4 - python=3.8.0=h0371630_2 + - python-dateutil=2.8.2=pyhd3eb1b0_0 - pytorch=1.11.0=py3.8_cuda11.5_cudnn8.3.2_0 + - pytorch-mutex=1.0=cuda + - pytz=2022.7=py38h06a4308_0 + - qt=5.9.7=h5867ecd_1 + - readline=7.0=h7b6447c_5 - scikit-learn=0.24.2=py38ha9443f7_0 - - scipy=1.7.3=py38hc147768_0 - - seaborn=0.11.1=pyhd3eb1b0_0 - - tokenizers=0.10.3=py38hb317417_1 - - tqdm=4.62.3=pyhd3eb1b0_1 + - scipy=1.7.3=py38h6c91a56_2 + - seaborn=0.12.2=py38h06a4308_0 + - setuptools=66.0.0=py38h06a4308_0 + - sip=4.19.13=py38h295c915_0 + - six=1.16.0=pyhd3eb1b0_1 + - sqlite=3.33.0=h62c20be_0 + - threadpoolctl=2.2.0=pyh0d69192_0 + - tk=8.6.12=h1ccaba5_0 + - tokenizers=0.11.4=py38h3dcd8bd_1 + - tornado=6.2=py38h5eee18b_0 + - tqdm=4.65.0=py38hb070fc8_0 + - typing_extensions=4.5.0=py38h06a4308_0 + - wheel=0.38.4=py38h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - zipp=3.11.0=py38h06a4308_0 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.5=hc292b87_0 - pip: - - datasets==1.18.3 - - fastcluster==1.2.4 - - leven==1.0.4 - - pytorch-lightning==0.7.1 - - transformers==4.16.2 -prefix: /home/ardigen/miniconda3/envs/bertrand + - aiohttp==3.8.4 + - aiosignal==1.3.1 + - async-timeout==4.0.2 + - attrs==23.1.0 + - certifi==2023.5.7 + - charset-normalizer==3.1.0 + - click==8.1.3 + - datasets==2.12.0 + - dill==0.3.6 + - fastcluster==1.2.6 + - filelock==3.12.0 + - frozenlist==1.3.3 + - fsspec==2023.5.0 + - huggingface-hub==0.14.1 + - idna==3.4 + - leven==1.0.4 + - lightning-utilities==0.8.0 + - multidict==6.0.4 + - multiprocess==0.70.14 + - nose==1.3.7 + - pyarrow==12.0.0 + - pytorch-lightning==2.0.2 + - pyyaml==6.0 + - regex==2023.5.5 + - requests==2.31.0 + - responses==0.18.0 + - sacremoses==0.0.53 + - torchmetrics==0.11.4 + - transformers==4.16.2 + - urllib3==2.0.2 + - xxhash==3.2.0 + - yarl==1.9.2 From 88942bccffd1eec1e36899c3b1c0cbc1fe71d68b Mon Sep 17 00:00:00 2001 From: yivlad Date: Wed, 31 May 2023 00:25:28 +0200 Subject: [PATCH 10/18] Update SUPERVISED_TRAINING_ARGS --- bertrand/training/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/bertrand/training/config.py b/bertrand/training/config.py index 63cee8b..1a1ee79 100644 --- a/bertrand/training/config.py +++ b/bertrand/training/config.py @@ -28,6 +28,7 @@ logging_steps=10, # logging every 10 steps evaluation_strategy="epoch", # model is evaluated every epoch save_strategy="epoch", # model is saved every epoch + metric_for_best_model="roc", ) # Training args for MLM From 454c622d30412f63ac21c672ee0fcbad60b7e804 Mon Sep 17 00:00:00 2001 From: yivlad Date: Sun, 4 Jun 2023 22:25:31 +0200 Subject: [PATCH 11/18] Simplify protbert --- bertrand/training/prot_bert.py | 11 +++++------ bertrand/training/train.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py index ca2d1aa..2cb9099 100644 --- a/bertrand/training/prot_bert.py +++ b/bertrand/training/prot_bert.py @@ -1,16 +1,15 @@ from transformers import BertForSequenceClassification import torch -from torch import nn from transformers.modeling_outputs import SequenceClassifierOutput from bertrand.model.focal_loss import FocalLoss PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert' -class ProteinClassifier(nn.Module): - def __init__(self): - super(ProteinClassifier, self).__init__() +class ProteinClassifier(BertForSequenceClassification): + def __init__(self, config): + super().__init__(config) self.num_labels = 2 - self.bert = BertForSequenceClassification.from_pretrained(PRE_TRAINED_MODEL_NAME) + self.config = config def forward( self, @@ -26,7 +25,7 @@ def forward( output_hidden_states: bool = None, return_dict: bool = None, ): - outputs = self.bert( + outputs = super().forward( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, diff --git a/bertrand/training/train.py b/bertrand/training/train.py index 216eee6..14b2627 100644 --- a/bertrand/training/train.py +++ b/bertrand/training/train.py @@ -10,7 +10,7 @@ from bertrand.training.metrics import mean_auroc_per_peptide_cluster from bertrand.training.config import SUPERVISED_TRAINING_ARGS from bertrand.model.tokenization import tokenizer -from bertrand.training.prot_bert import ProteinClassifier +from bertrand.training.prot_bert import PRE_TRAINED_MODEL_NAME, ProteinClassifier def parse_args() -> argparse.Namespace: @@ -88,7 +88,7 @@ def compute_metrics_and_save_predictions(p): training_args = get_training_args(output_dir) data_collator = DataCollatorWithPadding(tokenizer=tokenizer) - model = ProteinClassifier() + model = ProteinClassifier.from_pretrained(PRE_TRAINED_MODEL_NAME) trainer = Trainer( model=model, From 5be693901c0a3699139b42128595d5c55c37b65e Mon Sep 17 00:00:00 2001 From: yivlad Date: Sun, 4 Jun 2023 22:27:33 +0200 Subject: [PATCH 12/18] Update --- bertrand/training/prot_bert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py index 2cb9099..d1bc604 100644 --- a/bertrand/training/prot_bert.py +++ b/bertrand/training/prot_bert.py @@ -45,7 +45,7 @@ def forward( loss = torch.mean(loss * weights) if not return_dict: - output = (logits, outputs.hidden_states, outputs.attentions,) + output = (logits,) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( From 4cfcf052cfc4d56f230f169b4c476a8d64295523 Mon Sep 17 00:00:00 2001 From: yivlad Date: Wed, 7 Jun 2023 00:15:08 +0200 Subject: [PATCH 13/18] Update evaluate function --- bertrand/training/evaluate.py | 9 ++++----- bertrand/training/prot_bert.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/bertrand/training/evaluate.py b/bertrand/training/evaluate.py index 48c3791..f9828bf 100644 --- a/bertrand/training/evaluate.py +++ b/bertrand/training/evaluate.py @@ -1,7 +1,6 @@ import os import shutil from copy import deepcopy -from functools import partial from glob import glob from typing import Union, List, Tuple, Dict @@ -14,8 +13,8 @@ from bertrand.training.config import SUPERVISED_TRAINING_ARGS from bertrand.training.dataset import PeptideTCRDataset from bertrand.training.metrics import mean_auroc_per_peptide_cluster -from bertrand.model.model import BERTrand from bertrand.model.tokenization import tokenizer +from bertrand.training.prot_bert import ProteinClassifier from bertrand.training.utils import get_last_ckpt, load_metrics_df import argparse @@ -47,7 +46,7 @@ def parse_args() -> argparse.Namespace: def get_trainer( - model: BERTrand, test_dataset: PeptideTCRDataset, batch_size: int = 512 + model: ProteinClassifier, test_dataset: PeptideTCRDataset, batch_size: int = 512 ) -> Trainer: """ Creates a Trainer for the model to do inference on a dataset @@ -70,7 +69,7 @@ def get_trainer( def get_predictions( - model: BERTrand, test_dataset: PeptideTCRDataset + model: ProteinClassifier, test_dataset: PeptideTCRDataset ) -> PredictionOutput: """ Performs inference on a dataset @@ -90,7 +89,7 @@ def evaluate_cancer(cancer_dataset: PeptideTCRDataset, ckpt: str) -> pd.DataFram :param ckpt: model checkpoint folder :return: dataframe with AUROC for every peptide """ - model = BERTrand.from_pretrained(ckpt) + model = ProteinClassifier.from_pretrained(ckpt) model.eval() predictions = get_predictions(model, cancer_dataset) diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py index d1bc604..b22ad58 100644 --- a/bertrand/training/prot_bert.py +++ b/bertrand/training/prot_bert.py @@ -4,7 +4,7 @@ from bertrand.model.focal_loss import FocalLoss -PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert' +PRE_TRAINED_MODEL_NAME = 'yarongef/DistilProtBert' class ProteinClassifier(BertForSequenceClassification): def __init__(self, config): super().__init__(config) From 9829f3168cb15a29f22151c05d48ad134c339eb2 Mon Sep 17 00:00:00 2001 From: Yatsenko Vladyslav Date: Sat, 9 Sep 2023 16:44:50 +0200 Subject: [PATCH 14/18] Add job script --- analysis.sh | 6 +++--- bertrand-job.sh | 17 +++++++++++++++++ env.yml | 1 + train_and_evaluate.sh | 4 ++-- 4 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 bertrand-job.sh diff --git a/analysis.sh b/analysis.sh index a011da8..c49f7f3 100755 --- a/analysis.sh +++ b/analysis.sh @@ -4,12 +4,12 @@ CPU=$2 # First run MLM pre-training # This step is faster with a GPU -# bash pretraining.sh "$DIR"/pretraining +bash pretraining.sh "$DIR"/pretraining # Then generate negative decoys # This step is very CPU and RAM intensive -bash negative_decoys.sh "$DIR"/negative_decoys "$CPU" +# bash negative_decoys.sh "$DIR"/negative_decoys "$CPU" # Finally perform supervised training and evaluate the model # This step is faster with a GPU -bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training \ No newline at end of file +# bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training diff --git a/bertrand-job.sh b/bertrand-job.sh new file mode 100644 index 0000000..cdbff19 --- /dev/null +++ b/bertrand-job.sh @@ -0,0 +1,17 @@ +#!/bin/bash +#SBATCH --mail-type=ALL # Powiadomienia mailowe. Opcje: NONE, BEGIN, END, FAIL, ALL +#SBATCH --mail-user=jacenko.vlad@gmail.com # adres e-mail +#SBATCH --ntasks=4 # Uruchomienie na jednym procesorze +#SBATCH --mem=32gb +#SBATCH --gpus=a100:1 +#SBATCH --time=72:00:00 # maksymalny limit czasu DD-HH:MM:SS +#SBATCH --partition=long + +pwd; hostname; date + +source /home2/sfglab/yvladyslav/anaconda3/etc/profile.d/conda.sh +cd bertrand +conda activate bertrand +./analysis.sh "/home2/sfglab/yvladyslav/bertrand_results" 16 + +date diff --git a/env.yml b/env.yml index b0aa302..d15a02a 100644 --- a/env.yml +++ b/env.yml @@ -132,3 +132,4 @@ dependencies: - urllib3==2.0.2 - xxhash==3.2.0 - yarl==1.9.2 +prefix: /home2/sfglab/yvladyslav/anaconda3/envs/bertrand diff --git a/train_and_evaluate.sh b/train_and_evaluate.sh index a350cd5..647dd45 100644 --- a/train_and_evaluate.sh +++ b/train_and_evaluate.sh @@ -1,4 +1,4 @@ -set -x +set -ex DATA_DIR=$1 MODEL_DIR=$2 OUT_DIR=$3 @@ -14,4 +14,4 @@ python -m bertrand.training.train \ python -m bertrand.training.evaluate \ --datasets-dir=$DATA_DIR \ --results-dir=$OUT_DIR \ - --out=$OUT_DIR/results.csv \ No newline at end of file + --out=$OUT_DIR/results.csv From 7bb3450ade19c39f3cd0ca4a74009e8380c08b63 Mon Sep 17 00:00:00 2001 From: yivlad Date: Sat, 9 Sep 2023 17:22:57 +0200 Subject: [PATCH 15/18] Set splits amount to 1 --- analysis.sh | 2 +- bertrand-job.sh | 2 +- train_and_evaluate.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/analysis.sh b/analysis.sh index c49f7f3..dcccd12 100755 --- a/analysis.sh +++ b/analysis.sh @@ -12,4 +12,4 @@ bash pretraining.sh "$DIR"/pretraining # Finally perform supervised training and evaluate the model # This step is faster with a GPU -# bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training +bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training diff --git a/bertrand-job.sh b/bertrand-job.sh index cdbff19..ec88676 100644 --- a/bertrand-job.sh +++ b/bertrand-job.sh @@ -12,6 +12,6 @@ pwd; hostname; date source /home2/sfglab/yvladyslav/anaconda3/etc/profile.d/conda.sh cd bertrand conda activate bertrand -./analysis.sh "/home2/sfglab/yvladyslav/bertrand_results" 16 +./analysis.sh "/home2/sfglab/yvladyslav/distilbert/bertrand_results" 4 date diff --git a/train_and_evaluate.sh b/train_and_evaluate.sh index 647dd45..4df5aee 100644 --- a/train_and_evaluate.sh +++ b/train_and_evaluate.sh @@ -9,7 +9,7 @@ python -m bertrand.training.train \ --input-dir=$DATA_DIR \ --model-ckpt=$MODEL_DIR \ --output-dir=$OUT_DIR \ - --n-splits=21 + --n-splits=1 python -m bertrand.training.evaluate \ --datasets-dir=$DATA_DIR \ From dac4f30eead4bdb3245e82faa14a8a64352f7dd4 Mon Sep 17 00:00:00 2001 From: yivlad Date: Sat, 9 Sep 2023 18:15:08 +0200 Subject: [PATCH 16/18] Keep only pretrain step --- analysis.sh | 2 +- bertrand-job.sh | 2 +- bertrand/pretraining/train_mlm.py | 4 +-- bertrand/training/evaluate.py | 8 ++--- bertrand/training/prot_bert.py | 56 ------------------------------- bertrand/training/train.py | 19 +++++++---- 6 files changed, 21 insertions(+), 70 deletions(-) delete mode 100644 bertrand/training/prot_bert.py diff --git a/analysis.sh b/analysis.sh index dcccd12..c49f7f3 100755 --- a/analysis.sh +++ b/analysis.sh @@ -12,4 +12,4 @@ bash pretraining.sh "$DIR"/pretraining # Finally perform supervised training and evaluate the model # This step is faster with a GPU -bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training +# bash train_and_evaluate.sh "$DIR"/negative_decoys/datasets "$DIR"/pretraining/model "$DIR"/training diff --git a/bertrand-job.sh b/bertrand-job.sh index ec88676..44f6e43 100644 --- a/bertrand-job.sh +++ b/bertrand-job.sh @@ -12,6 +12,6 @@ pwd; hostname; date source /home2/sfglab/yvladyslav/anaconda3/etc/profile.d/conda.sh cd bertrand conda activate bertrand -./analysis.sh "/home2/sfglab/yvladyslav/distilbert/bertrand_results" 4 +./analysis.sh "/home2/sfglab/yvladyslav/pretrain-mlm/bertrand_results" 4 date diff --git a/bertrand/pretraining/train_mlm.py b/bertrand/pretraining/train_mlm.py index 911a70d..8181e40 100644 --- a/bertrand/pretraining/train_mlm.py +++ b/bertrand/pretraining/train_mlm.py @@ -11,7 +11,7 @@ DataCollatorForLanguageModeling, ) -from bertrand.training.config import MLM_TRAINING_ARGS +from bertrand.training.config import BERT_CONFIG, MLM_TRAINING_ARGS from bertrand.pretraining.dataset_mlm import PeptideTCRMLMDataset from bertrand.model.tokenization import tokenizer @@ -69,7 +69,7 @@ def get_training_args(output_dir: str) -> TrainingArguments: train_dataset = PeptideTCRMLMDataset(train) val_dataset = PeptideTCRMLMDataset(val) - model = BertForMaskedLM.from_pretrained("Rostlab/prot_bert") + model = BertForMaskedLM(BERT_CONFIG) data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=True, mlm_probability=args.mlm_frac ) diff --git a/bertrand/training/evaluate.py b/bertrand/training/evaluate.py index f9828bf..d1b8b0f 100644 --- a/bertrand/training/evaluate.py +++ b/bertrand/training/evaluate.py @@ -13,8 +13,8 @@ from bertrand.training.config import SUPERVISED_TRAINING_ARGS from bertrand.training.dataset import PeptideTCRDataset from bertrand.training.metrics import mean_auroc_per_peptide_cluster +from bertrand.model.model import BERTrand from bertrand.model.tokenization import tokenizer -from bertrand.training.prot_bert import ProteinClassifier from bertrand.training.utils import get_last_ckpt, load_metrics_df import argparse @@ -46,7 +46,7 @@ def parse_args() -> argparse.Namespace: def get_trainer( - model: ProteinClassifier, test_dataset: PeptideTCRDataset, batch_size: int = 512 + model: BERTrand, test_dataset: PeptideTCRDataset, batch_size: int = 512 ) -> Trainer: """ Creates a Trainer for the model to do inference on a dataset @@ -69,7 +69,7 @@ def get_trainer( def get_predictions( - model: ProteinClassifier, test_dataset: PeptideTCRDataset + model: BERTrand, test_dataset: PeptideTCRDataset ) -> PredictionOutput: """ Performs inference on a dataset @@ -89,7 +89,7 @@ def evaluate_cancer(cancer_dataset: PeptideTCRDataset, ckpt: str) -> pd.DataFram :param ckpt: model checkpoint folder :return: dataframe with AUROC for every peptide """ - model = ProteinClassifier.from_pretrained(ckpt) + model = BERTrand.from_pretrained(ckpt) model.eval() predictions = get_predictions(model, cancer_dataset) diff --git a/bertrand/training/prot_bert.py b/bertrand/training/prot_bert.py deleted file mode 100644 index b22ad58..0000000 --- a/bertrand/training/prot_bert.py +++ /dev/null @@ -1,56 +0,0 @@ -from transformers import BertForSequenceClassification -import torch -from transformers.modeling_outputs import SequenceClassifierOutput - -from bertrand.model.focal_loss import FocalLoss - -PRE_TRAINED_MODEL_NAME = 'yarongef/DistilProtBert' -class ProteinClassifier(BertForSequenceClassification): - def __init__(self, config): - super().__init__(config) - self.num_labels = 2 - self.config = config - - def forward( - self, - input_ids: torch.Tensor = None, - attention_mask: torch.Tensor = None, - token_type_ids: torch.Tensor = None, - position_ids: torch.Tensor = None, - head_mask: torch.Tensor = None, - inputs_embeds: torch.Tensor = None, - labels: torch.Tensor = None, - weights: torch.Tensor = None, - output_attentions: bool = None, - output_hidden_states: bool = None, - return_dict: bool = None, - ): - outputs = super().forward( - input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) - - logits = outputs.logits - - loss = None - if labels is not None: - loss_fct = FocalLoss(gamma=3, alpha=0.25, no_agg=True) - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - loss = torch.mean(loss * weights) - - if not return_dict: - output = (logits,) - return ((loss,) + output) if loss is not None else output - - return SequenceClassifierOutput( - loss=loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/bertrand/training/train.py b/bertrand/training/train.py index 14b2627..a5944be 100644 --- a/bertrand/training/train.py +++ b/bertrand/training/train.py @@ -8,9 +8,9 @@ from bertrand.training.dataset import PeptideTCRDataset from bertrand.training.metrics import mean_auroc_per_peptide_cluster -from bertrand.training.config import SUPERVISED_TRAINING_ARGS +from bertrand.training.config import BERT_CONFIG, SUPERVISED_TRAINING_ARGS +from bertrand.model.model import BERTrand from bertrand.model.tokenization import tokenizer -from bertrand.training.prot_bert import PRE_TRAINED_MODEL_NAME, ProteinClassifier def parse_args() -> argparse.Namespace: @@ -63,6 +63,8 @@ def get_training_args(output_dir: str) -> TrainingArguments: def train_and_evaluate( train_dataset: PeptideTCRDataset, val_dataset: PeptideTCRDataset, + model_class, + model_ckpt: str, output_dir: str, ) -> None: """ @@ -75,6 +77,7 @@ def train_and_evaluate( :param output_dir: folder to save model checkpoints and predictions for `val_dataset` for every epoch """ predictions = [] + logging.info(f"Model class: {model_class}") def compute_metrics_and_save_predictions(p): predictions.append(p) @@ -85,11 +88,15 @@ def compute_metrics_and_save_predictions(p): True, ) + if model_ckpt: + logging.info(f"Loading model from {model_ckpt}") + model = model_class.from_pretrained(model_ckpt) + else: + logging.info("Initializing model from scratch") + model = model_class(BERT_CONFIG) + training_args = get_training_args(output_dir) data_collator = DataCollatorWithPadding(tokenizer=tokenizer) - - model = ProteinClassifier.from_pretrained(PRE_TRAINED_MODEL_NAME) - trainer = Trainer( model=model, args=training_args, @@ -126,5 +133,5 @@ def compute_metrics_and_save_predictions(p): val_dataset = PeptideTCRDataset(dataset, cv_seed=cv_seed, subset="val+test") logging.info("Training started") train_and_evaluate( - train_dataset, val_dataset, dataset_out_dir, + train_dataset, val_dataset, BERTrand, args.model_ckpt, dataset_out_dir, ) From f7557194ed758b3f7a28bd4952929a89c555549a Mon Sep 17 00:00:00 2001 From: yivlad Date: Sat, 9 Sep 2023 18:16:00 +0200 Subject: [PATCH 17/18] Remove roc metric --- bertrand/training/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bertrand/training/config.py b/bertrand/training/config.py index 1a1ee79..63cee8b 100644 --- a/bertrand/training/config.py +++ b/bertrand/training/config.py @@ -28,7 +28,6 @@ logging_steps=10, # logging every 10 steps evaluation_strategy="epoch", # model is evaluated every epoch save_strategy="epoch", # model is saved every epoch - metric_for_best_model="roc", ) # Training args for MLM From 2ee4e1f477838619aec5f5abf514a8ff65648033 Mon Sep 17 00:00:00 2001 From: yivlad Date: Sat, 9 Sep 2023 19:02:43 +0200 Subject: [PATCH 18/18] Fix --- bertrand/pretraining/peptide_tcr_repertoire.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bertrand/pretraining/peptide_tcr_repertoire.py b/bertrand/pretraining/peptide_tcr_repertoire.py index 0ea5894..b792b11 100644 --- a/bertrand/pretraining/peptide_tcr_repertoire.py +++ b/bertrand/pretraining/peptide_tcr_repertoire.py @@ -80,11 +80,10 @@ def read_peptides(fn: str) -> pd.DataFrame: logging.info(f"{len(presented_peptides)} peptides read") presented_unique = ( presented_peptides.reset_index() - .groupby("Peptide2") + .groupby("peptide_seq") .agg( { "HLA_type": lambda x: "|".join(sorted(x)), - "index": lambda x: "|".join(sorted(x)), } ) .reset_index() @@ -114,7 +113,7 @@ def sample_peptide_tcr_repertoire( ) peptides_sampled.loc[:, "CDR3b"] = synthetic_tcrs.values - peptide_tcr_repertoire = peptides_sampled.rename(columns={"Peptide2": "Peptide"}) + peptide_tcr_repertoire = peptides_sampled.rename(columns={"peptide_seq": "Peptide"}) return peptide_tcr_repertoire