8 changes: 8 additions & 0 deletions .gitignore
@@ -1,3 +1,11 @@
trlx
algorithm_distillation/tasks/lm/sentiment/ray_results/
algorithm_distillation/tasks/lm/sentiment/rollouts/run-*
algorithm_distillation/tasks/lm/sentiment/decoded_rollouts
algorithm_distillation/tasks/lm/sentiment/wandb/
wandb/


# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
59 changes: 58 additions & 1 deletion README.md
@@ -1 +1,58 @@
# Algorithm-Distillation-RLHF

A reinforcement learning algorithm is characterised by the trajectories it generates during training.

We are interested in "algorithm distillation": whether those trajectories can be modelled by transformers, as studied in the original DeepMind Algorithm Distillation paper.

A particular focus of this repo is to extend prior work to the case where:
1. the trajectories have been generated by the TRLx library during RLHF training of language models
2. the transformer modelling the trajectories is itself a standard language model


## On data formats

A trajectory is typically defined as a list of `(state, action, reward)` triples. For training purposes, it is sometimes useful to augment each triple `(s, a, r)` with `logprobs`: the log-probability of taking action $a$ in state $s$ under the policy that generated the trajectory.

We therefore define an **RL Format Trajectory** as a sequence of `(state, action, reward, logprobs)` tuples.
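As a concrete illustration, an RL-format step might be represented by a small container like the sketch below. This is hypothetical (the `RLFormatStep` name and field types are placeholders), not the representation actually used by the models in this repo:

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class RLFormatStep:
    """One step of an RL Format Trajectory: (state, action, reward, logprobs)."""
    state: torch.Tensor     # observation at time t
    action: torch.Tensor    # action taken at time t
    reward: float           # reward received for that action
    logprobs: torch.Tensor  # log-probability of the action under the generating policy


# An RL Format Trajectory is then an ordered sequence of such steps.
RLFormatTrajectory = List[RLFormatStep]
```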

The typical way to model these trajectories with a transformer is to map the final hidden state through separate prediction heads. That is, for a given tuple `(s, a, r, l)`, a transformer $f$ produces predictions $(\hat{s}, \hat{a}, \hat{r}, \hat{l})$.

In this repo, this is done via the models in `/models/rl`.
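Below is a minimal sketch of the separate-heads idea, assuming continuous state/action spaces and a placeholder backbone; it is not the actual `/models/rl` implementation, and the dimensions and head set are illustrative only:

```python
import torch
import torch.nn as nn


class SeparateHeadsTrajectoryModel(nn.Module):
    """Sketch: one shared transformer backbone, separate heads per prediction."""

    def __init__(self, d_model: int, state_dim: int, action_dim: int, n_heads: int = 4):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model, nhead=n_heads, batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, num_layers=2)
        # Separate heads applied to the final hidden states.
        self.state_head = nn.Linear(d_model, state_dim)
        self.action_head = nn.Linear(d_model, action_dim)
        self.reward_head = nn.Linear(d_model, 1)

    def forward(self, x: torch.Tensor):
        # x: (batch, seq_len, d_model) embeddings of the (s, a, r, l) sequence
        h = self.backbone(x)
        return self.state_head(h), self.action_head(h), self.reward_head(h)
```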

We are also interested in the ability of standard language models (with language modeling heads) to learn trajectories. To this end we define a **Language Format Trajectory** as a trajectory serialised into a string. There are many possible ways to do this, and the optimal one requires investigation. For example, for trajectories generated using TRLx when finetuning a language model on positive sentiment, we can format the trajectory as the string:

```
prompt: Dan went down to the shops.
completion: He smiled as he walked - the sun was shining.
reward: 0.9975
###
```

It's less obvious how to do this when the task is not a language task, such as moonlander. Enumerating the states as coordinates might work, but requires experimentation.
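As a purely hypothetical sketch, a 2D control task's state could be serialised as coordinates in a similar keyed-text style; whether such strings tokenise and train well is exactly what needs experimenting with:

```python
def format_control_step(state, action, reward) -> str:
    """Hypothetical serialisation of a non-language (x, y) state as text."""
    x, y = state
    return f"state: ({x:.2f}, {y:.2f})\naction: {action}\nreward: {reward:.4f}\n###\n"


# e.g. format_control_step((0.12, -1.50), 2, 0.87)
```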

Trajectories in *Language format* are learnt by models in `/models/lm`.

## To summarise:

`/models` contains the "algorithm distillation models", transformers that are trained in a supervised fashion to learn RL trajectories. We distinguish between models that operate on *RL Format* trajectories and *Language format* trajectories.

`/tasks` contains code to produce the RL trajectories that the models learn. It can store this data however it likes, but each task should expose a `torch.utils.data.Dataset` that can return trajectory data in either *RL Format* or *Language format*.
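A rough sketch of that contract is shown below. It is hypothetical (`ExampleTaskTrajectories` and its serialisation are placeholders); the concrete task classes such as `SentimentLanguageTrajectories` further down are the real reference:

```python
import torch


class ExampleTaskTrajectories(torch.utils.data.Dataset):
    """Hypothetical task dataset that can return either trajectory format."""

    def __init__(self, trajectory_format: str = "language"):
        assert trajectory_format in ("rl", "language")
        self.trajectory_format = trajectory_format
        self.trajectories = []  # populated however the task likes

    def __len__(self):
        return len(self.trajectories)

    def __getitem__(self, idx):
        traj = self.trajectories[idx]
        if self.trajectory_format == "rl":
            # RL format: the raw (state, action, reward, logprobs) tuples
            return traj
        # Language format: one serialised string per trajectory
        return "".join(f"{s} {a} {r}\n" for (s, a, r, _) in traj)
```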

## Generating trajectory data
I am using my own fork of TRLx that has rollout logging.

## ToDo:

Today:
- [x] Set up repo structure (just the language stuff for now; @H can add in his)
- [x] Add train script for `models/lm/casuallm`
- [ ] Clone @H's work and merge it in (`/models/rl` and `/tasks/rl`) (25 mins)
- [ ] Proper PR for TRLx (25 mins)
- [ ] Post guide and project tasks on Discord

Future:
- [ ] Add more elegant metaclass switching between `...LanguageTrajectories` and `...RlTrajectories`
- [ ] Add online evaluation script for `models/lm/casuallm`
- [ ] Improve train script to include reward accuracy
- [ ] Run some preliminary experiments
- [ ] Add a `__main__` file with a click CLI for running experiments
Empty file.
83 changes: 83 additions & 0 deletions algorithm_distillation/models/casual_lm/legacy_train.py
@@ -0,0 +1,83 @@
import torch
import wandb
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from accelerate import Accelerator
from rollouts_as_lm_task import RolloutsAsLanguageModellingTask
from utils import ShuffledIterableDataset
from tqdm.auto import tqdm

wandb.init(project="algorithm-distillation")
logging_table = wandb.Table(columns=["Step", "Generation"])

accelerator = Accelerator()

model = AutoModelForCausalLM.from_pretrained('gpt2')
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
train_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/train')
train_dataset = ShuffledIterableDataset(train_dataset, buffer_size=10_000)
eval_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', verbose=False)
generate_dataset = RolloutsAsLanguageModellingTask(tokenizer, './decoded_rollouts/eval', for_generation=True, verbose=False)

train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=False)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)

model.train()
total_steps = 0
for epoch in range(10):
    for batch in tqdm(train_dataloader, desc=f'Training epoch {epoch}'):
        # train
        optimizer.zero_grad()
        output = model(**batch)
        loss = output.loss
        wandb.log({'loss': loss.item(), 'step': total_steps})
        accelerator.backward(loss)
        optimizer.step()
        total_steps += 1

        # eval
        eval_steps = 25
        eval_size = 20
        if total_steps % eval_steps == 0:
            model.eval()
            eval_loss = 0
            eval_dataloader = torch.utils.data.DataLoader(eval_dataset, shuffle=False)
            eval_dataloader = accelerator.prepare(eval_dataloader)
            for idx, batch in tqdm(enumerate(eval_dataloader), desc=f'Eval after {total_steps} steps.'):
                if idx >= eval_size:
                    break
                output = model(**batch)
                loss = output.loss
                eval_loss += loss.item()
            eval_loss /= idx
            wandb.log({'eval_loss': eval_loss, 'step': total_steps})
            print(f'Avg loss after {total_steps}: {eval_loss}')
            model.train()

        # generate examples
        generate_steps = 10
        generate_size = 3
        if total_steps % generate_steps == 0:
            model.eval()

            generate_dataloader = torch.utils.data.DataLoader(generate_dataset, shuffle=False)
            generate_dataloader = accelerator.prepare(generate_dataloader)

            for idx, batch in enumerate(generate_dataloader):
                if idx >= generate_size:
                    break
                batch = {k: v.squeeze(0) for k, v in batch.items()}
                outputs = model.generate(**batch, max_length=100, do_sample=True, pad_token_id=tokenizer.eos_token_id)
                text = tokenizer.decode(outputs[0])

                logging_table.add_data(total_steps, text)

            logging_table = wandb.Table(columns=logging_table.columns, data=logging_table.data)
            wandb.log({'Generations Table': logging_table})

            model.train()
Empty file.
1 change: 1 addition & 0 deletions algorithm_distillation/tasks/lm/sentiment/__init__.py
@@ -0,0 +1 @@
from .dataset import SentimentLanguageTrajectories
90 changes: 90 additions & 0 deletions algorithm_distillation/tasks/lm/sentiment/dataset.py
@@ -0,0 +1,90 @@
import torch
from pathlib import Path
import json
from transformers import AutoTokenizer
from typing import Dict, Any


class SentimentRlTrajectories(torch.utils.data.Dataset):
    def __init__(self):
        raise NotImplementedError()


class SentimentLanguageTrajectories(torch.utils.data.IterableDataset):
    def __init__(self, tokenizer: AutoTokenizer, split: str, for_generation: bool = False, verbose: bool = True):
        self.tokenizer = tokenizer
        self.verbose = verbose
        self.for_generation = for_generation

        if split == 'train':
            self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "train"
        elif split == 'eval':
            self.rollouts_folder = Path(__file__).parent / "decoded_rollouts" / "eval"
        else:
            raise RuntimeError(f"split must be either 'train' or 'eval', got: {split}")

    def format_rollout(self, d: Dict[Any, Any]) -> str:
        return f"Prompt: {d['query_text']}\nCompletion: {d['response_text']}\nReward: {d['rewards'][-1]}\n\n"

    def tokenize_for_training(self, x: str):
        inputs = self.tokenizer(x, truncation=True, return_tensors='pt')
        return {
            'input_ids': inputs.input_ids,
            'attention_mask': inputs.attention_mask,
            'labels': inputs.input_ids
        }

    def __iter__(self):
        runs = [run for run in self.rollouts_folder.iterdir() if run.name.startswith('run-e')]
        if self.verbose:
            print(f'Iterating over {len(runs)} runs...')

        for run in runs:
            config = json.loads(open(run / 'config.json', 'r').read())
            epochs = [epoch for epoch in run.iterdir() if epoch.name != 'config.json']

            if self.verbose:
                print(f'...and {run.name} has {len(epochs)} epochs.')

            epochs = sorted(epochs)
            for epoch in epochs:
                rollouts = json.loads(open(epoch, 'r').read())

                rollout_idx = 0
                prompt = ""
                while rollout_idx < len(rollouts):

                    if self.for_generation:
                        d = rollouts[rollout_idx]
                        rollout_idx += 1
                        generation_prompt = f"Prompt: {d['query_text']}\nCompletion:"
                        yield self.tokenize_for_training(generation_prompt)

                    else:
                        rollout = self.format_rollout(rollouts[rollout_idx])
                        rollout_idx += 1

                        new_prompt = prompt + rollout
                        new_ids = self.tokenizer.tokenize(new_prompt)

                        # Pack as many rollouts as fit into the model's context window.
                        if len(new_ids) > self.tokenizer.model_max_length:
                            yield self.tokenize_for_training(prompt)
                            prompt = ""
                        else:
                            prompt = new_prompt

                if len(prompt) > 0:
                    yield self.tokenize_for_training(prompt)


if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    dataset = SentimentLanguageTrajectories(tokenizer, split='train', for_generation=False)
    for ex in dataset:
        print(tokenizer.decode(ex['input_ids'][0]))
        print('\n---------\n')
101 changes: 101 additions & 0 deletions algorithm_distillation/tasks/lm/sentiment/decode_rollouts.py
@@ -0,0 +1,101 @@
from pathlib import Path
import json
from transformers import AutoTokenizer
import click
from typing import Union
from shutil import copy2


@click.group()
def main():
    """
    CLI for formatting the rollouts into training data.

    \b
    1. decode-epoch: Decode a single epoch rollout .json file
    2. decode-run: Decode an entire PPO run (multiple epoch files)
    3. decode-rollouts: Decode an entire directory (multiple runs)
    """


def get_tokenizer(tokenizer: Union[str, AutoTokenizer]) -> AutoTokenizer:
    if isinstance(tokenizer, str):
        return AutoTokenizer.from_pretrained(tokenizer)
    else:
        return tokenizer


@main.command()
@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with')
@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input JSON file')
@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the JSON file to be created. Will overwrite if necessary.')
def decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    _decode_epoch(tokenizer, input_fpath, output_fpath)


def _decode_epoch(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    assert input_fpath.exists()
    assert input_fpath.is_file()
    assert input_fpath.name.endswith('.json')
    assert output_fpath.name.endswith('.json')
    if not output_fpath.parent.exists():
        output_fpath.parent.mkdir()

    tokenizer = get_tokenizer(tokenizer)

    with open(input_fpath, 'r') as f:
        rollouts = json.loads(f.read())

    for rollout in rollouts:
        rollout['query_text'] = tokenizer.decode(rollout['query_tensor'], skip_special_tokens=True)
        rollout['response_text'] = tokenizer.decode(rollout['response_tensor'], skip_special_tokens=True)

    with open(output_fpath, 'w') as f:
        f.write(json.dumps(rollouts, indent=2))


@main.command()
@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with')
@click.option('--input-fpath', type=click.Path(path_type=Path), help='the directory containing the JSON epoch files')
@click.option('--output-fpath', type=click.Path(path_type=Path), help='the path of the folder to be created.')
def decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    _decode_run(tokenizer, input_fpath, output_fpath)


def _decode_run(tokenizer: Union[str, AutoTokenizer], input_fpath: Path, output_fpath: Path):
    assert input_fpath.exists()
    assert input_fpath.is_dir()
    if not output_fpath.exists():
        output_fpath.mkdir()

    # Copy over the config file
    assert (input_fpath / 'config.json').exists()
    copy2(input_fpath / 'config.json', output_fpath / 'config.json')

    # Decode the rest of the files
    tokenizer = get_tokenizer(tokenizer)

    epochs = [fpath for fpath in input_fpath.iterdir() if fpath.name != 'config.json']
    for epoch in epochs:
        _decode_epoch(tokenizer, epoch, output_fpath / epoch.name)


@main.command()
@click.option('--tokenizer', type=str, default='gpt2', help='tokenizer to decode with')
@click.option('--input-fpath', type=click.Path(path_type=Path), help='the input directory')
@click.option('--output-fpath', type=click.Path(path_type=Path), help='the output directory')
def decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path):
    _decode_rollouts(tokenizer, input_fpath, output_fpath)


def _decode_rollouts(tokenizer: str, input_fpath: Path, output_fpath: Path):
    assert input_fpath.exists()
    assert input_fpath.is_dir()
    if not output_fpath.exists():
        output_fpath.mkdir()

    tokenizer = get_tokenizer(tokenizer)
    runs = [fpath for fpath in input_fpath.iterdir() if fpath.name.startswith('run-')]
    for run in runs:
        _decode_run(tokenizer, run, output_fpath / run.name)


if __name__ == '__main__':
    main()
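For reference, the decoding helpers above can also be driven programmatically. The snippet below is a hypothetical invocation; the paths are illustrative and only assumed to follow the rollouts/decoded_rollouts layout implied by the .gitignore entries above:

```python
from pathlib import Path

# Hypothetical usage: decode every run-* directory of raw rollouts into
# text rollouts, using the same tokenizer the policy was trained with.
_decode_rollouts(
    tokenizer='gpt2',
    input_fpath=Path('algorithm_distillation/tasks/lm/sentiment/rollouts'),
    output_fpath=Path('algorithm_distillation/tasks/lm/sentiment/decoded_rollouts'),
)
```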