Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*/dataset/
*/train_output/
24 changes: 11 additions & 13 deletions examples/chemprot/README.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
# Fine-tuning for tagging: End-to-end example

1. Preprocess the data with `uv run examples/chemprot/preprocess_chemprot.py data/chemprot`
1. Preprocess the data with `uv run examples/chemprot/preprocess_chemprot.py`

2. Fine-tune with something like:
2. Fine-tune for NER with something like:

```bash
cnlpt train \
--task_name chemical_ner gene_ner \
--data_dir data/chemprot \
--encoder_name allenai/scibert_scivocab_uncased \
--do_train \
--do_eval \
--cache_dir cache/ \
--output_dir temp/ \
uv run cnlpt train \
--model_type proj \
--encoder allenai/scibert_scivocab_uncased \
--data_dir ./dataset \
--task chemical_ner --task gene_ner \
--output_dir ./train_output \
--overwrite_output_dir \
--num_train_epochs 50 \
--do_train --do_eval \
--num_train_epochs 3 \
--learning_rate 2e-5 \
--lr_scheduler_type constant \
--report_to none \
--save_strategy no \
--save_strategy best \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 10 \
--weight_decay 0.2
Expand Down
28 changes: 13 additions & 15 deletions examples/chemprot/preprocess_chemprot.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import bisect
import itertools
import os
import re
from dataclasses import dataclass
from sys import argv
from typing import Any, Union
from pathlib import Path
from typing import Any

import polars as pl
from datasets import load_dataset
from datasets.dataset_dict import Dataset, DatasetDict
from datasets.utils import disable_progress_bars, enable_progress_bars
from rich.console import Console


def load_chemprot_dataset(cache_dir="./cache") -> DatasetDict:
    """Load the ``chemprot_full_source`` configuration of ``bigbio/chemprot``.

    The ``datasets`` progress bars are switched off for the duration of the
    load so they do not interleave with the caller's own console output, and
    are switched back on afterwards, including when the load fails.

    Args:
        cache_dir: Directory handed to ``load_dataset`` as ``cache_dir``.

    Returns:
        The object returned by ``load_dataset`` for this configuration.
    """
    disable_progress_bars()
    try:
        return load_dataset(
            "bigbio/chemprot", "chemprot_full_source", cache_dir=cache_dir
        )
    finally:
        # Runs on both success and failure, so an exception raised by
        # load_dataset cannot leave progress bars disabled for the rest of
        # the process.
        enable_progress_bars()


def clean_text(text: str):
Expand Down Expand Up @@ -156,25 +161,18 @@ def preprocess_data(split: Dataset):
)


def main(out_dir: Union[str, os.PathLike, None] = None):
    """Preprocess every ChemProt split and write one TSV file per split.

    Args:
        out_dir: Directory that receives ``train.tsv``, ``test.tsv`` and
            ``validation.tsv``. When omitted, a ``dataset`` directory next to
            this script is used.
    """
    console = Console()

    if out_dir is None:
        out_dir = Path(__file__).parent / "dataset"
    else:
        out_dir = Path(out_dir)
    # parents=True so a nested target such as "data/chemprot" works even when
    # the intermediate directory does not exist yet.
    out_dir.mkdir(parents=True, exist_ok=True)

    with console.status("Loading dataset...") as st:
        dataset = load_chemprot_dataset()
        for split in ("train", "test", "validation"):
            st.update(f"Preprocessing {split} data...")
            preprocessed = preprocess_data(dataset[split])
            preprocessed.write_csv(out_dir / f"{split}.tsv", separator="\t")

    console.print(
        f"[green i]Preprocessed chemprot data saved to [repr.filename]{out_dir}[/]."
    )


if __name__ == "__main__":
    # An explicit output directory is still accepted as the first argument;
    # with no argument the script-relative default is used.
    main(argv[1] if len(argv) > 1 else None)
118 changes: 55 additions & 63 deletions examples/uci_drug/README.md
Original file line number Diff line number Diff line change
@@ -1,74 +1,66 @@
### Fine-tuning for classification: End-to-end example
# Drug Review Sentiment Classification

1. Download data from [Drug Reviews (Druglib.com) Data Set](https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com) to `data` folder and extract. Pay attention to their terms:
1. only use the data for research purposes
2. don't use the data for any commercial purposes
3. don't distribute the data to anyone else
4. cite us
## Jupyter notebook example

2. Run ```python examples/uci_drug/transform_uci_drug.py <raw dir> <processed dir>``` to preprocess the data from the extract directory into a new directory. This will create {train,dev,test}.tsv in the processed directory specified, where the sentiment ratings have been collapsed into 3 categories.
See the [example notebook](./uci_drug.ipynb) for a step-by-step walkthrough of
how to use CNLPT to train a model for sentiment classification of drug reviews.

3. Fine-tune with something like:
## CLI example

```bash
cnlpt train \
--data_dir <processed dir> \
--task_name sentiment \
--encoder_name roberta-base \
--do_train \
--do_eval \
--cache_dir cache/ \
--output_dir temp/ \
--overwrite_output_dir \
--evals_per_epoch 5 \
--num_train_epochs 1 \
--learning_rate 1e-5 \
--report_to none \
--metric_for_best_model eval_sentiment.avg_micro_f1 \
--load_best_model_at_end \
--save_strategy best
```

On our hardware, that command results in eval performance like the following:
```sentiment = {'acc': 0.7041800643086816, 'f1': [0.7916666666666666, 0.7228915662650603, 0.19444444444444442], 'acc_and_f1': [0.7479233654876741, 0.7135358152868709, 0.449312254376563], 'recall': [0.8216216216216217, 0.8695652173913043, 0.12280701754385964], 'precision': [0.7638190954773869, 0.6185567010309279, 0.4666666666666667]}```

#### Error Analysis for Classification

If you run the above command with the `--error_analysis` flag, you can obtain the `dev` instances for which the model made an erroneous
prediction, organized by their original index in `dev` split, in the `eval_predictions...tsv` file in the `--output_dir` argument.
For us the first line of this file (after the header) is:

```
text sentiment
2 Benefits: <cr> helped aleviate whip lash symptoms <cr> Side effects: <cr> none that i noticed <cr> Overall comments: <cr> i took the medications for the prescribed time and symptoms improved, however, I still have some symptoms which are being treated through physical therapy since the accident was only in December Ground: Medium Predicted: High

