diff --git a/envs/bp3.yaml b/envs/bp3.yaml index c6deaf6..385cc55 100644 --- a/envs/bp3.yaml +++ b/envs/bp3.yaml @@ -5,9 +5,6 @@ dependencies: - pip - python==3.8.8 - - - - pip: - bp3==0.0.12.7 - fair-esm==1.0.3 diff --git a/envs/env.yaml b/envs/env.yaml index 5a3a278..f7351c9 100644 --- a/envs/env.yaml +++ b/envs/env.yaml @@ -9,6 +9,10 @@ dependencies: - matplotlib - polars - biopython + - plotnine + - pyarrow + - matplotlib + - scikit-learn - pip: - torch diff --git a/notebooks/example_code.ipynb b/notebooks/example_code.ipynb deleted file mode 100644 index 903ad99..0000000 --- a/notebooks/example_code.ipynb +++ /dev/null @@ -1,284 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "bae1dc08", - "metadata": {}, - "source": [ - "## Environment:\n", - "\n", - "This notebook will run with the 'envs/env.yaml` environment (epident-experiments)" - ] - }, - { - "cell_type": "markdown", - "id": "257bb096", - "metadata": {}, - "source": [ - "## Bepipred 3 dataset\n", - "\n", - "- job_name: unique identifier for protein, comes from hash of seq\n", - "- seq: amino acid sequence of protein\n", - "- test: boolean indicating if seq is part of test set\n", - "- epitope_boolmask: boolean array the same length as seq indiciating if the AA at that position is an epitope residue\n", - "- raw_protein_id: original ID assigned to protein in BP3C50ID set\n", - "- RSA: relative solvent accessiblity of the protein at each AA, calculated by FreeSASA\n", - "- SA: absolute solvent accessibility of the protein at each AA, calculated by FreeSASA" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "777f08b2", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "shape: (358, 7)
job_nameseqtestepitope_boolmaskraw_protein_idRSASA
strstrboollist[bool]strlist[f64]list[f64]
"bf2a62534941cf895971e1daa33a46…"LIQTPSSLLVQTNHTAKMSCEVKSISKLTS…true[false, false, … false]"3b9k_B"[0.205823, 0.471213, … 1.001547][36.957627, 82.806331, … 152.205138]
"d4febd28417e8a4bf6266337c7a2de…"GNVDLVFLFDGSMSLQPDEFQKILDFMKDV…true[false, false, … false]"3hi6_A"[0.840245, 0.294451, … 0.605393][68.13546, 42.698276, … 129.669178]
"17d233a2b305a3544cf6c164f8ad67…"DERETWSGKVDFLLSVIGFAVDLANVWRFP…true[false, false, … false]"4xp9_C"[1.100846, 1.039373, … 1.03129][157.156786, 181.038055, … 185.178502]
"34e0c5de18ccd222f24d4bc9d0f0e4…"KAMHVAQPAVVLASSRGIASFVCEYASPGK…true[true, true, … false]"5ggv_Y"[0.731129, 0.872878, … 1.227995][149.866882, 94.934212, … 168.493256]
"f4c930a3f1b5fb78cef62c5021adc0…"GSHHHHHHGSGTDITNQLTNVTVGIDSGTT…true[false, false, … false]"5jq6_A"[1.462795, 0.796439, … 0.90302][118.618012, 94.250586, … 119.379304]
"2c282aeeb88596bf1f1f99be1bb7f0…"LDKIDLSYETTESGDTAVSEDSYDKYASQN…false[false, false, … false]"7jum_A"[0.474081, 0.908022, … 0.986187][85.126053, 129.629226, … 143.007018]
"5196520df0000bf1b3fafa8c0e9ecc…"TDRQLAEEYLYRYGYTRVASLGPALLLLQK…false[false, false, … false]"5th9_A"[0.871364, 0.432033, … 0.256134][122.513787, 61.677101, … 54.861323]
"96836e4358c57e3f571a4f2bb8a8f8…"LPWLNVSADGDNVHLVLNVSEEQHFGLSLY…false[false, false, … true]"6hga_B"[0.93068, 0.110532, … 1.220648][167.112849, 15.166135, … 223.341912]
"9d838eec0c24655e9902a3ac128a34…"CSSPPCECHQEEDFRVTCKDIQRIPSLPPS…false[false, false, … false]"2xwt_C"[0.479508, 1.059907, … 0.528113][63.390948, 125.429402, … 74.252646]
"cb56653d3f7b5272b7874963549242…"CSVVVGENYSIKCDATKCTIEDKNRGIIKT…false[false, false, … false]"6vtw_A"[0.628252, 0.682526, … 0.637367][83.054881, 80.770091, … 113.980313]
" - ], - "text/plain": [ - "shape: (358, 7)\n", - "┌───────────────┬──────────────┬───────┬──────────────┬──────────────┬──────────────┬──────────────┐\n", - "│ job_name ┆ seq ┆ test ┆ epitope_bool ┆ raw_protein_ ┆ RSA ┆ SA │\n", - "│ --- ┆ --- ┆ --- ┆ mask ┆ id ┆ --- ┆ --- │\n", - "│ str ┆ str ┆ bool ┆ --- ┆ --- ┆ list[f64] ┆ list[f64] │\n", - "│ ┆ ┆ ┆ list[bool] ┆ str ┆ ┆ │\n", - "╞═══════════════╪══════════════╪═══════╪══════════════╪══════════════╪══════════════╪══════════════╡\n", - "│ bf2a62534941c ┆ LIQTPSSLLVQT ┆ true ┆ [false, ┆ 3b9k_B ┆ [0.205823, ┆ [36.957627, │\n", - "│ f895971e1daa3 ┆ NHTAKMSCEVKS ┆ ┆ false, … ┆ ┆ 0.471213, … ┆ 82.806331, … │\n", - "│ 3a46… ┆ ISKLTS… ┆ ┆ false] ┆ ┆ 1.00154… ┆ 152.2… │\n", - "│ d4febd28417e8 ┆ GNVDLVFLFDGS ┆ true ┆ [false, ┆ 3hi6_A ┆ [0.840245, ┆ [68.13546, │\n", - "│ a4bf6266337c7 ┆ MSLQPDEFQKIL ┆ ┆ false, … ┆ ┆ 0.294451, … ┆ 42.698276, … │\n", - "│ a2de… ┆ DFMKDV… ┆ ┆ false] ┆ ┆ 0.60539… ┆ 129.66… │\n", - "│ 17d233a2b305a ┆ DERETWSGKVDF ┆ true ┆ [false, ┆ 4xp9_C ┆ [1.100846, ┆ [157.156786, │\n", - "│ 3544cf6c164f8 ┆ LLSVIGFAVDLA ┆ ┆ false, … ┆ ┆ 1.039373, … ┆ 181.038055, │\n", - "│ ad67… ┆ NVWRFP… ┆ ┆ false] ┆ ┆ 1.03129… ┆ … 185… │\n", - "│ 34e0c5de18ccd ┆ KAMHVAQPAVVL ┆ true ┆ [true, true, ┆ 5ggv_Y ┆ [0.731129, ┆ [149.866882, │\n", - "│ 222f24d4bc9d0 ┆ ASSRGIASFVCE ┆ ┆ … false] ┆ ┆ 0.872878, … ┆ 94.934212, … │\n", - "│ f0e4… ┆ YASPGK… ┆ ┆ ┆ ┆ 1.22799… ┆ 168.… │\n", - "│ f4c930a3f1b5f ┆ GSHHHHHHGSGT ┆ true ┆ [false, ┆ 5jq6_A ┆ [1.462795, ┆ [118.618012, │\n", - "│ b78cef62c5021 ┆ DITNQLTNVTVG ┆ ┆ false, … ┆ ┆ 0.796439, … ┆ 94.250586, … │\n", - "│ adc0… ┆ IDSGTT… ┆ ┆ false] ┆ ┆ 0.90302… ┆ 119.… │\n", - "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", - "│ 2c282aeeb8859 ┆ LDKIDLSYETTE ┆ false ┆ [false, ┆ 7jum_A ┆ [0.474081, ┆ [85.126053, │\n", - "│ 6bf1f1f99be1b ┆ SGDTAVSEDSYD ┆ ┆ false, … ┆ ┆ 0.908022, … ┆ 129.629226, │\n", - "│ b7f0… ┆ KYASQN… ┆ ┆ false] ┆ ┆ 0.98618… ┆ … 143.… │\n", - "│ 5196520df0000 ┆ TDRQLAEEYLYR ┆ false ┆ 
[false, ┆ 5th9_A ┆ [0.871364, ┆ [122.513787, │\n", - "│ bf1b3fafa8c0e ┆ YGYTRVASLGPA ┆ ┆ false, … ┆ ┆ 0.432033, … ┆ 61.677101, … │\n", - "│ 9ecc… ┆ LLLLQK… ┆ ┆ false] ┆ ┆ 0.25613… ┆ 54.8… │\n", - "│ 96836e4358c57 ┆ LPWLNVSADGDN ┆ false ┆ [false, ┆ 6hga_B ┆ [0.93068, ┆ [167.112849, │\n", - "│ e3f571a4f2bb8 ┆ VHLVLNVSEEQH ┆ ┆ false, … ┆ ┆ 0.110532, … ┆ 15.166135, … │\n", - "│ a8f8… ┆ FGLSLY… ┆ ┆ true] ┆ ┆ 1.220648… ┆ 223.… │\n", - "│ 9d838eec0c246 ┆ CSSPPCECHQEE ┆ false ┆ [false, ┆ 2xwt_C ┆ [0.479508, ┆ [63.390948, │\n", - "│ 55e9902a3ac12 ┆ DFRVTCKDIQRI ┆ ┆ false, … ┆ ┆ 1.059907, … ┆ 125.429402, │\n", - "│ 8a34… ┆ PSLPPS… ┆ ┆ false] ┆ ┆ 0.52811… ┆ … 74.2… │\n", - "│ cb56653d3f7b5 ┆ CSVVVGENYSIK ┆ false ┆ [false, ┆ 6vtw_A ┆ [0.628252, ┆ [83.054881, │\n", - "│ 272b787496354 ┆ CDATKCTIEDKN ┆ ┆ false, … ┆ ┆ 0.682526, … ┆ 80.770091, … │\n", - "│ 9242… ┆ RGIIKT… ┆ ┆ false] ┆ ┆ 0.63736… ┆ 113.9… │\n", - "└───────────────┴──────────────┴───────┴──────────────┴──────────────┴──────────────┴──────────────┘" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import polars as pl\n", - "\n", - "bp3 = pl.read_parquet(\"../data/bp3c50id/bp3c50id.rsa.parquet\")\n", - "\n", - "bp3" - ] - }, - { - "cell_type": "markdown", - "id": "49c40ec6", - "metadata": {}, - "source": [ - "## Getting structural embeddings for a protein\n" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "d1cd1126", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[-316. , -916. , -148. , ..., 114.5 , -175. , 56.75],\n", - " [-498. , -616. , -120.5 , ..., 198. , -235. , 105.5 ],\n", - " [-508. , -608. , -107. , ..., 288. , -241. , -6.03],\n", - " ...,\n", - " [-370. , -528. , -242. , ..., 163. , -460. , -227. ],\n", - " [-370. , -528. , -242. , ..., 163. , -460. , -227. ],\n", - " [-370. , -528. , -242. , ..., 163. , -460. , -227. 
]],\n", - " shape=(256, 384), dtype=float16)" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from mdaf3.AF3OutputParser import AF3Output\n", - "from pathlib import Path\n", - "\n", - "\n", - "INF_DIR = Path(\"../data/bp3c50id/inference\")\n", - "sample_job_name = bp3.select(\"job_name\")[0].item()\n", - "\n", - "af3_output = AF3Output(INF_DIR / sample_job_name)\n", - "\n", - "af3_single_embed = af3_output.get_single_embeddings()\n", - "af3_pairwise_embed = af3_output.get_pair_embeddings()\n", - "\n", - "af3_single_embed" - ] - }, - { - "cell_type": "markdown", - "id": "7f1cda64", - "metadata": {}, - "source": [ - "### The af3_output object can do a lot more:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "8e84ca95", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Alpha carbon pLDDT: [89.05000305 90.90000153 93.55000305 92.66999817 95.01999664 96.\n", - " 97.02999878 96.05000305 97.12000275 96.63999939 96.61000061 95.05999756\n", - " 93.01999664 93.45999908 92.69000244 93.5 93.80999756 93.55000305\n", - " 92.87999725 91.15000153 88.44000244 82.69999695 76.98000336 72.58000183\n", - " 68.90000153 72.23999786 69.37999725 71.36000061 77.56999969 84.30999756\n", - " 87.66999817 91.04000092 92.34999847 93.26999664 93. 92.87000275\n", - " 86.84999847 79.66000366 76.95999908 76.51000214 80.98999786 81.33999634\n", - " 84.69999695 89.47000122 90.79000092 90.59999847 90.86000061 90.16999817\n", - " 89.73999786 88.19000244 84.12999725 81.91999817 78.51000214 80.23999786\n", - " 81.58000183 82.37999725 84.91000366 85.86000061 87. 
85.43000031\n", - " 80.48000336 74.48000336 74.86000061 60.11000061 59.43999863 62.54999924\n", - " 57.09000015 70.15000153 77.37999725 75.16000366 75.26999664 73.08999634\n", - " 64.12999725 57.47999954 55.90000153 61.22000122 75.45999908 83.11000061\n", - " 84.58000183 87.79000092 90.01999664 90.48999786 90.80000305 88.38999939\n", - " 89.47000122 91.26999664 91.04000092 91.87999725 89.41000366 90.81999969\n", - " 93.27999878 93.80999756 95.09999847 95.44000244 95.37999725 93.48999786\n", - " 92.26000214 88.79000092 80.73000336 69.51999664 64.80000305 66.72000122\n", - " 77.19999695 84.18000031 89.05000305 91.90000153 93.55000305 94.54000092\n", - " 96.01000214 96.37999725 96.51000214 95.43000031 96.23999786 95.08999634\n", - " 96.22000122 91.58999634 80.5 ]\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/lwoods/miniconda3/envs/epident-experiments/lib/python3.13/site-packages/MDAnalysis/coordinates/MMCIF.py:139: UserWarning: 1 A^3 CRYST1 record, this is usually a placeholder. 
Unit cell dimensions will be set to None.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "u = af3_output.get_mda_universe()\n", - "\n", - "# select all alpha carbons in topology\n", - "calphas = u.select_atoms(\"name CA\")\n", - "\n", - "# get their plddt\n", - "print(f\"Alpha carbon pLDDT: {calphas.tempfactors}\")\n", - "\n", - "# get the contact probability array\n", - "contact_probs = af3_output.get_contact_prob_ndarr()" - ] - }, - { - "cell_type": "markdown", - "id": "91ae6e70", - "metadata": {}, - "source": [ - "## Getting LM embeddings for a protein\n" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "ae7add4b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[-7.9217e-02, -8.2230e-02, 5.8380e-02, ..., 2.4681e-01,\n", - " 9.6495e-02, 1.1700e+02],\n", - " [ 2.7191e-01, 1.3160e-01, -1.2749e-01, ..., 8.3813e-02,\n", - " 2.6999e-02, 1.1700e+02],\n", - " [ 7.5211e-02, -1.2474e-01, -3.1285e-01, ..., -7.0912e-02,\n", - " -1.3021e-01, 1.1700e+02],\n", - " ...,\n", - " [-9.3008e-02, 1.5062e-01, 3.5336e-01, ..., -3.2767e-01,\n", - " -1.1053e-01, 1.1700e+02],\n", - " [ 6.3777e-02, 1.2429e-01, 2.3989e-01, ..., -2.6909e-01,\n", - " 8.8695e-02, 1.1700e+02],\n", - " [ 7.8697e-02, -1.0143e-02, 3.3305e-01, ..., -1.6285e-01,\n", - " 1.1192e-01, 1.1700e+02]])" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from pathlib import Path\n", - "import torch\n", - "\n", - "ESM_ENCODING_DIR = Path(\"/tgen_labs/altin/esm_encodings\")\n", - "\n", - "sample_job_name = bp3.select(\"job_name\")[0].item()\n", - "\n", - "esm_2_embed = torch.load(ESM_ENCODING_DIR / (sample_job_name + \".pt\"))\n", - "\n", - "esm_2_embed" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "epident-experiments", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - 
"mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.13.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/regression.ipynb b/notebooks/regression.ipynb new file mode 100644 index 0000000..6c95435 --- /dev/null +++ b/notebooks/regression.ipynb @@ -0,0 +1,715 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "97313f50", + "metadata": {}, + "source": [ + "# Machine Learning on Embeddings for Epitope Prediction\n", + "\n", + "The goal of this notebook is to put together a basic machine learning pipeline that can make epitope predictions using embeddings from AF3 and ESM" + ] + }, + { + "cell_type": "markdown", + "id": "0fa5456c", + "metadata": {}, + "source": [ + "## Environment:\n", + "\n", + "This notebook will run with the `envs/env.yaml` environment (epident-experiments)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47295d86", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jsesate/miniconda3/envs/epident-experiments/lib/python3.13/site-packages/MDAnalysis/coordinates/MMCIF.py:139: UserWarning: 1 A^3 CRYST1 record, this is usually a placeholder. 
###################
# --- Imports --- #
###################

import pickle
import sys
import os
from pathlib import Path

import polars as pl
import polars.selectors as cs
import pandas as pd
import torch
import numpy as np

from mdaf3.AF3OutputParser import AF3Output
from MDAnalysis.lib import distances
import networkx as nx

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import StratifiedKFold, KFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import roc_auc_score

from plotnine import *
import matplotlib.pyplot as plt
theme_set(theme_classic())

#########################
# --- Configuration --- #
#########################

# The ESM encodings live on a shared filesystem; allow an environment override
# so the notebook can run elsewhere without editing the path.
ESM_ENCODING_DIR = Path(os.environ.get("ESM_ENCODING_DIR", "/tgen_labs/altin/esm_encodings"))
INF_DIR = Path("../data/bp3c50id/inference")

NUM_ESM_EMB_VARS = 1280
# The AF3 single embedding printed by the example notebook had shape (256, 384).
# The previous value (348) named fewer fields than the embedding has channels,
# which appears to drop the trailing channels when the struct is built.
# The width is validated against the data in `explode_embedding` below.
NUM_AF3_EMB_VARS = 384

# Thresholds used to decide which residue pairs become edges in the AF3 network.
cutoff_prob = 0.825  # minimum AF3 contact probability
cutoff_pae = 10      # maximum predicted aligned error

SUBGRAPH_DISTANCE_CUTOFF = 15
induced_subgraphs = []

# TEMP: this row has an ESM embedding shorter than its AF3 embedding and is dropped.
MISMATCHED_ROW_INDEX = 192

# Laplacian eigenvalues with magnitude below this are treated as zero.
LAPLACIAN_ZERO_TOL = 1e-3


###################
# --- Helpers --- #
###################

def load_af3_output(job_name):
    """Open the AF3 inference output stored under ``INF_DIR / job_name``."""
    return AF3Output(INF_DIR / job_name)


def build_contact_graph(af3_output, prob_cutoff, pae_cutoff):
    """Build the residue contact network for one AF3 prediction.

    Nodes are the resids of the protein's alpha carbons. An edge (i, j), i < j,
    is added when the contact probability is >= ``prob_cutoff`` AND the
    predicted aligned error is <= ``pae_cutoff`` AND the CA-CA distance is
    greater than 1 (the same filter the original element-wise product applied).
    The edge weight is the CA-CA distance.

    Only the upper triangle of the contact-probability and PAE matrices is
    consulted, as in the original nested loop.
    """
    ca_atoms = af3_output.get_mda_universe().select_atoms("protein and name CA")
    n_residues = len(ca_atoms)

    # self_distance_array returns the flattened upper triangle (k=1).
    distance_matrix = np.zeros((n_residues, n_residues))
    upper = np.triu_indices(n_residues, k=1)
    distance_matrix[upper] = distances.self_distance_array(ca_atoms.positions)

    confident_contact = (af3_output.get_contact_prob_ndarr() >= prob_cutoff) & (
        af3_output.get_pae_ndarr() <= pae_cutoff
    )
    edge_mask = np.triu(confident_contact & (distance_matrix > 1), k=1)

    resids = ca_atoms.resids
    graph = nx.Graph()
    graph.add_nodes_from(resids)
    # argwhere walks the mask in row-major order, matching the original i/j loops.
    for i, j in np.argwhere(edge_mask):
        graph.add_edge(resids[i], resids[j], weight=distance_matrix[i, j])
    return graph


def nonzero_laplacian_extremes(graph):
    """Return (smallest, largest) non-zero eigenvalue of the graph Laplacian.

    Eigenvalues with magnitude below ``LAPLACIAN_ZERO_TOL`` are discarded before
    taking the minimum, which is how the notebook extracts the Fiedler value.
    The Laplacian of an undirected graph is symmetric, so ``eigvalsh`` is used
    (real eigenvalues, no complex round-off). Returns ``(nan, nan)`` when no
    eigenvalue survives the filter (e.g. a graph without edges).
    """
    laplacian = nx.laplacian_matrix(graph).toarray().astype(float)
    eigvals = np.linalg.eigvalsh(laplacian)
    nonzero = eigvals[np.abs(eigvals) >= LAPLACIAN_ZERO_TOL]
    if nonzero.size == 0:
        return float("nan"), float("nan")
    return float(nonzero.min()), float(nonzero.max())


def explode_embedding(frame, column, prefix, width):
    """Expand a column of 1-D tensors into ``width`` scalar columns.

    Returns a frame with ``index`` plus ``{prefix}_0 .. {prefix}_{width-1}``.
    Raises ``ValueError`` if the stacked embeddings are not exactly ``width``
    wide, instead of letting the struct conversion drop or pad channels.
    """
    def to_float_lists(series):
        stacked = np.stack([t.cpu().numpy() for t in series.to_list()])
        if stacked.shape[1] != width:
            raise ValueError(
                f"{column}: expected embeddings of width {width}, got {stacked.shape[1]}"
            )
        return pl.Series(stacked, dtype=pl.List(pl.Float64))

    return (
        frame.select(
            pl.col("index"),
            pl.col(column).map_batches(to_float_lists, return_dtype=pl.List(pl.Float64)),
        )
        .with_columns(
            pl.col(column).list.to_struct(fields=[f"{prefix}_{i}" for i in range(width)])
        )
        .unnest(column)
    )


#################################
# --- Import Bepipred3 Data --- #
#################################
# - job_name: unique identifier for protein, comes from hash of seq
# - seq: amino acid sequence of protein
# - train: boolean indicating if seq is part of train set
# - epitope_boolmask: boolean array the same length as seq indicating if the AA at that position is an epitope residue
# - raw_protein_id: original ID assigned to protein in BP3C50ID set
# - RSA: relative solvent accessibility of the protein at each AA, calculated by FreeSASA
# - SA: absolute solvent accessibility of the protein at each AA, calculated by FreeSASA

bp3 = pl.read_parquet("../data/bp3c50id/bp3c50id.rsa.parquet")

# train and test labels were swapped in the parquet, so "test" is really the train flag
bp3 = bp3.rename({"test": "train"})
bp3 = bp3.with_row_index()

########################
# --- Num Residues --- #
########################

if "seq_len" not in bp3.columns:
    bp3.insert_column(3, pl.Series("seq_len", [len(s) for s in bp3["seq"].to_list()]))

###################
# --- AF3 PTM --- #
###################

# Guard on the column that is actually inserted ("ptm"); the old check looked
# for "ptms", which never exists.
if "ptm" not in bp3.columns:
    ptm_values = [
        load_af3_output(name).get_summary_metrics()["ptm"]
        for name in bp3["job_name"].to_list()
    ]
    bp3.insert_column(1, pl.Series("ptm", ptm_values))

#####################
# --- AF3 pLDDT --- #
#####################

# Previously guarded by the pTM column name (copy-paste slip).
if "pLDDT" not in bp3.columns:
    plddt_values = []
    for name in bp3["job_name"].to_list():
        ca_atoms = load_af3_output(name).get_mda_universe().select_atoms("protein and name CA")
        # per-CA temperature factor holds the pLDDT
        plddt_values.append([float(v) for v in ca_atoms.tempfactors])
    bp3.insert_column(1, pl.Series("pLDDT", plddt_values))

##########################
# --- ESM Embeddings --- #
##########################

if "esm_emb" not in bp3.columns:
    esm_embeddings = [
        # remove last column of embedding (sequence lengths)
        torch.load(ESM_ENCODING_DIR / (name + ".pt"))[:, :-1]
        for name in bp3["job_name"].to_list()
    ]
    bp3.insert_column(1, pl.Series("esm_emb", esm_embeddings))

##########################
# --- AF3 Embeddings --- #
##########################

if "af_emb" not in bp3.columns:
    af_embeddings = []
    for cols in bp3.iter_rows(named=True):
        single_embed = load_af3_output(cols["job_name"]).get_single_embeddings()
        # Remove AlphaFold padding by keeping the first seq_len rows. Slicing
        # unconditionally also covers the unpadded case, where the old code
        # reused the previous protein's embedding.
        af_embeddings.append(torch.from_numpy(single_embed[: cols["seq_len"], :]))
    bp3.insert_column(1, pl.Series("af_emb", af_embeddings))
    # TEMP: for some reason, this row has an esm embedding smaller than af3 embedding...
    bp3 = bp3.filter(pl.col("index") != MISMATCHED_ROW_INDEX).drop("index")
    bp3 = bp3.with_row_index()

#########################################
# --- AF3 Protein Structure Network --- #
#########################################
# For each AF3 structure, we generate a graph whose nodes are residues and
# whose edge weights are the distances between residues. Edges are only drawn
# when the probability of contact is at least `cutoff_prob` AND the predicted
# aligned error is at most `cutoff_pae`.

if "af_graph" not in bp3.columns:
    af_graphs = [
        build_contact_graph(load_af3_output(name), cutoff_prob, cutoff_pae)
        for name in bp3["job_name"].to_list()
    ]
    bp3.insert_column(1, pl.Series("af_graph", af_graphs))

###############################################
# --- Analysis of AF3 Network's Structure --- #
###############################################

# One value per node; column order here is the column order of graph_features.
PER_RESIDUE_GRAPH_METRICS = {
    "closeness_centrality": nx.closeness_centrality,
    "betweenness_centrality": nx.betweenness_centrality,
    "load_centrality": nx.load_centrality,
    "eigenvector_centrality": lambda g: nx.eigenvector_centrality(g, max_iter=10000),
    "degree_centrality": nx.degree_centrality,
    "clustering": nx.clustering,
    "coreness": nx.core_number,
    "triangles": nx.triangles,
}

graph_metric_columns = {name: [] for name in PER_RESIDUE_GRAPH_METRICS}
density_values = []
fiedler_values = []
laplacian_max_values = []
for cols in bp3.iter_rows(named=True):
    graph = cols["af_graph"]
    for metric_name, metric_fn in PER_RESIDUE_GRAPH_METRICS.items():
        graph_metric_columns[metric_name].append(
            pl.Series(metric_name, list(metric_fn(graph).values()), dtype=pl.Float64)
        )
    density_values.append(nx.density(graph))
    fiedler, largest = nonzero_laplacian_extremes(graph)
    fiedler_values.append(fiedler)
    laplacian_max_values.append(largest)

graph_features = pl.DataFrame(
    [pl.Series(name, values) for name, values in graph_metric_columns.items()]
    + [
        pl.Series("density", density_values),
        pl.Series("lapl_f", fiedler_values),
        pl.Series("lapl_n1", laplacian_max_values),
    ]
).with_row_index()

if "clustering" not in bp3.columns:
    bp3 = bp3.join(graph_features, on="index", how="full").drop(["index_right"])

##########################################
# --- Transform to Per-Residue Basis --- #
##########################################

# output column -> bp3 column holding one value per residue
PER_RESIDUE_SOURCES = {
    "esm_emb": "esm_emb",
    "af_emb": "af_emb",
    "closeness_centrality": "closeness_centrality",
    "betweenness_centrality": "betweenness_centrality",
    "load_centrality": "load_centrality",
    "eigenvector_centrality": "eigenvector_centrality",
    "degree_centrality": "degree_centrality",
    "clustering": "clustering",
    "coreness": "coreness",
    "triangles": "triangles",
    "pLDDT": "pLDDT",
    "rsa": "RSA",
    "sa": "SA",
    "epitope": "epitope_boolmask",
}

# protein-level values, repeated once per residue
PER_PROTEIN_COLUMNS = [
    "job_name", "raw_protein_id", "seq", "seq_len", "af_graph",
    "density", "lapl_f", "lapl_n1", "ptm", "train",
]

RESIDUE_FRAME_ORDER = [
    "job_name", "raw_protein_id", "seq", "seq_len", "esm_emb", "af_emb", "af_graph",
    "closeness_centrality", "betweenness_centrality", "load_centrality", "eigenvector_centrality",
    "degree_centrality", "clustering", "coreness", "triangles", "density", "lapl_f", "lapl_n1",
    "ptm", "pLDDT", "rsa", "sa", "epitope", "train",
]

residue_columns = {name: [] for name in RESIDUE_FRAME_ORDER}
for cols in bp3.iter_rows(named=True):
    n_residues = cols["seq_len"]

    # Residue Features
    for out_name, src_name in PER_RESIDUE_SOURCES.items():
        values = cols[src_name]
        # Fail loudly on a misaligned protein rather than silently shifting
        # every later residue's features against its label.
        if len(values) != n_residues:
            raise ValueError(
                f"{cols['job_name']}: column '{src_name}' has {len(values)} entries "
                f"but seq_len is {n_residues}"
            )
        residue_columns[out_name].extend(values)

    # Global Features
    for name in PER_PROTEIN_COLUMNS:
        residue_columns[name].extend([cols[name]] * n_residues)

bp3_res = pl.DataFrame(
    [pl.Series(name, residue_columns[name]) for name in RESIDUE_FRAME_ORDER]
).with_row_index()

#####################################
# --- Explode Embedding Columns --- #
#####################################

bp3_esm_res = explode_embedding(bp3_res, "esm_emb", "esm", NUM_ESM_EMB_VARS)
bp3_af3_res = explode_embedding(bp3_res, "af_emb", "af3", NUM_AF3_EMB_VARS)

bp3_no_emb_res = bp3_res.drop(["af_emb", "esm_emb"])
bp3_res = bp3_no_emb_res.join(bp3_esm_res, on="index", how="full").drop(["index_right"])
bp3_res = bp3_res.join(bp3_af3_res, on="index", how="full").drop(["index_right"])
bp3_df = bp3_res.drop("af_graph")
+ { + "cell_type": "code", + "execution_count": null, + "id": "68686bd6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num Epitope Residues: 10811\n", + "Num Non-Epitope Residues: 70702\n", + "Column Names: ['index', 'job_name', 'raw_protein_id', 'seq', 'seq_len', 'closeness_centrality', 'betweenness_centrality', 'load_centrality', 'eigenvector_centrality', 'degree_centrality', 'clustering', 'coreness', 'triangles', 'density', 'lapl_f', 'lapl_n1', 'ptm', 'pLDDT', 'rsa', 'sa', 'epitope', 'train', 'esm_0', 'esm_1', 'esm_2', 'esm_3', 'esm_4', 'esm_5', 'esm_6', 'esm_7', 'esm_8', 'esm_9', 'esm_10', 'esm_11', 'esm_12', 'esm_13', 'esm_14', 'esm_15', 'esm_16', 'esm_17', 'esm_18', 'esm_19', 'esm_20', 'esm_21', 'esm_22', 'esm_23', 'esm_24', 'esm_25', 'esm_26', 'esm_27', 'esm_28', 'esm_29', 'esm_30', 'esm_31', 'esm_32', 'esm_33', 'esm_34', 'esm_35', 'esm_36', 'esm_37', 'esm_38', 'esm_39', 'esm_40', 'esm_41', 'esm_42', 'esm_43', 'esm_44', 'esm_45', 'esm_46', 'esm_47', 'esm_48', 'esm_49', 'esm_50', 'esm_51', 'esm_52', 'esm_53', 'esm_54', 'esm_55', 'esm_56', 'esm_57', 'esm_58', 'esm_59', 'esm_60', 'esm_61', 'esm_62', 'esm_63', 'esm_64', 'esm_65', 'esm_66', 'esm_67', 'esm_68', 'esm_69', 'esm_70', 'esm_71', 'esm_72', 'esm_73', 'esm_74', 'esm_75', 'esm_76', 'esm_77', 'esm_78', 'esm_79', 'esm_80', 'esm_81', 'esm_82', 'esm_83', 'esm_84', 'esm_85', 'esm_86', 'esm_87', 'esm_88', 'esm_89', 'esm_90', 'esm_91', 'esm_92', 'esm_93', 'esm_94', 'esm_95', 'esm_96', 'esm_97', 'esm_98', 'esm_99', 'esm_100', 'esm_101', 'esm_102', 'esm_103', 'esm_104', 'esm_105', 'esm_106', 'esm_107', 'esm_108', 'esm_109', 'esm_110', 'esm_111', 'esm_112', 'esm_113', 'esm_114', 'esm_115', 'esm_116', 'esm_117', 'esm_118', 'esm_119', 'esm_120', 'esm_121', 'esm_122', 'esm_123', 'esm_124', 'esm_125', 'esm_126', 'esm_127', 'esm_128', 'esm_129', 'esm_130', 'esm_131', 'esm_132', 'esm_133', 'esm_134', 'esm_135', 'esm_136', 'esm_137', 'esm_138', 'esm_139', 'esm_140', 
'esm_141', 'esm_142', 'esm_143', 'esm_144', 'esm_145', 'esm_146', 'esm_147', 'esm_148', 'esm_149', 'esm_150', 'esm_151', 'esm_152', 'esm_153', 'esm_154', 'esm_155', 'esm_156', 'esm_157', 'esm_158', 'esm_159', 'esm_160', 'esm_161', 'esm_162', 'esm_163', 'esm_164', 'esm_165', 'esm_166', 'esm_167', 'esm_168', 'esm_169', 'esm_170', 'esm_171', 'esm_172', 'esm_173', 'esm_174', 'esm_175', 'esm_176', 'esm_177', 'esm_178', 'esm_179', 'esm_180', 'esm_181', 'esm_182', 'esm_183', 'esm_184', 'esm_185', 'esm_186', 'esm_187', 'esm_188', 'esm_189', 'esm_190', 'esm_191', 'esm_192', 'esm_193', 'esm_194', 'esm_195', 'esm_196', 'esm_197', 'esm_198', 'esm_199', 'esm_200', 'esm_201', 'esm_202', 'esm_203', 'esm_204', 'esm_205', 'esm_206', 'esm_207', 'esm_208', 'esm_209', 'esm_210', 'esm_211', 'esm_212', 'esm_213', 'esm_214', 'esm_215', 'esm_216', 'esm_217', 'esm_218', 'esm_219', 'esm_220', 'esm_221', 'esm_222', 'esm_223', 'esm_224', 'esm_225', 'esm_226', 'esm_227', 'esm_228', 'esm_229', 'esm_230', 'esm_231', 'esm_232', 'esm_233', 'esm_234', 'esm_235', 'esm_236', 'esm_237', 'esm_238', 'esm_239', 'esm_240', 'esm_241', 'esm_242', 'esm_243', 'esm_244', 'esm_245', 'esm_246', 'esm_247', 'esm_248', 'esm_249', 'esm_250', 'esm_251', 'esm_252', 'esm_253', 'esm_254', 'esm_255', 'esm_256', 'esm_257', 'esm_258', 'esm_259', 'esm_260', 'esm_261', 'esm_262', 'esm_263', 'esm_264', 'esm_265', 'esm_266', 'esm_267', 'esm_268', 'esm_269', 'esm_270', 'esm_271', 'esm_272', 'esm_273', 'esm_274', 'esm_275', 'esm_276', 'esm_277', 'esm_278', 'esm_279', 'esm_280', 'esm_281', 'esm_282', 'esm_283', 'esm_284', 'esm_285', 'esm_286', 'esm_287', 'esm_288', 'esm_289', 'esm_290', 'esm_291', 'esm_292', 'esm_293', 'esm_294', 'esm_295', 'esm_296', 'esm_297', 'esm_298', 'esm_299', 'esm_300', 'esm_301', 'esm_302', 'esm_303', 'esm_304', 'esm_305', 'esm_306', 'esm_307', 'esm_308', 'esm_309', 'esm_310', 'esm_311', 'esm_312', 'esm_313', 'esm_314', 'esm_315', 'esm_316', 'esm_317', 'esm_318', 'esm_319', 'esm_320', 'esm_321', 
'esm_322', 'esm_323', 'esm_324', 'esm_325', 'esm_326', 'esm_327', 'esm_328', 'esm_329', 'esm_330', 'esm_331', 'esm_332', 'esm_333', 'esm_334', 'esm_335', 'esm_336', 'esm_337', 'esm_338', 'esm_339', 'esm_340', 'esm_341', 'esm_342', 'esm_343', 'esm_344', 'esm_345', 'esm_346', 'esm_347', 'esm_348', 'esm_349', 'esm_350', 'esm_351', 'esm_352', 'esm_353', 'esm_354', 'esm_355', 'esm_356', 'esm_357', 'esm_358', 'esm_359', 'esm_360', 'esm_361', 'esm_362', 'esm_363', 'esm_364', 'esm_365', 'esm_366', 'esm_367', 'esm_368', 'esm_369', 'esm_370', 'esm_371', 'esm_372', 'esm_373', 'esm_374', 'esm_375', 'esm_376', 'esm_377', 'esm_378', 'esm_379', 'esm_380', 'esm_381', 'esm_382', 'esm_383', 'esm_384', 'esm_385', 'esm_386', 'esm_387', 'esm_388', 'esm_389', 'esm_390', 'esm_391', 'esm_392', 'esm_393', 'esm_394', 'esm_395', 'esm_396', 'esm_397', 'esm_398', 'esm_399', 'esm_400', 'esm_401', 'esm_402', 'esm_403', 'esm_404', 'esm_405', 'esm_406', 'esm_407', 'esm_408', 'esm_409', 'esm_410', 'esm_411', 'esm_412', 'esm_413', 'esm_414', 'esm_415', 'esm_416', 'esm_417', 'esm_418', 'esm_419', 'esm_420', 'esm_421', 'esm_422', 'esm_423', 'esm_424', 'esm_425', 'esm_426', 'esm_427', 'esm_428', 'esm_429', 'esm_430', 'esm_431', 'esm_432', 'esm_433', 'esm_434', 'esm_435', 'esm_436', 'esm_437', 'esm_438', 'esm_439', 'esm_440', 'esm_441', 'esm_442', 'esm_443', 'esm_444', 'esm_445', 'esm_446', 'esm_447', 'esm_448', 'esm_449', 'esm_450', 'esm_451', 'esm_452', 'esm_453', 'esm_454', 'esm_455', 'esm_456', 'esm_457', 'esm_458', 'esm_459', 'esm_460', 'esm_461', 'esm_462', 'esm_463', 'esm_464', 'esm_465', 'esm_466', 'esm_467', 'esm_468', 'esm_469', 'esm_470', 'esm_471', 'esm_472', 'esm_473', 'esm_474', 'esm_475', 'esm_476', 'esm_477', 'esm_478', 'esm_479', 'esm_480', 'esm_481', 'esm_482', 'esm_483', 'esm_484', 'esm_485', 'esm_486', 'esm_487', 'esm_488', 'esm_489', 'esm_490', 'esm_491', 'esm_492', 'esm_493', 'esm_494', 'esm_495', 'esm_496', 'esm_497', 'esm_498', 'esm_499', 'esm_500', 'esm_501', 'esm_502', 
'esm_503', 'esm_504', 'esm_505', 'esm_506', 'esm_507', 'esm_508', 'esm_509', 'esm_510', 'esm_511', 'esm_512', 'esm_513', 'esm_514', 'esm_515', 'esm_516', 'esm_517', 'esm_518', 'esm_519', 'esm_520', 'esm_521', 'esm_522', 'esm_523', 'esm_524', 'esm_525', 'esm_526', 'esm_527', 'esm_528', 'esm_529', 'esm_530', 'esm_531', 'esm_532', 'esm_533', 'esm_534', 'esm_535', 'esm_536', 'esm_537', 'esm_538', 'esm_539', 'esm_540', 'esm_541', 'esm_542', 'esm_543', 'esm_544', 'esm_545', 'esm_546', 'esm_547', 'esm_548', 'esm_549', 'esm_550', 'esm_551', 'esm_552', 'esm_553', 'esm_554', 'esm_555', 'esm_556', 'esm_557', 'esm_558', 'esm_559', 'esm_560', 'esm_561', 'esm_562', 'esm_563', 'esm_564', 'esm_565', 'esm_566', 'esm_567', 'esm_568', 'esm_569', 'esm_570', 'esm_571', 'esm_572', 'esm_573', 'esm_574', 'esm_575', 'esm_576', 'esm_577', 'esm_578', 'esm_579', 'esm_580', 'esm_581', 'esm_582', 'esm_583', 'esm_584', 'esm_585', 'esm_586', 'esm_587', 'esm_588', 'esm_589', 'esm_590', 'esm_591', 'esm_592', 'esm_593', 'esm_594', 'esm_595', 'esm_596', 'esm_597', 'esm_598', 'esm_599', 'esm_600', 'esm_601', 'esm_602', 'esm_603', 'esm_604', 'esm_605', 'esm_606', 'esm_607', 'esm_608', 'esm_609', 'esm_610', 'esm_611', 'esm_612', 'esm_613', 'esm_614', 'esm_615', 'esm_616', 'esm_617', 'esm_618', 'esm_619', 'esm_620', 'esm_621', 'esm_622', 'esm_623', 'esm_624', 'esm_625', 'esm_626', 'esm_627', 'esm_628', 'esm_629', 'esm_630', 'esm_631', 'esm_632', 'esm_633', 'esm_634', 'esm_635', 'esm_636', 'esm_637', 'esm_638', 'esm_639', 'esm_640', 'esm_641', 'esm_642', 'esm_643', 'esm_644', 'esm_645', 'esm_646', 'esm_647', 'esm_648', 'esm_649', 'esm_650', 'esm_651', 'esm_652', 'esm_653', 'esm_654', 'esm_655', 'esm_656', 'esm_657', 'esm_658', 'esm_659', 'esm_660', 'esm_661', 'esm_662', 'esm_663', 'esm_664', 'esm_665', 'esm_666', 'esm_667', 'esm_668', 'esm_669', 'esm_670', 'esm_671', 'esm_672', 'esm_673', 'esm_674', 'esm_675', 'esm_676', 'esm_677', 'esm_678', 'esm_679', 'esm_680', 'esm_681', 'esm_682', 'esm_683', 
'esm_684', 'esm_685', 'esm_686', 'esm_687', 'esm_688', 'esm_689', 'esm_690', 'esm_691', 'esm_692', 'esm_693', 'esm_694', 'esm_695', 'esm_696', 'esm_697', 'esm_698', 'esm_699', 'esm_700', 'esm_701', 'esm_702', 'esm_703', 'esm_704', 'esm_705', 'esm_706', 'esm_707', 'esm_708', 'esm_709', 'esm_710', 'esm_711', 'esm_712', 'esm_713', 'esm_714', 'esm_715', 'esm_716', 'esm_717', 'esm_718', 'esm_719', 'esm_720', 'esm_721', 'esm_722', 'esm_723', 'esm_724', 'esm_725', 'esm_726', 'esm_727', 'esm_728', 'esm_729', 'esm_730', 'esm_731', 'esm_732', 'esm_733', 'esm_734', 'esm_735', 'esm_736', 'esm_737', 'esm_738', 'esm_739', 'esm_740', 'esm_741', 'esm_742', 'esm_743', 'esm_744', 'esm_745', 'esm_746', 'esm_747', 'esm_748', 'esm_749', 'esm_750', 'esm_751', 'esm_752', 'esm_753', 'esm_754', 'esm_755', 'esm_756', 'esm_757', 'esm_758', 'esm_759', 'esm_760', 'esm_761', 'esm_762', 'esm_763', 'esm_764', 'esm_765', 'esm_766', 'esm_767', 'esm_768', 'esm_769', 'esm_770', 'esm_771', 'esm_772', 'esm_773', 'esm_774', 'esm_775', 'esm_776', 'esm_777', 'esm_778', 'esm_779', 'esm_780', 'esm_781', 'esm_782', 'esm_783', 'esm_784', 'esm_785', 'esm_786', 'esm_787', 'esm_788', 'esm_789', 'esm_790', 'esm_791', 'esm_792', 'esm_793', 'esm_794', 'esm_795', 'esm_796', 'esm_797', 'esm_798', 'esm_799', 'esm_800', 'esm_801', 'esm_802', 'esm_803', 'esm_804', 'esm_805', 'esm_806', 'esm_807', 'esm_808', 'esm_809', 'esm_810', 'esm_811', 'esm_812', 'esm_813', 'esm_814', 'esm_815', 'esm_816', 'esm_817', 'esm_818', 'esm_819', 'esm_820', 'esm_821', 'esm_822', 'esm_823', 'esm_824', 'esm_825', 'esm_826', 'esm_827', 'esm_828', 'esm_829', 'esm_830', 'esm_831', 'esm_832', 'esm_833', 'esm_834', 'esm_835', 'esm_836', 'esm_837', 'esm_838', 'esm_839', 'esm_840', 'esm_841', 'esm_842', 'esm_843', 'esm_844', 'esm_845', 'esm_846', 'esm_847', 'esm_848', 'esm_849', 'esm_850', 'esm_851', 'esm_852', 'esm_853', 'esm_854', 'esm_855', 'esm_856', 'esm_857', 'esm_858', 'esm_859', 'esm_860', 'esm_861', 'esm_862', 'esm_863', 'esm_864', 
'esm_865', 'esm_866', 'esm_867', 'esm_868', 'esm_869', 'esm_870', 'esm_871', 'esm_872', 'esm_873', 'esm_874', 'esm_875', 'esm_876', 'esm_877', 'esm_878', 'esm_879', 'esm_880', 'esm_881', 'esm_882', 'esm_883', 'esm_884', 'esm_885', 'esm_886', 'esm_887', 'esm_888', 'esm_889', 'esm_890', 'esm_891', 'esm_892', 'esm_893', 'esm_894', 'esm_895', 'esm_896', 'esm_897', 'esm_898', 'esm_899', 'esm_900', 'esm_901', 'esm_902', 'esm_903', 'esm_904', 'esm_905', 'esm_906', 'esm_907', 'esm_908', 'esm_909', 'esm_910', 'esm_911', 'esm_912', 'esm_913', 'esm_914', 'esm_915', 'esm_916', 'esm_917', 'esm_918', 'esm_919', 'esm_920', 'esm_921', 'esm_922', 'esm_923', 'esm_924', 'esm_925', 'esm_926', 'esm_927', 'esm_928', 'esm_929', 'esm_930', 'esm_931', 'esm_932', 'esm_933', 'esm_934', 'esm_935', 'esm_936', 'esm_937', 'esm_938', 'esm_939', 'esm_940', 'esm_941', 'esm_942', 'esm_943', 'esm_944', 'esm_945', 'esm_946', 'esm_947', 'esm_948', 'esm_949', 'esm_950', 'esm_951', 'esm_952', 'esm_953', 'esm_954', 'esm_955', 'esm_956', 'esm_957', 'esm_958', 'esm_959', 'esm_960', 'esm_961', 'esm_962', 'esm_963', 'esm_964', 'esm_965', 'esm_966', 'esm_967', 'esm_968', 'esm_969', 'esm_970', 'esm_971', 'esm_972', 'esm_973', 'esm_974', 'esm_975', 'esm_976', 'esm_977', 'esm_978', 'esm_979', 'esm_980', 'esm_981', 'esm_982', 'esm_983', 'esm_984', 'esm_985', 'esm_986', 'esm_987', 'esm_988', 'esm_989', 'esm_990', 'esm_991', 'esm_992', 'esm_993', 'esm_994', 'esm_995', 'esm_996', 'esm_997', 'esm_998', 'esm_999', 'esm_1000', 'esm_1001', 'esm_1002', 'esm_1003', 'esm_1004', 'esm_1005', 'esm_1006', 'esm_1007', 'esm_1008', 'esm_1009', 'esm_1010', 'esm_1011', 'esm_1012', 'esm_1013', 'esm_1014', 'esm_1015', 'esm_1016', 'esm_1017', 'esm_1018', 'esm_1019', 'esm_1020', 'esm_1021', 'esm_1022', 'esm_1023', 'esm_1024', 'esm_1025', 'esm_1026', 'esm_1027', 'esm_1028', 'esm_1029', 'esm_1030', 'esm_1031', 'esm_1032', 'esm_1033', 'esm_1034', 'esm_1035', 'esm_1036', 'esm_1037', 'esm_1038', 'esm_1039', 'esm_1040', 'esm_1041', 
'esm_1042', 'esm_1043', 'esm_1044', 'esm_1045', 'esm_1046', 'esm_1047', 'esm_1048', 'esm_1049', 'esm_1050', 'esm_1051', 'esm_1052', 'esm_1053', 'esm_1054', 'esm_1055', 'esm_1056', 'esm_1057', 'esm_1058', 'esm_1059', 'esm_1060', 'esm_1061', 'esm_1062', 'esm_1063', 'esm_1064', 'esm_1065', 'esm_1066', 'esm_1067', 'esm_1068', 'esm_1069', 'esm_1070', 'esm_1071', 'esm_1072', 'esm_1073', 'esm_1074', 'esm_1075', 'esm_1076', 'esm_1077', 'esm_1078', 'esm_1079', 'esm_1080', 'esm_1081', 'esm_1082', 'esm_1083', 'esm_1084', 'esm_1085', 'esm_1086', 'esm_1087', 'esm_1088', 'esm_1089', 'esm_1090', 'esm_1091', 'esm_1092', 'esm_1093', 'esm_1094', 'esm_1095', 'esm_1096', 'esm_1097', 'esm_1098', 'esm_1099', 'esm_1100', 'esm_1101', 'esm_1102', 'esm_1103', 'esm_1104', 'esm_1105', 'esm_1106', 'esm_1107', 'esm_1108', 'esm_1109', 'esm_1110', 'esm_1111', 'esm_1112', 'esm_1113', 'esm_1114', 'esm_1115', 'esm_1116', 'esm_1117', 'esm_1118', 'esm_1119', 'esm_1120', 'esm_1121', 'esm_1122', 'esm_1123', 'esm_1124', 'esm_1125', 'esm_1126', 'esm_1127', 'esm_1128', 'esm_1129', 'esm_1130', 'esm_1131', 'esm_1132', 'esm_1133', 'esm_1134', 'esm_1135', 'esm_1136', 'esm_1137', 'esm_1138', 'esm_1139', 'esm_1140', 'esm_1141', 'esm_1142', 'esm_1143', 'esm_1144', 'esm_1145', 'esm_1146', 'esm_1147', 'esm_1148', 'esm_1149', 'esm_1150', 'esm_1151', 'esm_1152', 'esm_1153', 'esm_1154', 'esm_1155', 'esm_1156', 'esm_1157', 'esm_1158', 'esm_1159', 'esm_1160', 'esm_1161', 'esm_1162', 'esm_1163', 'esm_1164', 'esm_1165', 'esm_1166', 'esm_1167', 'esm_1168', 'esm_1169', 'esm_1170', 'esm_1171', 'esm_1172', 'esm_1173', 'esm_1174', 'esm_1175', 'esm_1176', 'esm_1177', 'esm_1178', 'esm_1179', 'esm_1180', 'esm_1181', 'esm_1182', 'esm_1183', 'esm_1184', 'esm_1185', 'esm_1186', 'esm_1187', 'esm_1188', 'esm_1189', 'esm_1190', 'esm_1191', 'esm_1192', 'esm_1193', 'esm_1194', 'esm_1195', 'esm_1196', 'esm_1197', 'esm_1198', 'esm_1199', 'esm_1200', 'esm_1201', 'esm_1202', 'esm_1203', 'esm_1204', 'esm_1205', 'esm_1206', 'esm_1207', 
'esm_1208', 'esm_1209', 'esm_1210', 'esm_1211', 'esm_1212', 'esm_1213', 'esm_1214', 'esm_1215', 'esm_1216', 'esm_1217', 'esm_1218', 'esm_1219', 'esm_1220', 'esm_1221', 'esm_1222', 'esm_1223', 'esm_1224', 'esm_1225', 'esm_1226', 'esm_1227', 'esm_1228', 'esm_1229', 'esm_1230', 'esm_1231', 'esm_1232', 'esm_1233', 'esm_1234', 'esm_1235', 'esm_1236', 'esm_1237', 'esm_1238', 'esm_1239', 'esm_1240', 'esm_1241', 'esm_1242', 'esm_1243', 'esm_1244', 'esm_1245', 'esm_1246', 'esm_1247', 'esm_1248', 'esm_1249', 'esm_1250', 'esm_1251', 'esm_1252', 'esm_1253', 'esm_1254', 'esm_1255', 'esm_1256', 'esm_1257', 'esm_1258', 'esm_1259', 'esm_1260', 'esm_1261', 'esm_1262', 'esm_1263', 'esm_1264', 'esm_1265', 'esm_1266', 'esm_1267', 'esm_1268', 'esm_1269', 'esm_1270', 'esm_1271', 'esm_1272', 'esm_1273', 'esm_1274', 'esm_1275', 'esm_1276', 'esm_1277', 'esm_1278', 'esm_1279', 'af3_0', 'af3_1', 'af3_2', 'af3_3', 'af3_4', 'af3_5', 'af3_6', 'af3_7', 'af3_8', 'af3_9', 'af3_10', 'af3_11', 'af3_12', 'af3_13', 'af3_14', 'af3_15', 'af3_16', 'af3_17', 'af3_18', 'af3_19', 'af3_20', 'af3_21', 'af3_22', 'af3_23', 'af3_24', 'af3_25', 'af3_26', 'af3_27', 'af3_28', 'af3_29', 'af3_30', 'af3_31', 'af3_32', 'af3_33', 'af3_34', 'af3_35', 'af3_36', 'af3_37', 'af3_38', 'af3_39', 'af3_40', 'af3_41', 'af3_42', 'af3_43', 'af3_44', 'af3_45', 'af3_46', 'af3_47', 'af3_48', 'af3_49', 'af3_50', 'af3_51', 'af3_52', 'af3_53', 'af3_54', 'af3_55', 'af3_56', 'af3_57', 'af3_58', 'af3_59', 'af3_60', 'af3_61', 'af3_62', 'af3_63', 'af3_64', 'af3_65', 'af3_66', 'af3_67', 'af3_68', 'af3_69', 'af3_70', 'af3_71', 'af3_72', 'af3_73', 'af3_74', 'af3_75', 'af3_76', 'af3_77', 'af3_78', 'af3_79', 'af3_80', 'af3_81', 'af3_82', 'af3_83', 'af3_84', 'af3_85', 'af3_86', 'af3_87', 'af3_88', 'af3_89', 'af3_90', 'af3_91', 'af3_92', 'af3_93', 'af3_94', 'af3_95', 'af3_96', 'af3_97', 'af3_98', 'af3_99', 'af3_100', 'af3_101', 'af3_102', 'af3_103', 'af3_104', 'af3_105', 'af3_106', 'af3_107', 'af3_108', 'af3_109', 'af3_110', 'af3_111', 'af3_112', 
'af3_113', 'af3_114', 'af3_115', 'af3_116', 'af3_117', 'af3_118', 'af3_119', 'af3_120', 'af3_121', 'af3_122', 'af3_123', 'af3_124', 'af3_125', 'af3_126', 'af3_127', 'af3_128', 'af3_129', 'af3_130', 'af3_131', 'af3_132', 'af3_133', 'af3_134', 'af3_135', 'af3_136', 'af3_137', 'af3_138', 'af3_139', 'af3_140', 'af3_141', 'af3_142', 'af3_143', 'af3_144', 'af3_145', 'af3_146', 'af3_147', 'af3_148', 'af3_149', 'af3_150', 'af3_151', 'af3_152', 'af3_153', 'af3_154', 'af3_155', 'af3_156', 'af3_157', 'af3_158', 'af3_159', 'af3_160', 'af3_161', 'af3_162', 'af3_163', 'af3_164', 'af3_165', 'af3_166', 'af3_167', 'af3_168', 'af3_169', 'af3_170', 'af3_171', 'af3_172', 'af3_173', 'af3_174', 'af3_175', 'af3_176', 'af3_177', 'af3_178', 'af3_179', 'af3_180', 'af3_181', 'af3_182', 'af3_183', 'af3_184', 'af3_185', 'af3_186', 'af3_187', 'af3_188', 'af3_189', 'af3_190', 'af3_191', 'af3_192', 'af3_193', 'af3_194', 'af3_195', 'af3_196', 'af3_197', 'af3_198', 'af3_199', 'af3_200', 'af3_201', 'af3_202', 'af3_203', 'af3_204', 'af3_205', 'af3_206', 'af3_207', 'af3_208', 'af3_209', 'af3_210', 'af3_211', 'af3_212', 'af3_213', 'af3_214', 'af3_215', 'af3_216', 'af3_217', 'af3_218', 'af3_219', 'af3_220', 'af3_221', 'af3_222', 'af3_223', 'af3_224', 'af3_225', 'af3_226', 'af3_227', 'af3_228', 'af3_229', 'af3_230', 'af3_231', 'af3_232', 'af3_233', 'af3_234', 'af3_235', 'af3_236', 'af3_237', 'af3_238', 'af3_239', 'af3_240', 'af3_241', 'af3_242', 'af3_243', 'af3_244', 'af3_245', 'af3_246', 'af3_247', 'af3_248', 'af3_249', 'af3_250', 'af3_251', 'af3_252', 'af3_253', 'af3_254', 'af3_255', 'af3_256', 'af3_257', 'af3_258', 'af3_259', 'af3_260', 'af3_261', 'af3_262', 'af3_263', 'af3_264', 'af3_265', 'af3_266', 'af3_267', 'af3_268', 'af3_269', 'af3_270', 'af3_271', 'af3_272', 'af3_273', 'af3_274', 'af3_275', 'af3_276', 'af3_277', 'af3_278', 'af3_279', 'af3_280', 'af3_281', 'af3_282', 'af3_283', 'af3_284', 'af3_285', 'af3_286', 'af3_287', 'af3_288', 'af3_289', 'af3_290', 'af3_291', 'af3_292', 'af3_293', 
'af3_294', 'af3_295', 'af3_296', 'af3_297', 'af3_298', 'af3_299', 'af3_300', 'af3_301', 'af3_302', 'af3_303', 'af3_304', 'af3_305', 'af3_306', 'af3_307', 'af3_308', 'af3_309', 'af3_310', 'af3_311', 'af3_312', 'af3_313', 'af3_314', 'af3_315', 'af3_316', 'af3_317', 'af3_318', 'af3_319', 'af3_320', 'af3_321', 'af3_322', 'af3_323', 'af3_324', 'af3_325', 'af3_326', 'af3_327', 'af3_328', 'af3_329', 'af3_330', 'af3_331', 'af3_332', 'af3_333', 'af3_334', 'af3_335', 'af3_336', 'af3_337', 'af3_338', 'af3_339', 'af3_340', 'af3_341', 'af3_342', 'af3_343', 'af3_344', 'af3_345', 'af3_346', 'af3_347']\n" + ] + } + ], + "source": [ + "##########################\n", + "# --- Load DataFrame --- #\n", + "##########################\n", + "\n", + "# Uncomment to (re)write the pickle from `bp3_df`:\n", + "#with open('bp3_pae10.pkl', 'wb') as f:\n", + "#    pickle.dump(bp3_df, f)\n", + "\n", + "# NOTE: pickle.load can execute arbitrary code; only load a file you generated yourself.\n", + "with open(\"bp3_pae10.pkl\", \"rb\") as f:\n", + "    bp3_all_df = pickle.load(f)\n", + "\n", + "# Split rows on the `train` flag.\n", + "bp3_train_df = bp3_all_df.filter(pl.col(\"train\") == True)\n", + "bp3_test_df = bp3_all_df.filter(pl.col(\"train\") == False)\n", + "\n", + "# Counts are computed outside the f-strings: reusing the delimiter quote inside an\n", + "# f-string replacement field is a SyntaxError before Python 3.12.\n", + "n_epitope = len(bp3_train_df.filter(pl.col(\"epitope\") == True))\n", + "n_non_epitope = len(bp3_train_df.filter(pl.col(\"epitope\") == False))\n", + "print(f\"Num Epitope Residues: {n_epitope}\")\n", + "print(f\"Num Non-Epitope Residues: {n_non_epitope}\")\n", + "\n", + "print(f\"Column Names: {bp3_all_df.columns}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3becb0b", + "metadata": {}, + "outputs": [], + "source": [ + "#######################################################################\n", + "# --- Visualize Variable Distributions for Epitope vs. 
NonEpitope --- #\n", + "#######################################################################\n", + "\n", + "# Column whose distribution is compared between epitope and non-epitope residues.\n", + "plot_var = \"pLDDT\"\n", + "\n", + "# Convert only the two plotted columns; the full frame has more than 1,600 columns.\n", + "bp3_plot = bp3_train_df.select([\"epitope\", plot_var]).to_pandas()\n", + "\n", + "(\n", + "    ggplot(bp3_plot, aes(x=\"epitope\", y=plot_var))\n", + "    + geom_boxplot()\n", + "    + labs(x=\"Epitope Status\", y=f\"{plot_var}\")\n", + "    # + geom_jitter()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a766107", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num Features: 1632\n" + ] + } + ], + "source": [ + "#############################\n", + "# --- Feature Selection --- #\n", + "#############################\n", + "\n", + "# Embedding columns are named esm_<i> and af3_<i>.\n", + "esm_features = [f\"esm_{emb}\" for emb in range(NUM_ESM_EMB_VARS)]\n", + "af3_features = [f\"af3_{emb}\" for emb in range(NUM_AF3_EMB_VARS)]\n", + "\n", + "# Scalar features currently in use.\n", + "scalar_features = [\"seq_len\", \"ptm\", \"rsa\", \"sa\"]\n", + "\n", + "# Columns currently commented out of the feature set:\n", + "#   'closeness_centrality', 'betweenness_centrality', 'load_centrality',\n", + "#   'eigenvector_centrality', 'degree_centrality', 'clustering',\n", + "#   'coreness', 'triangles', 'density', 'lapl_n1', 'lapl_f', 'pLDDT'\n", + "\n", + "# Feature order: ESM embeddings, AF3 embeddings, then scalars.\n", + "agg_features = [*esm_features, *af3_features, *scalar_features]\n", + "\n", + "print(f\"Num Features: {len(agg_features)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 187, + "id": "54f52a38", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--- Cross-Validation Fold Details ---\n", + "Fold 1: Train AUC = 0.7720, Test AUC = 0.7450\n", + "Fold 2: Train AUC = 0.7781, Test AUC = 0.7380\n", + "Fold 3: Train AUC = 0.7792, Test AUC = 0.7384\n", + "Fold 4: Train AUC = 0.7748, Test AUC = 0.7583\n", + "Fold 5: Train AUC = 0.7717, Test AUC = 
0.7499\n", + "\n", + "--- Overfitting Check ---\n", + "PCA Variance Kept: 95.0% with 1103.6000 (+/- 3.3823) components\n", + "Average Training AUC across folds: 0.7752 (+/- 0.0031)\n", + "Average Test (Validation) AUC across folds: 0.7459 (+/- 0.0076)\n" + ] + } + ], + "source": [ + "###################################\n", + "# --- 5-Fold Cross Validation --- #\n", + "###################################\n", + "\n", + "train_df = bp3_train_df.to_pandas()\n", + "X_df = train_df[agg_features]\n", + "y_df = train_df[\"epitope\"]\n", + "\n", + "X = X_df.values\n", + "y = y_df.values\n", + "\n", + "# Loop-invariant settings, defined once so the summary below reads the values the folds used.\n", + "n_splits = 5\n", + "variance_threshold = 0.95  # fraction of variance the retained PCA components must explain\n", + "\n", + "# shuffle=False: each fold is a contiguous block of rows.\n", + "# NOTE(review): rows are residues, so plain KFold can put residues of one protein (`job_name`)\n", + "# on both sides of a split; consider GroupKFold with job_name as groups -- TODO confirm.\n", + "cv = KFold(n_splits=n_splits, shuffle=False)\n", + "\n", + "train_auc_scores = []\n", + "test_auc_scores = []\n", + "components = []\n", + "\n", + "print(\"--- Cross-Validation Fold Details ---\")\n", + "for fold, (train_index, test_index) in enumerate(cv.split(X, y)):\n", + "\n", + "    # --- Split ---\n", + "    X_train, X_test = X[train_index], X[test_index]\n", + "    y_train, y_test = y[train_index], y[test_index]\n", + "\n", + "    # --- Scale features (fit on the training fold only) ---\n", + "    scaler = StandardScaler()\n", + "    scaler.fit(X_train)\n", + "    X_train = scaler.transform(X_train)\n", + "    X_test = scaler.transform(X_test)\n", + "\n", + "    # --- Calibrate PCA ---\n", + "    # n_components=None keeps every component, i.e. min(n_samples, n_features).\n", + "    pca = PCA(n_components=None, random_state=11)\n", + "    pca.fit(X_train)\n", + "\n", + "    # Smallest k whose cumulative explained variance reaches the threshold.\n", + "    cumulative_variance = np.cumsum(pca.explained_variance_ratio_)\n", + "    optimal_k = np.where(cumulative_variance >= variance_threshold)[0][0] + 1\n", + "    components.append(optimal_k)\n", + "\n", + "    # --- Apply PCA ---\n", + "    pca_final = PCA(n_components=optimal_k, random_state=11)\n", + "    X_train = pca_final.fit_transform(X_train)\n", + "    X_test = pca_final.transform(X_test)\n", + 
"\n", + "    # Lasso regularization (best so far).\n", + "    # Ridge alternative tried: penalty='l2', C=0.000025.\n", + "    clf = LogisticRegression(solver=\"saga\", class_weight=\"balanced\", penalty='l1', C=0.0025, max_iter=500, n_jobs=-1, random_state=11)\n", + "    clf.fit(X_train, y_train)\n", + "\n", + "    # --- Training AUC Calculation ---\n", + "    y_train_proba = clf.predict_proba(X_train)[:, 1]\n", + "    train_auc = roc_auc_score(y_train, y_train_proba)\n", + "    train_auc_scores.append(train_auc)\n", + "\n", + "    # --- Test AUC Calculation ---\n", + "    y_test_proba = clf.predict_proba(X_test)[:, 1]\n", + "    test_auc = roc_auc_score(y_test, y_test_proba)\n", + "    test_auc_scores.append(test_auc)\n", + "\n", + "    print(f\"Fold {fold+1}: Train AUC = {train_auc:.4f}, Test AUC = {test_auc:.4f}\")\n", + "\n", + "# --- Summary across folds ---\n", + "mean_auc_test = np.mean(test_auc_scores)\n", + "std_auc_test = np.std(test_auc_scores)\n", + "mean_train_auc = np.mean(train_auc_scores)\n", + "std_train_auc = np.std(train_auc_scores)\n", + "mean_components = np.mean(components)\n", + "std_components = np.std(components)\n", + "\n", + "# --- Overfitting Check Section ---\n", + "print(\"\\n--- Overfitting Check ---\")\n", + "print(f\"PCA Variance Kept: {variance_threshold*100}% with {mean_components:.4f} (+/- {std_components:.4f}) components\")\n", + "print(f\"Average Training AUC across folds: {mean_train_auc:.4f} (+/- {std_train_auc:.4f})\")\n", + "print(f\"Average Test (Validation) AUC across folds: {mean_auc_test:.4f} (+/- {std_auc_test:.4f})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "9ec60245", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PCA Variance Kept: 95.0% with 1024 
components\n", + "Train AUC = 0.7421, Test AUC = 0.7703\n" + ] + } + ], + "source": [ + "################################\n", + "# --- BP3 Final Evaluation --- #\n", + "################################\n", + "\n", + "train_df = bp3_train_df.to_pandas()\n", + "X_train = train_df[agg_features]\n", + "y_train = train_df[\"epitope\"]\n", + "\n", + "test_df = bp3_test_df.to_pandas()\n", + "X_test = test_df[agg_features]\n", + "y_test = test_df[\"epitope\"]\n", + "\n", + "# --- Scale features (fit on the training set only) ---\n", + "scaler = StandardScaler()\n", + "scaler.fit(X_train)\n", + "X_train = scaler.transform(X_train)\n", + "X_test = scaler.transform(X_test)\n", + "\n", + "# --- PCA ---\n", + "# n_components=None keeps every component, i.e. min(n_samples, n_features).\n", + "pca = PCA(n_components=None, random_state=11)\n", + "pca.fit(X_train)\n", + "\n", + "# Smallest k whose cumulative explained variance reaches the threshold.\n", + "cumulative_variance = np.cumsum(pca.explained_variance_ratio_)\n", + "variance_threshold = 0.95\n", + "optimal_k = np.where(cumulative_variance >= variance_threshold)[0][0] + 1\n", + "\n", + "pca_final = PCA(n_components=optimal_k, random_state=11)\n", + "X_train = pca_final.fit_transform(X_train)\n", + "X_test = pca_final.transform(X_test)\n", + "\n", + "# --- Fit Model ---\n", + "clf = LogisticRegression(solver=\"saga\", class_weight=\"balanced\", penalty='l1', C=0.0025, max_iter=500, n_jobs=-1, random_state=11)\n", + "clf.fit(X_train, y_train)\n", + "\n", + "# --- AUC on the training and held-out test sets ---\n", + "# Kept in plain variables rather than appended to the cross-validation lists\n", + "# (`components`, `train_auc_scores`, `test_auc_scores`), which hold per-fold values only.\n", + "y_train_proba = clf.predict_proba(X_train)[:, 1]\n", + "train_auc = roc_auc_score(y_train, y_train_proba)\n", + "\n", + "y_test_proba = clf.predict_proba(X_test)[:, 1]\n", + "test_auc = roc_auc_score(y_test, y_test_proba)\n", + "\n", + "print(f\"PCA Variance Kept: {variance_threshold*100}% with {optimal_k} components\")\n", + "print(f\"Train AUC = {train_auc:.4f}, Test AUC = {test_auc:.4f}\")" + ] + } + ], + 
"metadata": { + "kernelspec": { + "display_name": "epident-experiments", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}