8 changes: 8 additions & 0 deletions .gitignore
@@ -1,3 +1,11 @@
trlx
algorithm_distillation/tasks/lm/sentiment/ray_results/
algorithm_distillation/tasks/lm/sentiment/rollouts/run-*
algorithm_distillation/tasks/lm/sentiment/decoded_rollouts
algorithm_distillation/tasks/lm/sentiment/wandb/
wandb/


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
59 changes: 58 additions & 1 deletion README.md
@@ -1 +1,58 @@
# Algorithm-Distillation-RLHF

A reinforcement learning algorithm is characterised by the trajectories it generates during training.

We are interested in "algorithm distillation" - whether these trajectories can be modelled by transformers, as studied in the original DeepMind algorithm distillation paper.

A particular focus of this repo is to extend prior work to the case where:
1. the trajectories have been generated by the TRLx library during RLHF training of language models
2. the transformer modelling the trajectories is itself a standard language model


## On data formats

A trajectory is typically defined as a list of `(state, action, reward)` triples. For training purposes, it is sometimes useful to augment this with `logprobs`: for each triple `(s, a, r)`, the log-probability of taking action $a$ in state $s$ under the policy that generated the trajectory.

We therefore define an **RL Format Trajectory** as a sequence of `(state, action, reward, logprobs)` tuples.

The typical way to learn to model these trajectories with a transformer is to separately map the final hidden state through a different head for each element. That is, for a given tuple `(s, a, r, l)`, a transformer $f$ maps it to $(\hat{s}, \hat{a}, \hat{r}, \hat{l})$.

In this repo, this is done via the models in `/models/rl`.
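As a rough sketch of this multi-head setup (illustrative only, not the actual `/models/rl` code; `hidden_dim`, `state_dim` and `num_actions` are made-up parameters):

```python
import torch
import torch.nn as nn

class TrajectoryHeads(nn.Module):
    """Separate prediction heads over a transformer's final hidden states,
    one per element of the (state, action, reward, logprobs) tuple."""

    def __init__(self, hidden_dim: int, state_dim: int, num_actions: int):
        super().__init__()
        self.state_head = nn.Linear(hidden_dim, state_dim)     # predicted next state
        self.action_head = nn.Linear(hidden_dim, num_actions)  # action logits
        self.reward_head = nn.Linear(hidden_dim, 1)             # scalar reward
        self.logprob_head = nn.Linear(hidden_dim, 1)             # behaviour-policy logprob

    def forward(self, hidden: torch.Tensor):
        # hidden: (batch, seq_len, hidden_dim) final hidden states from the transformer
        return (
            self.state_head(hidden),
            self.action_head(hidden),
            self.reward_head(hidden),
            self.logprob_head(hidden),
        )

# usage sketch:
# heads = TrajectoryHeads(hidden_dim=768, state_dim=4, num_actions=2)
# s_hat, a_hat, r_hat, l_hat = heads(transformer_hidden_states)
```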

We are also interested in the ability of standard language models (with language modeling heads) to learn trajectories. To this end we define a **Language Format Trajectory** as a trajectory serialised into a string. There are many possible ways to do this, and the optimal one requires investigation. For example, for trajectories generated using TRLx when finetuning a language model on positive sentiment, we can format the trajectory as the string:

```
prompt: Dan went down to the shops.
completion: He smiled as he walked - the sun was shining.
reward: 0.9975
###
```

It's less obvious how to do this when the task is not a language task (e.g. moonlander). Enumerating the states as coordinates might work, but this requires experimentation.

Trajectories in *Language format* are learnt by models in `/models/lm`.
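As a minimal sketch of this serialisation (assuming rollouts have already been decoded into dicts with `query_text`, `response_text` and `rewards` keys, as in the sentiment task; the exact template is one of the things that requires investigation):

```python
from typing import Any, Dict


def rollout_to_string(rollout: Dict[str, Any]) -> str:
    """Serialise one decoded rollout into a language-format trajectory chunk."""
    return (
        f"prompt: {rollout['query_text']}\n"
        f"completion: {rollout['response_text']}\n"
        f"reward: {rollout['rewards'][-1]}\n"
        "###\n"
    )


# Hypothetical rollout dict, matching the example above:
example = {
    "query_text": "Dan went down to the shops.",
    "response_text": "He smiled as he walked - the sun was shining.",
    "rewards": [0.0, 0.9975],
}
print(rollout_to_string(example))
```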

## To summarise:

`/models` contains the "algorithm distillation models", transformers that are trained in a supervised fashion to learn RL trajectories. We distinguish between models that operate on *RL Format* trajectories and *Language format* trajectories.

`/tasks` contains code to produce the RL trajectories that the models learn. It can store this data however it likes, but each task should expose a `torch.utils.data.Dataset` that can return trajectory data in either *RL Format* or *Language format*.
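As a sketch of the intended interface (the class names here are illustrative; the sentiment task follows this pattern with `SentimentRlTrajectories` and `SentimentLanguageTrajectories`):

```python
import torch


class MyTaskRlTrajectories(torch.utils.data.Dataset):
    """Returns trajectories in RL format: (state, action, reward, logprobs) tuples."""

    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, idx: int):
        raise NotImplementedError


class MyTaskLanguageTrajectories(torch.utils.data.IterableDataset):
    """Yields trajectories serialised into strings and tokenised (Language format)."""

    def __iter__(self):
        raise NotImplementedError
```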

## Generating trajectory data
I am using my own fork of TRLx that has rollout logging.
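For reference, each decoded rollout ends up as a JSON record roughly like the following (a sketch based on the fields used in `decode_rollouts.py` and `dataset.py`; the values are invented and the fork's exact schema may differ):

```python
# Hypothetical decoded rollout entry; token ids and rewards are made up.
example_rollout = {
    "query_tensor": [35, 272, 1816, 866],       # prompt token ids logged by TRLx
    "response_tensor": [679, 13541, 355, 339],  # completion token ids
    "rewards": [0.0, 0.0, 0.0, 0.9975],         # reward sequence; the dataset reads the final entry
    "query_text": "Dan went down",              # added by decode_rollouts.py
    "response_text": "He smiled as he",         # added by decode_rollouts.py
}
```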

## ToDo:

Today:
- [x] Set up repo structure (just for your language stuff, @H can add in his)
- [x] Add train script for `models/lm/casuallm`
- [ ] Clone @H's stuff and merge it in (`/models/rl` and `/tasks/rl`) (25 mins)
- [ ] Proper PR for TRLx (25 mins)
- [ ] Post guide and project tasks on discord

Future:
- [ ] Add more elegant metaclass switching between `...LanguageTrajectories` and `...RlTrajectories`
- [ ] Add online evaluation script for `models/lm/casuallm`
- [ ] Improve train script to include reward accuracy
- [ ] Run some preliminary experiments
- [ ] Add `__main__` file with click CLI interface for running experiments
5 changes: 0 additions & 5 deletions algorithm_distillation/__init__.py
@@ -1,5 +0,0 @@
from .ad import AlgorithmDistillation, GymAD
from .task import GymTask, Task
from .task_manager import TaskManager

__all__ = ["AlgorithmDistillation", "GymAD", "Task", "GymTask", "TaskManager"]
Empty file.
83 changes: 83 additions & 0 deletions algorithm_distillation/models/lm/legacy_train.py
@@ -0,0 +1,83 @@
import torch
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from accelerate import Accelerator
from rollouts_as_lm_task import RolloutsAsLanguageModellingTask
from utils import ShuffledIterableDataset
from tqdm.auto import tqdm

wandb.init(project="algorithm-distillation")
logging_table = wandb.Table(columns=["Step", "Generation"])

accelerator = Accelerator()

model = AutoModelForCausalLM.from_pretrained('gpt2')
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
train_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/train')
train_dataset = ShuffledIterableDataset(train_dataset, buffer_size=10_000)
eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', verbose=False)
generate_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', for_generation=True, verbose=False)

train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

model.train()
total_steps = 0
for epoch in range(10):
    for batch in tqdm(train_dataloader, desc=f'Training epoch {epoch}'):
        # train
        optimizer.zero_grad()
        output = model(**batch)
        loss = output.loss
        wandb.log({'loss': loss.item(), 'step': total_steps})
        accelerator.backward(loss)
        optimizer.step()
        total_steps += 1

        # eval
        eval_steps = 25
        eval_size = 20
        if total_steps % eval_steps == 0:
            model.eval()
            eval_loss = 0
            num_eval_batches = 0
            eval_dataloader = torch.utils.data.DataLoader(eval_dataset, shuffle=False)
            eval_dataloader = accelerator.prepare(eval_dataloader)
            for idx, batch in tqdm(enumerate(eval_dataloader), desc=f'Eval after {total_steps} steps.'):
                if idx >= eval_size:
                    break
                output = model(**batch)
                loss = output.loss
                eval_loss += loss.item()
                num_eval_batches += 1
            eval_loss /= num_eval_batches  # average over the batches actually evaluated
            wandb.log({'eval_loss': eval_loss, 'step': total_steps})
            print(f'Avg loss after {total_steps}: {eval_loss}')
            model.train()

        # generate examples
        generate_steps = 10
        generate_size = 3
        if total_steps % generate_steps == 0:
            model.eval()

            generate_dataloader = torch.utils.data.DataLoader(generate_dataset, shuffle=False)
            generate_dataloader = accelerator.prepare(generate_dataloader)

            for idx, batch in enumerate(generate_dataloader):
                if idx >= generate_size:
                    break
                # the generation dataset yields (1, seq_len) tensors; drop the extra DataLoader batch dim
                batch = {k: v.squeeze(0) for k, v in batch.items()}
                outputs = model.generate(**batch, max_length=100, do_sample=True, pad_token_id=tokenizer.eos_token_id)
                text = tokenizer.decode(outputs[0])

                logging_table.add_data(total_steps, text)

            # wandb tables cannot be re-logged after mutation, so rebuild the table before logging
            logging_table = wandb.Table(columns=logging_table.columns, data=logging_table.data)
            wandb.log({'Generations Table': logging_table})

            model.train()
1 change: 1 addition & 0 deletions algorithm_distillation/tasks/lm/sentiment/__init__.py
@@ -0,0 +1 @@
from .dataset import SentimentLanguageTrajectories
90 changes: 90 additions & 0 deletions algorithm_distillation/tasks/lm/sentiment/dataset.py
@@ -0,0 +1,90 @@
import torch
from pathlib import Path
import json
from transformers import AutoTokenizer
from typing import Dict, Any


class SentimentRlTrajectories(torch.utils.data.Dataset):
    def __init__(self):
        raise NotImplementedError()


class SentimentLanguageTrajectories(torch.utils.data.IterableDataset):
    def __init__(self, tokenizer: AutoTokenizer, split: str, for_generation: bool = False, verbose: bool = True):
        self.tokenizer = tokenizer
        self.verbose = verbose
        self.for_generation = for_generation

        if split == 'train':
            self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "train"
        elif split == 'eval':
            self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "eval"
        else:
            raise RuntimeError(f"split must be either 'train' or 'eval', got: {split}")

    def format_rollout(self, d: Dict[Any, Any]) -> str:
        return f"Prompt: {d['query_text']}\nCompletion: {d['response_text']}\nReward: {d['rewards'][-1]}\n\n"

    def tokenize_for_training(self, x: str):
        inputs = self.tokenizer(x, truncation=True, return_tensors='pt')
        return {
            'input_ids': inputs.input_ids,
            'attention_mask': inputs.attention_mask,
            'labels': inputs.input_ids
        }

    def __iter__(self):
        runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')]
        if self.verbose:
            print(f'Iterating over {len(runs)} runs...')

        for run in runs:
            config = json.loads(open(run / 'config.json', 'r').read())
            epochs = [epoch for epoch in run.iterdir() if epoch.name != 'config.json']

            if self.verbose:
                print(f'...and {run.name} has {len(epochs)} epochs.')

            epochs = sorted(epochs)
            for epoch in epochs:
                rollouts = json.loads(open(epoch, 'r').read())

                rollout_idx = 0
                prompt = ""
                while rollout_idx < len(rollouts):

                    if self.for_generation:
                        # Yield only the prompt prefix so the model completes the rest.
                        d = rollouts[rollout_idx]
                        rollout_idx += 1
                        generation_prompt = f"Prompt: {d['query_text']}\nCompletion:"
                        yield self.tokenize_for_training(generation_prompt)

                    else:
                        # Pack as many serialised rollouts as fit into the model's context window.
                        rollout = self.format_rollout(rollouts[rollout_idx])
                        rollout_idx += 1

                        new_prompt = prompt + rollout
                        new_ids = self.tokenizer.tokenize(new_prompt)

                        if len(new_ids) > self.tokenizer.model_max_length:
                            # Context is full: emit the packed prompt and start a fresh one
                            # (note the rollout that caused the overflow is dropped, not carried over).
                            yield self.tokenize_for_training(prompt)
                            prompt = ""
                        else:
                            prompt = new_prompt

                if len(prompt) > 0:
                    yield self.tokenize_for_training(prompt)


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    dataset = SentimentLanguageTrajectories(tokenizer, split='train', for_generation=False)
    for ex in dataset:
        print(tokenizer.decode(ex['input_ids'][0]))
        print('\n---------\n')
101 changes: 101 additions & 0 deletions algorithm_distillation/tasks/lm/sentiment/decode_rollouts.py
@@ -0,0 +1,101 @@
from pathlib import Path
import json
from transformers import AutoTokenizer
import click
from typing import Union
from shutil import copy2

@click.group()
def main():
    """
    CLI for formatting the rollouts into training data.

    \b
    1. decode-epoch: Decode a single epoch rollout .json file
    2. decode-run: Decode an entire PPO run (multiple epoch files)
    3. decode-rollouts: Decode an entire directory (multiple runs)
    """


def get_tokenizer(tokenizer: Union[str, AutoTokenizer]) -> AutoTokenizer:
    if isinstance(tokenizer, str):
        return AutoTokenizer.from_pretrained(tokenizer)
    else:
        return tokenizer


@main.command()
@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with')
@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input JSON file')
@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the JSON file to be created. Will overwrite if necessary.')
def decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    _decode_epoch(tokenizer, input_fpath, output_fpath)


def _decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    assert input_fpath.exists()
    assert input_fpath.is_file()
    assert input_fpath.name.endswith('.json')
    assert output_fpath.name.endswith('.json')
    if not output_fpath.parent.exists():
        output_fpath.parent.mkdir()

    tokenizer = get_tokenizer(tokenizer)

    with open(input_fpath, 'r') as f:
        rollouts = json.loads(f.read())

    # Add human-readable text fields alongside the raw token-id tensors.
    for rollout in rollouts:
        rollout['query_text'] = tokenizer.decode(rollout['query_tensor'], skip_special_tokens=True)
        rollout['response_text'] = tokenizer.decode(rollout['response_tensor'], skip_special_tokens=True)

    with open(output_fpath, 'w') as f:
        f.write(json.dumps(rollouts, indent=2))


@main.command()
@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with')
@click.option('--input-fpath', type=click.Path(path_type=Path), help='the directory containing the JSON epoch files')
@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the folder to be created.')
def decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    _decode_run(tokenizer, input_fpath, output_fpath)


def _decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    assert input_fpath.exists()
    assert input_fpath.is_dir()
    if not output_fpath.exists():
        output_fpath.mkdir()

    # Copy over the config file
    assert (input_fpath / 'config.json').exists()
    copy2(input_fpath / 'config.json', output_fpath / 'config.json')

    # Decode the rest of the files
    tokenizer = get_tokenizer(tokenizer)

    epochs = [fpath for fpath in input_fpath.iterdir() if fpath.name != 'config.json']
    for epoch in epochs:
        _decode_epoch(tokenizer, epoch, output_fpath / epoch.name)


@main.command()
@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with')
@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input directory')
@click.option('--output-fpath', type=click.Path(path_type=Path), help='the output directory')
def decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path):
    _decode_rollouts(tokenizer, input_fpath, output_fpath)


def _decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path):
    assert input_fpath.exists()
    assert input_fpath.is_dir()
    if not output_fpath.exists():
        output_fpath.mkdir()

    tokenizer = get_tokenizer(tokenizer)
    runs = [fpath for fpath in input_fpath.iterdir() if fpath.name.startswith('run-')]
    for run in runs:
        _decode_run(tokenizer, run, output_fpath / run.name)


if __name__ == '__main__':
    main()