From 42e682ac8ebe1d5bb9da74ebffb63968eafcf270 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Mon, 5 Dec 2022 13:48:49 +0000 Subject: [PATCH 01/21] collect rollouts for hyperparameter sweep on roc_story --- .gitignore | 4 ++ .../sentiment-data/ppo_config.yml | 49 ++++++++++++++++ .../ppo_roc_story_sentiments.py | 57 +++++++++++++++++++ .../sentiment-data/ppo_sweep.yml | 17 ++++++ .../sentiment-data/rollouts/.gitkeep | 0 5 files changed, 127 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/ppo_config.yml create mode 100644 algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py create mode 100644 algorithm_distillation/sentiment-data/ppo_sweep.yml create mode 100644 algorithm_distillation/sentiment-data/rollouts/.gitkeep diff --git a/.gitignore b/.gitignore index b6e4761..440260f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +trlx +algorithm_distillation/sentiment-data/ray_results/ +algorithm_distillation/sentiment-data/rollouts/run-* + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/algorithm_distillation/sentiment-data/ppo_config.yml b/algorithm_distillation/sentiment-data/ppo_config.yml new file mode 100644 index 0000000..70ff782 --- /dev/null +++ b/algorithm_distillation/sentiment-data/ppo_config.yml @@ -0,0 +1,49 @@ +model: + model_path: "lvwerra/gpt2-imdb" # Name of hf model to load + tokenizer_path: "gpt2" # Name of hf tokenizer to load + model_type: "AcceleratePPOModel" # Name of accelerate model type to load + num_layers_unfrozen: 2 # Number of bottom layers to freeze during training + +train: + seq_length: 48 # Size of LM context + epochs: 100 # Train for max(epochs, total_steps) + total_steps: 10000 # Train for max(epochs, total_steps) + batch_size: 128 # batch size + + lr_init: 1.0e-4 # init learning rate + lr_target: 1.0e-4 # target final learning rate + opt_betas: [0.9, 0.95] # adam betas + opt_eps: 1.0e-8 # adam eps + weight_decay: 1.0e-6 # weight decay param + + checkpoint_interval: 10000 # checkpoint interval + eval_interval: 100 # eval interval + + pipeline: "PromptPipeline" # prompt pipeline to load + orchestrator: "PPOOrchestrator" # orchestrator to load + + rollout_logging_dir: "/home/ubuntu/Algorithm-Distillation-RLHF/sentiment-data/rollouts" + +method: + name: 'ppoconfig' # Name of RL method config + num_rollouts: 128 # Number of rollouts to collect per epoch + chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator + ppo_epochs: 4 # Number of ppo epochs + init_kl_coef: 0.05 # init kl coefficient + target: 6 # target kl coefficient, set None for fixed kl coef + horizon: 10000 # PPO horizon + gamma: 1 # PPO discount + lam: 0.95 # PPO lambda + cliprange: 0.2 # clip range + cliprange_value: 0.2 # clip range + vf_coef: 1 # value term weight + scale_reward: False # False | "ref" | "running" estimate against which to scale rewards + ref_mean: null + ref_std: null # rescale rewards with this deviation + cliprange_reward: 10 + gen_kwargs: + max_length: 48 # LM max sample gen length + min_length: 48 # LM min sample gen length + top_k: 0.0 # top k + top_p: 1.0 # top p + do_sample: True # sample diff --git a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py new file mode 100644 index 0000000..c2c38a1 --- /dev/null +++ b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py @@ -0,0 +1,57 @@ +from posixpath import dirname +from datasets import load_dataset +from transformers import pipeline 
+import os +import yaml + +import trlx +import torch +from typing import List +from trlx.data.configs import TRLConfig + +from trlx.utils.loading import get_model, get_orchestrator, get_pipeline + +def get_positive_score(scores): + "Extract value associated with a positive sentiment from pipeline's output" + return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + +default_config = yaml.safe_load(open(os.path.join(dirname(__file__), "ppo_config.yml"))) + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + if torch.cuda.is_available(): + device = int(os.environ.get("LOCAL_RANK", 0)) + else: + device = -1 + + + sentiment_fn = pipeline( + "sentiment-analysis", + "lvwerra/distilbert-imdb", + top_k=2, + truncation=True, + batch_size=256, + device=device, + ) + + def reward_fn(samples: List[str]) -> List[float]: + sentiments = list(map(get_positive_score, sentiment_fn(samples))) + return sentiments + + # Take few words off of movies reviews as prompts + stories = load_dataset("adamlin/roc_story") + prompts = [d['sentence1'] for d in stories['train']] + eval_prompts = [d for d in stories['validation'][:64]['sentence1']] + + model = trlx.train( + model_path="gpt2", + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=eval_prompts, + config=config, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/ppo_sweep.yml b/algorithm_distillation/sentiment-data/ppo_sweep.yml new file mode 100644 index 0000000..95469ef --- /dev/null +++ b/algorithm_distillation/sentiment-data/ppo_sweep.yml @@ -0,0 +1,17 @@ +tune_config: + mode: "max" + metric: "mean_reward" + search_alg: "random" + scheduler: "fifo" + num_samples: 32 + +# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs +lr_init: + strategy: "loguniform" + values: [0.00001, 0.01] +init_kl_coef: + strategy: "uniform" + values: [0, 0.2] +vf_coef: + strategy: "uniform" + values: [0.5, 2] diff --git a/algorithm_distillation/sentiment-data/rollouts/.gitkeep b/algorithm_distillation/sentiment-data/rollouts/.gitkeep new file mode 100644 index 0000000..e69de29 From 5e4157846abd1e8d19e656585a4e64a8b20a7a96 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Mon, 5 Dec 2022 14:49:04 +0000 Subject: [PATCH 02/21] update reward model to one with wider range of sentiment options --- .gitignore | 1 + .../sentiment-data/ppo_config.yml | 2 +- .../sentiment-data/ppo_roc_story_sentiments.py | 13 ++++++++----- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 440260f..2fe9190 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ trlx algorithm_distillation/sentiment-data/ray_results/ algorithm_distillation/sentiment-data/rollouts/run-* +algorithm_distillation/sentiment-data/wandb/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/algorithm_distillation/sentiment-data/ppo_config.yml b/algorithm_distillation/sentiment-data/ppo_config.yml index 70ff782..fed94fe 100644 --- a/algorithm_distillation/sentiment-data/ppo_config.yml +++ b/algorithm_distillation/sentiment-data/ppo_config.yml @@ -22,7 +22,7 @@ train: pipeline: "PromptPipeline" # prompt pipeline to load orchestrator: "PPOOrchestrator" # orchestrator to load - rollout_logging_dir: "/home/ubuntu/Algorithm-Distillation-RLHF/sentiment-data/rollouts" + rollout_logging_dir: "~/Algorithm-Distillation-RLHF/algorithm_distillation/sentiment-data/rollouts" method: name: 'ppoconfig' # Name of RL method config diff --git 
a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py index c2c38a1..e06b555 100644 --- a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py +++ b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py @@ -3,6 +3,7 @@ from transformers import pipeline import os import yaml +from functools import partial import trlx import torch @@ -11,9 +12,10 @@ from trlx.utils.loading import get_model, get_orchestrator, get_pipeline -def get_positive_score(scores): +def get_score_for_label(label, scores): "Extract value associated with a positive sentiment from pipeline's output" - return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] + label_to_score = {d['label'] : d['score'] for d in scores} + return label_to_score[label] default_config = yaml.safe_load(open(os.path.join(dirname(__file__), "ppo_config.yml"))) @@ -28,15 +30,16 @@ def main(hparams={}): sentiment_fn = pipeline( "sentiment-analysis", - "lvwerra/distilbert-imdb", - top_k=2, + "bhadresh-savani/distilbert-base-uncased-emotion", truncation=True, batch_size=256, device=device, + return_all_scores=True, ) def reward_fn(samples: List[str]) -> List[float]: - sentiments = list(map(get_positive_score, sentiment_fn(samples))) + output_batch = sentiment_fn(samples) + sentiments = list(map(partial(get_score_for_label, 'joy'), output_batch)) return sentiments # Take few words off of movies reviews as prompts From 88cfc22d31e16998ac0b655c0c309585e652af59 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Thu, 5 Jan 2023 14:18:45 +0000 Subject: [PATCH 03/21] update for relative logging paths --- .gitignore | 1 + algorithm_distillation/sentiment-data/ppo_config.yml | 2 +- .../sentiment-data/ppo_roc_story_sentiments.py | 1 - 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 2fe9190..976f19e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ trlx algorithm_distillation/sentiment-data/ray_results/ algorithm_distillation/sentiment-data/rollouts/run-* algorithm_distillation/sentiment-data/wandb/ +wandb/ # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/algorithm_distillation/sentiment-data/ppo_config.yml b/algorithm_distillation/sentiment-data/ppo_config.yml index fed94fe..7afc44a 100644 --- a/algorithm_distillation/sentiment-data/ppo_config.yml +++ b/algorithm_distillation/sentiment-data/ppo_config.yml @@ -22,7 +22,7 @@ train: pipeline: "PromptPipeline" # prompt pipeline to load orchestrator: "PPOOrchestrator" # orchestrator to load - rollout_logging_dir: "~/Algorithm-Distillation-RLHF/algorithm_distillation/sentiment-data/rollouts" + rollout_logging_dir: "../algorithm_distillation/sentiment-data/rollouts" method: name: 'ppoconfig' # Name of RL method config diff --git a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py index e06b555..3679b95 100644 --- a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py +++ b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py @@ -42,7 +42,6 @@ def reward_fn(samples: List[str]) -> List[float]: sentiments = list(map(partial(get_score_for_label, 'joy'), output_batch)) return sentiments - # Take few words off of movies reviews as prompts stories = load_dataset("adamlin/roc_story") prompts = [d['sentence1'] for d in stories['train']] eval_prompts = [d for d in stories['validation'][:64]['sentence1']] From 
3c12ceb2bcf298b0581e0f7f7fe00b37b0ac4f1a Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Thu, 5 Jan 2023 14:29:51 +0000 Subject: [PATCH 04/21] add requirements.txt --- requirements.txt | 86 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..17bbea6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,86 @@ +accelerate==0.14.0 +aiohttp==3.8.3 +aiosignal==1.3.1 +async-timeout==4.0.2 +attrs==22.1.0 +certifi==2022.9.24 +cfgv==3.3.1 +charset-normalizer==2.1.1 +click==8.0.4 +datasets==2.7.1 +deepspeed==0.7.5 +dill==0.3.6 +distlib==0.3.6 +docker-pycreds==0.4.0 +einops==0.6.0 +exceptiongroup==1.0.4 +filelock==3.8.0 +frozenlist==1.3.3 +fsspec==2022.11.0 +gitdb==4.0.9 +GitPython==3.1.29 +grpcio==1.50.0 +hjson==3.1.0 +huggingface-hub==0.11.0 +identify==2.5.9 +idna==3.4 +iniconfig==1.1.1 +joblib==1.2.0 +jsonschema==4.17.1 +msgpack==1.0.4 +multidict==6.0.2 +multiprocess==0.70.14 +networkx==2.8.8 +ninja==1.11.1 +nodeenv==1.7.0 +numpy==1.23.5 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +packaging==21.3 +pandas==1.5.2 +pathtools==0.1.2 +platformdirs==2.5.4 +pluggy==1.0.0 +pre-commit==2.20.0 +promise==2.3 +protobuf==4.21.9 +psutil==5.9.4 +py-cpuinfo==9.0.0 +pyarrow==10.0.1 +pydantic==1.10.2 +pyparsing==3.0.9 +pyrsistent==0.19.2 +pytest==7.2.0 +python-dateutil==2.8.2 +pytz==2022.6 +PyYAML==6.0 +ray==2.1.0 +regex==2022.10.31 +requests==2.28.1 +responses==0.18.0 +scikit-learn==1.1.3 +scipy==1.9.3 +sentry-sdk==1.11.1 +setproctitle==1.3.2 +shortuuid==1.0.11 +six==1.16.0 +smmap==5.0.0 +tabulate==0.9.0 +threadpoolctl==3.1.0 +tokenizers==0.13.2 +toml==0.10.2 +tomli==2.0.1 +torch==1.13.0 +torchtyping==0.1.4 +tqdm==4.64.1 +transformers==4.24.0 +-e git+https://github.com/thomfoster/trlx.git@e76b34338eb1551e8eb5f0eff7f694a4711d8b83#egg=trlx +typeguard==2.13.3 +typing_extensions==4.4.0 +urllib3==1.26.12 +virtualenv==20.16.7 +wandb==0.13.5 +xxhash==3.1.0 +yarl==1.8.1 From 6ff57bae6b4f0c6c1202ef5d50e353dd90e6784b Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Thu, 5 Jan 2023 16:18:45 +0000 Subject: [PATCH 05/21] add incredibly overengineered script for decoding rollouts --- .gitignore | 2 + .../sentiment-data/decode_rollout.py | 101 ++++++++++++++++++ 2 files changed, 103 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/decode_rollout.py diff --git a/.gitignore b/.gitignore index 976f19e..b672c9e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,11 @@ trlx algorithm_distillation/sentiment-data/ray_results/ algorithm_distillation/sentiment-data/rollouts/run-* +algorithm_distillation/sentiment-data/decoded_rollouts algorithm_distillation/sentiment-data/wandb/ wandb/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/algorithm_distillation/sentiment-data/decode_rollout.py b/algorithm_distillation/sentiment-data/decode_rollout.py new file mode 100644 index 0000000..13b2b65 --- /dev/null +++ b/algorithm_distillation/sentiment-data/decode_rollout.py @@ -0,0 +1,101 @@ +from pathlib import Path +import json +from transformers import AutoTokenizer +import click +from typing import Union +from shutil import copy2 + +@click.group() +def main(): + """ + CLI for formatting the rollouts into training data. + + \b + 1. decode-epoch: Decode a single epoch rollout .json file + 2. decode-run: Decode an entire PPO run (multiple epoch files) + 3. 
decode-rollouts: Deocde an entire directory (multiple runs) + """ + +def get_tokenizer(tokenizer: Union[str, AutoTokenizer]) -> AutoTokenizer: + if isinstance(tokenizer, str): + return AutoTokenizer.from_pretrained(tokenizer) + else: + return tokenizer + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input JSON file') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the JSON file to be created. Will overwrite if necessary.') +def decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + _decode_epoch(tokenizer, input_fpath, output_fpath) + +def _decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_file() + assert input_fpath.name.endswith('.json') + assert output_fpath.name.endswith('.json') + if not output_fpath.parent.exists(): + output_fpath.parent.mkdir() + + tokenizer = get_tokenizer(tokenizer) + + with open(input_fpath, 'r') as f: + rollouts = json.loads(f.read()) + + for rollout in rollouts: + rollout['query_text'] = tokenizer.decode(rollout['query_tensor'], skip_special_tokens=True) + rollout['response_text'] = tokenizer.decode(rollout['response_tensor'], skip_special_tokens=True) + + with open(output_fpath, 'w') as f: + f.write(json.dumps(rollouts, indent=2)) + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the directory containing the JSON epoch files') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the folder to be created.') +def decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + _decode_run(tokenizer, input_fpath, output_fpath) + +def _decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_dir() + if not output_fpath.exists(): + output_fpath.mkdir() + + # Copy over the config file + assert (input_fpath / 'config.json').exists() + copy2(input_fpath / 'config.json', output_fpath / 'config.json') + + # Decode the rest of the files + tokenizer = get_tokenizer(tokenizer) + + epochs = [fpath for fpath in input_fpath.iterdir() if fpath.name != 'config.json'] + for epoch in epochs: + _decode_epoch(tokenizer, epoch, output_fpath / epoch.name) + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input directory') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the output directory') +def decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): + _decode_rollouts(tokenizer, input_fpath, output_fpath) + +def _decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_dir() + if not output_fpath.exists(): + output_fpath.mkdir() + + tokenizer = get_tokenizer(tokenizer) + runs = [fpath for fpath in input_fpath.iterdir() if fpath.name.startswith('run-')] + for run in runs: + _decode_run(tokenizer, run, output_fpath / run.name) + +if __name__ == '__main__': + main() \ No newline at end of file From 315b00fbd4af11fea7640beabe6ac80b50e9b25a Mon Sep 17 00:00:00 2001 From: 
Thomas Foster Date: Thu, 5 Jan 2023 18:20:49 +0000 Subject: [PATCH 06/21] rename decode_rollout > decode_rollouts --- .../sentiment-data/decode_rollouts.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/decode_rollouts.py diff --git a/algorithm_distillation/sentiment-data/decode_rollouts.py b/algorithm_distillation/sentiment-data/decode_rollouts.py new file mode 100644 index 0000000..13b2b65 --- /dev/null +++ b/algorithm_distillation/sentiment-data/decode_rollouts.py @@ -0,0 +1,101 @@ +from pathlib import Path +import json +from transformers import AutoTokenizer +import click +from typing import Union +from shutil import copy2 + +@click.group() +def main(): + """ + CLI for formatting the rollouts into training data. + + \b + 1. decode-epoch: Decode a single epoch rollout .json file + 2. decode-run: Decode an entire PPO run (multiple epoch files) + 3. decode-rollouts: Deocde an entire directory (multiple runs) + """ + +def get_tokenizer(tokenizer: Union[str, AutoTokenizer]) -> AutoTokenizer: + if isinstance(tokenizer, str): + return AutoTokenizer.from_pretrained(tokenizer) + else: + return tokenizer + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input JSON file') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the JSON file to be created. Will overwrite if necessary.') +def decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + _decode_epoch(tokenizer, input_fpath, output_fpath) + +def _decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_file() + assert input_fpath.name.endswith('.json') + assert output_fpath.name.endswith('.json') + if not output_fpath.parent.exists(): + output_fpath.parent.mkdir() + + tokenizer = get_tokenizer(tokenizer) + + with open(input_fpath, 'r') as f: + rollouts = json.loads(f.read()) + + for rollout in rollouts: + rollout['query_text'] = tokenizer.decode(rollout['query_tensor'], skip_special_tokens=True) + rollout['response_text'] = tokenizer.decode(rollout['response_tensor'], skip_special_tokens=True) + + with open(output_fpath, 'w') as f: + f.write(json.dumps(rollouts, indent=2)) + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the directory containing the JSON epoch files') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the folder to be created.') +def decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + _decode_run(tokenizer, input_fpath, output_fpath) + +def _decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_dir() + if not output_fpath.exists(): + output_fpath.mkdir() + + # Copy over the config file + assert (input_fpath / 'config.json').exists() + copy2(input_fpath / 'config.json', output_fpath / 'config.json') + + # Decode the rest of the files + tokenizer = get_tokenizer(tokenizer) + + epochs = [fpath for fpath in input_fpath.iterdir() if fpath.name != 'config.json'] + for epoch in epochs: + _decode_epoch(tokenizer, epoch, output_fpath / epoch.name) + + +@main.command() 
+@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input directory') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the output directory') +def decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): + _decode_rollouts(tokenizer, input_fpath, output_fpath) + +def _decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_dir() + if not output_fpath.exists(): + output_fpath.mkdir() + + tokenizer = get_tokenizer(tokenizer) + runs = [fpath for fpath in input_fpath.iterdir() if fpath.name.startswith('run-')] + for run in runs: + _decode_run(tokenizer, run, output_fpath / run.name) + +if __name__ == '__main__': + main() \ No newline at end of file From 8c48e165a4b94b78a663c1b352254bb5ada35142 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Thu, 5 Jan 2023 18:21:51 +0000 Subject: [PATCH 07/21] rename ppo_roc_story_sentiment_rollouts.py >> generate_ppo_roc_story_sentiment_rollouts.py to reflect that this is the script to generate data, not the class that uses it --- ...nerate_ppo_roc_story_sentiment_rollouts.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/generate_ppo_roc_story_sentiment_rollouts.py diff --git a/algorithm_distillation/sentiment-data/generate_ppo_roc_story_sentiment_rollouts.py b/algorithm_distillation/sentiment-data/generate_ppo_roc_story_sentiment_rollouts.py new file mode 100644 index 0000000..e2f0582 --- /dev/null +++ b/algorithm_distillation/sentiment-data/generate_ppo_roc_story_sentiment_rollouts.py @@ -0,0 +1,60 @@ +from posixpath import dirname +from datasets import load_dataset +from transformers import pipeline +import os +import yaml +from functools import partial + +import trlx +import torch +from typing import List +from trlx.data.configs import TRLConfig + +from trlx.utils.loading import get_model, get_orchestrator, get_pipeline + +def get_score_for_label(label, scores): + "Extract value associated with a positive sentiment from pipeline's output" + label_to_score = {d['label'] : d['score'] for d in scores} + return label_to_score[label] + +default_config = yaml.safe_load(open(os.path.join(dirname(__file__), "ppo_config.yml"))) + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + if torch.cuda.is_available(): + device = int(os.environ.get("LOCAL_RANK", 0)) + else: + device = -1 + + + sentiment_fn = pipeline( + "sentiment-analysis", + "bhadresh-savani/distilbert-base-uncased-emotion", + truncation=True, + batch_size=256, + device=device, + return_all_scores=True, + ) + + def reward_fn(samples: List[str]) -> List[float]: + output_batch = sentiment_fn(samples) + sentiments = list(map(partial(get_score_for_label, 'sadness'), output_batch)) + return sentiments + + stories = load_dataset("adamlin/roc_story") + format_prompt = lambda d : d['sentence1'] + ' ' + d['sentence2'] + prompts = [format_prompt(d) for d in stories['train']] + eval_prompts = [format_prompt(d) for d in stories['validation'].select(range(64))] + + model = trlx.train( + model_path="gpt2", + reward_fn=reward_fn, + prompts=prompts, + eval_prompts=eval_prompts, + config=config, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file From 35fa94d9880df8224d093ea9dd8a1bdb7c73a6bc Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Thu, 5 Jan 2023 18:22:19 +0000 Subject: 
[PATCH 08/21] add class to use rollout data as language modelling task --- .../ppo_roc_story_sentiment_rollouts.py | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py diff --git a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py new file mode 100644 index 0000000..424dcf5 --- /dev/null +++ b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py @@ -0,0 +1,55 @@ +import torch +from pathlib import Path +import json +from transformers import AutoTokenizer +from typing import Dict, Any + +class RolloutsAsLanguageModellingTask(torch.utils.data.IterableDataset): + def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str): + self.tokenizer = tokenizer + self.rollouts_folder = Path(rollouts_folder_fpath) + + def format_rollout(self, d: Dict[Any, Any]) -> str: + return f"Prompt:{d['query_text']}\nCompletion:{d['response_text']}\nReward:{d['rewards'][-1]}\n\n" + + def __iter__(self): + + runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] + print(f'Found {len(runs)} runs...') + for run in runs: + + config = json.loads(open(run / 'config.json', 'r').read()) + epochs = [epoch for epoch in run.iterdir() if epoch.name != 'config.json'] + print(f'Run {run.name} has {len(epochs)} epochs...') + epochs = sorted(epochs) + for epoch in epochs: + # print(f'Yielding from epoch {epoch.name}') + + rollouts = json.loads(open(epoch, 'r').read()) + + rollout_idx = 0 + prompt = "" + while rollout_idx < len(rollouts): + rollout = self.format_rollout(rollouts[rollout_idx]) + rollout_idx += 1 + + new_prompt = prompt + rollout + new_ids = self.tokenizer.tokenize(new_prompt) + + if len(new_ids) > tokenizer.model_max_length: # self.tokenizer.model_max_length: + yield self.tokenizer(prompt, return_tensors='pt') + prompt = "" + else: + prompt = new_prompt + + + yield self.tokenizer(prompt, return_tensors='pt') + + raise StopIteration + +if __name__ == '__main__': + tokenizer = AutoTokenizer.from_pretrained('gpt2') + dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts') + for ex in dataset: + print(tokenizer.decode(ex['input_ids'][0])) + print('\n---------\n') \ No newline at end of file From 2f2ac70776eb2ab1bc742ae3fbf2ce3f19f60933 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Thu, 5 Jan 2023 19:09:58 +0000 Subject: [PATCH 09/21] rename to reflect that this is a generic wrapper for any rollout, not just roc stories --- .../sentiment-data/rollouts_as_lm_task.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/rollouts_as_lm_task.py diff --git a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py new file mode 100644 index 0000000..4858262 --- /dev/null +++ b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py @@ -0,0 +1,63 @@ +import torch +from pathlib import Path +import json +from transformers import AutoTokenizer +from typing import Dict, Any + +class RolloutsAsLanguageModellingTask(torch.utils.data.IterableDataset): + def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str): + self.tokenizer = tokenizer + self.rollouts_folder = Path(rollouts_folder_fpath) + + def format_rollout(self, d: Dict[Any, Any]) -> str: + return 
f"Prompt:{d['query_text']}\nCompletion:{d['response_text']}\nReward:{d['rewards'][-1]}\n\n" + + def tokenize_for_training(self, x: str): + inputs = self.tokenizer(x, truncation=True, return_tensors='pt') + return { + 'input_ids': inputs.input_ids, + 'attention_mask': inputs.attention_mask, + 'labels': inputs.input_ids + } + + def __iter__(self): + + runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] + print(f'Found {len(runs)} runs...') + for run in runs: + + config = json.loads(open(run / 'config.json', 'r').read()) + epochs = [epoch for epoch in run.iterdir() if epoch.name != 'config.json'] + print(f'Run {run.name} has {len(epochs)} epochs...') + epochs = sorted(epochs) + for epoch in epochs: + # print(f'Yielding from epoch {epoch.name}') + + rollouts = json.loads(open(epoch, 'r').read()) + + rollout_idx = 0 + prompt = "" + while rollout_idx < len(rollouts): + rollout = self.format_rollout(rollouts[rollout_idx]) + rollout_idx += 1 + + new_prompt = prompt + rollout + new_ids = self.tokenizer.tokenize(new_prompt) + + if len(new_ids) > self.tokenizer.model_max_length: # self.tokenizer.model_max_length: + yield self.tokenize_for_training(prompt) + prompt = "" + else: + prompt = new_prompt + + if len(prompt) > 0: + yield self.tokenize_for_training(prompt) + + raise StopIteration + +if __name__ == '__main__': + tokenizer = AutoTokenizer.from_pretrained('gpt2') + dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts') + for ex in dataset: + print(tokenizer.decode(ex['input_ids'][0])) + print('\n---------\n') \ No newline at end of file From f368d936e9b2b367fdcfa95e6900879a87f10a2b Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Thu, 5 Jan 2023 19:10:19 +0000 Subject: [PATCH 10/21] super simple script to train lm with accelerate on the rollout data --- .../sentiment-data/train_lm.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/train_lm.py diff --git a/algorithm_distillation/sentiment-data/train_lm.py b/algorithm_distillation/sentiment-data/train_lm.py new file mode 100644 index 0000000..455c279 --- /dev/null +++ b/algorithm_distillation/sentiment-data/train_lm.py @@ -0,0 +1,31 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from accelerate import Accelerator +from rollouts_as_lm_task import RolloutsAsLanguageModellingTask +from tqdm.auto import tqdm + +accelerator = Accelerator() + +model = AutoModelForCausalLM.from_pretrained('gpt2') +optimizer = torch.optim.Adam(model.parameters()) + +tokenizer = AutoTokenizer.from_pretrained('gpt2') +dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts') +data = torch.utils.data.DataLoader(dataset, shuffle=False) + +model, optimizer, data = accelerator.prepare(model, optimizer, data) + +model.train() +for epoch in range(10): + for batch in tqdm(data, desc=f'Epoch {epoch}'): + + optimizer.zero_grad() + + output = model(**batch) + + loss = output.loss + + accelerator.backward(loss) + + optimizer.step() \ No newline at end of file From b983670f857b6fe8c452f05daab70f127d4c7da0 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Fri, 6 Jan 2023 15:34:52 +0000 Subject: [PATCH 11/21] add eval loop --- .../sentiment-data/train_lm.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/algorithm_distillation/sentiment-data/train_lm.py b/algorithm_distillation/sentiment-data/train_lm.py index 455c279..2b32e3d 100644 --- 
a/algorithm_distillation/sentiment-data/train_lm.py +++ b/algorithm_distillation/sentiment-data/train_lm.py @@ -8,24 +8,43 @@ accelerator = Accelerator() model = AutoModelForCausalLM.from_pretrained('gpt2') -optimizer = torch.optim.Adam(model.parameters()) +optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) tokenizer = AutoTokenizer.from_pretrained('gpt2') -dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts') -data = torch.utils.data.DataLoader(dataset, shuffle=False) +train_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/train') +eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval') -model, optimizer, data = accelerator.prepare(model, optimizer, data) +train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False) + +model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader) model.train() +total_steps = 0 for epoch in range(10): - for batch in tqdm(data, desc=f'Epoch {epoch}'): - + for batch in tqdm(train_dataloader, desc=f'Training epoch {epoch}'): + # train optimizer.zero_grad() - output = model(**batch) - loss = output.loss - accelerator.backward(loss) - - optimizer.step() \ No newline at end of file + optimizer.step() + total_steps += 1 + + # eval + eval_steps = 25 + eval_size = 20 + if total_steps % eval_steps == 0: + model.eval() + eval_loss = 0 + eval_dataloader = torch.utils.data.DataLoader(eval_dataset, shuffle=False) + eval_dataloader = accelerator.prepare(eval_dataloader) + for idx, batch in tqdm(enumerate(eval_dataloader), desc=f'Eval after {total_steps} steps.'): + if idx >= eval_size: + break + output = model(**batch) + loss = output.loss + eval_loss += loss.item() + eval_loss /= idx + print(f'Avg loss after {total_steps}: {eval_loss}') + model.train() + \ No newline at end of file From 23410dda1bb4b5bf3a793ebfc690177c51ccc24c Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Fri, 6 Jan 2023 15:36:49 +0000 Subject: [PATCH 12/21] add verbosity flag for dataset --- .../sentiment-data/rollouts_as_lm_task.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py index 4858262..23a3f01 100644 --- a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py +++ b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py @@ -5,9 +5,10 @@ from typing import Dict, Any class RolloutsAsLanguageModellingTask(torch.utils.data.IterableDataset): - def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str): + def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str, verbose: bool = True): self.tokenizer = tokenizer self.rollouts_folder = Path(rollouts_folder_fpath) + self.verbose = verbose def format_rollout(self, d: Dict[Any, Any]) -> str: return f"Prompt:{d['query_text']}\nCompletion:{d['response_text']}\nReward:{d['rewards'][-1]}\n\n" @@ -22,13 +23,18 @@ def tokenize_for_training(self, x: str): def __iter__(self): - runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] - print(f'Found {len(runs)} runs...') + if self.verbose: + runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] + + print(f'Iterating over {len(runs)} runs...') for run in runs: config = json.loads(open(run / 'config.json', 'r').read()) epochs = [epoch for epoch in run.iterdir() if epoch.name != 'config.json'] - print(f'Run {run.name} has 
{len(epochs)} epochs...') + + if self.verbose: + print(f'...and {run.name} has {len(epochs)} epochs.') + epochs = sorted(epochs) for epoch in epochs: # print(f'Yielding from epoch {epoch.name}') From ca5c5e270ddd00bd0b26b045437b17c613a9779d Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Fri, 6 Jan 2023 15:38:23 +0000 Subject: [PATCH 13/21] use verbosity flag --- algorithm_distillation/sentiment-data/train_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithm_distillation/sentiment-data/train_lm.py b/algorithm_distillation/sentiment-data/train_lm.py index 2b32e3d..32c2cfb 100644 --- a/algorithm_distillation/sentiment-data/train_lm.py +++ b/algorithm_distillation/sentiment-data/train_lm.py @@ -12,7 +12,7 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2') train_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/train') -eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval') +eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', verbose=False) train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False) From 3935c55ffc0dc37ef324aa9b7c4223bbf04c3450 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Fri, 6 Jan 2023 16:33:33 +0000 Subject: [PATCH 14/21] add flag for yielding prompt only for generation during evaluation --- .../sentiment-data/rollouts_as_lm_task.py | 41 +++++++++++-------- .../sentiment-data/train_lm.py | 28 +++++++++++++ 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py index 23a3f01..9e436ab 100644 --- a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py +++ b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py @@ -5,13 +5,14 @@ from typing import Dict, Any class RolloutsAsLanguageModellingTask(torch.utils.data.IterableDataset): - def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str, verbose: bool = True): + def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str, for_generation: bool = False, verbose: bool = True): self.tokenizer = tokenizer self.rollouts_folder = Path(rollouts_folder_fpath) self.verbose = verbose + self.for_generation = for_generation def format_rollout(self, d: Dict[Any, Any]) -> str: - return f"Prompt:{d['query_text']}\nCompletion:{d['response_text']}\nReward:{d['rewards'][-1]}\n\n" + return f"Prompt: {d['query_text']}\nCompletion: {d['response_text']}\nReward: {d['rewards'][-1]}\n\n" def tokenize_for_training(self, x: str): inputs = self.tokenizer(x, truncation=True, return_tensors='pt') @@ -23,10 +24,10 @@ def tokenize_for_training(self, x: str): def __iter__(self): - if self.verbose: - runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] - - print(f'Iterating over {len(runs)} runs...') + runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] + if self.verbose: + print(f'Iterating over {len(runs)} runs...') + for run in runs: config = json.loads(open(run / 'config.json', 'r').read()) @@ -44,17 +45,25 @@ def __iter__(self): rollout_idx = 0 prompt = "" while rollout_idx < len(rollouts): - rollout = self.format_rollout(rollouts[rollout_idx]) - rollout_idx += 1 - - new_prompt = prompt + rollout - new_ids = self.tokenizer.tokenize(new_prompt) - if len(new_ids) > self.tokenizer.model_max_length: # self.tokenizer.model_max_length: - yield self.tokenize_for_training(prompt) - 
prompt = "" + if self.for_generation: + d = rollouts[rollout_idx] + rollout_idx += 1 + generation_prompt = f"Prompt: {d['query_text']}\nCompletion:" + yield self.tokenize_for_training(generation_prompt) + else: - prompt = new_prompt + rollout = self.format_rollout(rollouts[rollout_idx]) + rollout_idx += 1 + + new_prompt = prompt + rollout + new_ids = self.tokenizer.tokenize(new_prompt) + + if len(new_ids) > self.tokenizer.model_max_length: # self.tokenizer.model_max_length: + yield self.tokenize_for_training(prompt) + prompt = "" + else: + prompt = new_prompt if len(prompt) > 0: yield self.tokenize_for_training(prompt) @@ -63,7 +72,7 @@ def __iter__(self): if __name__ == '__main__': tokenizer = AutoTokenizer.from_pretrained('gpt2') - dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts') + dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts', for_generation=True) for ex in dataset: print(tokenizer.decode(ex['input_ids'][0])) print('\n---------\n') \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/train_lm.py b/algorithm_distillation/sentiment-data/train_lm.py index 32c2cfb..d6579d1 100644 --- a/algorithm_distillation/sentiment-data/train_lm.py +++ b/algorithm_distillation/sentiment-data/train_lm.py @@ -13,6 +13,7 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2') train_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/train') eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', verbose=False) +generate_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', for_generation=True, verbose=False) train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False) @@ -47,4 +48,31 @@ eval_loss /= idx print(f'Avg loss after {total_steps}: {eval_loss}') model.train() + + # generate examples + generate_steps = 10 + generate_size = 3 + if total_steps % generate_steps == 0: + model.eval() + + generate_dataloader = torch.utils.data.DataLoader(generate_dataset, shuffle=False) + generate_dataloader = accelerator.prepare(generate_dataloader) + + print('-------') + print('Generating examples') + print('--------') + + for idx, batch in enumerate(generate_dataloader): + if idx >= generate_size: + break + batch = {k:v.squeeze(0) for k,v in batch.items()} + outputs = model.generate(**batch, max_length=100, do_sample=True, pad_token_id=tokenizer.eos_token_id) + text = tokenizer.decode(outputs[0]) + + + print(text) + print('------------------') + print() + + model.train() \ No newline at end of file From eb6708163bcde0925c992a0ffd705553c578196c Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Fri, 6 Jan 2023 17:24:48 +0000 Subject: [PATCH 15/21] fix iterator --- algorithm_distillation/sentiment-data/rollouts_as_lm_task.py | 1 - 1 file changed, 1 deletion(-) diff --git a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py index 9e436ab..e2c4c3d 100644 --- a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py +++ b/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py @@ -68,7 +68,6 @@ def __iter__(self): if len(prompt) > 0: yield self.tokenize_for_training(prompt) - raise StopIteration if __name__ == '__main__': tokenizer = AutoTokenizer.from_pretrained('gpt2') From e949b19430829674db90cef534d8678cbe07be9a Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Fri, 6 Jan 2023 17:26:28 +0000 Subject: [PATCH 16/21] add wandb logging --- .../sentiment-data/train_lm.py | 20 
+++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/algorithm_distillation/sentiment-data/train_lm.py b/algorithm_distillation/sentiment-data/train_lm.py index d6579d1..e806f64 100644 --- a/algorithm_distillation/sentiment-data/train_lm.py +++ b/algorithm_distillation/sentiment-data/train_lm.py @@ -1,10 +1,15 @@ import torch +import wandb from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset from accelerate import Accelerator from rollouts_as_lm_task import RolloutsAsLanguageModellingTask +from utils import ShuffledIterableDataset from tqdm.auto import tqdm +wandb.init(project="algorithm-distillation") +logging_table = wandb.Table(columns=["Step", "Generation"]) + accelerator = Accelerator() model = AutoModelForCausalLM.from_pretrained('gpt2') @@ -27,6 +32,7 @@ optimizer.zero_grad() output = model(**batch) loss = output.loss + wandb.log({'loss': loss.item(), 'step': total_steps}) accelerator.backward(loss) optimizer.step() total_steps += 1 @@ -46,6 +52,7 @@ loss = output.loss eval_loss += loss.item() eval_loss /= idx + wandb.log({'eval_loss': eval_loss, 'step': total_steps}) print(f'Avg loss after {total_steps}: {eval_loss}') model.train() @@ -58,21 +65,18 @@ generate_dataloader = torch.utils.data.DataLoader(generate_dataset, shuffle=False) generate_dataloader = accelerator.prepare(generate_dataloader) - print('-------') - print('Generating examples') - print('--------') - for idx, batch in enumerate(generate_dataloader): if idx >= generate_size: break batch = {k:v.squeeze(0) for k,v in batch.items()} outputs = model.generate(**batch, max_length=100, do_sample=True, pad_token_id=tokenizer.eos_token_id) text = tokenizer.decode(outputs[0]) - - print(text) - print('------------------') - print() + logging_table.add_data(total_steps, text) + + logging_table = wandb.Table(columns=logging_table.columns, data=logging_table.data) + wandb.log({'Generations Table': logging_table}) + model.train() \ No newline at end of file From 5f67a14c02dca327b758c275fcfc3699d4c6ee67 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Fri, 6 Jan 2023 17:26:51 +0000 Subject: [PATCH 17/21] add dataset shuffling --- .../sentiment-data/train_lm.py | 1 + .../sentiment-data/utils.py | 31 +++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 algorithm_distillation/sentiment-data/utils.py diff --git a/algorithm_distillation/sentiment-data/train_lm.py b/algorithm_distillation/sentiment-data/train_lm.py index e806f64..13ad30e 100644 --- a/algorithm_distillation/sentiment-data/train_lm.py +++ b/algorithm_distillation/sentiment-data/train_lm.py @@ -17,6 +17,7 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2') train_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/train') +train_dataset = ShuffledIterableDataset(train_dataset, buffer_size=10_000) eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', verbose=False) generate_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', for_generation=True, verbose=False) diff --git a/algorithm_distillation/sentiment-data/utils.py b/algorithm_distillation/sentiment-data/utils.py new file mode 100644 index 0000000..98abbf2 --- /dev/null +++ b/algorithm_distillation/sentiment-data/utils.py @@ -0,0 +1,31 @@ +import torch +import random + +class ShuffledIterableDataset(torch.utils.data.IterableDataset): + + def __init__(self, original_dataset: torch.utils.data.IterableDataset, buffer_size: int = 10_000): + 
self.original_dataset = original_dataset + self.buffer_size = buffer_size + + def __iter__(self): + original_iterator = iter(self.original_dataset) + + # fill buffer (or until runs out) + buffer = [] + while len(buffer) < self.buffer_size: + try: + x = next(original_iterator) + except StopIteration: + break + buffer.append(x) + + # shuffle, yield and replace until original runs out + for x in original_iterator: + random.shuffle(buffer) + yield buffer[-1] + buffer[-1] = x + + # empty the remaining + for x in buffer: + yield x + \ No newline at end of file From b336e171546dd1f000f3ec1b039eb9443ad94661 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Sun, 29 Jan 2023 10:51:29 +0000 Subject: [PATCH 18/21] folder level refactor --- .gitignore | 8 +- README.md | 46 +++++++- .../casual_lm /train.py} | 0 .../sentiment-data/decode_rollout.py | 101 ------------------ .../ppo_roc_story_sentiment_rollouts.py | 55 ---------- .../ppo_roc_story_sentiments.py | 59 ---------- .../.gitkeep => tasks/lm/__init__.py} | 0 .../tasks/lm/sentiment/__init__.py | 1 + .../lm/sentiment/dataset.py} | 19 +++- .../lm/sentiment}/decode_rollouts.py | 0 ...nerate_ppo_roc_story_sentiment_rollouts.py | 0 .../lm/sentiment}/ppo_config.yml | 4 +- .../lm/sentiment}/ppo_sweep.yml | 2 +- .../tasks/lm/sentiment/rollouts/.gitkeep | 0 .../lm/sentiment}/utils.py | 0 algorithm_distillation/tasks/rl/__init__.py | 0 16 files changed, 69 insertions(+), 226 deletions(-) rename algorithm_distillation/{sentiment-data/train_lm.py => models/casual_lm /train.py} (100%) delete mode 100644 algorithm_distillation/sentiment-data/decode_rollout.py delete mode 100644 algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py delete mode 100644 algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py rename algorithm_distillation/{sentiment-data/rollouts/.gitkeep => tasks/lm/__init__.py} (100%) create mode 100644 algorithm_distillation/tasks/lm/sentiment/__init__.py rename algorithm_distillation/{sentiment-data/rollouts_as_lm_task.py => tasks/lm/sentiment/dataset.py} (80%) rename algorithm_distillation/{sentiment-data => tasks/lm/sentiment}/decode_rollouts.py (100%) rename algorithm_distillation/{sentiment-data => tasks/lm/sentiment}/generate_ppo_roc_story_sentiment_rollouts.py (100%) rename algorithm_distillation/{sentiment-data => tasks/lm/sentiment}/ppo_config.yml (93%) rename algorithm_distillation/{sentiment-data => tasks/lm/sentiment}/ppo_sweep.yml (94%) create mode 100644 algorithm_distillation/tasks/lm/sentiment/rollouts/.gitkeep rename algorithm_distillation/{sentiment-data => tasks/lm/sentiment}/utils.py (100%) create mode 100644 algorithm_distillation/tasks/rl/__init__.py diff --git a/.gitignore b/.gitignore index b672c9e..6b91dd2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,8 @@ trlx -algorithm_distillation/sentiment-data/ray_results/ -algorithm_distillation/sentiment-data/rollouts/run-* -algorithm_distillation/sentiment-data/decoded_rollouts -algorithm_distillation/sentiment-data/wandb/ +algorithm_distillation/tasks/lm/sentiment/ray_results/ +algorithm_distillation/tasks/lm/sentiment/rollouts/run-* +algorithm_distillation/tasks/lm/sentiment/decoded_rollouts +algorithm_distillation/tasks/lm/sentiment/wandb/ wandb/ diff --git a/README.md b/README.md index fea9271..9d4e03b 100644 --- a/README.md +++ b/README.md @@ -1 +1,45 @@ -# Algorithm-Distillation-RLHF \ No newline at end of file +# Algorithm-Distillation-RLHF + +A reinforcement learning algorithm is characterised by the trajectories it generates during 
training. + +We are interested in "algorithm distillation" - whether trajectories can be modelled by transformers, as studied in the original deepmind algorithm distillation paper. + +A particular focus of this repo is to extend prior work to the case where: +1. the trajectories have been generated by the TRLx library during RLHF training of language models +2. the transformer modelling the trajectories is itself a standard language model + + +## On data formats + +A trajectory is typically defined as a list of `(state, action, reward)` triples. For training purposes, it is sometimes useful to augment this to include `logprobs`, which is, for each triple `(s, a, r)`, the probability of taking action $a$ at state $s$ as determined the policy generating the trajectory. + +We therefore define an **RL Format Trajectory** as a sequence of `(state, action, reward, logprobs)` tuples. + +The typical way to learn to model these trajectories with a transformer is to seperately map the final hidden state using 3 different heads. That is, for a given triple `(s,a,r,l)` a transformer $f$ maps to $(\hat{s}, \hat{a}, \hat{r}, \hat{l})$. + +In this repo, this is done via the models in `/models/rl`. + +We are also interested in the ability of standard language models (with language modeling heads) to learn trajectories. To this end we define a **Language Format Trajectory** as a trajectory serialised into a string. There are many possible ways to do this, and the optimal one requires investigation. For example, for trajectories generated using TRLx when finetuning a language model on positive sentiment, we can format the trajectory as the string: + +``` +prompt: Dan went down to the shops. +completion: He smiled as he walked - the sun was shining. +reward: 0.9975 +### +``` + +It's less obvious how to do this when the task is not a language task, such as moonlander. Enumerating the states as coordinates might work, but requires experimentation. + +Trajectories in *Language format* are learn by models in `/models/lm`. + +## To summarise: + +`/models` contains the "algorithm distillation models", transformers that are trained in a supervised fashion to learn RL trajectories. We distinguish between models that operate on *RL Format* trajectories and *Language format* trajectories. + +`/tasks` contains code to produce the RL trajectories that the models learn. It can store this data however it likes, but each task should expose a `torch.utils.data.Dataset` that can return trajectory data in either *RL Format* or *Language format*. + +## ToDo: + +[ ] Set up repo structure (just for your language stuff, @H can add in his) +[ ] Post guide and project tasks on discord +[ ] Run some preliminary experiments \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/train_lm.py b/algorithm_distillation/models/casual_lm /train.py similarity index 100% rename from algorithm_distillation/sentiment-data/train_lm.py rename to algorithm_distillation/models/casual_lm /train.py diff --git a/algorithm_distillation/sentiment-data/decode_rollout.py b/algorithm_distillation/sentiment-data/decode_rollout.py deleted file mode 100644 index 13b2b65..0000000 --- a/algorithm_distillation/sentiment-data/decode_rollout.py +++ /dev/null @@ -1,101 +0,0 @@ -from pathlib import Path -import json -from transformers import AutoTokenizer -import click -from typing import Union -from shutil import copy2 - -@click.group() -def main(): - """ - CLI for formatting the rollouts into training data. - - \b - 1. 
decode-epoch: Decode a single epoch rollout .json file - 2. decode-run: Decode an entire PPO run (multiple epoch files) - 3. decode-rollouts: Deocde an entire directory (multiple runs) - """ - -def get_tokenizer(tokenizer: Union[str, AutoTokenizer]) -> AutoTokenizer: - if isinstance(tokenizer, str): - return AutoTokenizer.from_pretrained(tokenizer) - else: - return tokenizer - - -@main.command() -@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') -@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input JSON file') -@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the JSON file to be created. Will overwrite if necessary.') -def decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): - _decode_epoch(tokenizer, input_fpath, output_fpath) - -def _decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): - - assert input_fpath.exists() - assert input_fpath.is_file() - assert input_fpath.name.endswith('.json') - assert output_fpath.name.endswith('.json') - if not output_fpath.parent.exists(): - output_fpath.parent.mkdir() - - tokenizer = get_tokenizer(tokenizer) - - with open(input_fpath, 'r') as f: - rollouts = json.loads(f.read()) - - for rollout in rollouts: - rollout['query_text'] = tokenizer.decode(rollout['query_tensor'], skip_special_tokens=True) - rollout['response_text'] = tokenizer.decode(rollout['response_tensor'], skip_special_tokens=True) - - with open(output_fpath, 'w') as f: - f.write(json.dumps(rollouts, indent=2)) - - -@main.command() -@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') -@click.option('--input-fpath', type=click.Path(path_type=Path), help='the directory containing the JSON epoch files') -@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the folder to be created.') -def decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): - _decode_run(tokenizer, input_fpath, output_fpath) - -def _decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): - - assert input_fpath.exists() - assert input_fpath.is_dir() - if not output_fpath.exists(): - output_fpath.mkdir() - - # Copy over the config file - assert (input_fpath / 'config.json').exists() - copy2(input_fpath / 'config.json', output_fpath / 'config.json') - - # Decode the rest of the files - tokenizer = get_tokenizer(tokenizer) - - epochs = [fpath for fpath in input_fpath.iterdir() if fpath.name != 'config.json'] - for epoch in epochs: - _decode_epoch(tokenizer, epoch, output_fpath / epoch.name) - - -@main.command() -@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') -@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input directory') -@click.option('--output-fpath', type=click.Path(path_type=Path), help='the output directory') -def decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): - _decode_rollouts(tokenizer, input_fpath, output_fpath) - -def _decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): - - assert input_fpath.exists() - assert input_fpath.is_dir() - if not output_fpath.exists(): - output_fpath.mkdir() - - tokenizer = get_tokenizer(tokenizer) - runs = [fpath for fpath in input_fpath.iterdir() if fpath.name.startswith('run-')] - for run in runs: - _decode_run(tokenizer, run, output_fpath / run.name) - -if __name__ == 
'__main__': - main() \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py deleted file mode 100644 index 424dcf5..0000000 --- a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiment_rollouts.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -from pathlib import Path -import json -from transformers import AutoTokenizer -from typing import Dict, Any - -class RolloutsAsLanguageModellingTask(torch.utils.data.IterableDataset): - def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str): - self.tokenizer = tokenizer - self.rollouts_folder = Path(rollouts_folder_fpath) - - def format_rollout(self, d: Dict[Any, Any]) -> str: - return f"Prompt:{d['query_text']}\nCompletion:{d['response_text']}\nReward:{d['rewards'][-1]}\n\n" - - def __iter__(self): - - runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] - print(f'Found {len(runs)} runs...') - for run in runs: - - config = json.loads(open(run / 'config.json', 'r').read()) - epochs = [epoch for epoch in run.iterdir() if epoch.name != 'config.json'] - print(f'Run {run.name} has {len(epochs)} epochs...') - epochs = sorted(epochs) - for epoch in epochs: - # print(f'Yielding from epoch {epoch.name}') - - rollouts = json.loads(open(epoch, 'r').read()) - - rollout_idx = 0 - prompt = "" - while rollout_idx < len(rollouts): - rollout = self.format_rollout(rollouts[rollout_idx]) - rollout_idx += 1 - - new_prompt = prompt + rollout - new_ids = self.tokenizer.tokenize(new_prompt) - - if len(new_ids) > tokenizer.model_max_length: # self.tokenizer.model_max_length: - yield self.tokenizer(prompt, return_tensors='pt') - prompt = "" - else: - prompt = new_prompt - - - yield self.tokenizer(prompt, return_tensors='pt') - - raise StopIteration - -if __name__ == '__main__': - tokenizer = AutoTokenizer.from_pretrained('gpt2') - dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts') - for ex in dataset: - print(tokenizer.decode(ex['input_ids'][0])) - print('\n---------\n') \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py b/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py deleted file mode 100644 index 3679b95..0000000 --- a/algorithm_distillation/sentiment-data/ppo_roc_story_sentiments.py +++ /dev/null @@ -1,59 +0,0 @@ -from posixpath import dirname -from datasets import load_dataset -from transformers import pipeline -import os -import yaml -from functools import partial - -import trlx -import torch -from typing import List -from trlx.data.configs import TRLConfig - -from trlx.utils.loading import get_model, get_orchestrator, get_pipeline - -def get_score_for_label(label, scores): - "Extract value associated with a positive sentiment from pipeline's output" - label_to_score = {d['label'] : d['score'] for d in scores} - return label_to_score[label] - -default_config = yaml.safe_load(open(os.path.join(dirname(__file__), "ppo_config.yml"))) - -def main(hparams={}): - config = TRLConfig.update(default_config, hparams) - - if torch.cuda.is_available(): - device = int(os.environ.get("LOCAL_RANK", 0)) - else: - device = -1 - - - sentiment_fn = pipeline( - "sentiment-analysis", - "bhadresh-savani/distilbert-base-uncased-emotion", - truncation=True, - batch_size=256, - device=device, - return_all_scores=True, - ) - - def reward_fn(samples: List[str]) -> List[float]: - output_batch = 
sentiment_fn(samples) - sentiments = list(map(partial(get_score_for_label, 'joy'), output_batch)) - return sentiments - - stories = load_dataset("adamlin/roc_story") - prompts = [d['sentence1'] for d in stories['train']] - eval_prompts = [d for d in stories['validation'][:64]['sentence1']] - - model = trlx.train( - model_path="gpt2", - reward_fn=reward_fn, - prompts=prompts, - eval_prompts=eval_prompts, - config=config, - ) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/rollouts/.gitkeep b/algorithm_distillation/tasks/lm/__init__.py similarity index 100% rename from algorithm_distillation/sentiment-data/rollouts/.gitkeep rename to algorithm_distillation/tasks/lm/__init__.py diff --git a/algorithm_distillation/tasks/lm/sentiment/__init__.py b/algorithm_distillation/tasks/lm/sentiment/__init__.py new file mode 100644 index 0000000..e067b54 --- /dev/null +++ b/algorithm_distillation/tasks/lm/sentiment/__init__.py @@ -0,0 +1 @@ +from .dataset import SentimentTrajectories \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py b/algorithm_distillation/tasks/lm/sentiment/dataset.py similarity index 80% rename from algorithm_distillation/sentiment-data/rollouts_as_lm_task.py rename to algorithm_distillation/tasks/lm/sentiment/dataset.py index e2c4c3d..69a0a3c 100644 --- a/algorithm_distillation/sentiment-data/rollouts_as_lm_task.py +++ b/algorithm_distillation/tasks/lm/sentiment/dataset.py @@ -2,9 +2,22 @@ from pathlib import Path import json from transformers import AutoTokenizer -from typing import Dict, Any +from typing import Dict, Any, Union -class RolloutsAsLanguageModellingTask(torch.utils.data.IterableDataset): +Dataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset] + +class SentimentTrajectories(Dataset): + def __init__(self, format:str, *args, **kwargs): + if format == "language": + self = SentimentAsLanguageTrajectories(*args, **kwargs) + elif format == "rl": + raise NotImplementedError() + # self = SentimentAsRlTrajectories(*stargs, **kwargs) + else: + raise RuntimeError(f"format must be either 'language' or 'rl', got: {format}") + + +class SentimentAsLanguageTrajectories(torch.utils.data.IterableDataset): def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str, for_generation: bool = False, verbose: bool = True): self.tokenizer = tokenizer self.rollouts_folder = Path(rollouts_folder_fpath) @@ -71,7 +84,7 @@ def __iter__(self): if __name__ == '__main__': tokenizer = AutoTokenizer.from_pretrained('gpt2') - dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts', for_generation=True) + dataset = SentimentTrajectories("language", tokenizer, './decoded_rollouts', for_generation=False) for ex in dataset: print(tokenizer.decode(ex['input_ids'][0])) print('\n---------\n') \ No newline at end of file diff --git a/algorithm_distillation/sentiment-data/decode_rollouts.py b/algorithm_distillation/tasks/lm/sentiment/decode_rollouts.py similarity index 100% rename from algorithm_distillation/sentiment-data/decode_rollouts.py rename to algorithm_distillation/tasks/lm/sentiment/decode_rollouts.py diff --git a/algorithm_distillation/sentiment-data/generate_ppo_roc_story_sentiment_rollouts.py b/algorithm_distillation/tasks/lm/sentiment/generate_ppo_roc_story_sentiment_rollouts.py similarity index 100% rename from algorithm_distillation/sentiment-data/generate_ppo_roc_story_sentiment_rollouts.py rename to 
algorithm_distillation/tasks/lm/sentiment/generate_ppo_roc_story_sentiment_rollouts.py diff --git a/algorithm_distillation/sentiment-data/ppo_config.yml b/algorithm_distillation/tasks/lm/sentiment/ppo_config.yml similarity index 93% rename from algorithm_distillation/sentiment-data/ppo_config.yml rename to algorithm_distillation/tasks/lm/sentiment/ppo_config.yml index 7afc44a..d09339b 100644 --- a/algorithm_distillation/sentiment-data/ppo_config.yml +++ b/algorithm_distillation/tasks/lm/sentiment/ppo_config.yml @@ -1,5 +1,5 @@ model: - model_path: "lvwerra/gpt2-imdb" # Name of hf model to load + model_path: "lvwerra/gpt2-xl" # Name of hf model to load tokenizer_path: "gpt2" # Name of hf tokenizer to load model_type: "AcceleratePPOModel" # Name of accelerate model type to load num_layers_unfrozen: 2 # Number of bottom layers to freeze during training @@ -7,7 +7,7 @@ model: train: seq_length: 48 # Size of LM context epochs: 100 # Train for max(epochs, total_steps) - total_steps: 10000 # Train for max(epochs, total_steps) + total_steps: 1000 # Train for max(epochs, total_steps) batch_size: 128 # batch size lr_init: 1.0e-4 # init learning rate diff --git a/algorithm_distillation/sentiment-data/ppo_sweep.yml b/algorithm_distillation/tasks/lm/sentiment/ppo_sweep.yml similarity index 94% rename from algorithm_distillation/sentiment-data/ppo_sweep.yml rename to algorithm_distillation/tasks/lm/sentiment/ppo_sweep.yml index 95469ef..dc5faf0 100644 --- a/algorithm_distillation/sentiment-data/ppo_sweep.yml +++ b/algorithm_distillation/tasks/lm/sentiment/ppo_sweep.yml @@ -3,7 +3,7 @@ tune_config: metric: "mean_reward" search_alg: "random" scheduler: "fifo" - num_samples: 32 + num_samples: 4 # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs lr_init: diff --git a/algorithm_distillation/tasks/lm/sentiment/rollouts/.gitkeep b/algorithm_distillation/tasks/lm/sentiment/rollouts/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/sentiment-data/utils.py b/algorithm_distillation/tasks/lm/sentiment/utils.py similarity index 100% rename from algorithm_distillation/sentiment-data/utils.py rename to algorithm_distillation/tasks/lm/sentiment/utils.py diff --git a/algorithm_distillation/tasks/rl/__init__.py b/algorithm_distillation/tasks/rl/__init__.py new file mode 100644 index 0000000..e69de29 From 5cc1d91e0c459ac5d6dad49ce9d34926b121fa03 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Sun, 29 Jan 2023 12:48:01 +0000 Subject: [PATCH 19/21] add mega simple example script --- README.md | 16 ++++++- algorithm_distillation/__init__.py | 0 .../casual_lm /{train.py => legacy_train.py} | 0 .../tasks/{lm/sentiment => }/utils.py | 0 algorithm_distillation/train.py | 42 +++++++++++++++++++ 5 files changed, 56 insertions(+), 2 deletions(-) create mode 100644 algorithm_distillation/__init__.py rename algorithm_distillation/models/casual_lm /{train.py => legacy_train.py} (100%) rename algorithm_distillation/tasks/{lm/sentiment => }/utils.py (100%) create mode 100644 algorithm_distillation/train.py diff --git a/README.md b/README.md index 9d4e03b..97d33db 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,20 @@ Trajectories in *Language format* are learn by models in `/models/lm`. `/tasks` contains code to produce the RL trajectories that the models learn. It can store this data however it likes, but each task should expose a `torch.utils.data.Dataset` that can return trajectory data in either *RL Format* or *Language format*. 
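As a rough illustration of that contract, a task's language-format dataset could look something like the sketch below. The class name, constructor arguments and rollout fields here are illustrative assumptions for this note, not code from the repo:

```python
# Illustrative sketch only: hypothetical class and field names, not the repo's actual code.
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer


class ToyLanguageTrajectories(torch.utils.data.IterableDataset):
    """Yields tokenized trajectories rendered in the 'Language format'."""

    def __init__(self, tokenizer, rollouts: List[Dict[Any, Any]]):
        self.tokenizer = tokenizer
        self.rollouts = rollouts  # e.g. dicts holding query text, response text and a reward

    def format_rollout(self, d: Dict[Any, Any]) -> str:
        # Prompt/Completion/Reward layout, loosely mirroring the decoded-rollout fields
        # used elsewhere in these patches.
        return f"Prompt: {d['query_text']}\nCompletion: {d['response_text']}\nReward: {d['reward']}\n\n"

    def __iter__(self):
        for d in self.rollouts:
            yield self.tokenizer(self.format_rollout(d), return_tensors="pt")


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    demo = [{"query_text": "Tom walked in", "response_text": " and smiled.", "reward": 0.97}]
    for ex in ToyLanguageTrajectories(tok, demo):
        print(tok.decode(ex["input_ids"][0]))
```

The concrete sentiment task follows the same pattern: it renders each decoded rollout as a Prompt/Completion/Reward string, tokenizes it, and yields it for causal-LM training.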
+## Generating trajectory data +I am using my own fork of TRLx that has rollout logging. + ## ToDo: -[ ] Set up repo structure (just for your language stuff, @H can add in his) +Today: +[X] Set up repo structure (just for your language stuff, @H can add in his) +[ ] Add train script for models/lm/casuallm (25 mins) +[ ] Clone H's stuff and merge with @H stuff (/models/rl) and (/tasks/rl) (25 mins) +[ ] Proper PR for TRLx (25 mins) [ ] Post guide and project tasks on discord -[ ] Run some preliminary experiments \ No newline at end of file + +Future: +[ ] Add online evaluation script for models/lm/casuallm +[ ] Improve train script to include reward accuracy +[ ] Run some preliminary experiments +[ ] Add __main__ file with click CLI interface for running experiments \ No newline at end of file diff --git a/algorithm_distillation/__init__.py b/algorithm_distillation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/models/casual_lm /train.py b/algorithm_distillation/models/casual_lm /legacy_train.py similarity index 100% rename from algorithm_distillation/models/casual_lm /train.py rename to algorithm_distillation/models/casual_lm /legacy_train.py diff --git a/algorithm_distillation/tasks/lm/sentiment/utils.py b/algorithm_distillation/tasks/utils.py similarity index 100% rename from algorithm_distillation/tasks/lm/sentiment/utils.py rename to algorithm_distillation/tasks/utils.py diff --git a/algorithm_distillation/train.py b/algorithm_distillation/train.py new file mode 100644 index 0000000..d82e808 --- /dev/null +++ b/algorithm_distillation/train.py @@ -0,0 +1,42 @@ + +from .tasks.lm.sentiment import SentimentTrajectories +from .tasks.utils import ShuffledIterableDataset + +from torch.utils.data import DataLoader +from torch.optim import Adam +from tqdm.auto import tqdm +import wandb + +accelerator = Accelerator() + +# Logging inits +wandb.init(project="algorithm-distillation") +logging_table = wandb.Table(columns=['step', 'generation']) + +# Data +tokenizer = AutoTokenizer.from_pretrained('gpt2') +train_dataset = SentimentTrajectories(format="language", tokenizer=tokenizer) +train_dataset = ShuffledIterableDataset(train_dataset, buffer_size=10_000) +# eval_dataset = ... +# generate_dataset = ... +train_dataloader = DataLoader(train_dataset, shuffle=False) + +# Setup parameters for training with accelerate +model = AutoModelForCausalLM.from_pretrained('gpt2') +optimizer = Adam(model.parameters(), lr=5e-5) +model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader) + +# Train +model.train() +total_steps = 0 +for epoch in range(10): + for batch in tqdm(train_dataloader, desc=f'Training epoch {epoch}'): + + optimizer.zero_grad() + output = model(**batch) + loss = output.loss + wandb.log({'loss': loss.item(), 'step': total_steps}) + accelerator.backward(loss) + optimizer.step() + total_steps +=1 + \ No newline at end of file From 6b422f8a34392f28f5bbaaae886756450dd9870d Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Sun, 29 Jan 2023 13:11:05 +0000 Subject: [PATCH 20/21] fix train script --- README.md | 3 +- .../tasks/lm/sentiment/__init__.py | 2 +- .../tasks/lm/sentiment/dataset.py | 34 +++++++++---------- algorithm_distillation/train.py | 8 +++-- 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 97d33db..783ac49 100644 --- a/README.md +++ b/README.md @@ -45,12 +45,13 @@ I am using my own fork of TRLx that has rollout logging. 
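The train script above wraps the trajectory dataset in a `ShuffledIterableDataset` from `tasks/utils.py`. That helper's body isn't shown in these patches, but a buffer-based shuffle along the following lines would fit how it is called; this is purely an assumed sketch with an illustrative class name, not the repo's actual implementation:

```python
# Assumed sketch of a buffer-based shuffle for an IterableDataset; the real
# tasks/utils.py helper may differ.
import random

import torch


class ShuffleBufferIterableDataset(torch.utils.data.IterableDataset):
    """Approximately shuffles an IterableDataset using a fixed-size buffer."""

    def __init__(self, dataset, buffer_size: int = 10_000, seed: int = 0):
        self.dataset = dataset
        self.buffer_size = buffer_size
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)
        buffer = []
        for item in self.dataset:
            if len(buffer) < self.buffer_size:
                buffer.append(item)
                continue
            # Emit a random buffered item and keep the incoming one in its place.
            idx = rng.randrange(self.buffer_size)
            yield buffer[idx]
            buffer[idx] = item
        # Flush whatever is left in the buffer in random order.
        rng.shuffle(buffer)
        yield from buffer
```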
Today: [X] Set up repo structure (just for your language stuff, @H can add in his) -[ ] Add train script for models/lm/casuallm (25 mins) +[X] Add train script for models/lm/casuallm [ ] Clone H's stuff and merge with @H stuff (/models/rl) and (/tasks/rl) (25 mins) [ ] Proper PR for TRLx (25 mins) [ ] Post guide and project tasks on discord Future: +[ ] Add more elegant meta class switching between ...LanguageTrajectories and ...RlTrajectories [ ] Add online evaluation script for models/lm/casuallm [ ] Improve train script to include reward accuracy [ ] Run some preliminary experiments diff --git a/algorithm_distillation/tasks/lm/sentiment/__init__.py b/algorithm_distillation/tasks/lm/sentiment/__init__.py index e067b54..f562f25 100644 --- a/algorithm_distillation/tasks/lm/sentiment/__init__.py +++ b/algorithm_distillation/tasks/lm/sentiment/__init__.py @@ -1 +1 @@ -from .dataset import SentimentTrajectories \ No newline at end of file +from .dataset import SentimentLanguageTrajectories \ No newline at end of file diff --git a/algorithm_distillation/tasks/lm/sentiment/dataset.py b/algorithm_distillation/tasks/lm/sentiment/dataset.py index 69a0a3c..d453fca 100644 --- a/algorithm_distillation/tasks/lm/sentiment/dataset.py +++ b/algorithm_distillation/tasks/lm/sentiment/dataset.py @@ -2,28 +2,28 @@ from pathlib import Path import json from transformers import AutoTokenizer -from typing import Dict, Any, Union - -Dataset = Union[torch.utils.data.Dataset, torch.utils.data.IterableDataset] - -class SentimentTrajectories(Dataset): - def __init__(self, format:str, *args, **kwargs): - if format == "language": - self = SentimentAsLanguageTrajectories(*args, **kwargs) - elif format == "rl": - raise NotImplementedError() - # self = SentimentAsRlTrajectories(*stargs, **kwargs) - else: - raise RuntimeError(f"format must be either 'language' or 'rl', got: {format}") +from typing import Dict, Any +from pathlib import Path +class SentimentRlTrajectories(torch.utils.data.Dataset): + def __init__(self): + raise NotImplementedError() + -class SentimentAsLanguageTrajectories(torch.utils.data.IterableDataset): - def __init__(self, tokenizer: AutoTokenizer, rollouts_folder_fpath: str, for_generation: bool = False, verbose: bool = True): +class SentimentLanguageTrajectories(torch.utils.data.IterableDataset): + def __init__(self, tokenizer: AutoTokenizer, split: str, for_generation: bool = False, verbose: bool = True): self.tokenizer = tokenizer - self.rollouts_folder = Path(rollouts_folder_fpath) self.verbose = verbose self.for_generation = for_generation + if split == 'train': + self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "train" + elif split == 'eval': + self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "eval" + else: + raise RuntimeError(f"split must be either 'train' or 'eval', got: {split}") + + def format_rollout(self, d: Dict[Any, Any]) -> str: return f"Prompt: {d['query_text']}\nCompletion: {d['response_text']}\nReward: {d['rewards'][-1]}\n\n" @@ -84,7 +84,7 @@ def __iter__(self): if __name__ == '__main__': tokenizer = AutoTokenizer.from_pretrained('gpt2') - dataset = SentimentTrajectories("language", tokenizer, './decoded_rollouts', for_generation=False) + dataset = SentimentLanguageTrajectories(tokenizer, split='train', for_generation=False) for ex in dataset: print(tokenizer.decode(ex['input_ids'][0])) print('\n---------\n') \ No newline at end of file diff --git a/algorithm_distillation/train.py b/algorithm_distillation/train.py index d82e808..7579dc1 100644 
--- a/algorithm_distillation/train.py +++ b/algorithm_distillation/train.py @@ -1,7 +1,9 @@ -from .tasks.lm.sentiment import SentimentTrajectories -from .tasks.utils import ShuffledIterableDataset +from tasks.lm.sentiment import SentimentLanguageTrajectories +from tasks.utils import ShuffledIterableDataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from accelerate import Accelerator from torch.utils.data import DataLoader from torch.optim import Adam from tqdm.auto import tqdm @@ -15,7 +17,7 @@ # Data tokenizer = AutoTokenizer.from_pretrained('gpt2') -train_dataset = SentimentTrajectories(format="language", tokenizer=tokenizer) +train_dataset = SentimentLanguageTrajectories(split="train", tokenizer=tokenizer) train_dataset = ShuffledIterableDataset(train_dataset, buffer_size=10_000) # eval_dataset = ... # generate_dataset = ... From 6f8c8f95aefd6885b422610741d1f4acf2342645 Mon Sep 17 00:00:00 2001 From: Thomas Foster Date: Sun, 29 Jan 2023 13:13:24 +0000 Subject: [PATCH 21/21] update readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 783ac49..61441ae 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ reward: 0.9975 It's less obvious how to do this when the task is not a language task, such as moonlander. Enumerating the states as coordinates might work, but requires experimentation. -Trajectories in *Language format* are learn by models in `/models/lm`. +Trajectories in *Language format* are learnt by models in `/models/lm`. ## To summarise: