Abstract

Downstream tasks built on protein language models (such as ESM-2) typically attach high-dimensional projection layers that are largely redundant. This project demonstrates Distil-Rank, a pipeline that compresses these downstream projection layers by roughly 10x and recovers near-perfect functional fidelity (99.8%+ cosine similarity) through knowledge distillation, exploiting the low intrinsic dimensionality of biological embeddings.
Tested on projection layers operating on embeddings from the ESM-2 (650M) model across three teacher initialization strategies.
| Teacher Init | SVD Cosine | Student Cosine | Improvement | MSE Reduction |
|---|---|---|---|---|
| Random | 0.467 | 0.998 | +53.1% | 99.5% |
| Task-Trained | 0.999 | 1.000 | +0.1% | 98.8% |
| Spectral Decay | 0.847 | 0.999 | +15.2% | 99.2% |

| Metric | Value |
|---|---|
| Speedup | ~7.5x |
| Compression | 9.9x (1.64M → 0.16M params) |
| Target Rank | 64 |
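In the tables above, the SVD column reports the fidelity of the truncated-SVD baseline alone and the Student column the fidelity after distillation, both measured against the teacher's outputs on held-out embeddings. A minimal sketch of how such metrics can be computed (function and variable names here are illustrative, not the benchmark's actual API):

```python
import torch
import torch.nn.functional as F

def fidelity_metrics(teacher_out: torch.Tensor, student_out: torch.Tensor) -> dict:
    """Mean per-sample cosine similarity and MSE between teacher and student outputs.

    Both tensors are (batch, out_dim). Illustrative only -- the benchmark's exact
    metric definitions may differ.
    """
    cosine = F.cosine_similarity(student_out, teacher_out, dim=-1).mean().item()
    mse = F.mse_loss(student_out, teacher_out).item()
    return {"cosine": cosine, "mse": mse}

# "MSE reduction" can then be reported relative to an SVD-only baseline, e.g.:
# mse_reduction = 1.0 - distilled["mse"] / svd_only["mse"]
```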
- Task-Trained Teacher: Most realistic scenario. Teacher learns a denoising proxy task on real embeddings, creating weight structure aligned with the data manifold. SVD already achieves 0.999 fidelity, showing learned layers have natural low-rank structure.
- Random Teacher: Demonstrates distillation's power—even arbitrary projections achieve near-perfect recovery.
- Spectral Decay: Simulates realistic singular value decay patterns found in trained neural networks.
The singular value spectrum plot (right panel of the final report) shows why: task-trained weights have sharp spectral decay (energy concentrated in the top components), so SVD alone nearly suffices, while random weights have flat spectra and need distillation to recover fidelity.
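A minimal sketch of how the three teacher weight matrices might be constructed (the decay rate and the denoising proxy details are assumptions, not the benchmark's actual settings):

```python
import torch

def make_teacher_weight(strategy: str, dim: int = 1280) -> torch.Tensor:
    """Build a dense teacher projection matrix under one of three strategies."""
    if strategy == "random":
        # Flat singular value spectrum: no exploitable low-rank structure.
        return torch.randn(dim, dim) / dim ** 0.5
    if strategy == "spectral_decay":
        # Impose an explicit decaying spectrum via U diag(s) V^T.
        u, _ = torch.linalg.qr(torch.randn(dim, dim))
        v, _ = torch.linalg.qr(torch.randn(dim, dim))
        s = torch.exp(-torch.arange(dim, dtype=torch.float) / 50.0)  # assumed decay rate
        return u @ torch.diag(s) @ v.T
    if strategy == "task_trained":
        # In the full benchmark the teacher is trained on a denoising proxy task
        # over real embeddings; that training loop is omitted here.
        raise NotImplementedError("train a dense layer on the denoising proxy task")
    raise ValueError(strategy)
```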
The approach combines classical linear algebra (low-rank factorization via truncated SVD) with neural knowledge distillation.
A standard projection layer is a dense matrix $W \in \mathbb{R}^{m \times n}$ applied as $y = Wx$.

In biological data, inputs $x$ occupy a subspace of far lower dimension than the ambient embedding dimension, so much of a dense $W$'s capacity is redundant.

By forcing the rank constraint $W \approx UV$ with $U \in \mathbb{R}^{m \times r}$, $V \in \mathbb{R}^{r \times n}$, and $r \ll \min(m, n)$, both parameter count and compute shrink:

- Teacher: $O(m \cdot n)$ operations per input.
- Student: $O(r(m + n))$ operations per input.
- For $m = n = 1280$ and $r = 64$, this yields a theoretical ~10x reduction in FLOPs.
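As a sketch (assuming the student factors are seeded by a truncated SVD of the teacher weight; variable names are illustrative), the factorization and parameter counts look like this:

```python
import torch

def truncated_svd_factors(w: torch.Tensor, rank: int = 64):
    """Factor a dense (m, n) weight into U (m, r) and V (r, n), keeping the top-r singular directions."""
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    U = u[:, :rank] * s[:rank]          # absorb singular values into the left factor
    V = vh[:rank, :]
    return U, V

w = torch.randn(1280, 1280)
U, V = truncated_svd_factors(w, rank=64)
dense_params = w.numel()                # 1280 * 1280 = 1,638,400 (~1.64M)
lowrank_params = U.numel() + V.numel()  # 2 * 1280 * 64 = 163,840 (~0.16M)
print(dense_params / lowrank_params)    # ~10x, in line with the ~9.9x compression reported above
```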
When a neural network layer is trained on structured data (like protein embeddings), the weight matrix naturally develops spectral concentration: the learned transformation puts most of its energy into a few leading singular directions, which is exactly the structure a rank-$r$ factorization can capture. Distillation then closes whatever gap remains by fitting the factors directly to the teacher's outputs.
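A minimal sketch of the distillation stage, assuming the student is the SVD-initialized factor pair trained with an MSE objective against the teacher's outputs on embedding batches (optimizer, batch size, and step count are placeholders, not the benchmark's settings):

```python
import torch

def distill_low_rank(w_teacher: torch.Tensor, x_train: torch.Tensor,
                     rank: int = 64, steps: int = 2000, lr: float = 1e-3):
    """Fit U (m, r) and V (r, n) so that the student x -> (x V^T) U^T matches the teacher x -> x W^T."""
    u, s, vh = torch.linalg.svd(w_teacher, full_matrices=False)
    U = (u[:, :rank] * s[:rank]).clone().requires_grad_(True)
    V = vh[:rank, :].clone().requires_grad_(True)
    opt = torch.optim.Adam([U, V], lr=lr)
    for _ in range(steps):
        idx = torch.randint(0, x_train.shape[0], (256,))    # mini-batch of embeddings
        x = x_train[idx]
        with torch.no_grad():
            target = x @ w_teacher.T                        # teacher forward pass
        pred = (x @ V.T) @ U.T                              # student: two skinny matmuls
        loss = torch.nn.functional.mse_loss(pred, target)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return U.detach(), V.detach()
```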
The benchmark automatically detects if real data is available.
- Install dependencies: `pip install -r requirements.txt`
- Run the benchmark: `python main.py`
- Real Data Mode: If `embeddings_train.pt` and `embeddings_val.pt` are found, it runs the full benchmark on ESM-2 embeddings with three teacher initialization strategies.
- Synthetic Mode: If the files are missing, it falls back to structured synthetic data for demonstration.
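A minimal sketch of that detection logic (the synthetic generator shown here is an illustrative stand-in, not necessarily what `main.py` does):

```python
import os
import torch

def load_embeddings():
    """Use real ESM-2 embeddings if present, otherwise fall back to structured synthetic data."""
    if os.path.exists("embeddings_train.pt") and os.path.exists("embeddings_val.pt"):
        return torch.load("embeddings_train.pt"), torch.load("embeddings_val.pt"), "real"
    # Synthetic fallback: low-rank structured vectors mimicking an embedding manifold.
    basis = torch.randn(64, 1280)
    train = torch.randn(8000, 64) @ basis + 0.01 * torch.randn(8000, 1280)
    val = torch.randn(2000, 64) @ basis + 0.01 * torch.randn(2000, 1280)
    return train, val, "synthetic"
```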
Results are saved to `results/`:

- `distil_rank_summary.json` - All metrics in JSON format
- `distil_rank_final_report.png` - Visualization comparing strategies
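For downstream analysis the JSON summary can be loaded directly; the exact field names depend on the benchmark run:

```python
import json

with open("results/distil_rank_summary.json") as f:
    summary = json.load(f)
print(summary.keys())
```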
├── main.py # Main benchmark script
├── embeddings_train.pt # ESM-2 training embeddings (optional)
├── embeddings_val.pt # ESM-2 validation embeddings (optional)
├── requirements.txt # Dependencies
├── README.md
└── results/
├── distil_rank_summary.json
└── distil_rank_final_report.png
MIT