```

The number at the beginning of the line, 2, is the index of the instance in the `dev` split. The `text` column contains the text of the erroneous instances and the following columns are the tasks provided to the model, in this case, just `sentiment`. `Ground: Medium Predicted: High` indicates that the provided ground truth label for the instance sentiment is `Medium` but the model predicted `High`.

#### Human Readable Predictions for Classification
If you prefer, you can instead use the CLI to train the model:

Similarly if you run the above command with `--do_predict` you can obtain human readable predictions for the `test` split, in the `test_predictions...tsv` file. For us the first line of this file (after the header) is:

```
0 Benefits: <cr> The antibiotic may have destroyed bacteria causing my sinus infection. But it may also have been caused by a virus, so its hard to say. <cr> Side effects: <cr> Some back pain, some nauseau. <cr> Overall comments: <cr> Took the antibiotics for 14 days. Sinus infection was gone after the 6th day. Low

```

##### Prediction Probability Outputs for Classification

If you run the above command with the `--output_prob` flag (currently only supported for classification tasks), you can see the model's softmax-obtained probability for the predicted classification label. The first error analysis sample from `dev` would now look like:

```
text sentiment
2 Benefits: <cr> helped aleviate whip lash symptoms <cr> Side effects: <cr> none that i noticed <cr> Overall comments: <cr> i took the medications for the prescribed time and symptoms improved, however, I still have some symptoms which are being treated through physical therapy since the accident was only in December Ground: Medium Predicted: High , Probability 0.613825
### Download and preprocess the data

Use the [`prepare_data.py`](./prepare_data.py) script to download the data and convert it to CNLPT's data format:

```bash
uv run prepare_data.py
```

And the first prediction sample from `test` now looks like:
> [!TIP]
> About the dataset:
> This script downloads the
> [*Drug Reviews (Druglib.com)* dataset](https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com).
> Please be aware of the terms of use:
>
> > Important Notes:
> >
> > When using this dataset, you agree that you
> >
> > 1) only use the data for research purposes
> > 2) don't use the data for any commerical purposes
> > 3) don't distribute the data to anyone else
> > 4) cite UCI data lab and the source
>
> Here is the dataset's BibTeX citation:
>
> ```bibtex
> @misc{drug_reviews_(druglib.com)_461,
> author = {Kallumadi, Surya and Grer, Felix},
> title = {{Drug Reviews (Druglib.com)}},
> year = {2018},
> howpublished = {UCI Machine Learning Repository},
> note = {{DOI}: https://doi.org/10.24432/C55G6J}
> }
> ```

