Custom transformer model
Overview
- This project includes a PyTorch implementation of a custom transformer model that takes the contextual sequence of genes of interest as input, predicts relative gene expression, and produces saliency maps for the input sequence.
- Implemented in: Tross, M. C., Duggan, G., Shrestha, N., & Schnable, J. C. (2024). Models trained to predict differential expression across plant organs identify distal and proximal regulatory regions. https://doi.org/10.1101/2024.06.04.597477
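The overview does not say how input sequences are encoded. A common representation for DNA models is a fixed-length one-hot matrix, and the sketch below assumes that: A/C/G/T channels, all-zero rows for ambiguous bases (such as N) and for padding, and simple truncation to a maximum length (compare the --seq_max_len training argument). The function name one_hot_encode is hypothetical and is not taken from this repository.

```python
# Hypothetical sketch: one-hot encode a DNA sequence into a fixed-length matrix.
# Channel order (A, C, G, T), zero padding and right-truncation are assumptions,
# not the repository's documented encoding.
from typing import List

BASES = "ACGT"


def one_hot_encode(seq: str, max_len: int) -> List[List[int]]:
    """Return a max_len x 4 matrix; ambiguous bases and padding become all-zero rows."""
    seq = seq.upper()[:max_len]  # truncate sequences longer than max_len
    matrix = [[1 if base == b else 0 for b in BASES] for base in seq]
    # pad shorter sequences with all-zero rows up to max_len
    matrix.extend([[0, 0, 0, 0] for _ in range(max_len - len(seq))])
    return matrix
```

For example, one_hot_encode("ACGTN", 6) gives four single-hot rows for A, C, G and T, followed by two all-zero rows (one for N, one for padding).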
Getting Started
Prerequisites
- Python 3.x
- pip package manager
Installation
Clone the Repository
- Run the following commands:
git clone https://github.com/mtross2/transformer_regulatory_sequence.git
cd transformer_regulatory_sequence
Set Up a Virtual Environment (Optional but recommended)
- For Windows:
python -m venv venv
.\venv\Scripts\activate
- For Unix or macOS:
python3 -m venv venv
source venv/bin/activate
Install Required Packages
- Execute the command:
pip install -r requirements.txt
Install the Package (optional; needed only if you want to use it as a package)
- Use this command:
python setup.py install
Training
The train.py script trains a deep learning model to predict gene expression levels from gene sequences. To train the model, run the script with the required arguments, for example:
python train.py --data_dir /path/to/data --max_epochs 2000 --seq_max_len 90000 --batch_size 1 --num_gpus 1 --learning_rate 0.000001 --patience 100 --num_genes 28200 --num_val_genes 2000
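To illustrate how the flags in this command map onto Python values, here is a minimal argparse sketch. It is hypothetical: the types are inferred from the example values, every flag is assumed to be required, and train.py's actual parser may define defaults or additional options. The function names build_train_parser and parse_train_args are illustrative only.

```python
# Hypothetical sketch of a parser for the train.py flags shown above.
# Types are inferred from the example command; "required=True" is an assumption.
import argparse
from typing import List, Optional


def build_train_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train a gene expression model (sketch).")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Path to the data directory containing training data")
    parser.add_argument("--max_epochs", type=int, required=True,
                        help="Maximum number of epochs for training")
    parser.add_argument("--seq_max_len", type=int, required=True,
                        help="Maximum length of gene sequences")
    parser.add_argument("--batch_size", type=int, required=True,
                        help="Batch size for training")
    parser.add_argument("--num_gpus", type=int, required=True,
                        help="Number of GPUs to use")
    parser.add_argument("--learning_rate", type=float, required=True,
                        help="Learning rate for training the model")
    parser.add_argument("--patience", type=int, required=True,
                        help="Epochs without loss improvement before stopping")
    parser.add_argument("--num_genes", type=int, required=True,
                        help="Number of genes for training")
    parser.add_argument("--num_val_genes", type=int, required=True,
                        help="Number of genes for validation")
    return parser


def parse_train_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse argv (defaults to sys.argv[1:] when None)."""
    return build_train_parser().parse_args(argv)
```

Parsing the arguments of the example command above with this sketch yields, for instance, args.seq_max_len == 90000 and args.learning_rate == 1e-06.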
Arguments
data_dir: Path to the data directory containing training data.
max_epochs: Maximum number of epochs for training.
seq_max_len: Maximum length of gene sequences.
batch_size: Batch size for training.
num_gpus: Number of GPUs to use.
learning_rate: Learning rate for training the model.
patience: Number of epochs without improvement in the loss before training is stopped early.
num_genes: Number of genes used for training.
num_val_genes: Number of genes used for validation.
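The patience argument describes early stopping. The sketch below shows one common way to implement it, assuming the monitored quantity is a loss where lower is better. The class EarlyStopping and the helper epochs_run are hypothetical names; train.py may monitor a different loss or rely on a framework's built-in callback.

```python
# Hypothetical sketch of patience-based early stopping (lower loss is better).
from typing import Optional, Sequence


class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without a new best loss."""

    def __init__(self, patience: int, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.epochs_without_improvement = 0

    def step(self, loss: float) -> bool:
        """Record one epoch's loss; return True if training should stop."""
        if self.best is None or loss < self.best - self.min_delta:
            self.best = loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


def epochs_run(losses: Sequence[float], patience: int, max_epochs: int) -> int:
    """Return how many epochs a loop would run before early stopping or hitting max_epochs."""
    stopper = EarlyStopping(patience)
    for epoch, loss in enumerate(losses[:max_epochs], start=1):
        if stopper.step(loss):
            return epoch
    return min(len(losses), max_epochs)
```

With patience=3, the loss sequence [1.0, 0.8, 0.9, 0.85, 0.81] stops after epoch 5: epoch 2 sets the best loss, and epochs 3 to 5 fail to improve on it.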
Prediction
The predict.py script predicts gene expression levels and generates saliency maps for the provided gene sequences.
python predict.py --model /path/to/model.pth --sequence_file /path/to/sequence.txt --expression_file /path/to/expression.txt
Replace /path/to/model.pth, /path/to/sequence.txt, and /path/to/expression.txt with the paths to your trained model, gene sequence file, and expression file, respectively.
Arguments
model: Path to the saved model file.
sequence_file: Path to the file containing the gene sequence.
expression_file: Path to the file containing the gene expression data.
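This README does not describe how predict.py derives its saliency maps. Gradient-based saliency (backpropagating the prediction to the encoded input, for example via requires_grad_() and backward() in PyTorch) is one common choice. The dependency-free sketch below swaps in a different technique, in-silico mutagenesis: each position is scored by the largest change in prediction caused by any single-base substitution. Both toy_model (a stand-in for the trained network) and mutagenesis_saliency are hypothetical.

```python
# Hypothetical sketch: perturbation-based saliency (in-silico mutagenesis).
# `toy_model` is an illustrative stand-in for the trained network, not part of
# this repository; predict.py may compute gradient-based saliency instead.
from typing import Callable, List

BASES = "ACGT"


def toy_model(seq: str) -> float:
    """Illustrative scorer: number of occurrences of the motif 'TATA'."""
    return float(sum(1 for i in range(len(seq) - 3) if seq[i:i + 4] == "TATA"))


def mutagenesis_saliency(seq: str, model: Callable[[str], float]) -> List[float]:
    """Per-position saliency: max absolute prediction change over single-base substitutions."""
    reference = model(seq)
    scores = []
    for i, original in enumerate(seq):
        deltas = [
            abs(model(seq[:i] + alt + seq[i + 1:]) - reference)
            for alt in BASES
            if alt != original
        ]
        scores.append(max(deltas) if deltas else 0.0)
    return scores
```

For "GGTATAGG", only the four positions of the TATA motif change the toy prediction when mutated, so they receive saliency 1.0 and the flanking positions receive 0.0.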
License
This project is licensed under the CC-BY-NC License. See the LICENSE file for details.