diff --git a/README.md b/README.md index fea9271..6255ade 100644 --- a/README.md +++ b/README.md @@ -1 +1,10 @@ -# Algorithm-Distillation-RLHF \ No newline at end of file +# Algorithm-Distillation-RLHF + +The current state of this repo is a preliminary version of a replication of the Algorithmic Distillation algorithm +described in the paper [In-context Reinforcement Learning with Algorithm Distillation](https://arxiv.org/abs/2210.14215). +We also aim to make the codes general enough to try ideas beyond the paper. + +# Quick start +A demo script/notebook is not provided yet, but the unit test `tests/test_ad.py` provides a complete routine of applying +the transformer to the histories of toy tasks "FrozenLake-v1", "CartPole-v1". Please [take a look](tests/test_ad.py) +and feel free to plug in your own gym env. diff --git a/algorithm_distillation/__init__.py b/algorithm_distillation/__init__.py index 0a31263..cfde3d4 100644 --- a/algorithm_distillation/__init__.py +++ b/algorithm_distillation/__init__.py @@ -1,5 +1,4 @@ -from .ad import AlgorithmDistillation, GymAD -from .task import GymTask, Task -from .task_manager import TaskManager +from algorithm_distillation import models, tasks +from algorithm_distillation.ad import AlgorithmDistillation, GymAD -__all__ = ["AlgorithmDistillation", "GymAD", "Task", "GymTask", "TaskManager"] +__all__ = ["models", "tasks", "AlgorithmDistillation", "GymAD"] diff --git a/algorithm_distillation/ad.py b/algorithm_distillation/ad.py index f42a70f..fa88090 100644 --- a/algorithm_distillation/ad.py +++ b/algorithm_distillation/ad.py @@ -1,10 +1,14 @@ import abc +import logging +import numpy as np import torch +from tqdm import tqdm from algorithm_distillation.models.ad_transformer import ADTransformer - -from .task_manager import TaskManager +from algorithm_distillation.models.util import get_sequence +from algorithm_distillation.tasks.rl.task import GymTask +from algorithm_distillation.tasks.rl.task_manager import TaskManager class AlgorithmDistillation(abc.ABC): @@ -14,6 +18,7 @@ class AlgorithmDistillation(abc.ABC): def __init__(self, model: ADTransformer): self.model = model + self.logger = logging.getLogger(__name__) @abc.abstractmethod def train( @@ -23,8 +28,12 @@ def train( length: int, skip: int, batch_size: int, - **config - ): + **config, + ) -> list: + pass + + @abc.abstractmethod + def rollout(self, task, steps: int, skip: int) -> tuple: pass @@ -35,9 +44,11 @@ def train( steps: int, length: int, skip: int, - batch_size: int, - **config - ): + batch_size: int = 32, + lr: float = 1e-4, + verbose: int = 0, + **config, + ) -> list: """ Collect samples and train `steps` amount of gradient steps. @@ -45,24 +56,35 @@ def train( :param steps: the amount of gradient steps to train. :param length: the step-length of sampled sequences (not the sequence length which is 3x). :param skip: the amount of states to skip between two consecutive ones. - :param batch_size: the batch size. + :param batch_size: (Optional) the batch size. + :param lr: (Optional) the learning rate. + :param verbose: (Optional) verbose level. Nonzero => showing progress bar and certain logs. :param config: the extra config that goes into transformer training. - :return: None + :return: a list of losses """ + # Combine the config and the direct args. + # Note: direct args `batch_size` and `lr` override the config dict! + cfg = {**config, "batch_size": batch_size, "lr": lr} + # We implement a PyTorch training loop. # Use GPU if exists. 
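A quick illustration of the `cfg = {**config, "batch_size": batch_size, "lr": lr}` merge introduced above: in a dict literal the right-most keys win, so the explicit arguments override anything passed through `**config`. The values below are made up for the sketch.

```python
# Hypothetical extra config passed through **config.
config = {"batch_size": 64, "lr": 3e-3, "weight_decay": 0.0}
batch_size, lr = 32, 1e-4  # the direct arguments of GymAD.train

cfg = {**config, "batch_size": batch_size, "lr": lr}
assert cfg == {"batch_size": 32, "lr": 1e-4, "weight_decay": 0.0}
```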
device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") ) + if verbose: + self.logger.info(f"Device: {device.type}") + data_iter = self._get_data_iter( - steps, batch_size, task_manager, length, skip, device=device + steps, cfg["batch_size"], task_manager, length, skip, device=device ) self.model.to(device) - optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-2) + optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg["lr"]) self.model.train() # Set to train mode so that dropout and batch norm would update. losses = [] - for step, sample in enumerate(data_iter): + + _tqdm_iter = tqdm(enumerate(data_iter), total=steps, disable=(verbose == 0)) + for step, sample in _tqdm_iter: optimizer.zero_grad() obs, actions, rewards = sample one_hot_actions = torch.nn.functional.one_hot( @@ -75,9 +97,110 @@ def train( losses.append(loss.item()) optimizer.step() - self.model.eval() # By default, set to eval mode outside of training. + if verbose: # Update loss if verbose is on + _tqdm_iter.set_postfix(ordered_dict={"loss": losses[-1]}) + + self.model.eval() # By default, set to eval mode outside training. + return losses + + def rollout( + self, + task: GymTask, + steps: int, + skip: int, + verbose: int = 0, + ) -> tuple: + """ + Roll out for `steps` amount of steps (ignore the policy embedded in `task` and only uses its _env). + + :param task: the task to perform rollout on. + :param steps: the amount of steps to roll out. + :param skip: the amount of steps to skip (normally should be the same as `skip` during training). + :param verbose: (Optional) verbose level. Nonzero => showing progress bar and certain logs. + :return: the full sequences (observations, actions, rewards), each of length `steps`. + """ + device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + self.model.to(device) + + for st in ["obs", "act"]: + if getattr(task, f"{st}_dim") != getattr(self.model, f"{st}_dim"): + raise ValueError( + f"The task must have observation dimension {self.model.obs_dim}" + ) + env = task.env + + # The max_len of history should be 1 less than max_step_len (leave room for current_obs) + # (recall that max sequence length for the transformer is `model.max_step_len * 3`) + max_len = self.model.max_step_len - 1 + + # Prepare sequential inputs/outputs for the transformer + observations = torch.zeros( + (steps, task.obs_dim), device=device, dtype=torch.float + ) + # Predicted action logits + action_logits = torch.zeros( + (steps, task.act_dim), device=device, dtype=torch.float + ) + # The actual actions taken (argmax of action_logits) + actions = torch.zeros((steps,), device=device, dtype=torch.long) + # The actual one-hot encoded actions (nn.one_hot of actions) + actions_one_hot = torch.zeros( + (steps, task.act_dim), device=device, dtype=torch.float + ) + rewards = torch.zeros((steps, 1), device=device, dtype=torch.float) + terminals = torch.zeros((steps,), device=device, dtype=torch.bool) + + obs, done = None, True + cum_reward = 0.0 + + _tqdm_iter = tqdm(range(steps), disable=(verbose == 0)) + for step in _tqdm_iter: + if done: # Last step was terminal. Reset. + obs, done = ( + torch.tensor( + task.obs_post_process(np.array([env.reset()])), + device=device, + dtype=torch.float, + ), + False, + ) + + # TODO: can probably be optimized using kv cache + with torch.inference_mode(): + # The input of the model is collected from the index `step` (exclusive) going backwards + # with interval `skip + 1`. 
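Both `train` and `rollout` above feed the transformer one-hot encoded actions rather than raw action ids; a small self-contained sketch of that conversion (shapes are illustrative):

```python
import torch

act_dim = 4
actions = torch.tensor([2, 0, 3], dtype=torch.long)  # discrete action ids, shape (t,)

# Same pattern as in GymAD.train / GymAD.rollout: ids -> float one-hot vectors
# that the linear action embedding can consume.
one_hot = torch.nn.functional.one_hot(actions, num_classes=act_dim).type(torch.float)
assert one_hot.shape == (3, act_dim)
assert torch.equal(one_hot.argmax(dim=-1), actions)  # argmax recovers the ids
```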
It goes as far back as possible until either the beginning or + # `max_len` number of steps (the maximal number of steps that can fit into the transformer) + # We then take the argmax of the prediction of the next action and perform the + # rollout. + action_logits[step] = self.model( + get_sequence(observations, max_len, step, skip + 1)[None, :], + get_sequence(actions_one_hot, max_len, step, skip + 1)[None, :], + get_sequence(rewards, max_len, step, skip + 1)[None, :], + current_obs=obs[None, 0], + action_only=True, + )[0, min(step // (skip + 1), max_len - 1)] + + actions[step] = torch.argmax(action_logits[step]).type(torch.long) + actions_one_hot[step] = torch.nn.functional.one_hot( + actions[step], num_classes=task.act_dim + ).type(torch.float) + + observations[step] = obs[None, 0] + obs, rew, done, _ = env.step(actions[step].item()) # are still np.ndarray + obs = torch.tensor( + task.obs_post_process(np.array([obs])), device=device, dtype=torch.float + ) + + rewards[step] = float(rew) + terminals[step] = bool(done) + cum_reward += float(rew) + + if verbose: # Update loss if verbose is on + _tqdm_iter.set_postfix(ordered_dict={"cum_reward": cum_reward}) - print(losses) + return observations, actions, rewards, terminals @staticmethod def _get_data_iter( @@ -90,13 +213,19 @@ def _get_data_iter( yield ( torch.tensor( - [sample[0] for sample in samples], dtype=torch.float, device=device + np.array([sample[0] for sample in samples]), + dtype=torch.float, + device=device, ), # observations torch.tensor( - [sample[1] for sample in samples], dtype=torch.long, device=device + np.array([sample[1] for sample in samples]), + dtype=torch.long, + device=device, ), # actions torch.tensor( - [sample[2] for sample in samples], dtype=torch.float, device=device + np.array([sample[2] for sample in samples]), + dtype=torch.float, + device=device, ), # rewards ) @@ -109,5 +238,5 @@ def _compute_loss(x, y) -> torch.Tensor: """ assert y.dtype == torch.long assert x.shape[:-1] + (1,) == y.shape - x = torch.nn.functional.log_softmax(x) # (b, length, action_num) + x = torch.nn.functional.log_softmax(x, dim=-1) # (b, length, action_num) return -torch.take_along_dim(x, y, dim=len(y.shape) - 1).sum(-1).mean() diff --git a/algorithm_distillation/models/__init__.py b/algorithm_distillation/models/__init__.py index 6a82b3e..c6300e9 100644 --- a/algorithm_distillation/models/__init__.py +++ b/algorithm_distillation/models/__init__.py @@ -1,3 +1,3 @@ from .gpt2 import GPT2AD -__all__ = ["util", "GPT2AD"] +__all__ = ["sb3_util", "util", "GPT2AD"] diff --git a/algorithm_distillation/models/ad_transformer.py b/algorithm_distillation/models/ad_transformer.py index 334342b..bdb398b 100644 --- a/algorithm_distillation/models/ad_transformer.py +++ b/algorithm_distillation/models/ad_transformer.py @@ -6,6 +6,9 @@ class ADTransformer(abc.ABC, torch.nn.Module): obs_dim: int act_dim: int + # The maximal amount of steps in the input. + # Note: the corresponding max sequence length is `max_step_len * 3`. 
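The `_compute_loss` change above (passing `dim=-1` to `log_softmax`) keeps the loss equal to a standard cross-entropy over the taken actions; a quick sanity-check sketch with random logits:

```python
import torch
import torch.nn.functional as F

b, length, act_num = 2, 5, 4
logits = torch.randn(b, length, act_num)
y = torch.randint(0, act_num, (b, length, 1))  # taken action indices

log_p = F.log_softmax(logits, dim=-1)
loss_ad = -torch.take_along_dim(log_p, y, dim=-1).sum(-1).mean()   # as in _compute_loss
loss_ce = F.cross_entropy(logits.reshape(-1, act_num), y.reshape(-1))
assert torch.isclose(loss_ad, loss_ce)
```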
+ max_step_len: int @abc.abstractmethod def __init__(self, **kwargs): diff --git a/algorithm_distillation/models/gpt2.py b/algorithm_distillation/models/gpt2.py index 281bc60..256d11e 100644 --- a/algorithm_distillation/models/gpt2.py +++ b/algorithm_distillation/models/gpt2.py @@ -35,7 +35,7 @@ def __init__( obs_dim, act_dim, hidden_size, - max_ep_len=4096, + max_step_len=1024, action_tanh=True, obs_emb_cls=torch.nn.Linear, act_emb_cls=torch.nn.Linear, @@ -48,7 +48,7 @@ def __init__( :param obs_dim: observation dimension (as a flattened tensor) :param act_dim: action dimension (as a flattened tensor) :param hidden_size: the dimension of the embedding space - :param max_ep_len: (Optional) maximal episode length + :param max_step_len: (Optional) maximal episode length :param action_tanh: (Optional) apply tanh activation function on the action :param obs_emb_cls: (Optional) the nn.Module class for observation embedding :param act_emb_cls: (Optional) the nn.Module class for action embedding @@ -61,13 +61,16 @@ def __init__( self.act_dim = act_dim self.hidden_size = hidden_size self.action_tanh = action_tanh + self.max_step_len = max_step_len # Generate the most basic GPT2 config - config = transformers.GPT2Config(vocab_size=1, n_embd=hidden_size, **kwargs) + config = transformers.GPT2Config( + vocab_size=1, n_embd=hidden_size, n_positions=max_step_len * 3, **kwargs + ) self.transformers = transformers.GPT2Model(config) # Remove the position embedding by replacing it with a dummy. self.transformers.wpe = ZeroDummy((hidden_size,)) # This is our position embedding based on steps. - self.step_embedding = torch.nn.Embedding(max_ep_len, self.hidden_size) + self.step_embedding = torch.nn.Embedding(max_step_len, self.hidden_size) # The embedding layers self.obs_embedding = obs_emb_cls(self.obs_dim, self.hidden_size) @@ -110,9 +113,9 @@ def forward( But other cases are allowed, e.g., `..., latest_obs, latest_act, None` -> `next_reward` - :param obs: (b, t, obs_dim) - :param actions: (b, t, act_dim) - :param rewards: (b, t, 1) + :param obs: (b, t, obs_dim) or None if b==0 + :param actions: (b, t, act_dim) or None if b==0 + :param rewards: (b, t, 1) or None if b==0 :param current_obs: (Optional) (b, obs_dim) :param current_action: (Optional) shape (b, act_dim) :param current_reward: (Optional) shape (b, 1) @@ -122,32 +125,45 @@ def forward( :param action_only: (Optional) return predicted actions only. :return: predicted action logits (if action_only) or predicted action logits, rewards, and obs. """ - device = obs.device + if ( + obs is None or obs.numel() == 0 + ): # None or empty tensor are regarded as no history + assert current_obs is not None, "Empty input." + device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + batch_size, history_length = current_obs.shape[0], 0 + else: + device = obs.device + batch_size, history_length, _ = obs.shape - batch_size, timestep, _ = obs.shape if current_step_id is None: if step_ids is not None: logger.warning( - "'current_step_id' defaults to the number of steps. But it may conflict with the given 'step_ids'." + "'current_step_id' will default to the number of history steps. " + "But it may conflict with the given 'step_ids'." 
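On the `n_positions=max_step_len * 3` change above: each step occupies three tokens (obs, act, rew), so the GPT-2 context must be three times the step budget while the step embedding only needs `max_step_len` entries. A minimal sizing sketch:

```python
import torch
import transformers

hidden_size, max_step_len = 12, 16
config = transformers.GPT2Config(
    vocab_size=1, n_embd=hidden_size, n_positions=max_step_len * 3
)
gpt2 = transformers.GPT2Model(config)                          # room for 3 tokens per step
step_embedding = torch.nn.Embedding(max_step_len, hidden_size)

assert config.n_positions == 48
assert step_embedding.num_embeddings == max_step_len
```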
) - current_step_id = timestep + current_step_id = history_length - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, timestep), dtype=torch.float, device=device + if step_ids is None and history_length > 0: + step_ids = ( + torch.arange(0, history_length, dtype=torch.long, device=device) + .view(1, history_length) + .repeat((batch_size, 1)) ) - if step_ids is None: - step_ids = torch.arange(0, timestep, dtype=torch.long, device=device).view( - 1, timestep + if history_length == 0: + embedded_obs, embedded_act, embedded_rew = None, None, None + else: + embedded_steps = self.step_embedding(step_ids).view( + batch_size, history_length, self.hidden_size ) - embedded_steps = self.step_embedding(step_ids).view( - 1, timestep, self.hidden_size - ) + embedded_obs = self.obs_embedding(obs) + embedded_steps + embedded_act = self.act_embedding(actions) + embedded_steps + embedded_rew = self.rew_embedding(rewards) + embedded_steps - embedded_obs = self.obs_embedding(obs) + embedded_steps - embedded_act = self.act_embedding(actions) + embedded_steps - embedded_rew = self.rew_embedding(rewards) + embedded_steps embedded_latest_step = self.step_embedding( torch.tensor([current_step_id], dtype=torch.long, device=device) ) @@ -168,16 +184,28 @@ def forward( # Note: only affects axis 1. Axis 0 (batch) and axis 2 (embedding) are preserved. input_seq = stack_seq(embedded_obs, embedded_act, embedded_rew, extra) input_seq = self.layer_norm_in_embedding(input_seq) - attention_mask = ( - attention_mask.unsqueeze(-1).repeat((1, 1, 3)).view(batch_size, -1) - ) - attention_mask = torch.concat( - [ - attention_mask, - torch.ones((batch_size, num_extra), dtype=torch.float, device=device), - ], - dim=1, - ) + + if history_length == 0: # Empty history -> no need to concatenate with history. + attention_mask = torch.ones( + (batch_size, num_extra), dtype=torch.float, device=device + ) + else: # Nonempty history -> concatenate history sequence with current state info. + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, history_length), dtype=torch.float, device=device + ) + attention_mask = ( + attention_mask.unsqueeze(-1).repeat((1, 1, 3)).view(batch_size, -1) + ) + attention_mask = torch.concat( + [ + attention_mask, + torch.ones( + (batch_size, num_extra), dtype=torch.float, device=device + ), + ], + dim=1, + ) # Do inference using the underlying transformer. output = self.transformers( diff --git a/algorithm_distillation/models/sb3_util/__init__.py b/algorithm_distillation/models/sb3_util/__init__.py new file mode 100644 index 0000000..03cd976 --- /dev/null +++ b/algorithm_distillation/models/sb3_util/__init__.py @@ -0,0 +1,3 @@ +from .logger import CustomLogger, configure + +__all__ = ["CustomLogger", "configure"] diff --git a/algorithm_distillation/models/sb3_util/callback.py b/algorithm_distillation/models/sb3_util/callback.py new file mode 100644 index 0000000..f1de03f --- /dev/null +++ b/algorithm_distillation/models/sb3_util/callback.py @@ -0,0 +1,64 @@ +from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm + + +class RolloutCallback(BaseCallback): + """ + This is a custom callback that collects rollouts from an on-policy algorithm. + + :param buffer: The external replay buffer to save rollouts. + :param verbose: Verbosity level: 0 for no output, 1 for info messages, 2 for debug messages. 
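The attention-mask handling above expands a per-step mask to per-token granularity before appending the current-state tokens; a small sketch of the shape bookkeeping with toy sizes:

```python
import torch

batch_size, history_length, num_extra = 2, 4, 1
attention_mask = torch.ones((batch_size, history_length), dtype=torch.float)

# (b, t) -> (b, 3t): each step covers three tokens (obs, act, rew) ...
token_mask = attention_mask.unsqueeze(-1).repeat((1, 1, 3)).view(batch_size, -1)
# ... then append ones for the `num_extra` current-state tokens.
token_mask = torch.concat(
    [token_mask, torch.ones((batch_size, num_extra), dtype=torch.float)], dim=1
)
assert token_mask.shape == (batch_size, 3 * history_length + num_extra)
```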
+ """ + + def __init__( + self, agent: OnPolicyAlgorithm, buffer: ReplayBuffer, verbose: int = 0 + ): + super().__init__(verbose) + self.agent = agent + self.buffer = buffer + assert ( + self.buffer.buffer_size >= self.agent.rollout_buffer.buffer_size + ), "External replay buffer must be larger than the agent rollout buffer." + self._enforce_type() + + def _on_step(self) -> bool: + return True + + def _on_rollout_end(self) -> None: + start = self.buffer.pos + end = start + self.agent.rollout_buffer.buffer_size + + buffer = self.buffer + tgt_buffer = self.agent.rollout_buffer + + ranges = [(start, min(buffer.buffer_size, end))] + tgt_start = [0] + if buffer.buffer_size < end: # If overflown, wrap around to the beginning + ranges.append((0, end - buffer.buffer_size)) + tgt_start.append(buffer.buffer_size - start) + buffer.full = True + + for (st, ed), tst in zip(ranges, tgt_start): + buffer.observations[st:ed] = tgt_buffer.observations[ + tst : tst + ed - st + ].copy() + buffer.actions[st:ed] = tgt_buffer.actions[tst : tst + ed - st].copy() + buffer.rewards[st:ed] = tgt_buffer.rewards[tst : tst + ed - st].copy() + self.buffer.pos = ed # Update the pointer to the last ending range + + def _enforce_type(self) -> None: + # Rollout buffer's observation and action are float and it can cause inconsistency. + # So we force the types of buffer to be the same (should we emit a warning?). + buffer = self.buffer + tgt_buffer = self.agent.rollout_buffer + if tgt_buffer.observations is None: + raise RuntimeError("The rollout buffer is not initialized.") + + obs_type = tgt_buffer.observations.dtype + action_type = tgt_buffer.actions.dtype + reward_type = tgt_buffer.rewards.dtype + + buffer.observations = buffer.observations.astype(obs_type) + buffer.actions = buffer.actions.astype(action_type) + buffer.rewards = buffer.rewards.astype(reward_type) diff --git a/algorithm_distillation/models/sb3_util/logger.py b/algorithm_distillation/models/sb3_util/logger.py new file mode 100644 index 0000000..fb1150c --- /dev/null +++ b/algorithm_distillation/models/sb3_util/logger.py @@ -0,0 +1,78 @@ +import datetime +import os +import tempfile +from collections import defaultdict +from typing import Any, List, Optional, Tuple, Union + +from stable_baselines3.common.logger import Logger, make_output_format + + +class CustomLogger(Logger): + """ + A logger object can be plugged into an SB3 agent to record the metrics. Here we customize it to save metric + histories. One can further customize it and implement, for example, the connection with wandb. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.history_value = defaultdict(list) + self.history_mean_value = defaultdict(list) + + def record( + self, + key: str, + value: Any, + exclude: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> None: + super().record(key, value, exclude) + self.history_value[key].append(value) + + def record_mean( + self, + key: str, + value: Any, + exclude: Optional[Union[str, Tuple[str, ...]]] = None, + ) -> None: + super().record_mean(key, value, exclude) + self.history_mean_value[key].append(self.name_to_value[key]) + + +def configure( + folder: Optional[str] = None, + format_strings: Optional[List[str]] = None, + logger_class=Logger, +) -> Logger: + """ + Configure the current logger. + (This is almost the same as SB3's logger configuration helper function except one line in the parameter and + another line towards the end to allow for customized logger classes.) 
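The `_on_rollout_end` copy above splits the write into at most two contiguous ranges when it runs past the end of the circular replay buffer; the index arithmetic, pulled out as a plain-Python sketch (the helper name is made up):

```python
def copy_ranges(start: int, rollout_size: int, buffer_size: int):
    """Return ((dst_start, dst_end), src_offset) pairs, mirroring _on_rollout_end."""
    end = start + rollout_size
    ranges, src_offsets = [(start, min(buffer_size, end))], [0]
    if buffer_size < end:  # overflow: wrap the remainder around to index 0
        ranges.append((0, end - buffer_size))
        src_offsets.append(buffer_size - start)
    return list(zip(ranges, src_offsets))


# A 100-step rollout written at position 950 of a 1000-slot buffer:
assert copy_ranges(950, 100, 1000) == [((950, 1000), 0), ((0, 50), 50)]
```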
+ + :param folder: the save location + (if None, $SB3_LOGDIR, if still None, tempdir/SB3-[date & time]) + :param format_strings: the output logging format + (if None, $SB3_LOG_FORMAT, if still None, ['stdout', 'log', 'csv']) + :param logger_class: (Optional) the custom logger class. + :return: The logger object. + """ + if folder is None: + folder = os.getenv("SB3_LOGDIR") + if folder is None: + folder = os.path.join( + tempfile.gettempdir(), + datetime.datetime.now().strftime("SB3-%Y-%m-%d-%H-%M-%S-%f"), + ) + assert isinstance(folder, str) + os.makedirs(folder, exist_ok=True) + + log_suffix = "" + if format_strings is None: + format_strings = os.getenv("SB3_LOG_FORMAT", "stdout,log,csv").split(",") + + format_strings = list(filter(None, format_strings)) + output_formats = [make_output_format(f, folder, log_suffix) for f in format_strings] + + logger = logger_class(folder=folder, output_formats=output_formats) + # Only print when some files will be saved + if len(format_strings) > 0 and format_strings != ["stdout"]: + logger.log(f"Logging to {folder}") + return logger diff --git a/algorithm_distillation/models/util.py b/algorithm_distillation/models/util.py index a631ced..ee5c8d1 100644 --- a/algorithm_distillation/models/util.py +++ b/algorithm_distillation/models/util.py @@ -5,15 +5,59 @@ def stack_seq(obs, act, rew, extra=None) -> torch.Tensor: """ Stack up into a sequence (obs, act, rew, obs, act, rew, ...) in axis 1, and append extra in the end. + :param obs: shape (b, t, hidden_size) :param act: shape (b, t, hidden_size) :param rew: shape (b, t, hidden_size) :param extra: (Optional) shape (b, i, hidden_size) where i can be 1, 2 or 3 :return: shape (b, 3*t+i, hidden_size) """ + if obs is None: + # batch size is 0. We return extra. + return extra + batch_size, timestep, _ = obs.shape stacked = torch.stack((obs, act, rew), dim=2).view(batch_size, 3 * timestep, -1) if extra is None: return stacked else: return torch.concat([stacked, extra], dim=1) + + +def get_sequence(arr: torch.Tensor, num_items: int, end_idx: int, interval: int): + """ + Get the subsequence of indices 'end_idx - num_items * interval, ..., end_idx - interval' from + the PyTorch tensor `arr`. + Note: + - While `end_idx` is excluded, the intervals start backwards from it. + - Negative indices are ignored (meaning the return could be shorter). + - If the length of `arr` is less than `end_idx`, treat `arr` as a circular buffer. + Example: + arr = [7, 8, 9], length is 3 + num_items = 3, end_idx = 4, interval = 2 + -> [7, 9] + num_items = 1, end_idx = 3, interval = 2 + -> [8] + num_items = 1, end_idx = 2, interval = 2 + -> [7] + + :param arr: a tensor whose first dimension is the index we are concerned with. + :param num_items: the max number of items. + :param end_idx: the end index. + :param interval: the interval length. + :return: a subsequence of `arr` according to the description. + """ + length = arr.size(0) + # The size of the circular buffer determines max how many items it can return + num_items = min(num_items, (length + interval - 1) // interval) + # Get the actual start index (inclusive) + start_idx = max(end_idx - num_items * interval, end_idx % interval) + + if end_idx >= length: + # The subseq cuts in the middle. Update `end_idx` to the actual end index on the second half. 
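A sketch of how the `CustomLogger`/`configure` pair above is meant to be attached to an SB3 agent (mirroring what `tests/test_ad.py` does); the exact set of recorded keys depends on the algorithm:

```python
import gym
import stable_baselines3

from algorithm_distillation.models.sb3_util import CustomLogger, configure

agent = stable_baselines3.DQN("MlpPolicy", gym.make("FrozenLake-v1"), learning_starts=10)
logger = configure(None, None, CustomLogger)  # temp folder, default output formats
agent.set_logger(logger)

agent.learn(total_timesteps=100, log_interval=1)
print(sorted(logger.history_value))  # e.g. 'rollout/exploration_rate', 'train/loss', ...
```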
+ end_idx = max(0, end_idx % length - interval + 1) + return torch.concat( + [arr[start_idx % length :: interval], arr[:end_idx:interval]], dim=0 + ) + else: + return arr[start_idx:end_idx:interval] diff --git a/algorithm_distillation/task.py b/algorithm_distillation/task.py deleted file mode 100644 index 58e407e..0000000 --- a/algorithm_distillation/task.py +++ /dev/null @@ -1,240 +0,0 @@ -import abc -import math -import random -from typing import Optional - -import gym -import numpy as np -import stable_baselines3 -from stable_baselines3.common.buffers import ReplayBuffer - - -class Task(abc.ABC): - """ - This class controls the training of one single task. - Each object controls one single trainable model inside. - """ - - # `obs_dim` marks the total dimension of the observations - obs_dim: int - - @abc.abstractmethod - def __init__(self, **kwargs): - pass - - @abc.abstractmethod - def train(self, steps: int): - """ - Train a number of steps. - - :param steps: the number of steps to train - :return: None - """ - pass - - @abc.abstractmethod - def sample_history(self, length: int, skip: int = 0) -> tuple: - """ - Sample a history of specific length. - - :param length: the length of the history. - (note: length of training steps instead of length of sequence! Every step - includes obs, act, rew. Thus, the final sequence is 3x as long.) - :param skip: (Optional) skip certain amount of steps between two states. - :return: a tuple of (observations, actions, rewards). Each is a tensor. - """ - pass - - -class GymTask(Task): - # TODO: Still need some work to set up the on-policy algorithms due to problems in their rollout buffers. - _algorithms = {"DQN": stable_baselines3.DQN} - _on_policy = ("PPO",) - - # If the environment has a discrete obs space of n classes, return an (n,)-shape array for each obs. - # If the environment has a continuous obs space of certain shape, return a flattened array for each obs. - obs_dim: int - obs_cls: str - - def __init__(self, env, algorithm: str, config: Optional[dict] = None): - """ - GymTask takes in a gym environment and set up a stable-baselines3 algorithm to train a policy network. - - :param env: the gym environment - :param algorithm: the stable-baselines3 algorithm - :param config: (Optional) the stable-baselines3 algorithm config (except for the gym environment) - """ - self.env = env - if not isinstance(env.action_space, gym.spaces.Discrete): - raise ValueError("Only support discrete action spaces.") - - self.algorithm = algorithm - if algorithm not in self._algorithms: - raise ValueError( - f"Input must be one of {self._algorithms.keys()}. Got {algorithm} instead." - ) - - self.config = {} if config is None else config.copy() - self.config["env"] = self.env - # Default to MultiInputPolicy which is applicable most of the time. - if "policy" not in self.config: - self.config["policy"] = "MultiInputPolicy" - - self.agent = self._algorithms[algorithm](**self.config) - self.obs_dim, self.obs_cls = self._get_obs_specs() - - def train(self, steps: int): - self.agent.learn(total_timesteps=steps) - - def sample_history( - self, length: int, skip: int = 0, most_recent: bool = False - ) -> tuple: - """ - This implementation will sample their most recent histories. - - :param length: the length of the history . - (note: length of training steps instead of length of sequence! Every step - includes obs, act, rew. Thus, the final sequence is 3x as long.) 
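The docstring examples of `get_sequence` above, run directly (they also appear in `tests/test_util.py`): the tensor is treated as a circular buffer and items are picked backwards from `end_idx` at the given interval.

```python
import torch

from algorithm_distillation.models.util import get_sequence

arr = torch.tensor([7, 8, 9])
assert get_sequence(arr, num_items=3, end_idx=4, interval=2).tolist() == [7, 9]
assert get_sequence(arr, num_items=1, end_idx=3, interval=2).tolist() == [8]
assert get_sequence(arr, num_items=1, end_idx=2, interval=2).tolist() == [7]
```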
- :param skip: (Optional) skip certain amount of steps between two states - :param most_recent: (Optional) sample from the most recent histories. False to sample randomly. - :return: a tuple of (observations, actions, rewards). Each is a tensor. - """ - if self.algorithm in self._on_policy: - raise NotImplementedError("Not supporting on-policy algorithms yet.") - - buffer = self.agent.replay_buffer - - if buffer.n_envs != 1: - raise NotImplementedError("Not supporting parallel environments yet.") - - if most_recent: - return self._get_most_recent_history(buffer, length, skip) - else: - return self._randomly_sample_buffer(buffer, length, skip) - - def _get_obs_specs(self) -> tuple: - if not hasattr(self, "env"): - raise ValueError('Need to assign "env" attribute first.') - obs_space = self.env.observation_space - - if isinstance(obs_space, gym.spaces.Discrete): - return obs_space.n, "discrete" - elif isinstance(obs_space, gym.spaces.Box): - return math.prod([n for n in obs_space.shape]), "box" - else: - raise NotImplementedError( - f"The observation space does not support {type(obs_space)}." - ) - - def _obs_post_process(self, obs: np.ndarray) -> np.ndarray: - """ - Post-process the observations according to its type and shape. - - :param obs: the batched observation array. - :return: the processed observation array of shape (length, obs_dim). - """ - length = obs.shape[0] - if self.obs_cls == "discrete": - obs = obs.reshape((length, 1)) - # Return arrays according to one-hot encoding - return (obs == np.tile(np.arange(self.obs_dim), (length, 1))).astype(float) - elif self.obs_cls == "box": - # Flatten all the other - return obs.reshape((length, -1)) - else: - raise RuntimeError("Impossible code path.") - - @staticmethod - def _act_post_process(act: np.ndarray) -> np.ndarray: - """ - Post-process the actions. Assume actions are discrete with shape (length, 1) or (length). - - :param act: the batched action array. - :return: the processed action array of shape (length, 1). - """ - return act.reshape((-1, 1)).astype(int) - - @staticmethod - def _rew_post_process(rew: np.ndarray) -> np.ndarray: - """ - Post-process the rewards. Rewards are scalars with shape (length,). - - :param rew: the batched reward array. - :return: the processed reward array of shape (length, 1). - """ - return rew.reshape((-1, 1)).astype(float) - - def _get_most_recent_history( - self, buffer: ReplayBuffer, length: int, skip: int - ) -> tuple: - """ - Get the most recent history from the buffer. - - :param buffer: ReplayBuffer object from stable-baselines3. - :param length: the length of steps to sample. - :param skip: the amount to skip between states. - :return: an (observations, actions, rewards) tuple. 
- """ - pos = buffer.pos - total_length = length * (skip + 1) - if pos >= total_length: - return ( - self._obs_post_process( - buffer.observations[pos - total_length : pos : skip + 1] - ), - self._act_post_process( - buffer.actions[pos - total_length : pos : skip + 1] - ), - self._rew_post_process( - buffer.rewards[pos - total_length : pos : skip + 1] - ), - ) - else: - latest = ( - self._obs_post_process(buffer.observations[: pos : skip + 1]), - self._act_post_process(buffer.actions[: pos : skip + 1]), - self._rew_post_process(buffer.rewards[: pos : skip + 1]), - ) - if buffer.full: - extra = ( - self._obs_post_process( - buffer.observations[pos - length :: skip + 1] - ), - self._act_post_process(buffer.actions[pos - length :: skip + 1]), - self._rew_post_process(buffer.rewards[pos - length :: skip + 1]), - ) - latest = tuple( - [np.concatenate([extra[i], latest[i]], axis=0) for i in range(3)] - ) - - return latest - - def _randomly_sample_buffer( - self, buffer: ReplayBuffer, length: int, skip: int - ) -> tuple: - """ - Randomly sample a sequence from the buffer (requires that there is enough to sample from). - - :param buffer: ReplayBuffer object from stable-baselines3. - :param length: the length of steps to sample. - :param skip: the amount to skip between states. - :return: an (observations, actions, rewards) tuple. - """ - upper_bound = buffer.size if buffer.full else buffer.pos - total_length = length * (skip + 1) - - if total_length > upper_bound: - raise IndexError("Buffer contains fewer samples than necessary.") - - start = random.randint(0, upper_bound - total_length) - return ( - self._obs_post_process( - buffer.observations[start : start + total_length : skip + 1] - ), - self._act_post_process( - buffer.actions[start : start + total_length : skip + 1] - ), - self._rew_post_process( - buffer.rewards[start : start + total_length : skip + 1] - ), - ) diff --git a/algorithm_distillation/tasks/__init__.py b/algorithm_distillation/tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/algorithm_distillation/tasks/rl/__init__.py b/algorithm_distillation/tasks/rl/__init__.py new file mode 100644 index 0000000..7baca5f --- /dev/null +++ b/algorithm_distillation/tasks/rl/__init__.py @@ -0,0 +1,4 @@ +from algorithm_distillation.tasks.rl.task import GymTask, Task +from algorithm_distillation.tasks.rl.task_manager import TaskManager + +__all__ = ["Task", "GymTask", "TaskManager"] diff --git a/algorithm_distillation/tasks/rl/task.py b/algorithm_distillation/tasks/rl/task.py new file mode 100644 index 0000000..801df0c --- /dev/null +++ b/algorithm_distillation/tasks/rl/task.py @@ -0,0 +1,317 @@ +import abc +import math +import random +from typing import Optional + +import gym +import numpy as np +import stable_baselines3 +from stable_baselines3.common.buffers import ReplayBuffer + +from algorithm_distillation.models.sb3_util.callback import RolloutCallback + + +class Task(abc.ABC): + """ + This class controls the training of one single task. + Each object controls one single trainable model inside. + """ + + # `obs_dim` marks the total dimension of the observations + obs_dim: int + act_dim: int + env: object + + @abc.abstractmethod + def __init__(self, **kwargs): + pass + + @abc.abstractmethod + def train(self, steps: int, **kwargs): + """ + Train a number of steps. + + :param steps: the number of steps to train. + :param kwargs: (Optional) extra parameters for SB3 `learn` method. 
+ :return: None + """ + pass + + @abc.abstractmethod + def sample_history(self, length: int, skip: int = 0) -> tuple: + """ + Sample a history of specific length. + + :param length: the length of the history. + (note: length of training steps instead of length of sequence! Every step + includes obs, act, rew. Thus, the final sequence is 3x as long.) + :param skip: (Optional) skip certain amount of steps between two states. + :return: a tuple of (observations, actions, rewards). Each is a tensor. + """ + pass + + +class GymTask(Task): + # TODO: support TD3 for continuous action space? + _algorithms = { + "DQN": stable_baselines3.DQN, + "PPO": stable_baselines3.PPO, + "A2C": stable_baselines3.A2C, + } + _on_policy = ("PPO", "A2C") + + # If the environment has a discrete obs space of n classes, return an (n,)-shape array for each obs. + # If the environment has a continuous obs space of certain shape, return a flattened array for each obs. + obs_dim: int + act_dim: int + + obs_cls: str + act_cls: str + + def __init__( + self, + env, + algorithm: str, + buffer_size: int = 10_000, + config: Optional[dict] = None, + ): + """ + GymTask takes in a gym environment and set up a stable-baselines3 algorithm to train a policy network. + + :param env: the gym environment + :param algorithm: the stable-baselines3 algorithm + :param buffer_size: (Optional) the buffer size. This refers to the stored history length where the + transformer samples from, and applies to both on-policy and off-policy + :param config: (Optional) the stable-baselines3 algorithm config (except for the gym environment) + """ + self._env = env + if not isinstance(env.action_space, gym.spaces.Discrete): + raise NotImplementedError("Only supports discrete action spaces for now.") + + self.algorithm = algorithm + if algorithm not in self._algorithms: + raise ValueError( + f"Input must be one of {self._algorithms.keys()}. Got {algorithm} instead." 
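A small construction sketch for the `GymTask` API above, using the same off-policy vs. on-policy configs as the tests (values are illustrative, not tuned):

```python
import gym

from algorithm_distillation.tasks.rl import GymTask

off_policy_task = GymTask(
    gym.make("FrozenLake-v1"), "DQN", buffer_size=100,
    config={"learning_starts": 10, "policy": "MlpPolicy"},
)
on_policy_task = GymTask(
    gym.make("CartPole-v1"), "PPO", buffer_size=100,
    config={"n_steps": 30, "batch_size": 10, "policy": "MlpPolicy"},
)
assert off_policy_task.act_cls == on_policy_task.act_cls == "discrete"
```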
+ ) + self.buffer_size = buffer_size + + # Set up the agent + self.config = {} if config is None else config.copy() + self.config["env"] = self.env + if "buffer_size" in self.config or algorithm not in self._on_policy: + # Overwrite policy buffer size if exists/doing off-policy + self.config["buffer_size"] = self.buffer_size + + # Set up the defaults + defaults = { + "policy": "MultiInputPolicy", # a common policy applicable most of the time + } + # Set up the defaults specifically for either on-policy or off-policy algos + if self.algorithm in self._on_policy: + policy_level_defaults = { + "n_steps": 2048, + } + else: + policy_level_defaults = {} + self.config = {**defaults, **policy_level_defaults, **self.config} + # Set up the agent + self.agent = self._algorithms[algorithm](**self.config) + + # If on-policy, set up the the callback to collect rollout + if self.algorithm in self._on_policy: + buffer_config = { + "buffer_size": self.buffer_size, + "observation_space": self._env.observation_space, + "action_space": self._env.action_space, + } + self.buffer = ReplayBuffer(**buffer_config) + self.callback = RolloutCallback(self.agent, self.buffer) + else: + self.buffer = self.agent.replay_buffer + self.callback = None + + self.obs_dim, self.obs_cls = self._get_obs_specs() + self.act_dim, self.act_cls = self._get_act_specs() + + def train(self, steps: int, **kwargs): + self.agent.learn(total_timesteps=steps, callback=self.callback, **kwargs) + + def sample_history( + self, length: int, skip: int = 0, most_recent: bool = False + ) -> tuple: + """ + This implementation will sample their most recent histories. + + :param length: the length of the history . + (note: length of training steps instead of length of sequence! Every step + includes obs, act, rew. Thus, the final sequence is 3x as long.) + :param skip: (Optional) skip certain amount of steps between two states + :param most_recent: (Optional) get the most recent histories. False to sample randomly. + :return: a tuple of (observations, actions, rewards). Each is a tensor. + """ + if self.buffer.n_envs != 1: + raise NotImplementedError("Not supporting parallel environments yet.") + + if most_recent: + return self._get_most_recent_history(self.buffer, length, skip) + else: + return self._randomly_sample_buffer(self.buffer, length, skip) + + def _get_obs_specs(self) -> tuple: + obs_space = self.env.observation_space + + if isinstance(obs_space, gym.spaces.Discrete): + return obs_space.n, "discrete" + elif isinstance(obs_space, gym.spaces.Box): + return math.prod([n for n in obs_space.shape]), "box" + else: + raise NotImplementedError( + f"The observation space does not support {type(obs_space)}." + ) + + def _get_act_specs(self) -> tuple: + act_space = self.env.action_space + + if isinstance(act_space, gym.spaces.Discrete): + return act_space.n, "discrete" + else: + raise NotImplementedError( + f"The observation space does not support {type(act_space)}." + ) + + @property + def env(self) -> gym.Env: + if not hasattr(self, "_env"): + raise ValueError('Need to assign "_env" attribute first.') + return self._env + + def obs_post_process(self, obs: np.ndarray) -> np.ndarray: + """ + Post-process the observations according to its type and shape. + + :param obs: the batched observation array. + :return: the processed observation array of shape (length, obs_dim). 
+ """ + length = obs.shape[0] + if self.obs_cls == "discrete": + obs = obs.reshape((length, 1)) + # Return arrays according to one-hot encoding + return (obs == np.tile(np.arange(self.obs_dim), (length, 1))).astype(float) + elif self.obs_cls == "box": + # Flatten all the other + return obs.reshape((length, -1)) + else: + raise RuntimeError("Impossible code path.") + + @staticmethod + def act_post_process(act: np.ndarray) -> np.ndarray: + """ + Post-process the actions. Assume actions are discrete with shape (length, 1) or (length). + + :param act: the batched action array. + :return: the processed action array of shape (length, 1). + """ + return act.reshape((-1, 1)).astype(int) + + @staticmethod + def rew_post_process(rew: np.ndarray) -> np.ndarray: + """ + Post-process the rewards. Rewards are scalars with shape (length,). + + :param rew: the batched reward array. + :return: the processed reward array of shape (length, 1). + """ + return rew.reshape((-1, 1)).astype(float) + + def _get_most_recent_history( + self, buffer: ReplayBuffer, length: int, skip: int + ) -> tuple: + """ + Get the most recent history from the buffer. + + :param buffer: ReplayBuffer object from stable-baselines3. + :param length: the length of steps to sample. + :param skip: the amount to skip between states. + :return: an (observations, actions, rewards) tuple. + """ + pos = buffer.pos + total_length = (length - 1) * (skip + 1) + 1 + + assert ( + buffer.buffer_size > total_length + ), "Replay buffer size must be larger than the sequence length." + + start = (pos - total_length + buffer.buffer_size) % buffer.buffer_size + end = pos + + return self._get_obs_act_rew(buffer, start, end, skip) + + def _randomly_sample_buffer( + self, buffer: ReplayBuffer, length: int, skip: int + ) -> tuple: + """ + Randomly sample a sequence from the buffer (requires that there is enough to sample from). + + :param buffer: ReplayBuffer object from stable-baselines3. + :param length: the length of steps to sample. + :param skip: the amount to skip between states. + :return: an (observations, actions, rewards) tuple. + """ + total_length = (length - 1) * (skip + 1) + 1 + assert ( + buffer.buffer_size > total_length + ), "Replay buffer size must be larger than the sequence length." + + if not buffer.full: + start = random.randint(0, buffer.pos - total_length) + end = start + total_length + else: + start = ( + random.randint(0, buffer.buffer_size - total_length) + buffer.pos + ) % buffer.buffer_size + end = (start + total_length) % buffer.buffer_size + + return self._get_obs_act_rew(buffer, start, end, skip) + + @staticmethod + def _get_range( + array: np.ndarray, start: int, end: int, interval: int + ) -> np.ndarray: + """ + A helper function to either slice array[start:end:interval] or combine array[start::interval] and + array[:end:interval] depending on whether start < end. + + :param array: the sliced array. + :param start: the starting index. + :param end: the ending index (exclusive). + :param interval: the interval. + :return: the sliced sub-array. + """ + if start < end: + return array[start:end:interval] + else: + return np.concatenate( + [array[start::interval], array[:end:interval]], axis=0 + ) + + def _get_obs_act_rew(self, buffer: ReplayBuffer, start: int, end: int, skip: int): + """ + Return a tuple (obs, act, rew) sampled according to the buffer and the parameters. + + :param buffer: the replay buffer. + :param start: the starting index. + :param end: the ending index. + :param skip: the amount of states to skip. 
+ :return: the tuple (obs, act, rew) + """ + return ( + self.obs_post_process( + self._get_range(buffer.observations, start, end, skip + 1) + ), + self.act_post_process( + self._get_range(buffer.actions, start, end, skip + 1) + ), + self.rew_post_process( + self._get_range(buffer.rewards, start, end, skip + 1) + ), + ) diff --git a/algorithm_distillation/task_manager.py b/algorithm_distillation/tasks/rl/task_manager.py similarity index 80% rename from algorithm_distillation/task_manager.py rename to algorithm_distillation/tasks/rl/task_manager.py index d268fbd..523b9e3 100644 --- a/algorithm_distillation/task_manager.py +++ b/algorithm_distillation/tasks/rl/task_manager.py @@ -16,20 +16,22 @@ def __init__(self, tasks: List[Task]): if not tasks: raise ValueError("The task list cannot be empty.") for task in tasks[1:]: - if task.obs_dim != tasks[0].obs_dim: - raise ValueError("All tasks must have the same obs_dim.") + for st in ["obs", "act"]: + if getattr(task, f"{st}_dim") != getattr(tasks[0], f"{st}_dim"): + raise ValueError(f"All tasks must have the same {st}_dim.") self.tasks = tasks - def train(self, steps: int): + def train(self, steps: int, **kwargs): """ Train `steps` amount of steps for all the tasks. :param steps: the amount of gradient steps to train + :param kwargs: (Optional) extra parameters for SB3 `learn` method. :return: None """ for task in self.tasks: - task.train(steps) + task.train(steps, **kwargs) def sample_history(self, length: int, skip: int = 0) -> tuple: """ diff --git a/requirements.txt b/requirements.txt index 5887b2c..8303bfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ stable-baselines3 -transformers~=4.24.0 -torch~=1.12.1 -pytest~=7.2.0 -gym~=0.21.0 -numpy~=1.23.5 \ No newline at end of file +transformers>=4.24.0 +torch>=1.12.1 +pytest>=7.2.0 +gym==0.21.0 +numpy +tqdm \ No newline at end of file diff --git a/tests/test_ad.py b/tests/test_ad.py index 746ec02..f7b9c60 100644 --- a/tests/test_ad.py +++ b/tests/test_ad.py @@ -1,20 +1,57 @@ import gym +import pytest -from algorithm_distillation import GymTask, TaskManager, GymAD +from algorithm_distillation import GymAD +from algorithm_distillation.tasks.rl import GymTask, TaskManager from algorithm_distillation.models import GPT2AD +from algorithm_distillation.models.sb3_util import CustomLogger, configure -def test_ad(): - env = gym.make('FrozenLake-v1') - config = {'learning_starts': 10, - 'buffer_size': 1000, - 'policy': 'MlpPolicy' - } - model = GPT2AD(env.observation_space.n, env.action_space.n, 12, max_ep_len=16) +@pytest.mark.parametrize("policy", ['DQN', 'PPO', 'A2C']) +@pytest.mark.parametrize("env_name", ['FrozenLake-v1', 'CartPole-v1']) +def test_ad(policy: str, env_name: str): + env = gym.make(env_name) + + if policy in ['DQN', 'TD3']: + config = {'learning_starts': 10, + 'policy': 'MlpPolicy' + } + else: + config = {'n_steps': 30, + 'batch_size': 10, + 'policy': 'MlpPolicy' + } + if policy == 'A2C': + config.pop('batch_size') + + task = GymTask(env, policy, buffer_size=100, config=config) + model = GPT2AD(task.obs_dim, task.act_dim, 12, max_step_len=16) + + # Inject a customized logger + logger = configure(None, None, CustomLogger) + task.agent.set_logger(logger) - task = GymTask(env, 'DQN', config) task_manager = TaskManager([task]) - task_manager.train(100) + task_manager.train(100, log_interval=1) + + assert 'history_value' in task.agent.logger.__dir__() + # 100 total time-steps, but training only happens upon the finish of an episode. 
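A usage sketch for the updated `TaskManager` above: all tasks must agree on `obs_dim`/`act_dim`, `train` forwards extra kwargs to SB3 `learn`, and `sample_history` is assumed to return the same `(obs, act, rew)` tuple as `Task.sample_history`.

```python
import gym

from algorithm_distillation.tasks.rl import GymTask, TaskManager

config = {"learning_starts": 10, "policy": "MlpPolicy"}
tasks = [GymTask(gym.make("FrozenLake-v1"), "DQN", buffer_size=100, config=config)
         for _ in range(2)]  # same obs_dim/act_dim, so the manager accepts them

manager = TaskManager(tasks)
manager.train(100, log_interval=1)          # extra kwargs go to SB3 `learn`
obs, act, rew = manager.sample_history(10)  # assumed (obs, act, rew) tuple
```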
We don't know how many gradient + # steps are trained, but we are sure it is nonzero. + loss_key = 'train/policy_loss' if policy == 'A2C' else 'train/loss' + assert len(task.agent.logger.history_value[loss_key]) != 0 + # But we are sure that rollout happens 100 times. + if policy in ['DQN']: + assert len(task.agent.logger.history_value['rollout/exploration_rate']) == 100 + else: + # Only logged once, because training step (100) equals the rollout length. + # SB3 off-policy algorithms first collect rollouts accumulating `n_steps` until rollout buffer is full, + # and then check if training continues. + assert len(task.agent.logger.history_value['rollout/ep_rew_mean']) > 0 # A2C steps not deterministic?? ad = GymAD(model) - ad.train(task_manager, 100, 10, skip=0, batch_size=8) + ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1) + + obs, act, rew, term = ad.rollout(task, 100, 0, verbose=1) + assert obs.size(0) == 100 + obs, act, rew, term = ad.rollout(task, 100, 2, verbose=1) + assert obs.size(0) == 100 diff --git a/tests/test_env.py b/tests/test_env.py index 34fc082..3a160e1 100644 --- a/tests/test_env.py +++ b/tests/test_env.py @@ -1,43 +1,55 @@ import gym -from algorithm_distillation import GymTask +import pytest +from algorithm_distillation.tasks.rl import GymTask -def test_gym_task(): + +@pytest.mark.parametrize("policy", ['DQN', 'PPO', 'A2C']) +def test_gym_task(policy: str): env = gym.make('FrozenLake-v1') - config = {'learning_starts': 10, - 'buffer_size': 1000, - 'policy': 'MlpPolicy' - } - task = GymTask(env, 'DQN', config) + if policy in ['DQN', 'TD3']: + config = {'learning_starts': 10, + 'policy': 'MlpPolicy' + } + else: + config = {'n_steps': 3, + 'batch_size': 10, + 'policy': 'MlpPolicy' + } + if policy == 'A2C': + config.pop('batch_size') + task = GymTask(env, policy, buffer_size=100, config=config) assert task.obs_cls == 'discrete' assert task.obs_dim == 16 # The default observation space of FrozenLake is discrete 16 labels - task.train(100) + task.train(150) for most_recent in [True, False]: - sample = task.sample_history(10, most_recent=most_recent) - assert sample[0].shape == (10, 16) # Observations are discrete classes - assert sample[1].shape == (10, 1) # Actions are discrete classes - assert sample[2].shape == (10, 1) # Rewards. + for _ in range(100): # There is a bit of randomness. Try many times. + sample = task.sample_history(10, most_recent=most_recent) + assert sample[0].shape == (10, 16) # Observations are discrete classes + assert sample[1].shape == (10, 1) # Actions are discrete classes + assert sample[2].shape == (10, 1) # Rewards. - sample = task.sample_history(10, skip=2, most_recent=most_recent) - assert sample[0].shape == (10, 16) # Observations are discrete classes - assert sample[1].shape == (10, 1) # Actions are discrete classes - assert sample[2].shape == (10, 1) # Rewards. + sample = task.sample_history(10, skip=3, most_recent=most_recent) + assert sample[0].shape == (10, 16) # Observations are discrete classes + assert sample[1].shape == (10, 1) # Actions are discrete classes + assert sample[2].shape == (10, 1) # Rewards. # More than buffer - sample = task.sample_history(1000, skip=0, most_recent=True) - assert sample[0].shape == (100, 16) # Observations are discrete classes - assert sample[1].shape == (100, 1) # Actions are discrete classes - assert sample[2].shape == (100, 1) # Rewards. + try: + sample = task.sample_history(1000, skip=0, most_recent=True) + assert False, "Error should have been raised." 
+ except AssertionError: + assert True """# Please install gym[atari, accept-rom-license] manually if you want to run Atari. - env = gym.make('Alien-v4') + _env = gym.make('Alien-v4') config = {'learning_starts': 10, 'buffer_size': 1000, 'policy': 'MlpPolicy' } - task = GymTask(env, 'DQN', config) + task = GymTask(_env, 'DQN', config) task.train(100) """ diff --git a/tests/test_module.py b/tests/test_module.py index 2671c7e..abd6ef2 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -5,7 +5,7 @@ @pytest.mark.parametrize("n", [1, 2, 3]) def test_GPT2AD(n): - model = GPT2AD(2, n, 12, max_ep_len=16) + model = GPT2AD(2, n, 12, max_step_len=16) model.eval() # Disable dropout sample_obs = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.float) # two samples, two steps in each trajectory: diff --git a/tests/test_sb3_util.py b/tests/test_sb3_util.py new file mode 100644 index 0000000..dfa0604 --- /dev/null +++ b/tests/test_sb3_util.py @@ -0,0 +1,34 @@ +import gym +import numpy as np +import stable_baselines3 +from stable_baselines3.common.buffers import ReplayBuffer + +from algorithm_distillation.models.sb3_util.callback import RolloutCallback + + +def test_callback(): + env = gym.make('FrozenLake-v1') + config = {'batch_size': 10, + 'n_steps': 100, # determines the rollout size + 'policy': 'MlpPolicy' + } + buffer_config = {'buffer_size': 1000, + 'observation_space': env.observation_space, + 'action_space': env.action_space} + buffer = ReplayBuffer(**buffer_config) + agent = stable_baselines3.PPO(env=env, **config) + cb = RolloutCallback(agent, buffer) + + assert buffer.pos == 0 + agent.learn(200, callback=cb) + assert buffer.pos == 200 + # Note: There is a subtlety in comparing shapes of two buffers after the learning is finished. + # RolloutBuffer.get changes the shapes of everything. It will call swap_and_flatten which flattens + # the first two dimensions: (buffer_size, n_env, ...) -> (buffer_size * n_env, ...) + assert np.all(np.isclose(buffer.observations[100:200].flatten(), + agent.rollout_buffer.observations.flatten())) + assert np.all(np.isclose(buffer.actions[100:200].flatten(), + agent.rollout_buffer.actions.flatten())) + assert np.all(np.isclose(buffer.rewards[100:200].flatten(), + agent.rollout_buffer.rewards.flatten())) + diff --git a/tests/test_util.py b/tests/test_util.py index e888e59..e70fbc3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,5 +1,5 @@ import torch -from algorithm_distillation.models.util import stack_seq +from algorithm_distillation.models.util import stack_seq, get_sequence def test_stack_seq(): @@ -30,3 +30,29 @@ def test_stack_seq(): result = stack_seq(obs, act, rew, extra) assert torch.all(result.isclose(r)) + + # when batch size is 0, all of obs, act, rew should be None. It should just return the extra. + result = stack_seq(None, None, None, extra) + assert torch.all(result.isclose(extra)) + + +def test_get_sequence(): + arr = torch.tensor([7, 8, 9]) + n = 3 + end_idx = 4 + interval = 2 + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([7, 9])) + n = 1 + end_idx = 3 + interval = 2 + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([8])) + n = 1 + end_idx = 2 + interval = 2 + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([7])) + + n = 3 + end_idx = 5 + interval = 2 + # can only fetch a max elem that the circular buffer allows (i.e., 2 in this case) + assert all(get_sequence(arr, n, end_idx, interval) == torch.tensor([8, 7]))
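For readers coming from the README's quick-start note, an end-to-end sketch of the routine exercised by `tests/test_ad.py`: train an SB3 source agent, distill its learning histories into the transformer, then roll the transformer out in-context.

```python
import gym

from algorithm_distillation import GymAD
from algorithm_distillation.models import GPT2AD
from algorithm_distillation.tasks.rl import GymTask, TaskManager

env = gym.make("CartPole-v1")
task = GymTask(env, "DQN", buffer_size=100,
               config={"learning_starts": 10, "policy": "MlpPolicy"})
model = GPT2AD(task.obs_dim, task.act_dim, 12, max_step_len=16)

task_manager = TaskManager([task])
task_manager.train(100)                                  # source-algorithm training
ad = GymAD(model)
losses = ad.train(task_manager, 100, 10, skip=0, batch_size=8, verbose=1)
obs, act, rew, term = ad.rollout(task, 100, skip=0, verbose=1)
```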