Unofficial PyTorch version of ProteInfer (https://github.com/google-research/proteinfer), originally implemented in TensorFlow 1.X.
ProteInfer is a deep learning model for protein function prediction: it is trained to predict the functional properties of protein sequences. The authors provide pre-trained models for two tasks, Gene Ontology (GO) term and Enzyme Commission (EC) number prediction, and two data splits: random and clustered. Additionally, for every task and data split combination, the authors trained multiple models using different random seeds.
This repo contains PyTorch code to run inference, train, or extract embeddings for four ProteInfer models, one for each task/data split combination. All model weights are hosted on Hugging Face 🤗.
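The checkpoints can also be fetched programmatically. Below is a minimal sketch using the huggingface_hub client (a convenience for this README, not one of the repo's scripts); the repository id is the GO/random checkpoint referenced later in this README:

```python
# Sketch: download a ProteInferTorch checkpoint from Hugging Face.
# Requires `pip install huggingface_hub`.
from huggingface_hub import snapshot_download

# GO task, random split checkpoint (same id passed to --weights-dir below)
weights_dir = snapshot_download(repo_id="samirchar/proteinfertorch-go-random-13731645")
print(f"Weights downloaded to: {weights_dir}")
```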
The table below summarizes ProteInferTorch's F1 Micro performance on the original ProteInfer test sets using the PyTorch-converted weights:
| Data Split | Task | ProteInfer | ProteInferTorch | Weights 🤗 |
|---|---|---|---|---|
| random | GO | 0.885 | 0.886 | Link |
| clustered | GO | Not Reported | 0.784 | Link |
| random | EC | 0.977 | 0.979 | Link |
| clustered | EC | 0.914 | 0.914 | Link |
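For reference, F1 Micro pools true positives, false positives, and false negatives over every (sequence, label) pair before computing F1. The following is an illustrative sketch of the metric on toy multilabel arrays, not the repo's evaluation code:

```python
# Sketch: micro-averaged F1 for multilabel predictions (illustrative arrays only).
import numpy as np
from sklearn.metrics import f1_score

# Hypothetical binary label/prediction matrices of shape (num_sequences, num_labels)
y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 0], [0, 1, 1]])

# Micro averaging pools TP/FP/FN across all labels before computing F1.
print(f1_score(y_true, y_pred, average="micro"))
```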
TODO: ProteInferTorch's performance when training from scratch (i.e., from random weights). Using the training script and an 8-V100 GPU cluster, ProteInferTorch achieves XX F1 Micro on the GO task for the random split test set after XX epochs.
TODO: Insert train/val plot
- Installation
- Config
- Data
- Input data format
- Inference
- Extract Embeddings
- Train
- Citation
- Additional scripts
## Installation

```bash
git clone https://github.com/samirchar/proteinfertorch
cd proteinfertorch
conda env create -f environment.yml
conda activate proteinfertorch
pip install -e ./  # make sure ./ is the dir containing setup.py
```
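Optionally, sanity-check the install (this assumes the package imports under the name proteinfertorch):

```bash
python -c "import proteinfertorch; print('ok')"
```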
## Config

All default hyperparameters and script arguments are stored in config/config.yaml.
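To inspect the defaults programmatically, here is a minimal sketch (it assumes only that config/config.yaml is valid YAML; requires PyYAML):

```python
# Sketch: print the default configuration shipped with the repo.
import yaml  # PyYAML

with open("config/config.yaml") as f:
    config = yaml.safe_load(f)

print(config)  # nested dict of default hyperparameters and script arguments
```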
## Data

All the data needed to train and run inference with ProteInferTorch is available in a data.zip file (945 MB) hosted on Zenodo. Download and extract it with the following commands from the ProteInferTorch root folder:

```bash
sudo apt-get install unzip
curl -o data.zip "https://zenodo.org/records/14514368/files/data.zip?download=1"
unzip data.zip
```
The data folder has the following structure:
- data/
  - random_split/: contains the train, dev, and test FASTA files for all tasks using the random split method
  - clustered_split/: contains the train, dev, and test FASTA files for all tasks using the clustered split method
  - parenthood/: holds a JSON with the EC and GO graphs, used by ProteInfer to normalize output probabilities
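To peek at the parenthood graphs, here is a minimal sketch (the exact JSON file name inside data/parenthood/ is not assumed; the snippet simply grabs the first one it finds):

```python
# Sketch: inspect the EC/GO parenthood graphs shipped in data/parenthood/.
import glob
import json

# Avoid hard-coding a file name: take whatever JSON is in the folder.
path = glob.glob("data/parenthood/*.json")[0]
with open(path) as f:
    parenthood = json.load(f)

print(len(parenthood), "entries in", path)
```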
## Input data format

This package uses the standard FASTA format. All scripts have an optional argument called --fasta-separator, which defaults to " " and specifies how the elements of the header (i.e., the sequence ID and labels, if any) are separated.
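For example, a labeled record using the default separator could look like the following (the sequence ID, GO terms, and sequence are illustrative, not taken from the dataset):

```
>P12345 GO:0005524 GO:0016887
MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENARIQSKL
```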
## Inference

To run inference and evaluate model performance, run:

```bash
python bin/inference.py --data-path data/random_split/test_GO.fasta --vocabulary-path data/random_split/vocabularies/full_GO.json --weights-dir samirchar/proteinfertorch-go-random-13731645
```

To save the prediction logits, probabilities, and labels, add the flag --save-prediction-results. For EC numbers, use the full_EC.json vocabulary.
## Extract Embeddings

Users can extract and save ProteInferTorch embeddings using the get_embeddings.py script. The embeddings are stored in one or more .pt files inside --output-dir, depending on the value of --num-embedding-partitions.

```bash
python bin/get_embeddings.py --data-path data/random_split/test_GO.fasta --weights-dir samirchar/proteinfertorch-go-random-13731645
```

By default, --num-embedding-partitions=10.
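The partitions can then be loaded back for downstream use. A minimal sketch, assuming the files land in an outputs/ directory and are readable with torch.load (the directory name and file contents are assumptions; adapt them to the script's actual output):

```python
# Sketch: load all embedding partitions from the output directory.
# Directory name and per-file contents are assumptions, not the script's contract.
import glob
import torch

partition_files = sorted(glob.glob("outputs/*.pt"))  # one file per partition
# Recent PyTorch versions may need weights_only=False if the files store
# non-tensor objects (e.g., dicts mapping sequence IDs to embeddings).
partitions = [torch.load(f) for f in partition_files]
print(f"Loaded {len(partitions)} embedding partitions")
```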
## Train

The model can be trained from scratch or from pre-trained weights, depending on the value of the --weights-dir argument.

To train from scratch, run:

```bash
python bin/train.py --train-data-path data/random_split/train_GO.fasta --validation-data-path data/random_split/dev_GO.fasta --test-data-path data/random_split/test_GO.fasta --vocabulary-path data/random_split/vocabularies/full_GO.json
```

To start from pre-trained weights:

```bash
python bin/train.py --train-data-path data/random_split/train_GO.fasta --validation-data-path data/random_split/dev_GO.fasta --test-data-path data/random_split/test_GO.fasta --vocabulary-path data/random_split/vocabularies/full_GO.json --weights-dir samirchar/proteinfertorch-go-random-13731645
```

For EC numbers, use the full_EC.json vocabulary.
## Citation

If you use this model in your work, I would greatly appreciate it if you could cite it as follows:

```bibtex
@misc{char2024pytorchmodel,
  title={ProteInferTorch: a PyTorch implementation of ProteInfer},
  version={v1.0.0},
  author={Samir Char},
  year={2024},
  month={12},
  day={08},
  doi={10.5281/zenodo.14514368},
  url={https://github.com/samirchar/proteinfertorch}
}
```

## Additional scripts

This section describes additional scripts available in the bin folder.
The following commands create the train, dev, and test FASTA files for both tasks and data splits from the original datasets in TFRecord format:

```bash
conda env create -f proteinfer_conda_requirements.yml
conda activate proteinfer

python bin/make_proteinfer_dataset.py --data-input-dir data/clustered_split/tfrecords/ --data-output-dir data/clustered_split/ --vocab-output-dir data/clustered_split/vocabularies/ --annotation-types GO
python bin/make_proteinfer_dataset.py --data-input-dir data/clustered_split/tfrecords/ --data-output-dir data/clustered_split/ --vocab-output-dir data/clustered_split/vocabularies/ --annotation-types EC
python bin/make_proteinfer_dataset.py --data-input-dir data/random_split/tfrecords/ --data-output-dir data/random_split/ --vocab-output-dir data/random_split/vocabularies/ --annotation-types GO
python bin/make_proteinfer_dataset.py --data-input-dir data/random_split/tfrecords/ --data-output-dir data/random_split/ --vocab-output-dir data/random_split/vocabularies/ --annotation-types EC

conda activate proteinfertorch
```
Use the following commands to download the original TensorFlow weights for the two tasks and data splits and convert them to pickle (.pkl) format:

```bash
python bin/download_proteinfer_weights.py --task go --data-split clustered --ids 13703731 --output-dir data/model_weights/tf_weights/
python bin/download_proteinfer_weights.py --task go --data-split random --ids 13731645 --output-dir data/model_weights/tf_weights/
python bin/download_proteinfer_weights.py --task ec --data-split clustered --ids 13704042 --output-dir data/model_weights/tf_weights/
python bin/download_proteinfer_weights.py --task ec --data-split random --ids 13685140 --output-dir data/model_weights/tf_weights/
```