19 commits
64bd6ea
Fix sampling behavior. Reduce duplicated codes.
honglu2875 Dec 18, 2022
0316768
Minor fix on requirements.txt.
honglu2875 Dec 18, 2022
c98720e
Added customized logger and included its usage in the test.
honglu2875 Dec 18, 2022
23eb20b
Added rollout method for AD.
honglu2875 Dec 19, 2022
88faa77
Fix type.
honglu2875 Dec 19, 2022
dcc1350
Added verbose and return more logging info. Added a couple paragraphs…
honglu2875 Dec 19, 2022
3caa347
Fix tqdm.
honglu2875 Dec 19, 2022
1bb95fc
Fix sampling behavior.
honglu2875 Dec 22, 2022
de320ee
Fix rollout behavior when skip is more than 0.
honglu2875 Dec 24, 2022
5f90cc5
Merge remote-tracking branch 'honglu2875/more_testing' into more_testing
honglu2875 Dec 24, 2022
305de88
Enforce types for rewards/terminals in ad rollout (some sloppy envs m…
honglu2875 Dec 26, 2022
2ca2eb8
Supports PPO and A2C. Expands environment to continuous observation l…
honglu2875 Dec 26, 2022
fd58ec7
Fix format.
honglu2875 Dec 26, 2022
e2ae059
README and docstrings.
honglu2875 Dec 26, 2022
9e07368
Fix various bugs (length size issue, callback collection length issue…
honglu2875 Dec 29, 2022
89d91c1
A2C steps seem undeterministic... leave it alone for now.
honglu2875 Dec 29, 2022
2e0d95a
Refactored folder structures to pave the way for a merge.
honglu2875 Jan 12, 2023
bf519ea
Merge remote-tracking branch 'honglu2875/more_testing' into more_testing
honglu2875 Jan 12, 2023
75029b2
Refactored folder structures to pave the way for a merge.
honglu2875 Jan 12, 2023
11 changes: 10 additions & 1 deletion README.md
@@ -1 +1,10 @@
# Algorithm-Distillation-RLHF
# Algorithm-Distillation-RLHF

The current state of this repo is a preliminary replication of the Algorithm Distillation algorithm
described in the paper [In-context Reinforcement Learning with Algorithm Distillation](https://arxiv.org/abs/2210.14215).
We also aim to make the code general enough to try ideas beyond the paper.

# Quick start
A demo script/notebook is not provided yet, but the unit test `tests/test_ad.py` walks through a complete routine of applying
the transformer to the histories of the toy tasks "FrozenLake-v1" and "CartPole-v1". Please [take a look](tests/test_ad.py)
and feel free to plug in your own gym env.
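
For orientation, here is a minimal, hypothetical sketch of the intended flow with the `GymAD` API added in this PR. The `GPT2AD`/`TaskManager`/`GymTask` constructor arguments are placeholders rather than real signatures; `tests/test_ad.py` shows the actual wiring.

```python
from algorithm_distillation import GymAD
from algorithm_distillation.models import GPT2AD

# Placeholder setup: the exact constructor arguments are not shown here;
# see tests/test_ad.py for a working configuration.
model = GPT2AD(...)   # an ADTransformer whose obs_dim/act_dim match the tasks
task_manager = ...    # a TaskManager holding learning histories of GymTasks
eval_task = ...       # a GymTask wrapping the evaluation gym env

ad = GymAD(model)
losses = ad.train(task_manager, steps=1000, length=20, skip=0,
                  batch_size=32, lr=1e-4, verbose=1)
obs, actions, rewards, terminals = ad.rollout(eval_task, steps=200, skip=0, verbose=1)
```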
7 changes: 3 additions & 4 deletions algorithm_distillation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .ad import AlgorithmDistillation, GymAD
from .task import GymTask, Task
from .task_manager import TaskManager
from algorithm_distillation import models, tasks
from algorithm_distillation.ad import AlgorithmDistillation, GymAD

__all__ = ["AlgorithmDistillation", "GymAD", "Task", "GymTask", "TaskManager"]
__all__ = ["models", "tasks", "AlgorithmDistillation", "GymAD"]
165 changes: 147 additions & 18 deletions algorithm_distillation/ad.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import abc
import logging

import numpy as np
import torch
from tqdm import tqdm

from algorithm_distillation.models.ad_transformer import ADTransformer

from .task_manager import TaskManager
from algorithm_distillation.models.util import get_sequence
from algorithm_distillation.tasks.rl.task import GymTask
from algorithm_distillation.tasks.rl.task_manager import TaskManager


class AlgorithmDistillation(abc.ABC):
@@ -14,6 +18,7 @@ class AlgorithmDistillation(abc.ABC):

def __init__(self, model: ADTransformer):
self.model = model
self.logger = logging.getLogger(__name__)

@abc.abstractmethod
def train(
@@ -23,8 +28,12 @@ def train(
length: int,
skip: int,
batch_size: int,
**config
):
**config,
) -> list:
pass

@abc.abstractmethod
def rollout(self, task, steps: int, skip: int) -> tuple:
pass


@@ -35,34 +44,47 @@ def train(
steps: int,
length: int,
skip: int,
batch_size: int,
**config
):
batch_size: int = 32,
lr: float = 1e-4,
verbose: int = 0,
**config,
) -> list:
"""
Collect samples and train for `steps` gradient steps.

:param task_manager: the controller for a collection of tasks.
:param steps: the number of gradient steps to train for.
:param length: the step-length of sampled sequences (the token sequence length is 3x this).
:param skip: the number of states to skip between two consecutive ones.
:param batch_size: the batch size.
:param batch_size: (Optional) the batch size.
:param lr: (Optional) the learning rate.
:param verbose: (Optional) verbosity level. Nonzero => show a progress bar and extra logs.
:param config: the extra config that goes into transformer training.
:return: None
:return: a list of losses
"""
# Combine the config and the direct args.
# Note: direct args `batch_size` and `lr` override the config dict!
cfg = {**config, "batch_size": batch_size, "lr": lr}

# We implement a PyTorch training loop.
# Use GPU if available.
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
if verbose:
self.logger.info(f"Device: {device.type}")

data_iter = self._get_data_iter(
steps, batch_size, task_manager, length, skip, device=device
steps, cfg["batch_size"], task_manager, length, skip, device=device
)
self.model.to(device)
optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-2)
optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg["lr"])

self.model.train() # Set to train mode so that dropout and batch norm would update.
losses = []
for step, sample in enumerate(data_iter):

_tqdm_iter = tqdm(enumerate(data_iter), total=steps, disable=(verbose == 0))
for step, sample in _tqdm_iter:
optimizer.zero_grad()
obs, actions, rewards = sample
one_hot_actions = torch.nn.functional.one_hot(
@@ -75,9 +97,110 @@
losses.append(loss.item())
optimizer.step()

self.model.eval() # By default, set to eval mode outside of training.
if verbose: # Update loss if verbose is on
_tqdm_iter.set_postfix(ordered_dict={"loss": losses[-1]})

self.model.eval() # By default, set to eval mode outside training.
return losses

def rollout(
self,
task: GymTask,
steps: int,
skip: int,
verbose: int = 0,
) -> tuple:
"""
Roll out for `steps` steps (ignores the policy embedded in `task` and only uses its `_env`).

:param task: the task to perform rollout on.
:param steps: the number of steps to roll out.
:param skip: the number of steps to skip (normally the same as `skip` used during training).
:param verbose: (Optional) verbosity level. Nonzero => show a progress bar and extra logs.
:return: the full sequences (observations, actions, rewards, terminals), each of length `steps`.
"""
device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
self.model.to(device)

for st in ["obs", "act"]:
if getattr(task, f"{st}_dim") != getattr(self.model, f"{st}_dim"):
raise ValueError(
f"The task must have observation dimension {self.model.obs_dim}"
)
env = task.env

# The max_len of history should be 1 less than max_step_len (leave room for current_obs)
# (recall that max sequence length for the transformer is `model.max_step_len * 3`)
max_len = self.model.max_step_len - 1

# Prepare sequential inputs/outputs for the transformer
observations = torch.zeros(
(steps, task.obs_dim), device=device, dtype=torch.float
)
# Predicted action logits
action_logits = torch.zeros(
(steps, task.act_dim), device=device, dtype=torch.float
)
# The actual actions taken (argmax of action_logits)
actions = torch.zeros((steps,), device=device, dtype=torch.long)
# The actual one-hot encoded actions (nn.one_hot of actions)
actions_one_hot = torch.zeros(
(steps, task.act_dim), device=device, dtype=torch.float
)
rewards = torch.zeros((steps, 1), device=device, dtype=torch.float)
terminals = torch.zeros((steps,), device=device, dtype=torch.bool)

obs, done = None, True
cum_reward = 0.0

_tqdm_iter = tqdm(range(steps), disable=(verbose == 0))
for step in _tqdm_iter:
if done: # Last step was terminal. Reset.
obs, done = (
torch.tensor(
task.obs_post_process(np.array([env.reset()])),
device=device,
dtype=torch.float,
),
False,
)

# TODO: can probably be optimized using kv cache
with torch.inference_mode():
# The input of the model is collected from the index `step` (exclusive), going backwards
# with interval `skip + 1`. It goes as far back as possible, until either the beginning or
# `max_len` number of steps (the maximal number of steps that can fit into the transformer).
# We then take the argmax of the prediction of the next action and perform the
# rollout.
action_logits[step] = self.model(
get_sequence(observations, max_len, step, skip + 1)[None, :],
get_sequence(actions_one_hot, max_len, step, skip + 1)[None, :],
get_sequence(rewards, max_len, step, skip + 1)[None, :],
current_obs=obs[None, 0],
action_only=True,
)[0, min(step // (skip + 1), max_len - 1)]

actions[step] = torch.argmax(action_logits[step]).type(torch.long)
actions_one_hot[step] = torch.nn.functional.one_hot(
actions[step], num_classes=task.act_dim
).type(torch.float)

observations[step] = obs[None, 0]
obs, rew, done, _ = env.step(actions[step].item())  # raw gym outputs, not yet torch tensors
obs = torch.tensor(
task.obs_post_process(np.array([obs])), device=device, dtype=torch.float
)

rewards[step] = float(rew)
terminals[step] = bool(done)
cum_reward += float(rew)

if verbose: # Update the cumulative reward readout if verbose is on
_tqdm_iter.set_postfix(ordered_dict={"cum_reward": cum_reward})

print(losses)
return observations, actions, rewards, terminals
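
As an aside on the indexing used above: the history fed to the model at a given `step` is gathered backwards with stride `skip + 1`, up to `max_len` entries. The snippet below is an editor's illustration of that index arithmetic only, using a hypothetical `history_indices` helper; it is not the actual `get_sequence` from `algorithm_distillation.models.util`.

```python
def history_indices(step: int, skip: int, max_len: int) -> list:
    """Indices gathered from `step` (exclusive) going backwards with stride
    `skip + 1`, stopping at the start of the buffer or after `max_len` entries."""
    idx = []
    i = step - (skip + 1)
    while i >= 0 and len(idx) < max_len:
        idx.append(i)
        i -= skip + 1
    return idx[::-1]  # chronological order

print(history_indices(7, skip=1, max_len=3))  # [1, 3, 5]
```

With `skip = 0` this is simply the `max_len` most recent steps.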

@staticmethod
def _get_data_iter(
@@ -90,13 +213,19 @@ def _get_data_iter(

yield (
torch.tensor(
[sample[0] for sample in samples], dtype=torch.float, device=device
np.array([sample[0] for sample in samples]),
dtype=torch.float,
device=device,
), # observations
torch.tensor(
[sample[1] for sample in samples], dtype=torch.long, device=device
np.array([sample[1] for sample in samples]),
dtype=torch.long,
device=device,
), # actions
torch.tensor(
[sample[2] for sample in samples], dtype=torch.float, device=device
np.array([sample[2] for sample in samples]),
dtype=torch.float,
device=device,
), # rewards
)

@@ -109,5 +238,5 @@ def _compute_loss(x, y) -> torch.Tensor:
"""
assert y.dtype == torch.long
assert x.shape[:-1] + (1,) == y.shape
x = torch.nn.functional.log_softmax(x) # (b, length, action_num)
x = torch.nn.functional.log_softmax(x, dim=-1) # (b, length, action_num)
return -torch.take_along_dim(x, y, dim=len(y.shape) - 1).sum(-1).mean()
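
For reference, the loss above (negative log-softmax values gathered at the target actions, then averaged over batch and positions) is ordinary per-step cross-entropy. A quick standalone check of that equivalence, as an editor's illustration rather than part of the diff:

```python
import torch
import torch.nn.functional as F

b, length, action_num = 4, 6, 3
logits = torch.randn(b, length, action_num)
targets = torch.randint(action_num, (b, length, 1))  # same shape contract as `y`

log_p = F.log_softmax(logits, dim=-1)
loss_manual = -torch.take_along_dim(log_p, targets, dim=2).sum(-1).mean()
loss_ce = F.cross_entropy(logits.transpose(1, 2), targets.squeeze(-1))

assert torch.allclose(loss_manual, loss_ce, atol=1e-6)
```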
2 changes: 1 addition & 1 deletion algorithm_distillation/models/__init__.py
@@ -1,3 +1,3 @@
from .gpt2 import GPT2AD

__all__ = ["util", "GPT2AD"]
__all__ = ["sb3_util", "util", "GPT2AD"]
3 changes: 3 additions & 0 deletions algorithm_distillation/models/ad_transformer.py
@@ -6,6 +6,9 @@
class ADTransformer(abc.ABC, torch.nn.Module):
obs_dim: int
act_dim: int
# The maximal number of steps in the input.
# Note: the corresponding max sequence length is `max_step_len * 3`.
max_step_len: int

@abc.abstractmethod
def __init__(self, **kwargs):
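
The `max_step_len * 3` note above reflects that each environment step contributes three tokens (observation, action, reward) to the transformer's input sequence. A toy illustration of that bookkeeping, assuming a per-step (obs, action, reward) ordering and a shared embedding width; this is not the repo's actual embedding code:

```python
import torch
import torch.nn.functional as F

# Toy illustration: pad each modality to a common width and interleave per step,
# so `max_step_len` environment steps become `max_step_len * 3` tokens.
max_step_len, obs_dim, act_dim, d_model = 4, 8, 2, 16
obs = torch.zeros(max_step_len, obs_dim)
act = torch.zeros(max_step_len, act_dim)
rew = torch.zeros(max_step_len, 1)

tokens = torch.stack(
    [
        F.pad(obs, (0, d_model - obs_dim)),
        F.pad(act, (0, d_model - act_dim)),
        F.pad(rew, (0, d_model - 1)),
    ],
    dim=1,                      # (max_step_len, 3, d_model)
).reshape(max_step_len * 3, d_model)

assert tokens.shape == (max_step_len * 3, d_model)
```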