diff --git a/.gitignore b/.gitignore
index e343b2c..07ed892 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,11 @@
+trlx
+algorithm_distillation/tasks/lm/sentiment/ray_results/
+algorithm_distillation/tasks/lm/sentiment/rollouts/run-*
+algorithm_distillation/tasks/lm/sentiment/decoded_rollouts
+algorithm_distillation/tasks/lm/sentiment/wandb/
+wandb/
+
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/README.md b/README.md
index fea9271..61441ae 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,58 @@
-# Algorithm-Distillation-RLHF
\ No newline at end of file
+# Algorithm-Distillation-RLHF
+
+A reinforcement learning algorithm is characterised by the trajectories it generates during training.
+
+We are interested in "algorithm distillation" - whether those trajectories can be modelled by transformers, as studied in the original DeepMind algorithm distillation paper.
+
+A particular focus of this repo is to extend prior work to the case where:
+1. the trajectories have been generated by the TRLx library during RLHF training of language models
+2. the transformer modelling the trajectories is itself a standard language model
+
+
+## On data formats
+
+A trajectory is typically defined as a list of `(state, action, reward)` triples. For training purposes, it is sometimes useful to augment this with `logprobs`: for each triple `(s, a, r)`, the log-probability of taking action $a$ in state $s$ under the policy that generated the trajectory.
+
+We therefore define an **RL Format Trajectory** as a sequence of `(state, action, reward, logprobs)` tuples.
+
+The typical way to model these trajectories with a transformer is to map the final hidden state through a separate head for each element. That is, for a given tuple `(s, a, r, l)`, a transformer $f$ produces predictions $(\hat{s}, \hat{a}, \hat{r}, \hat{l})$ (see the first sketch in the "Illustrative sketches" section below).
+
+In this repo, this is done via the models in `/models/rl`.
+
+We are also interested in the ability of standard language models (with language modelling heads) to learn trajectories. To this end we define a **Language Format Trajectory** as a trajectory serialised into a string. There are many possible ways to do this, and finding the best one requires investigation. For example, for trajectories generated using TRLx when finetuning a language model on positive sentiment, we can format the trajectory as the string:
+
+```
+prompt: Dan went down to the shops.
+completion: He smiled as he walked - the sun was shining.
+reward: 0.9975
+###
+```
+
+It's less obvious how to do this when the task is not a language task, such as moonlander. Enumerating the states as coordinates might work, but requires experimentation.
+
+Trajectories in *Language Format* are learnt by models in `/models/lm` (see the second sketch in the "Illustrative sketches" section below).
+
+## To summarise:
+
+`/models` contains the "algorithm distillation models": transformers that are trained in a supervised fashion to learn RL trajectories. We distinguish between models that operate on *RL Format* trajectories and *Language Format* trajectories.
+
+`/tasks` contains code to produce the RL trajectories that the models learn. It can store this data however it likes, but each task should expose a `torch.utils.data.Dataset` that can return trajectory data in either *RL Format* or *Language Format* (a small interface sketch is given under "Illustrative sketches" below).
+
+## Generating trajectory data
+I am using my own fork of TRLx that has rollout logging.
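+
+## Illustrative sketches
+
+The snippets below are editorial sketches, not the repo's actual implementations. First, a minimal example of the multi-head readout described under "On data formats": one linear head per trajectory element, applied to a transformer's final hidden states. The class and argument names (`TrajectoryHeads`, `state_dim`, `num_actions`) are hypothetical; the real models live in `/models/rl`.
+
+```python
+import torch
+import torch.nn as nn
+
+
+class TrajectoryHeads(nn.Module):
+    """Map final hidden states to predicted (state, action, reward, logprob)."""
+
+    def __init__(self, hidden_dim: int, state_dim: int, num_actions: int):
+        super().__init__()
+        self.state_head = nn.Linear(hidden_dim, state_dim)
+        self.action_head = nn.Linear(hidden_dim, num_actions)
+        self.reward_head = nn.Linear(hidden_dim, 1)
+        self.logprob_head = nn.Linear(hidden_dim, 1)
+
+    def forward(self, hidden_states: torch.Tensor):
+        # hidden_states: (batch, seq_len, hidden_dim), e.g. the last hidden
+        # state of a GPT-2 backbone run over embedded (s, a, r, l) inputs.
+        return (
+            self.state_head(hidden_states),
+            self.action_head(hidden_states),
+            self.reward_head(hidden_states),
+            self.logprob_head(hidden_states),
+        )
+```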
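+
+Second, a sketch of one possible *Language Format* serialisation of a decoded TRLx rollout, matching the string format shown above. The keys (`query_text`, `response_text`, `rewards`) follow the decoded rollout JSON used in `tasks/lm/sentiment/dataset.py`; the exact template (capitalisation, separator) is one choice among many and differs slightly from `format_rollout` there.
+
+```python
+from typing import Any, Dict
+
+
+def rollout_to_string(rollout: Dict[str, Any]) -> str:
+    # Serialise one (prompt, completion, final reward) rollout into a string.
+    return (
+        f"prompt: {rollout['query_text']}\n"
+        f"completion: {rollout['response_text']}\n"
+        f"reward: {rollout['rewards'][-1]:.4f}\n"
+        "###\n"
+    )
+
+
+example = {
+    "query_text": "Dan went down to the shops.",
+    "response_text": "He smiled as he walked - the sun was shining.",
+    "rewards": [0.0, 0.9975],
+}
+print(rollout_to_string(example))
+```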
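+
+Finally, a sketch of the per-task dataset contract from "To summarise": each task exposes datasets that yield trajectories in one of the two formats. `TrajectoryFormat` and `get_sentiment_dataset` are hypothetical names; the concrete classes referenced are `SentimentRlTrajectories` and `SentimentLanguageTrajectories` in `tasks/lm/sentiment/dataset.py`, and the import style follows `train.py` (run from inside `algorithm_distillation/`).
+
+```python
+from enum import Enum
+
+import torch
+
+
+class TrajectoryFormat(Enum):
+    RL = "rl"
+    LANGUAGE = "language"
+
+
+def get_sentiment_dataset(fmt: TrajectoryFormat, tokenizer, split: str) -> torch.utils.data.Dataset:
+    """Return the sentiment task's dataset in the requested trajectory format."""
+    from tasks.lm.sentiment.dataset import (
+        SentimentLanguageTrajectories,
+        SentimentRlTrajectories,
+    )
+
+    if fmt is TrajectoryFormat.LANGUAGE:
+        return SentimentLanguageTrajectories(tokenizer, split=split)
+    # RL-format trajectories are not implemented yet (raises NotImplementedError).
+    return SentimentRlTrajectories()
+```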
+ +## ToDo: + +Today: +[X] Set up repo structure (just for your language stuff, @H can add in his) +[X] Add train script for models/lm/casuallm +[ ] Clone H's stuff and merge with @H stuff (/models/rl) and (/tasks/rl) (25 mins) +[ ] Proper PR for TRLx (25 mins) +[ ] Post guide and project tasks on discord + +Future: +[ ] Add more elegant meta class switching between ...LanguageTrajectories and ...RlTrajectories +[ ] Add online evaluation script for models/lm/casuallm +[ ] Improve train script to include reward accuracy +[ ] Run some preliminary experiments +[ ] Add __main__ file with click CLI interface for running experiments \ No newline at end of file diff --git a/algorithm_distillation/__init__.py b/algorithm_distillation/__init__.py index 0a31263..e69de29 100644 --- a/algorithm_distillation/__init__.py +++ b/algorithm_distillation/__init__.py @@ -1,5 +0,0 @@ -from .ad import AlgorithmDistillation, GymAD -from .task import GymTask, Task -from .task_manager import TaskManager - -__all__ = ["AlgorithmDistillation", "GymAD", "Task", "GymTask", "TaskManager"] diff --git a/algorithm_distillation/models/lm/__init__.py b/algorithm_distillation/models/lm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/models/lm/legacy_train.py b/algorithm_distillation/models/lm/legacy_train.py new file mode 100644 index 0000000..13ad30e --- /dev/null +++ b/algorithm_distillation/models/lm/legacy_train.py @@ -0,0 +1,83 @@ +import torch +import wandb +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from accelerate import Accelerator +from rollouts_as_lm_task import RolloutsAsLanguageModellingTask +from utils import ShuffledIterableDataset +from tqdm.auto import tqdm + +wandb.init(project="algorithm-distillation") +logging_table = wandb.Table(columns=["Step", "Generation"]) + +accelerator = Accelerator() + +model = AutoModelForCausalLM.from_pretrained('gpt2') +optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) + +tokenizer = AutoTokenizer.from_pretrained('gpt2') +train_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/train') +train_dataset = ShuffledIterableDataset(train_dataset, buffer_size=10_000) +eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', verbose=False) +generate_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', for_generation=True, verbose=False) + +train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False) + +model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader) + +model.train() +total_steps = 0 +for epoch in range(10): + for batch in tqdm(train_dataloader, desc=f'Training epoch {epoch}'): + # train + optimizer.zero_grad() + output = model(**batch) + loss = output.loss + wandb.log({'loss': loss.item(), 'step': total_steps}) + accelerator.backward(loss) + optimizer.step() + total_steps += 1 + + # eval + eval_steps = 25 + eval_size = 20 + if total_steps % eval_steps == 0: + model.eval() + eval_loss = 0 + eval_dataloader = torch.utils.data.DataLoader(eval_dataset, shuffle=False) + eval_dataloader = accelerator.prepare(eval_dataloader) + for idx, batch in tqdm(enumerate(eval_dataloader), desc=f'Eval after {total_steps} steps.'): + if idx >= eval_size: + break + output = model(**batch) + loss = output.loss + eval_loss += loss.item() + eval_loss /= idx + wandb.log({'eval_loss': eval_loss, 'step': total_steps}) + print(f'Avg loss after {total_steps}: {eval_loss}') + 
model.train() + + # generate examples + generate_steps = 10 + generate_size = 3 + if total_steps % generate_steps == 0: + model.eval() + + generate_dataloader = torch.utils.data.DataLoader(generate_dataset, shuffle=False) + generate_dataloader = accelerator.prepare(generate_dataloader) + + for idx, batch in enumerate(generate_dataloader): + if idx >= generate_size: + break + batch = {k:v.squeeze(0) for k,v in batch.items()} + outputs = model.generate(**batch, max_length=100, do_sample=True, pad_token_id=tokenizer.eos_token_id) + text = tokenizer.decode(outputs[0]) + + logging_table.add_data(total_steps, text) + + logging_table = wandb.Table(columns=logging_table.columns, data=logging_table.data) + wandb.log({'Generations Table': logging_table}) + + + model.train() + \ No newline at end of file diff --git a/algorithm_distillation/models/rl/__init__.py b/algorithm_distillation/models/rl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/models/ad_transformer.py b/algorithm_distillation/models/rl/ad_transformer.py similarity index 100% rename from algorithm_distillation/models/ad_transformer.py rename to algorithm_distillation/models/rl/ad_transformer.py diff --git a/algorithm_distillation/models/gpt2.py b/algorithm_distillation/models/rl/gpt2.py similarity index 100% rename from algorithm_distillation/models/gpt2.py rename to algorithm_distillation/models/rl/gpt2.py diff --git a/algorithm_distillation/models/util.py b/algorithm_distillation/models/rl/util.py similarity index 100% rename from algorithm_distillation/models/util.py rename to algorithm_distillation/models/rl/util.py diff --git a/algorithm_distillation/tasks/lm/__init__.py b/algorithm_distillation/tasks/lm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/tasks/lm/sentiment/__init__.py b/algorithm_distillation/tasks/lm/sentiment/__init__.py new file mode 100644 index 0000000..f562f25 --- /dev/null +++ b/algorithm_distillation/tasks/lm/sentiment/__init__.py @@ -0,0 +1 @@ +from .dataset import SentimentLanguageTrajectories \ No newline at end of file diff --git a/algorithm_distillation/tasks/lm/sentiment/dataset.py b/algorithm_distillation/tasks/lm/sentiment/dataset.py new file mode 100644 index 0000000..d453fca --- /dev/null +++ b/algorithm_distillation/tasks/lm/sentiment/dataset.py @@ -0,0 +1,90 @@ +import torch +from pathlib import Path +import json +from transformers import AutoTokenizer +from typing import Dict, Any +from pathlib import Path + +class SentimentRlTrajectories(torch.utils.data.Dataset): + def __init__(self): + raise NotImplementedError() + + +class SentimentLanguageTrajectories(torch.utils.data.IterableDataset): + def __init__(self, tokenizer: AutoTokenizer, split: str, for_generation: bool = False, verbose: bool = True): + self.tokenizer = tokenizer + self.verbose = verbose + self.for_generation = for_generation + + if split == 'train': + self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "train" + elif split == 'eval': + self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "eval" + else: + raise RuntimeError(f"split must be either 'train' or 'eval', got: {split}") + + + def format_rollout(self, d: Dict[Any, Any]) -> str: + return f"Prompt: {d['query_text']}\nCompletion: {d['response_text']}\nReward: {d['rewards'][-1]}\n\n" + + def tokenize_for_training(self, x: str): + inputs = self.tokenizer(x, truncation=True, return_tensors='pt') + return { + 'input_ids': inputs.input_ids, + 'attention_mask': 
inputs.attention_mask, + 'labels': inputs.input_ids + } + + def __iter__(self): + + runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')] + if self.verbose: + print(f'Iterating over {len(runs)} runs...') + + for run in runs: + + config = json.loads(open(run / 'config.json', 'r').read()) + epochs = [epoch for epoch in run.iterdir() if epoch.name != 'config.json'] + + if self.verbose: + print(f'...and {run.name} has {len(epochs)} epochs.') + + epochs = sorted(epochs) + for epoch in epochs: + # print(f'Yielding from epoch {epoch.name}') + + rollouts = json.loads(open(epoch, 'r').read()) + + rollout_idx = 0 + prompt = "" + while rollout_idx < len(rollouts): + + if self.for_generation: + d = rollouts[rollout_idx] + rollout_idx += 1 + generation_prompt = f"Prompt: {d['query_text']}\nCompletion:" + yield self.tokenize_for_training(generation_prompt) + + else: + rollout = self.format_rollout(rollouts[rollout_idx]) + rollout_idx += 1 + + new_prompt = prompt + rollout + new_ids = self.tokenizer.tokenize(new_prompt) + + if len(new_ids) > self.tokenizer.model_max_length: # self.tokenizer.model_max_length: + yield self.tokenize_for_training(prompt) + prompt = "" + else: + prompt = new_prompt + + if len(prompt) > 0: + yield self.tokenize_for_training(prompt) + + +if __name__ == '__main__': + tokenizer = AutoTokenizer.from_pretrained('gpt2') + dataset = SentimentLanguageTrajectories(tokenizer, split='train', for_generation=False) + for ex in dataset: + print(tokenizer.decode(ex['input_ids'][0])) + print('\n---------\n') \ No newline at end of file diff --git a/algorithm_distillation/tasks/lm/sentiment/decode_rollouts.py b/algorithm_distillation/tasks/lm/sentiment/decode_rollouts.py new file mode 100644 index 0000000..13b2b65 --- /dev/null +++ b/algorithm_distillation/tasks/lm/sentiment/decode_rollouts.py @@ -0,0 +1,101 @@ +from pathlib import Path +import json +from transformers import AutoTokenizer +import click +from typing import Union +from shutil import copy2 + +@click.group() +def main(): + """ + CLI for formatting the rollouts into training data. + + \b + 1. decode-epoch: Decode a single epoch rollout .json file + 2. decode-run: Decode an entire PPO run (multiple epoch files) + 3. decode-rollouts: Deocde an entire directory (multiple runs) + """ + +def get_tokenizer(tokenizer: Union[str, AutoTokenizer]) -> AutoTokenizer: + if isinstance(tokenizer, str): + return AutoTokenizer.from_pretrained(tokenizer) + else: + return tokenizer + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input JSON file') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the JSON file to be created. 
Will overwrite if necessary.') +def decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + _decode_epoch(tokenizer, input_fpath, output_fpath) + +def _decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_file() + assert input_fpath.name.endswith('.json') + assert output_fpath.name.endswith('.json') + if not output_fpath.parent.exists(): + output_fpath.parent.mkdir() + + tokenizer = get_tokenizer(tokenizer) + + with open(input_fpath, 'r') as f: + rollouts = json.loads(f.read()) + + for rollout in rollouts: + rollout['query_text'] = tokenizer.decode(rollout['query_tensor'], skip_special_tokens=True) + rollout['response_text'] = tokenizer.decode(rollout['response_tensor'], skip_special_tokens=True) + + with open(output_fpath, 'w') as f: + f.write(json.dumps(rollouts, indent=2)) + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the directory containing the JSON epoch files') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the folder to be created.') +def decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + _decode_run(tokenizer, input_fpath, output_fpath) + +def _decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_dir() + if not output_fpath.exists(): + output_fpath.mkdir() + + # Copy over the config file + assert (input_fpath / 'config.json').exists() + copy2(input_fpath / 'config.json', output_fpath / 'config.json') + + # Decode the rest of the files + tokenizer = get_tokenizer(tokenizer) + + epochs = [fpath for fpath in input_fpath.iterdir() if fpath.name != 'config.json'] + for epoch in epochs: + _decode_epoch(tokenizer, epoch, output_fpath / epoch.name) + + +@main.command() +@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with') +@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input directory') +@click.option('--output-fpath', type=click.Path(path_type=Path), help='the output directory') +def decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): + _decode_rollouts(tokenizer, input_fpath, output_fpath) + +def _decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path): + + assert input_fpath.exists() + assert input_fpath.is_dir() + if not output_fpath.exists(): + output_fpath.mkdir() + + tokenizer = get_tokenizer(tokenizer) + runs = [fpath for fpath in input_fpath.iterdir() if fpath.name.startswith('run-')] + for run in runs: + _decode_run(tokenizer, run, output_fpath / run.name) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/algorithm_distillation/tasks/lm/sentiment/generate_ppo_roc_story_sentiment_rollouts.py b/algorithm_distillation/tasks/lm/sentiment/generate_ppo_roc_story_sentiment_rollouts.py new file mode 100644 index 0000000..e2f0582 --- /dev/null +++ b/algorithm_distillation/tasks/lm/sentiment/generate_ppo_roc_story_sentiment_rollouts.py @@ -0,0 +1,60 @@ +from posixpath import dirname +from datasets import load_dataset +from transformers import pipeline +import os +import yaml +from functools import partial + +import trlx +import torch +from typing import List +from trlx.data.configs import TRLConfig + +from trlx.utils.loading import 
get_model, get_orchestrator, get_pipeline
+
+def get_score_for_label(label, scores):
+    "Extract the score associated with the given label from the pipeline's output"
+    label_to_score = {d['label']: d['score'] for d in scores}
+    return label_to_score[label]
+
+default_config = yaml.safe_load(open(os.path.join(dirname(__file__), "ppo_config.yml")))
+
+def main(hparams={}):
+    config = TRLConfig.update(default_config, hparams)
+
+    if torch.cuda.is_available():
+        device = int(os.environ.get("LOCAL_RANK", 0))
+    else:
+        device = -1
+
+
+    sentiment_fn = pipeline(
+        "sentiment-analysis",
+        "bhadresh-savani/distilbert-base-uncased-emotion",
+        truncation=True,
+        batch_size=256,
+        device=device,
+        return_all_scores=True,
+    )
+
+    def reward_fn(samples: List[str]) -> List[float]:
+        output_batch = sentiment_fn(samples)
+        sentiments = list(map(partial(get_score_for_label, 'sadness'), output_batch))
+        return sentiments
+
+    stories = load_dataset("adamlin/roc_story")
+    format_prompt = lambda d: d['sentence1'] + ' ' + d['sentence2']
+    prompts = [format_prompt(d) for d in stories['train']]
+    eval_prompts = [format_prompt(d) for d in stories['validation'].select(range(64))]
+
+    model = trlx.train(
+        model_path="gpt2",
+        reward_fn=reward_fn,
+        prompts=prompts,
+        eval_prompts=eval_prompts,
+        config=config,
+    )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/algorithm_distillation/tasks/lm/sentiment/ppo_config.yml b/algorithm_distillation/tasks/lm/sentiment/ppo_config.yml
new file mode 100644
index 0000000..d09339b
--- /dev/null
+++ b/algorithm_distillation/tasks/lm/sentiment/ppo_config.yml
@@ -0,0 +1,49 @@
+model:
+  model_path: "lvwerra/gpt2-xl" # Name of hf model to load
+  tokenizer_path: "gpt2" # Name of hf tokenizer to load
+  model_type: "AcceleratePPOModel" # Name of accelerate model type to load
+  num_layers_unfrozen: 2 # Number of layers (from the top) to leave unfrozen during training
+
+train:
+  seq_length: 48 # Size of LM context
+  epochs: 100 # Train for max(epochs, total_steps)
+  total_steps: 1000 # Train for max(epochs, total_steps)
+  batch_size: 128 # batch size
+
+  lr_init: 1.0e-4 # init learning rate
+  lr_target: 1.0e-4 # target final learning rate
+  opt_betas: [0.9, 0.95] # adam betas
+  opt_eps: 1.0e-8 # adam eps
+  weight_decay: 1.0e-6 # weight decay param
+
+  checkpoint_interval: 10000 # checkpoint interval
+  eval_interval: 100 # eval interval
+
+  pipeline: "PromptPipeline" # prompt pipeline to load
+  orchestrator: "PPOOrchestrator" # orchestrator to load
+
+  rollout_logging_dir: "../algorithm_distillation/sentiment-data/rollouts"
+
+method:
+  name: 'ppoconfig' # Name of RL method config
+  num_rollouts: 128 # Number of rollouts to collect per epoch
+  chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator
+  ppo_epochs: 4 # Number of ppo epochs
+  init_kl_coef: 0.05 # init kl coefficient
+  target: 6 # target KL for the adaptive KL controller, set None for a fixed kl coefficient
+  horizon: 10000 # PPO horizon
+  gamma: 1 # PPO discount
+  lam: 0.95 # PPO lambda
+  cliprange: 0.2 # clip range
+  cliprange_value: 0.2 # value clip range
+  vf_coef: 1 # value term weight
+  scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
+  ref_mean: null
+  ref_std: null # rescale rewards with this deviation
+  cliprange_reward: 10
+  gen_kwargs:
+    max_length: 48 # LM max sample gen length
+    min_length: 48 # LM min sample gen length
+    top_k: 0.0 # top k
+    top_p: 1.0 # top p
+    do_sample: True # sample
diff --git a/algorithm_distillation/tasks/lm/sentiment/ppo_sweep.yml
b/algorithm_distillation/tasks/lm/sentiment/ppo_sweep.yml new file mode 100644 index 0000000..dc5faf0 --- /dev/null +++ b/algorithm_distillation/tasks/lm/sentiment/ppo_sweep.yml @@ -0,0 +1,17 @@ +tune_config: + mode: "max" + metric: "mean_reward" + search_alg: "random" + scheduler: "fifo" + num_samples: 4 + +# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs +lr_init: + strategy: "loguniform" + values: [0.00001, 0.01] +init_kl_coef: + strategy: "uniform" + values: [0, 0.2] +vf_coef: + strategy: "uniform" + values: [0.5, 2] diff --git a/algorithm_distillation/tasks/lm/sentiment/rollouts/.gitkeep b/algorithm_distillation/tasks/lm/sentiment/rollouts/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/tasks/rl/__init__.py b/algorithm_distillation/tasks/rl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/tasks/rl/ad.py similarity index 100% rename from algorithm_distillation/ad.py rename to algorithm_distillation/tasks/rl/ad.py diff --git a/algorithm_distillation/task.py b/algorithm_distillation/tasks/rl/task.py similarity index 100% rename from algorithm_distillation/task.py rename to algorithm_distillation/tasks/rl/task.py diff --git a/algorithm_distillation/task_manager.py b/algorithm_distillation/tasks/rl/task_manager.py similarity index 100% rename from algorithm_distillation/task_manager.py rename to algorithm_distillation/tasks/rl/task_manager.py diff --git a/algorithm_distillation/tasks/utils.py b/algorithm_distillation/tasks/utils.py new file mode 100644 index 0000000..98abbf2 --- /dev/null +++ b/algorithm_distillation/tasks/utils.py @@ -0,0 +1,31 @@ +import torch +import random + +class ShuffledIterableDataset(torch.utils.data.IterableDataset): + + def __init__(self, original_dataset: torch.utils.data.IterableDataset, buffer_size: int = 10_000): + self.original_dataset = original_dataset + self.buffer_size = buffer_size + + def __iter__(self): + original_iterator = iter(self.original_dataset) + + # fill buffer (or until runs out) + buffer = [] + while len(buffer) < self.buffer_size: + try: + x = next(original_iterator) + except StopIteration: + break + buffer.append(x) + + # shuffle, yield and replace until original runs out + for x in original_iterator: + random.shuffle(buffer) + yield buffer[-1] + buffer[-1] = x + + # empty the remaining + for x in buffer: + yield x + \ No newline at end of file diff --git a/algorithm_distillation/train.py b/algorithm_distillation/train.py new file mode 100644 index 0000000..7579dc1 --- /dev/null +++ b/algorithm_distillation/train.py @@ -0,0 +1,44 @@ + +from tasks.lm.sentiment import SentimentLanguageTrajectories +from tasks.utils import ShuffledIterableDataset + +from transformers import AutoModelForCausalLM, AutoTokenizer +from accelerate import Accelerator +from torch.utils.data import DataLoader +from torch.optim import Adam +from tqdm.auto import tqdm +import wandb + +accelerator = Accelerator() + +# Logging inits +wandb.init(project="algorithm-distillation") +logging_table = wandb.Table(columns=['step', 'generation']) + +# Data +tokenizer = AutoTokenizer.from_pretrained('gpt2') +train_dataset = SentimentLanguageTrajectories(split="train", tokenizer=tokenizer) +train_dataset = ShuffledIterableDataset(train_dataset, buffer_size=10_000) +# eval_dataset = ... +# generate_dataset = ... 
+train_dataloader = DataLoader(train_dataset, shuffle=False) + +# Setup parameters for training with accelerate +model = AutoModelForCausalLM.from_pretrained('gpt2') +optimizer = Adam(model.parameters(), lr=5e-5) +model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader) + +# Train +model.train() +total_steps = 0 +for epoch in range(10): + for batch in tqdm(train_dataloader, desc=f'Training epoch {epoch}'): + + optimizer.zero_grad() + output = model(**batch) + loss = output.loss + wandb.log({'loss': loss.item(), 'step': total_steps}) + accelerator.backward(loss) + optimizer.step() + total_steps +=1 + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 5887b2c..f726850 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,90 @@ +accelerate==0.14.0 +aiohttp==3.8.3 +aiosignal==1.3.1 +async-timeout==4.0.2 +attrs==22.1.0 +certifi==2022.9.24 +cfgv==3.3.1 +charset-normalizer==2.1.1 +click==8.0.4 +datasets==2.7.1 +deepspeed==0.7.5 +dill==0.3.6 +distlib==0.3.6 +docker-pycreds==0.4.0 +einops==0.6.0 +exceptiongroup==1.0.4 +filelock==3.8.0 +frozenlist==1.3.3 +fsspec==2022.11.0 +gitdb==4.0.9 +GitPython==3.1.29 +grpcio==1.50.0 +hjson==3.1.0 +huggingface-hub==0.11.0 +identify==2.5.9 +idna==3.4 +iniconfig==1.1.1 +joblib==1.2.0 +jsonschema==4.17.1 +msgpack==1.0.4 +multidict==6.0.2 +multiprocess==0.70.14 +networkx==2.8.8 +ninja==1.11.1 +nodeenv==1.7.0 +numpy==1.23.5 +nvidia-cublas-cu11==11.10.3.66 +nvidia-cuda-nvrtc-cu11==11.7.99 +nvidia-cuda-runtime-cu11==11.7.99 +nvidia-cudnn-cu11==8.5.0.96 +packaging==21.3 +pandas==1.5.2 +pathtools==0.1.2 +platformdirs==2.5.4 +pluggy==1.0.0 +pre-commit==2.20.0 +promise==2.3 +protobuf==4.21.9 +psutil==5.9.4 +py-cpuinfo==9.0.0 +pyarrow==10.0.1 +pydantic==1.10.2 +pyparsing==3.0.9 +pyrsistent==0.19.2 +pytest==7.2.0 +python-dateutil==2.8.2 +pytz==2022.6 +PyYAML==6.0 +ray==2.1.0 +regex==2022.10.31 +requests==2.28.1 +responses==0.18.0 +scikit-learn==1.1.3 +scipy==1.9.3 +sentry-sdk==1.11.1 +setproctitle==1.3.2 +shortuuid==1.0.11 +six==1.16.0 +smmap==5.0.0 +tabulate==0.9.0 +threadpoolctl==3.1.0 +tokenizers==0.13.2 +toml==0.10.2 +tomli==2.0.1 +torch==1.13.0 +torchtyping==0.1.4 +tqdm==4.64.1 +transformers==4.24.0 +-e git+https://github.com/thomfoster/trlx.git@e76b34338eb1551e8eb5f0eff7f694a4711d8b83#egg=trlx +typeguard==2.13.3 +typing_extensions==4.4.0 +urllib3==1.26.12 +virtualenv==20.16.7 +wandb==0.13.5 +xxhash==3.1.0 +yarl==1.8.1 +======= stable-baselines3 transformers~=4.24.0 torch~=1.12.1