This repository contains code for reproducing the results of the paper "Supervised Learning of Universal Sentence Representations from Natural Language Inference Data" by Conneau et al. (2018).
- Create a virtual environment:
  `python -m venv venv`
- Activate the virtual environment:
  `. ./venv/bin/activate`
- Install the dependencies:
  `pip install -r requirements.txt`
- Install the SentEval framework:
  `./scripts/install_senteval.sh` (might not download the STS dataset on macOS)
- SentEval troubleshooting: SentEval relies on old code, so Python versions >=3.10 can lead to errors (a workaround sketch follows this list):
  - `ValueError: Function has keyword-only parameters or annotations, use inspect.signature() API which can support them`: a fix/workaround can be found in facebookresearch/SentEval#89.
  - `ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (750,) + inhomogeneous part.` on the STS14 benchmark: the STS14 code in SentEval doesn't work with newer numpy versions; a fix/workaround can be found in facebookresearch/SentEval#94.
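For the first error, one workaround in the spirit of the patch discussed in facebookresearch/SentEval#89 is to alias the legacy `inspect.getargspec` to `inspect.getfullargspec` before SentEval's code paths run. This monkeypatch is only a sketch, under the assumption that SentEval reads just the fields the two results share; patching SentEval's source as in the linked issue is the more direct fix.

```python
import inspect

# getargspec rejects keyword-only parameters and was removed in Python 3.11;
# getfullargspec accepts them, and the shared fields (args, varargs, defaults)
# are assumed to be all SentEval needs, so the alias is close enough here.
inspect.getargspec = inspect.getfullargspec

import senteval  # noqa: E402  # import after patching so the alias is in place
```

For the STS14 error, the root cause is that numpy 1.24 and later refuse to build ragged object arrays implicitly; besides the patch in facebookresearch/SentEval#94, pinning `numpy<1.24` in `requirements.txt` may work as a stopgap.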
Example training usage:
```
python -m src.train --model lstm
```
Acceptable `--model` values are `lstm`, `bi-lstm`, and `bi-lstm-pool`.
All arguments and their descriptions can be viewed with `python -m src.train -h`.
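To train every encoder variant in one go, a small driver can shell out to the same entry point; this is just a convenience sketch, not a script shipped with the repo:

```python
import subprocess

# Train each encoder variant via the documented CLI entry point.
for model in ["lstm", "bi-lstm", "bi-lstm-pool"]:
    subprocess.run(["python", "-m", "src.train", "--model", model], check=True)
```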
Example evaluation usage:
```
python -m src.eval --model lstm --checkpoint_path models/lstm_2024_04_17_13_47.pt
```
Acceptable `--model` values are `lstm`, `bi-lstm`, `bi-lstm-pool`, and `mean`.
All arguments and their descriptions can be viewed with `python -m src.eval -h`.
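The transfer results below come from SentEval, which `src.eval` presumably wraps via SentEval's documented engine interface. Here is a minimal, self-contained sketch of that interface, with a random-vector encoder standing in for the repo's real model; the `task_path`, embedding size, and parameter values are assumptions:

```python
import numpy as np
import senteval

def encode(sentence):
    # Hypothetical stand-in for the trained encoder: each sentence is
    # mapped to a 300-dimensional random vector, stable within a process.
    rng = np.random.default_rng(abs(hash(sentence)) % 2**32)
    return rng.standard_normal(300)

def prepare(params, samples):
    # Called once per task with all samples, e.g. to build a vocabulary.
    return

def batcher(params, batch):
    # `batch` is a list of tokenized sentences; return one embedding per sentence.
    sentences = [" ".join(tokens) if tokens else "." for tokens in batch]
    return np.vstack([encode(s) for s in sentences])

params = {"task_path": "SentEval/data", "usepytorch": True, "kfold": 10}
se = senteval.engine.SE(params, batcher, prepare)
results = se.eval(["MR", "CR", "SUBJ", "MPQA", "SST2", "TREC", "MRPC",
                   "SICKRelatedness", "SICKEntailment", "STS14"])
```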
Training logs can be found in the training report.
The pre-trained model checkpoints can be found here.
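The `.pt` extension suggests standard PyTorch serialization, so a downloaded checkpoint can likely be inspected as below; whether it is a plain `state_dict` or a wrapper dict is an assumption and may differ from what `src.train` actually saves:

```python
import torch

# Load on the CPU so no GPU is needed just to inspect the file.
ckpt = torch.load("models/lstm_2024_04_17_13_47.pt", map_location="cpu")

# A checkpoint is typically either a state_dict (parameter name -> tensor)
# or a wrapper dict containing one; print the top-level keys to find out.
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])
else:
    print(type(ckpt))
```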
Results on the SNLI task. Accuracy is rounded to two decimal places.
| Model | Validation accuracy | Test accuracy |
|---|---|---|
| Mean | 64.19 | 64.92 |
| LSTM | 81.07 | 80.53 |
| BiLSTM | 80.55 | 80.32 |
| BiLSTM with max pooling | 84.87 | 84.47 |
Following the methodology of Conneau et al. (2018), "micro" and "macro" averages are computed over development set (dev) results on the transfer tasks whose metric is accuracy; a sketch of the computation follows the table below. Accuracy is rounded to two decimal places.
| Model | Micro | Macro |
|---|---|---|
| Mean | 82.03 | 79.36 |
| LSTM | 79.75 | 78.94 |
| BiLSTM | 82.63 | 81.70 |
| BiLSTM with max pooling | 84.25 | 83.27 |
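As a reference for how the two averages differ, here is a minimal sketch; the accuracies are taken from the "Mean" row further below, while the dev-set sizes are illustrative placeholders, not SentEval's actual counts:

```python
# Micro vs. macro averaging of dev accuracies over transfer tasks.
dev_accuracy = {"MR": 76.79, "CR": 78.01, "SUBJ": 90.90, "MPQA": 87.41}
dev_size = {"MR": 1066, "CR": 378, "SUBJ": 1000, "MPQA": 1060}  # placeholders

# Macro average: unweighted mean over tasks.
macro = sum(dev_accuracy.values()) / len(dev_accuracy)

# Micro average: mean weighted by the number of dev samples in each task.
total = sum(dev_size.values())
micro = sum(dev_accuracy[t] * dev_size[t] for t in dev_accuracy) / total

print(f"macro={macro:.2f} micro={micro:.2f}")
```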
Per-task dev results on the transfer tasks. MRPC is reported as accuracy/F1, SICK-R as Pearson correlation, and STS14 as Pearson/Spearman correlation; all other columns are accuracies.

| Model | MR | CR | SUBJ | MPQA | SST | TREC | MRPC | SICK-R | SICK-E | STS14 |
|---|---|---|---|---|---|---|---|---|---|---|
| Mean | 76.79 | 78.01 | 90.90 | 87.41 | 80.45 | 81.40 | 71.83/80.65 | 0.7740 | 77.13 | 0.54/0.56 |
| LSTM | 73.86 | 77.69 | 86.38 | 87.69 | 77.98 | 75.40 | 73.04/81.39 | 0.8627 | 84.33 | 0.14/0.32 |
| BiLSTM | 74.60 | 79.08 | 89.33 | 88.06 | 79.41 | 87.80 | 73.57/82.12 | 0.8719 | 84.96 | 0.30/0.30 |
| BiLSTM with max pooling | 77.89 | 81.22 | 91.87 | 88.15 | 83.03 | 87.40 | 75.07/83.28 | 0.8824 | 85.06 | 0.69/0.67 |