### Train a model

The following example fine-tunes
[the RoBERTa base model](https://huggingface.co/FacebookAI/roberta-base)
with an added projection layer for classification:

```
text sentiment
0 Benefits: <cr> The antibiotic may have destroyed bacteria causing my sinus infection. But it may also have been caused by a virus, so its hard to say. <cr> Side effects: <cr> Some back pain, some nauseau. <cr> Overall comments: <cr> Took the antibiotics for 14 days. Sinus infection was gone after the 6th day. Low , Probability 0.370522
```bash
uv run cnlpt train \
--model_type proj \
--encoder roberta-base \
--data_dir ./dataset \
--task sentiment \
--output_dir ./train_output \
--overwrite_output_dir \
--do_train --do_eval --do_predict \
--evals_per_epoch 2 \
--learning_rate 1e-5 \
--metric_for_best_model 'sentiment.macro_f1' \
--load_best_model_at_end \
--save_strategy best
```
67 changes: 67 additions & 0 deletions examples/uci_drug/prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import io
import zipfile
from pathlib import Path

import polars as pl
import requests

DATASET_ZIP_URL = (
"https://archive.ics.uci.edu/static/public/461/drug+review+dataset+druglib+com.zip"
)
DATA_DIR = Path(__file__).parent / "dataset"


def _sentiment_from_rating(rating):
    """Collapse a numeric rating into one of three sentiment labels."""
    if rating < 5:
        return "Negative"
    if rating < 8:
        return "Neutral"
    return "Positive"


def preprocess_raw_data(unprocessed_path: str):
    """Read a raw tab-separated review file and reduce it to three columns.

    Args:
        unprocessed_path: Location of the raw TSV file to read.

    Returns:
        A polars frame with the columns ``id`` (taken from the input column
        whose header is the empty string), ``sentiment`` (label derived from
        ``rating``) and ``text`` (the three review fields joined by a space,
        with line breaks turned into `` <cr> `` and tabs into spaces).
    """
    raw_frame = pl.read_csv(unprocessed_path, separator="\t")

    sentiment_expr = pl.col("rating").map_elements(
        _sentiment_from_rating,
        return_dtype=pl.String,
    )

    text_expr = pl.concat_str(
        "benefitsReview",
        "sideEffectsReview",
        "commentsReview",
        separator=" ",
    )
    # Applied in this order: newlines, carriage returns, then tabs. Tabs are
    # replaced because the result is later written as a tab-separated file.
    for pattern, replacement in (
        ("\n", " <cr> "),
        ("\r", " <cr> "),
        ("\t", " "),
    ):
        text_expr = text_expr.str.replace_all(pattern, replacement)

    return raw_frame.select(id="", sentiment=sentiment_expr, text=text_expr)


if __name__ == "__main__":
    DATA_DIR.mkdir(exist_ok=True)

    # Download the dataset archive. The timeout keeps the script from hanging
    # indefinitely if the server stops responding.
    response = requests.get(DATASET_ZIP_URL, timeout=60)
    response.raise_for_status()

    # Extract straight from memory; the context manager closes the archive.
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive.extractall(DATA_DIR)

    raw_train_file = DATA_DIR / "drugLibTrain_raw.tsv"
    raw_test_file = DATA_DIR / "drugLibTest_raw.tsv"

    # Preprocess raw data
    preprocessed_train_data = preprocess_raw_data(raw_train_file)
    preprocessed_test_data = preprocess_raw_data(raw_test_file)

    # 90/10 split for train and dev. The split point is computed once and both
    # parts are cut explicitly, so the result is always exactly two frames
    # regardless of the row count.
    train_rows = int(preprocessed_train_data.shape[0] * 0.9)
    preprocessed_dev_data = preprocessed_train_data.slice(train_rows)
    preprocessed_train_data = preprocessed_train_data.head(train_rows)

    # Write to tsv files
    preprocessed_train_data.write_csv(DATA_DIR / "train.tsv", separator="\t")
    preprocessed_dev_data.write_csv(DATA_DIR / "dev.tsv", separator="\t")
    preprocessed_test_data.write_csv(DATA_DIR / "test.tsv", separator="\t")

    # Delete raw data files
    raw_train_file.unlink()
    raw_test_file.unlink()
Loading
Loading